diff --git a/.github/workflows/tpu-nightly-regression.yml b/.github/workflows/tpu-nightly-regression.yml index 15537db4f..10c967c12 100644 --- a/.github/workflows/tpu-nightly-regression.yml +++ b/.github/workflows/tpu-nightly-regression.yml @@ -119,22 +119,22 @@ jobs: python scripts/grpo_demo_llama3_qwen2.py --root-dir=/tmp/grpo_test --num-batches=20 --rollout-engine=vllm --rollout-server-mode=True --cluster-setup=disaggregated-3-way || FAILED=1 # SGLang Tests - unset JAX_PLATFORMS - pip list | egrep 'jax|flax|libtpu' - cd .. - git clone https://github.com/sgl-project/sglang-jax.git && cd sglang-jax/python && pip install -e . && cd ../.. - pip install jax==0.8.1 flax==0.12.0 libtpu==0.0.24 - pip list | egrep 'jax|flax|libtpu' - cd tunix + # unset JAX_PLATFORMS + # pip list | egrep 'jax|flax|libtpu' + # cd .. + # git clone https://github.com/sgl-project/sglang-jax.git && cd sglang-jax/python && pip install -e . && cd ../.. + # pip install jax==0.8.1 flax==0.12.0 libtpu==0.0.24 + # pip list | egrep 'jax|flax|libtpu' + # cd tunix - echo "📦 Executing: scripts/grpo_demo_llama3_qwen2.py with sglang_jax in colocated mode ..." - python scripts/grpo_demo_llama3_qwen2.py --root-dir=/tmp/grpo_test --num-batches=20 --rollout-engine=sglang_jax || FAILED=1 + # echo "📦 Executing: scripts/grpo_demo_llama3_qwen2.py with sglang_jax in colocated mode ..." + # python scripts/grpo_demo_llama3_qwen2.py --root-dir=/tmp/grpo_test --num-batches=20 --rollout-engine=sglang_jax || FAILED=1 - echo "📦 Executing: scripts/grpo_demo_llama3_qwen2.py with sglang_jax in 2 way disaggregated mode ..." - python scripts/grpo_demo_llama3_qwen2.py --root-dir=/tmp/grpo_test --num-batches=20 --rollout-engine=sglang_jax --cluster-setup=disaggregated-2-way || FAILED=1 + # echo "📦 Executing: scripts/grpo_demo_llama3_qwen2.py with sglang_jax in 2 way disaggregated mode ..." + # python scripts/grpo_demo_llama3_qwen2.py --root-dir=/tmp/grpo_test --num-batches=20 --rollout-engine=sglang_jax --cluster-setup=disaggregated-2-way || FAILED=1 - echo "📦 Executing: scripts/grpo_demo_llama3_qwen2.py with sglang_jax in 3 way disaggregated mode ..." - python scripts/grpo_demo_llama3_qwen2.py --root-dir=/tmp/grpo_test --num-batches=20 --rollout-engine=sglang_jax --cluster-setup=disaggregated-3-way || FAILED=1 + # echo "📦 Executing: scripts/grpo_demo_llama3_qwen2.py with sglang_jax in 3 way disaggregated mode ..." + # python scripts/grpo_demo_llama3_qwen2.py --root-dir=/tmp/grpo_test --num-batches=20 --rollout-engine=sglang_jax --cluster-setup=disaggregated-3-way || FAILED=1 # echo "📦 Executing: scripts/grpo_demo_llama3_qwen2.py with sglang_jax with LoRA ..." # python scripts/grpo_demo_llama3_qwen2.py --root-dir=/tmp/grpo_test --num-batches=20 --rollout-engine sglang_jax --enable-lora --lora-target-modules all || FAILED=1