Skip to content

Commit 3ebad77

Browse files
committed
Move SharedAdam to utils
2 parents 19209e2 + 836f03e commit 3ebad77

File tree

134 files changed

+17099
-4107
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

134 files changed

+17099
-4107
lines changed

.github/scripts/pre-build-script-win.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,4 +2,4 @@
22

33
pip install --upgrade setuptools
44

5-
export TORCHRL_BUILD_VERSION=0.9.0
5+
export TORCHRL_BUILD_VERSION=0.10.0

.github/scripts/td_script.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
#!/bin/bash
22

3-
export TORCHRL_BUILD_VERSION=0.9.0
3+
export TORCHRL_BUILD_VERSION=0.10.0
44
pip install --upgrade setuptools
55

66
# Check if ARCH is set to aarch64

.github/scripts/version_script.bat

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
@echo off
2-
set TORCHRL_BUILD_VERSION=0.9.0
2+
set TORCHRL_BUILD_VERSION=0.10.0
33
echo TORCHRL_BUILD_VERSION is set to %TORCHRL_BUILD_VERSION%
44

55
@echo on

.github/unittest/linux/scripts/environment.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,3 +35,4 @@ dependencies:
3535
- transformers
3636
- ninja
3737
- timm
38+
- safetensors

.github/unittest/linux_libs/scripts_brax/environment.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,3 +21,4 @@ dependencies:
2121
- hydra-core
2222
- jax[cuda12]
2323
- brax
24+
- psutil

.github/unittest/linux_libs/scripts_brax/run_test.sh

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,13 @@ conda activate ./env
88

99
export PYTORCH_TEST_WITH_SLOW='1'
1010
export LAZY_LEGACY_OP=False
11+
12+
# Configure JAX for proper GPU initialization
13+
export XLA_PYTHON_CLIENT_PREALLOCATE=false
14+
export XLA_PYTHON_CLIENT_ALLOCATOR=platform
15+
export TF_FORCE_GPU_ALLOW_GROWTH=true
16+
export CUDA_VISIBLE_DEVICES=0
17+
1118
python -m torch.utils.collect_env
1219
# Avoid error: "fatal: unsafe repository"
1320
git config --global --add safe.directory '*'
@@ -28,7 +35,33 @@ export MAGNUM_LOG=verbose MAGNUM_GPU_VALIDATION=ON
2835
# this workflow only tests the libs
2936
python -c "import brax"
3037
python -c "import brax.envs"
31-
python -c "import jax"
38+
39+
# Initialize JAX with proper GPU configuration
40+
python -c "
41+
import jax
42+
import jax.numpy as jnp
43+
import os
44+
45+
# Configure JAX for GPU
46+
os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false'
47+
os.environ['XLA_PYTHON_CLIENT_ALLOCATOR'] = 'platform'
48+
49+
# Test JAX GPU availability
50+
try:
51+
devices = jax.devices()
52+
print(f'JAX devices: {devices}')
53+
if len(devices) > 1:
54+
print('JAX GPU is available')
55+
else:
56+
print('JAX CPU only')
57+
except Exception as e:
58+
print(f'JAX initialization error: {e}')
59+
# Fallback to CPU
60+
os.environ['JAX_PLATFORM_NAME'] = 'cpu'
61+
jax.config.update('jax_platform_name', 'cpu')
62+
print('Falling back to JAX CPU')
63+
"
64+
3265
python3 -c 'import torch;t = torch.ones([2,2], device="cuda:0");print(t);print("tensor device:" + str(t.device))'
3366

3467
python .github/unittest/helpers/coverage_run_parallel.py -m pytest test/test_libs.py --instafail -v --durations 200 --capture no -k TestBrax --error-for-skips

.github/unittest/linux_libs/scripts_minari/environment.yml

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,4 +20,11 @@ dependencies:
2020
- scipy
2121
- hydra-core
2222
- minari[gcs,hdf5,hf]
23-
- gymnasium<1.0.0
23+
- gymnasium>=1.2.0
24+
- ale-py
25+
- gymnasium-robotics
26+
- minari[create]
27+
- jax
28+
- mujoco
29+
- mujoco-py<2.2,>=2.1
30+
- minigrid

.github/unittest/linux_libs/scripts_minari/install.sh

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,7 @@ unset PYTORCH_VERSION
77

88
set -e
99

10-
eval "$(./conda/bin/conda shell.bash hook)"
11-
conda activate ./env
10+
# Note: This script is sourced by run_all.sh, so the environment is already active
1211

1312
if [ "${CU_VERSION:-}" == cpu ] ; then
1413
version="cpu"
@@ -22,22 +21,21 @@ else
2221
version="$(python -c "print('.'.join(\"${CUDA_VERSION}\".split('.')[:2]))")"
2322
fi
2423

25-
2624
# submodules
2725
git submodule sync && git submodule update --init --recursive
2826

2927
printf "Installing PyTorch with cu128"
3028
if [[ "$TORCH_VERSION" == "nightly" ]]; then
3129
if [ "${CU_VERSION:-}" == cpu ] ; then
32-
pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cpu -U
30+
uv pip install --pre torch --index-url https://download.pytorch.org/whl/nightly/cpu -U
3331
else
34-
pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu128 -U
32+
uv pip install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu128 -U
3533
fi
3634
elif [[ "$TORCH_VERSION" == "stable" ]]; then
3735
if [ "${CU_VERSION:-}" == cpu ] ; then
38-
pip3 install torch --index-url https://download.pytorch.org/whl/cpu
36+
uv pip install torch --index-url https://download.pytorch.org/whl/cpu
3937
else
40-
pip3 install torch --index-url https://download.pytorch.org/whl/cu128
38+
uv pip install torch --index-url https://download.pytorch.org/whl/cu128
4139
fi
4240
else
4341
printf "Failed to install pytorch"
@@ -46,9 +44,9 @@ fi
4644

4745
# install tensordict
4846
if [[ "$RELEASE" == 0 ]]; then
49-
pip3 install git+https://github.com/pytorch/tensordict.git
47+
uv pip install git+https://github.com/pytorch/tensordict.git
5048
else
51-
pip3 install tensordict
49+
uv pip install tensordict
5250
fi
5351

5452
# smoke test

.github/unittest/linux_libs/scripts_minari/post_process.sh

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,5 +2,4 @@
22

33
set -e
44

5-
eval "$(./conda/bin/conda shell.bash hook)"
6-
conda activate ./env
5+
# Note: This script is sourced by run_all.sh, so the environment is already active
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
hypothesis
2+
future
3+
cloudpickle
4+
pytest
5+
pytest-cov
6+
pytest-mock
7+
pytest-instafail
8+
pytest-rerunfailures
9+
pytest-error-for-skips
10+
pytest-asyncio
11+
expecttest
12+
pybind11[global]
13+
pyyaml
14+
scipy
15+
hydra-core
16+
minari[gcs,hdf5,hf,create]
17+
gymnasium>=1.2.0
18+
ale-py
19+
gymnasium-robotics
20+
mujoco

0 commit comments

Comments
 (0)