Skip to content

Commit 3e1567b

Browse files
committed
[CI] Fix SOTA runs
ghstack-source-id: f19d10d Pull-Request: #3252
1 parent 227b33c commit 3e1567b

File tree

5 files changed

+40
-62
lines changed

5 files changed

+40
-62
lines changed

.github/unittest/linux/scripts/environment.yml

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,11 +26,9 @@ dependencies:
2626
- tensorboard
2727
- imageio==2.26.0
2828
- wandb
29-
- mujoco<3.3.6
3029
- mlflow
3130
- av
3231
- coverage
33-
- ray
3432
- transformers
3533
- ninja
3634
- timm

.github/unittest/linux/scripts/run_all.sh

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -119,8 +119,23 @@ if [[ "$PYTHON_VERSION" != "3.13" && "$PYTHON_VERSION" != "3.14" ]]; then
119119
pip3 install dm_control
120120
fi
121121

122+
# Install ray for Python < 3.14 (ray doesn't support Python 3.14 yet)
123+
if [[ "$PYTHON_VERSION" != "3.14" ]]; then
124+
echo "installing ray"
125+
pip3 install ray
126+
fi
127+
128+
# Install mujoco for Python < 3.14 (mujoco doesn't have Python 3.14 wheels yet)
129+
if [[ "$PYTHON_VERSION" != "3.14" ]]; then
130+
echo "installing mujoco"
131+
pip3 install "mujoco<3.3.6"
132+
fi
133+
122134
echo "installing gymnasium"
123-
if [[ "$PYTHON_VERSION" == "3.12" ]]; then
135+
if [[ "$PYTHON_VERSION" == "3.14" ]]; then
136+
# Python 3.14: no mujoco wheels available
137+
pip3 install "gymnasium[atari]>=1.1"
138+
elif [[ "$PYTHON_VERSION" == "3.12" ]]; then
124139
pip3 install ale-py
125140
pip3 install sympy
126141
pip3 install "gymnasium[mujoco]>=1.1" mo-gymnasium[mujoco]

.github/unittest/linux_sota/scripts/run_all.sh

Lines changed: 17 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,8 @@ dpkg-reconfigure -f noninteractive tzdata || true
1919
apt-get upgrade -y
2020
apt-get install -y vim git wget cmake
2121

22-
apt-get install -y libglfw3 libgl1-mesa-glx libosmesa6 libglew-dev libosmesa6-dev
23-
apt-get install -y libglvnd0 libgl1 libglx0 libegl1 libgles2
22+
apt-get install -y libglfw3 libosmesa6 libglew-dev libosmesa6-dev
23+
apt-get install -y libglvnd0 libgl1 libglx0 libglx-mesa0 libegl1 libgles2
2424
apt-get install -y g++ gcc patchelf
2525

2626
this_dir="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )"
@@ -36,7 +36,6 @@ git config --global --add safe.directory '*'
3636
root_dir="$(git rev-parse --show-toplevel)"
3737
conda_dir="${root_dir}/conda"
3838
env_dir="${root_dir}/env"
39-
lib_dir="${env_dir}/lib"
4039

4140
cd "${root_dir}"
4241

@@ -72,14 +71,6 @@ fi
7271
printf "* Verified Python implementation: %s\n" "$python_impl"
7372

7473
# 3. Install mujoco
75-
printf "* Installing mujoco and related\n"
76-
mkdir -p $root_dir/.mujoco
77-
cd $root_dir/.mujoco/
78-
#wget https://github.com/deepmind/mujoco/releases/download/2.1.1/mujoco-2.1.1-linux-x86_64.tar.gz
79-
#tar -xf mujoco-2.1.1-linux-x86_64.tar.gz
80-
wget https://mujoco.org/download/mujoco210-linux-x86_64.tar.gz
81-
tar -xf mujoco210-linux-x86_64.tar.gz
82-
cd "${root_dir}"
8374

8475
# 4. Install Conda dependencies
8576
printf "* Installing dependencies (except PyTorch)\n"
@@ -89,9 +80,6 @@ if ! grep -q "python=${PYTHON_VERSION}" "${this_dir}/environment.yml"; then
8980
fi
9081
cat "${this_dir}/environment.yml"
9182

92-
export MUJOCO_PY_MUJOCO_PATH=$root_dir/.mujoco/mujoco210
93-
#export MJLIB_PATH=$root_dir/.mujoco/mujoco-2.1.1/lib/libmujoco.so.2.1.1
94-
export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$root_dir/.mujoco/mujoco210/bin
9583
export SDL_VIDEODRIVER=dummy
9684
export MUJOCO_GL=egl
9785
export PYOPENGL_PLATFORM=egl
@@ -100,26 +88,21 @@ export COMPOSITE_LP_AGGREGATE=0
10088

10189
conda env config vars set \
10290
MAX_IDLE_COUNT=1000 \
103-
MUJOCO_PY_MUJOCO_PATH=$root_dir/.mujoco/mujoco210 \
10491
DISPLAY=:99 \
105-
LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$root_dir/.mujoco/mujoco210/bin \
10692
SDL_VIDEODRIVER=dummy \
10793
MUJOCO_GL=egl \
10894
PYOPENGL_PLATFORM=egl \
10995
BATCHED_PIPE_TIMEOUT=60 \
11096
TOKENIZERS_PARALLELISM=true
11197

112-
pip install pip --upgrade
98+
# Use python -m pip to ensure we use conda's Python, not system GraalPy
99+
python -m pip install pip --upgrade
113100

114101
conda env update --file "${this_dir}/environment.yml" --prune
115102

116103
conda deactivate
117104
conda activate "${env_dir}"
118105

119-
# install d4rl
120-
pip install free-mujoco-py
121-
pip install git+https://github.com/Farama-Foundation/d4rl@master#egg=d4rl
122-
123106
# TODO: move this down -- will break torchrl installation
124107
conda install -y -c conda-forge libstdcxx-ng=12
125108
## find libstdc - search in the env's lib directory first, then fall back to conda packages
@@ -144,12 +127,6 @@ fi
144127
conda deactivate
145128
conda activate "${env_dir}"
146129

147-
# compile mujoco-py (bc it's done at runtime for whatever reason someone thought it was a good idea)
148-
python -c """import gym;import d4rl"""
149-
150-
# install ale-py: manylinux names are broken for CentOS so we need to manually download and
151-
# rename them
152-
153130
# ============================================================================================ #
154131
# ================================ PyTorch & TorchRL ========================================= #
155132

@@ -160,26 +137,29 @@ elif [[ ${#CU_VERSION} -eq 5 ]]; then
160137
CUDA_VERSION="${CU_VERSION:2:2}.${CU_VERSION:4:1}"
161138
fi
162139
echo "Using CUDA $CUDA_VERSION as determined by CU_VERSION ($CU_VERSION)"
163-
version="$(python -c "print('.'.join(\"${CUDA_VERSION}\".split('.')[:2]))")"
164-
165140
# submodules
166141
git submodule sync && git submodule update --init --recursive
167142

168-
pip3 install ale-py -U
169-
pip3 install "gym[atari,accept-rom-license]" "gymnasium>=1.1.0" -U
143+
# Install jax with CUDA support before ale-py to satisfy its build dependency
144+
# (ale-py source builds require jax/jaxlib which needs CUDA-compatible wheels)
145+
# See: https://docs.jax.dev/en/latest/installation.html
146+
python -m pip install --upgrade "jax[cuda13-local]"
147+
# Use --no-build-isolation so ale-py uses the already-installed jax/jaxlib
148+
python -m pip install ale-py -U --no-build-isolation
149+
python -m pip install "gymnasium[atari,accept-rom-license,mujoco]>=1.1.0" -U
170150

171151
printf "Installing PyTorch with %s\n" "${CU_VERSION}"
172152
if [[ "$TORCH_VERSION" == "nightly" ]]; then
173153
if [ "${CU_VERSION:-}" == cpu ] ; then
174-
pip3 install --pre torch torchvision numpy==1.26.4 --index-url https://download.pytorch.org/whl/nightly/cpu -U
154+
python -m pip install --pre torch torchvision numpy==1.26.4 --index-url https://download.pytorch.org/whl/nightly/cpu -U
175155
else
176-
pip3 install --pre torch torchvision numpy==1.26.4 --index-url https://download.pytorch.org/whl/nightly/$CU_VERSION
156+
python -m pip install --pre torch torchvision numpy==1.26.4 --index-url https://download.pytorch.org/whl/nightly/$CU_VERSION
177157
fi
178158
elif [[ "$TORCH_VERSION" == "stable" ]]; then
179159
if [ "${CU_VERSION:-}" == cpu ] ; then
180-
pip3 install torch torchvision numpy==1.26.4 --index-url https://download.pytorch.org/whl/cpu
160+
python -m pip install torch torchvision numpy==1.26.4 --index-url https://download.pytorch.org/whl/cpu
181161
else
182-
pip3 install torch torchvision numpy==1.26.4 --index-url https://download.pytorch.org/whl/$CU_VERSION
162+
python -m pip install torch torchvision numpy==1.26.4 --index-url https://download.pytorch.org/whl/$CU_VERSION
183163
fi
184164
else
185165
printf "Failed to install pytorch"
@@ -194,9 +174,9 @@ python -c "import functorch"
194174

195175
# install tensordict
196176
if [[ "$RELEASE" == 0 ]]; then
197-
pip3 install git+https://github.com/pytorch/tensordict.git
177+
python -m pip install git+https://github.com/pytorch/tensordict.git
198178
else
199-
pip3 install tensordict
179+
python -m pip install tensordict
200180
fi
201181

202182
printf "* Installing torchrl\n"

.github/unittest/linux_sota/scripts/test_sota.py

Lines changed: 4 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -15,21 +15,6 @@
1515
), "Composite LP must be set to False. Run this test with COMPOSITE_LP_AGGREGATE=0"
1616

1717
commands = {
18-
"dt": """python sota-implementations/decision_transformer/dt.py \
19-
optim.pretrain_gradient_steps=55 \
20-
optim.updates_per_episode=3 \
21-
optim.warmup_steps=10 \
22-
logger.backend= \
23-
env.backend=gymnasium \
24-
env.name=HalfCheetah-v4
25-
""",
26-
"online_dt": """python sota-implementations/decision_transformer/online_dt.py \
27-
optim.pretrain_gradient_steps=55 \
28-
optim.updates_per_episode=3 \
29-
optim.warmup_steps=10 \
30-
env.backend=gymnasium \
31-
logger.backend=
32-
""",
3318
"td3_bc": """python sota-implementations/td3_bc/td3_bc.py \
3419
optim.gradient_steps=55 \
3520
logger.backend=
@@ -39,7 +24,7 @@
3924
collector.frames_per_batch=20 \
4025
collector.num_workers=1 \
4126
logger.backend= \
42-
env.backend=gym \
27+
env.backend=gymnasium \
4328
logger.test_interval=10
4429
""",
4530
"ppo_mujoco": """python sota-implementations/ppo/ppo_mujoco.py \
@@ -57,7 +42,7 @@
5742
loss.mini_batch_size=20 \
5843
loss.ppo_epochs=2 \
5944
logger.backend= \
60-
env.backend=gym \
45+
env.backend=gymnasium \
6146
logger.test_interval=10
6247
""",
6348
"ddpg": """python sota-implementations/ddpg/ddpg.py \
@@ -84,7 +69,7 @@
8469
collector.frames_per_batch=20 \
8570
loss.mini_batch_size=20 \
8671
logger.backend= \
87-
env.backend=gym \
72+
env.backend=gymnasium \
8873
logger.test_interval=40
8974
""",
9075
"dqn_atari": """python sota-implementations/dqn/dqn_atari.py \
@@ -94,7 +79,7 @@
9479
buffer.batch_size=10 \
9580
loss.num_updates=1 \
9681
logger.backend= \
97-
env.backend=gym \
82+
env.backend=gymnasium \
9883
buffer.buffer_size=120
9984
""",
10085
"discrete_cql_online": """python sota-implementations/cql/discrete_cql_online.py \

.github/workflows/test-linux-sota.yml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,14 +26,14 @@ jobs:
2626
tests:
2727
strategy:
2828
matrix:
29-
python_version: ["3.9"]
30-
cuda_arch_version: ["12.8"]
29+
python_version: ["3.10"]
30+
cuda_arch_version: ["13.0"]
3131
fail-fast: false
3232
uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
3333
with:
3434
runner: linux.g5.4xlarge.nvidia.gpu
3535
repository: pytorch/rl
36-
docker-image: "nvidia/cuda:12.2.0-devel-ubuntu22.04"
36+
docker-image: "nvidia/cuda:13.0.2-cudnn-devel-ubuntu24.04"
3737
gpu-arch-type: cuda
3838
gpu-arch-version: ${{ matrix.cuda_arch_version }}
3939
timeout: 90

0 commit comments

Comments
 (0)