Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 3 additions & 16 deletions docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@

extensions = [
"myst_nb",
"sphinx_gallery.gen_gallery",
"sphinx_collections",
# api docs
"sphinx.ext.autodoc",
Expand All @@ -34,10 +33,9 @@
"_build",
"Thumbs.db",
".DS_Store",
(
"_collections/examples/model_load/from_safetensor_load/*"
"_collections/examples/rl/README.md"
),
"_collections/examples/README.rst",
"_collections/examples/model_load/from_safetensor_load/*",
"_collections/examples/rl/README.md",
"_collections/examples/sft/**",
"_collections/examples/deepscaler/**",
]
Expand All @@ -61,17 +59,6 @@
"navigation_with_keys": False,
}

# -- Options for sphinx-gallery ----------------------------------------------

sphinx_gallery_conf = {
"examples_dirs": "_collections/examples", # path to your example scripts
"gallery_dirs": (
"_collections/gallery"
), # path to where to save gallery generated output
"filename_pattern": "*.py",
"ignore_pattern": r"rl/|sft/|deepscaler/",
}

# -- Options for myst -------------------------------------------------------
myst_enable_extensions = [
"amsmath",
Expand Down
107 changes: 107 additions & 0 deletions docs/examples.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,15 @@
# Examples and Guides

``` {toctree}
:maxdepth: 1
:hidden:

_collections/examples/grpo_gemma
_collections/examples/logit_distillation
_collections/examples/qlora_gemma
_collections/examples/dpo_gemma
```

This section provides a high-level overview of the Colab notebooks, scripts, and
example directories.

Expand Down Expand Up @@ -59,3 +69,100 @@ All examples are located in this
</tr>
</tbody>
</table>

## GCE VM Setup for Fine-Tuning

### 1. Create TPU VM

Create a v5litepod-8 TPU VM in GCE:

* SW version: `v2-alpha-tpuv5-lite`
* Name: `v5-8`

Reference:
[TPU Runtime Versions](https://cloud.google.com/tpu/docs/runtimes?hl=en#training-v5p-v5e)

### 2. Configure VM

SSH into the VM using the supplied gcloud command, then run:

```bash
# Create .env file with required credentials
vim .env

# Download and install Anaconda
curl -O https://repo.anaconda.com/archive/Anaconda3-2025.06-0-Linux-x86_64.sh
bash ~/Anaconda3-2025.06-0-Linux-x86_64.sh # always input "yes"/enter
source ~/.bashrc

# Create conda environment (Python 3.12 - MUST BE 12, NOT 11!)
conda create -n colab python=3.12 -y
conda activate colab

# Install dependencies
pip install 'ipykernel<7' jupyterlab
pip install -U "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
pip install --upgrade clu
```

Reference:
[Run JAX on TPU](https://cloud.google.com/tpu/docs/run-calculation-jax)

Exit the SSH session after setup is complete.

### 3. Connect from Local Machine

From your local machine, run the following to connect to Jupyter Lab:

```bash
gcloud compute tpus tpu-vm ssh v5-8 --zone=us-west1-c \
-- -L 8080:localhost:8080 -L 6006:localhost:6006 \
"source \$HOME/anaconda3/etc/profile.d/conda.sh && \
conda activate colab && \
jupyter lab \
--ServerApp.allow_origin='https://colab.research.google.com' \
--port=8080 \
--no-browser \
--ServerApp.port_retries=0 \
--ServerApp.allow_credentials=True"
```

Reference:
[Local Runtimes in Colab](https://research.google.com/colaboratory/local-runtimes.html)

### 4. Environment Variables

Example `.env` file:

```bash
HF_TOKEN=
KAGGLE_USERNAME=
KAGGLE_KEY=
WANDB_API_KEY=
```

## Loading Saved Safetensors Models

To load a saved safetensors model back into JAX (with a given `local_path`):

```python
import os
import jax
import jax.numpy as jnp
from tunix.models.gemma3 import params_safetensors as params_safetensors_lib


local_path = '[PLACEHOLDER]'
MESH = [(1, 1), ("fsdp", "tp")]

mesh = jax.make_mesh(*MESH, axis_types=(jax.sharding.AxisType.Auto,) * len(MESH[0]))
with mesh:
model = params_safetensors_lib.create_model_from_safe_tensors(
os.path.abspath(local_path), (model_config), mesh, dtype=jnp.bfloat16
)
```

## Notes

* **IMPORTANT**: Use `%pip` not `!pip` in notebooks!
* Python 3.12 is the recommended version
85 changes: 0 additions & 85 deletions docs/gallery.rst

This file was deleted.

1 change: 0 additions & 1 deletion docs/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ examples.md
talks.md
contributing.md
code-of-conduct.md
gallery.rst
```

```{eval-rst}
Expand Down
Loading