
Commit a13ebd3

Merge branch 'pytorch:main' into fix_wandb_logger
2 parents: 0f51e6b + 851a041

128 files changed: 23097 additions, 4095 deletions


.github/unittest/linux_libs/scripts_brax/environment.yml

Lines changed: 1 addition & 0 deletions

@@ -21,3 +21,4 @@ dependencies:
 - hydra-core
 - jax[cuda12]
 - brax
+- psutil
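
psutil exposes process and system metrics (memory, CPU) to the test scripts. The commit doesn't state why it was added here, so the snippet below is only an illustrative sketch of the kind of introspection it enables in CI, not code from the repository:

```python
# Illustrative only: report memory/CPU figures from inside a CI test job.
import os
import psutil

proc = psutil.Process(os.getpid())
print(f"RSS memory: {proc.memory_info().rss / 1e6:.1f} MB")
print(f"System memory in use: {psutil.virtual_memory().percent}%")
print(f"CPU count: {psutil.cpu_count()}")
```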

.github/unittest/linux_libs/scripts_brax/run_test.sh

Lines changed: 34 additions & 1 deletion

@@ -8,6 +8,13 @@ conda activate ./env
 
 export PYTORCH_TEST_WITH_SLOW='1'
 export LAZY_LEGACY_OP=False
+
+# Configure JAX for proper GPU initialization
+export XLA_PYTHON_CLIENT_PREALLOCATE=false
+export XLA_PYTHON_CLIENT_ALLOCATOR=platform
+export TF_FORCE_GPU_ALLOW_GROWTH=true
+export CUDA_VISIBLE_DEVICES=0
+
 python -m torch.utils.collect_env
 # Avoid error: "fatal: unsafe repository"
 git config --global --add safe.directory '*'
@@ -28,7 +35,33 @@ export MAGNUM_LOG=verbose MAGNUM_GPU_VALIDATION=ON
 # this workflow only tests the libs
 python -c "import brax"
 python -c "import brax.envs"
-python -c "import jax"
+
+# Initialize JAX with proper GPU configuration
+python -c "
+import jax
+import jax.numpy as jnp
+import os
+
+# Configure JAX for GPU
+os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false'
+os.environ['XLA_PYTHON_CLIENT_ALLOCATOR'] = 'platform'
+
+# Test JAX GPU availability
+try:
+    devices = jax.devices()
+    print(f'JAX devices: {devices}')
+    if len(devices) > 1:
+        print('JAX GPU is available')
+    else:
+        print('JAX CPU only')
+except Exception as e:
+    print(f'JAX initialization error: {e}')
+    # Fallback to CPU
+    os.environ['JAX_PLATFORM_NAME'] = 'cpu'
+    jax.config.update('jax_platform_name', 'cpu')
+    print('Falling back to JAX CPU')
+"
+
 python3 -c 'import torch;t = torch.ones([2,2], device="cuda:0");print(t);print("tensor device:" + str(t.device))'
 
 python .github/unittest/helpers/coverage_run_parallel.py -m pytest test/test_libs.py --instafail -v --durations 200 --capture no -k TestBrax --error-for-skips
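
One caveat in the committed check: `jax.devices()` returns the devices of the default backend, so a single-GPU host yields a list of length 1 and the `len(devices) > 1` branch would report "JAX CPU only". A platform-based check avoids that; here is a minimal sketch (relying only on the documented `Device.platform` attribute, not part of this commit):

```python
# Sketch: decide GPU vs CPU from each device's platform rather than the count.
import jax

devices = jax.devices()
print(f"JAX devices: {devices}")
# Older jaxlib builds report 'gpu'; newer ones report 'cuda' or 'rocm'.
if any(d.platform in ("gpu", "cuda", "rocm") for d in devices):
    print("JAX GPU is available")
else:
    print("JAX CPU only")
```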

.github/unittest/linux_libs/scripts_openx/environment.yml

Lines changed: 1 addition & 1 deletion

@@ -21,5 +21,5 @@ dependencies:
 - hydra-core
 - tqdm
 - h5py
-- datasets
+- datasets<4.0.0
 - pillow
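
The `<4.0.0` pin shields the OpenX tests from breaking changes in the `datasets` 4.x line. If a runtime guard is ever wanted as well, a minimal sketch (the `packaging` dependency is an assumption, not part of this environment file):

```python
# Sketch: verify the installed `datasets` version respects the <4.0.0 pin.
from importlib.metadata import version
from packaging.version import Version  # assumed available in the env

installed = Version(version("datasets"))
assert installed < Version("4.0.0"), f"datasets {installed} breaks the <4.0.0 pin"
print(f"datasets {installed}: OK")
```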

.github/unittest/windows_optdepts/scripts/unittest.sh

Lines changed: 34 additions & 3 deletions

@@ -14,6 +14,12 @@ env_dir="${root_dir}/env"
 
 cd "${root_dir}"
 
+echo "=== Starting Windows CI setup ==="
+echo "Current directory: $(pwd)"
+echo "Python version: $PYTHON_VERSION"
+echo "CU_VERSION: $CU_VERSION"
+echo "TORCH_VERSION: $TORCH_VERSION"
+
 eval "$($(which conda) shell.bash hook)" && set -x
 
 # Create test environment at ./env
@@ -28,11 +34,12 @@ echo $(which python)
 echo $(python --version)
 echo $(conda info -e)
 
-
+echo "=== Installing test dependencies ==="
 python -m pip install hypothesis future cloudpickle pytest pytest-cov pytest-mock pytest-instafail pytest-rerunfailures expecttest pyyaml scipy coverage
 
 # =================================== Install =================================================
 
+echo "=== Installing PyTorch and dependencies ==="
 
 # TODO, refactor the below logic to make it easy to understand how to get correct cuda_version.
 if [ "${CU_VERSION:-}" == cpu ] ; then
@@ -56,8 +63,8 @@ else
     cudatoolkit="${cuda_toolkit_pckg}=${version}"
 fi
 
-
 # submodules
+echo "=== Updating git submodules ==="
 git submodule sync && git submodule update --init --recursive
 python -m pip install "numpy<2.0"
 
@@ -92,6 +99,7 @@ fi
 #python -m pip install pip --upgrade
 
 # install tensordict
+echo "=== Installing tensordict ==="
 if [[ "$RELEASE" == 0 ]]; then
     conda install anaconda::cmake -y
 
@@ -103,11 +111,13 @@ else
 fi
 
 # smoke test
+echo "=== Testing tensordict import ==="
 python -c """
 from tensordict import TensorDict
 print('successfully imported tensordict')
 """
 
+echo "=== Setting up CUDA environment ==="
 source "$this_dir/set_cuda_envs.sh"
 
 printf "* Installing torchrl\n"
@@ -117,13 +127,15 @@ whatsinside=$(ls -rtlh ./torchrl)
 echo $whatsinside
 
 # smoke test
+echo "=== Testing torchrl import ==="
 python -c """
 from torchrl.data import ReplayBuffer
 print('successfully imported torchrl')
 """
 
 # =================================== Run =================================================
 
+echo "=== Setting up test environment ==="
 source "$this_dir/set_cuda_envs.sh"
 
 # we don't use torchsnapshot
@@ -132,5 +144,24 @@ export MAX_IDLE_COUNT=60
 export BATCHED_PIPE_TIMEOUT=60
 export LAZY_LEGACY_OP=False
 
+echo "=== Collecting environment info ==="
 python -m torch.utils.collect_env
-pytest --junitxml=test-results/junit.xml -v --durations 200 --ignore test/test_distributed.py --ignore test/test_rlhf.py
+
+echo "=== Starting pytest execution ==="
+echo "Current working directory: $(pwd)"
+echo "Python executable: $(which python)"
+echo "Pytest executable: $(which pytest)"
+
+# Create test-results directory if it doesn't exist
+mkdir -p test-results
+
+# Run pytest with explicit error handling
+set +e  # Don't exit on error for pytest
+pytest --junitxml=test-results/junit.xml -v --durations 200 --ignore test/test_distributed.py --ignore test/test_rlhf.py --ignore test/llm
+PYTEST_EXIT_CODE=$?
+set -e  # Re-enable exit on error
+
+echo "=== Pytest completed with exit code: $PYTEST_EXIT_CODE ==="
+
+# Exit with pytest's exit code
+exit $PYTEST_EXIT_CODE
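
The `set +e` / `set -e` bracket is the shell idiom for letting a failing command run to completion under `set -e` semantics while still capturing and propagating its status. The same pattern for a Python driver, should one ever replace the shell wrapper, looks like this (a sketch; the flags mirror the command above):

```python
# Sketch: run pytest without aborting the wrapper, then forward its exit code.
import subprocess
import sys
from pathlib import Path

Path("test-results").mkdir(exist_ok=True)  # mirror `mkdir -p test-results`
result = subprocess.run(
    ["pytest", "--junitxml=test-results/junit.xml", "-v", "--durations", "200"],
    check=False,  # like `set +e`: a failing run doesn't raise
)
print(f"=== Pytest completed with exit code: {result.returncode} ===")
sys.exit(result.returncode)  # like `exit $PYTEST_EXIT_CODE`
```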

.github/workflows/benchmarks_pr.yml

Lines changed: 3 additions & 1 deletion

@@ -14,7 +14,9 @@ jobs:
 
   benchmark_cpu:
     name: CPU Pytest benchmark
-    runs-on: linux.g5.4xlarge.nvidia.cpu
+    runs-on: linux.4xlarge
+    # Disabling job since it hasn't worked for months
+    if: false
     defaults:
       run:
         shell: bash -l {0}

.github/workflows/test-linux-libs.yml

Lines changed: 33 additions & 33 deletions

@@ -21,39 +21,39 @@ permissions:
 
 jobs:
 
-  unittests-atari-dqn:
-    strategy:
-      matrix:
-        python_version: ["3.10"]
-        cuda_arch_version: ["12.8"]
-    if: ${{ github.event_name == 'push' || contains(github.event.pull_request.labels.*.name, 'Data') }}
-    uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
-    with:
-      repository: pytorch/rl
-      runner: "linux.g5.4xlarge.nvidia.gpu"
-      docker-image: "nvidia/cuda:12.4.0-devel-ubuntu22.04"
-      timeout: 120
-      script: |
-        if [[ "${{ github.ref }}" =~ release/* ]]; then
-          export RELEASE=1
-          export TORCH_VERSION=stable
-        else
-          export RELEASE=0
-          export TORCH_VERSION=nightly
-        fi
-
-        set -euo pipefail
-        export PYTHON_VERSION="3.10"
-        export CU_VERSION="cu128"
-        export TAR_OPTIONS="--no-same-owner"
-        export UPLOAD_CHANNEL="nightly"
-        export TF_CPP_MIN_LOG_LEVEL=0
-        export TD_GET_DEFAULTS_TO_NONE=1
-
-        bash .github/unittest/linux_libs/scripts_ataridqn/setup_env.sh
-        bash .github/unittest/linux_libs/scripts_ataridqn/install.sh
-        bash .github/unittest/linux_libs/scripts_ataridqn/run_test.sh
-        bash .github/unittest/linux_libs/scripts_ataridqn/post_process.sh
+  # unittests-atari-dqn:
+  #   strategy:
+  #     matrix:
+  #       python_version: ["3.10"]
+  #       cuda_arch_version: ["12.8"]
+  #   if: ${{ github.event_name == 'push' || contains(github.event.pull_request.labels.*.name, 'Data') }}
+  #   uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
+  #   with:
+  #     repository: pytorch/rl
+  #     runner: "linux.g5.4xlarge.nvidia.gpu"
+  #     docker-image: "nvidia/cuda:12.4.0-devel-ubuntu22.04"
+  #     timeout: 120
+  #     script: |
+  #       if [[ "${{ github.ref }}" =~ release/* ]]; then
+  #         export RELEASE=1
+  #         export TORCH_VERSION=stable
+  #       else
+  #         export RELEASE=0
+  #         export TORCH_VERSION=nightly
+  #       fi
+
+  #       set -euo pipefail
+  #       export PYTHON_VERSION="3.10"
+  #       export CU_VERSION="cu128"
+  #       export TAR_OPTIONS="--no-same-owner"
+  #       export UPLOAD_CHANNEL="nightly"
+  #       export TF_CPP_MIN_LOG_LEVEL=0
+  #       export TD_GET_DEFAULTS_TO_NONE=1
+
+  #       bash .github/unittest/linux_libs/scripts_ataridqn/setup_env.sh
+  #       bash .github/unittest/linux_libs/scripts_ataridqn/install.sh
+  #       bash .github/unittest/linux_libs/scripts_ataridqn/run_test.sh
+  #       bash .github/unittest/linux_libs/scripts_ataridqn/post_process.sh
 
   unittests-brax:
     strategy:

.gitignore

Lines changed: 1 addition & 0 deletions

@@ -44,6 +44,7 @@ htmlcov/
 .coverage
 .coverage.*
 .cache
+.neptune
 nosetests.xml
 coverage.xml
 *.cover

README.md

Lines changed: 97 additions & 0 deletions

@@ -23,6 +23,57 @@
 
 **TorchRL** is an open-source Reinforcement Learning (RL) library for PyTorch.
 
+## 🚀 What's New
+
+### LLM API - Complete Framework for Language Model Fine-tuning
+
+TorchRL now includes a comprehensive **LLM API** for post-training and fine-tuning of language models! This new framework provides everything you need for RLHF, supervised fine-tuning, and tool-augmented training:
+
+- 🤖 **Unified LLM Wrappers**: Seamless integration with Hugging Face models and vLLM inference engines - more to come!
+- 💬 **Conversation Management**: Advanced [`History`](torchrl/data/llm/history.py) class for multi-turn dialogue with automatic chat template detection
+- 🛠️ **Tool Integration**: [Built-in support](torchrl/envs/llm/transforms/) for Python code execution, function calling, and custom tool transforms
+- 🎯 **Specialized Objectives**: [GRPO](torchrl/objectives/llm/grpo.py) (Group Relative Policy Optimization) and [SFT](torchrl/objectives/llm/sft.py) loss functions optimized for language models
+- ⚡ **High-Performance Collectors**: [Async data collection](torchrl/collectors/llm/) with distributed training support
+- 🔄 **Flexible Environments**: Transform-based architecture for reward computation, data loading, and conversation augmentation
+
+The LLM API follows TorchRL's modular design principles, allowing you to mix and match components for your specific use case. Check out the [complete documentation](https://pytorch.org/rl/main/reference/llms.html) and [GRPO implementation example](https://github.com/pytorch/rl/tree/main/sota-implementations/grpo) to get started!
+
+<details>
+  <summary>Quick LLM API Example</summary>
+
+```python
+from torchrl.envs.llm import ChatEnv
+from torchrl.envs.llm.transforms import PythonInterpreter
+from torchrl.modules.llm import TransformersWrapper
+from torchrl.objectives.llm import GRPOLoss
+from torchrl.collectors.llm import LLMCollector
+
+# Create environment with Python tool execution
+env = ChatEnv(
+    tokenizer=tokenizer,
+    system_prompt="You are an assistant that can execute Python code.",
+    batch_size=[1]
+).append_transform(PythonInterpreter())
+
+# Wrap your language model
+llm = TransformersWrapper(
+    model=model,
+    tokenizer=tokenizer,
+    input_mode="history"
+)
+
+# Set up GRPO training
+loss_fn = GRPOLoss(llm, critic, gamma=0.99)
+collector = LLMCollector(env, llm, frames_per_batch=100)
+
+# Training loop
+for data in collector:
+    loss = loss_fn(data)
+    loss.backward()
+    optimizer.step()
+```
+
+</details>
+
 ## Key features
 
 - 🐍 **Python-first**: Designed with Python as the primary language for ease of use and flexibility
@@ -516,6 +567,39 @@ And it is `functorch` and `torch.compile` compatible!
 - various [recipes](https://github.com/pytorch/rl/blob/main/torchrl/trainers/helpers/models.py) to build models that
   correspond to the environment being deployed.
 
+- **LLM API**: Complete framework for language model fine-tuning with unified wrappers for Hugging Face and vLLM backends,
+  conversation management with automatic chat template detection, tool integration (Python execution, function calling),
+  specialized objectives (GRPO, SFT), and high-performance async collectors. Perfect for RLHF, supervised fine-tuning,
+  and tool-augmented training scenarios.
+  <details>
+    <summary>Code</summary>
+
+  ```python
+  from tensordict import TensorDict
+  from torchrl.envs.llm import ChatEnv
+  from torchrl.modules.llm import TransformersWrapper
+  from torchrl.envs.llm.transforms import PythonInterpreter
+
+  # Create environment with tool execution
+  env = ChatEnv(
+      tokenizer=tokenizer,
+      system_prompt="You can execute Python code.",
+      batch_size=[1]
+  ).append_transform(PythonInterpreter())
+
+  # Wrap language model for training
+  llm = TransformersWrapper(
+      model=model,
+      tokenizer=tokenizer,
+      input_mode="history"
+  )
+
+  # Multi-turn conversation with tool use
+  obs = env.reset(TensorDict({"query": "Calculate 2+2"}, batch_size=[1]))
+  llm_output = llm(obs)  # Generates response
+  obs = env.step(llm_output)  # Environment processes response
+  ```
+  </details>
+
 If you feel a feature is missing from the library, please submit an issue!
 If you would like to contribute to new features, check our [call for contributions](https://github.com/pytorch/rl/issues/509) and our [contribution](https://github.com/pytorch/rl/blob/main/CONTRIBUTING.md) page.

@@ -792,6 +876,18 @@ A series of [State-of-the-Art implementations](https://github.com/pytorch/rl/blo
   <td> NA
   </td>
   </tr>
+  <tr>
+  <td><a href="https://github.com/pytorch/rl/blob/main/sota-implementations/grpo">LLM API (GRPO)</a>
+  </td>
+  <td> NA
+  </td>
+  <td> +
+  </td>
+  <td> +
+  </td>
+  <td> NA
+  </td>
+  </tr>
 </table>
 
 ** The number indicates expected speed-up compared to eager mode when executed on CPU. Numbers may vary depending on
@@ -800,6 +896,7 @@ A series of [State-of-the-Art implementations](https://github.com/pytorch/rl/blo
 and many more to come!
 
 [Code examples](examples/) displaying toy code snippets and training scripts are also available
+- [LLM API & GRPO](sota-implementations/grpo) - Complete language model fine-tuning pipeline
 - [RLHF](examples/rlhf)
 - [Memory-mapped replay buffers](examples/torchrl_features)

docs/source/_static/img/llm-data.svg

Lines changed: 5 additions & 0 deletions

docs/source/_static/img/llm-env.png

577 KB
