This guide explains how to set up the environment and train the HSTU/DLRM models on Cloud TPU v6.
If you are developing on a TPU VM directly, use a virtual environment to avoid conflicts with the system-level Python packages.
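Ensure you have Python 3.12 installed. The jax-tpu-embedding wheel used below is tagged `cp312`, so pip will only install it on CPython 3.12:

```bash
python3 --version
```

Run the following from the root of the repository: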
```bash
# Create the venv
python3 -m venv venv
# Activate it
source venv/bin/activate
```

Install the jax-tpu-embedding wheel and the remaining Python requirements:
```bash
# Install the jax-tpu-embedding wheel
pip install ./jax_tpu_embedding-0.1.0.dev20260226-cp312-cp312-manylinux_2_31_x86_64.whl
# Install the remaining Python dependencies
pip install -r requirements.txt
```

We need to force Protobuf 6.31.1 or newer to ensure compatibility with our TPU stack. Run this exactly as shown:
```bash
pip install "protobuf>=6.31.1" --no-deps
```

The `--no-deps` flag is required to prevent pip from downgrading Protobuf because of strict dependency pinning in other libraries.
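After installation, a quick check can confirm the environment is in the state the steps above expect: the venv is active, jax-tpu-embedding is installed, and Protobuf is at least 6.31.1. The sketch below is a hypothetical helper. The function names `version_ge`, `venv_active`, `protobuf_version`, and `check_env` are illustrative and not part of the repository. It assumes GNU `sort -V` for version comparison and uses `pip show` to detect the wheel.

```bash
# Hypothetical post-install sanity check (not part of the repository).
# Intended to run in the shell where the venv from this guide was activated.

# Succeeds when version $1 >= version $2, using GNU sort's version ordering.
version_ge() {
  [ "$(printf '%s\n%s\n' "$1" "$2" | sort -V | head -n1)" = "$2" ]
}

# Succeeds when a venv is active: `source venv/bin/activate` exports VIRTUAL_ENV
# and puts "$VIRTUAL_ENV/bin" first on PATH.
venv_active() {
  [ -n "${VIRTUAL_ENV:-}" ] || return 1
  case "$(command -v python3)" in
    "$VIRTUAL_ENV"/bin/*) return 0 ;;
    *) return 1 ;;
  esac
}

# Prints the installed protobuf version, or nothing if it is not installed.
protobuf_version() {
  python3 -c 'from importlib.metadata import version; print(version("protobuf"))' 2>/dev/null
}

# Reports every problem found; returns non-zero if any check fails.
check_env() {
  local status=0 pb
  if ! venv_active; then
    echo "venv is not active; run: source venv/bin/activate" >&2
    status=1
  fi
  if ! python3 -m pip show jax-tpu-embedding >/dev/null 2>&1; then
    echo "jax-tpu-embedding is not installed in this environment" >&2
    status=1
  fi
  pb="$(protobuf_version)"
  if ! version_ge "${pb:-0}" 6.31.1; then
    echo "protobuf>=6.31.1 is required, found ${pb:-none}" >&2
    status=1
  fi
  return "$status"
}

if check_env; then
  echo "environment looks OK"
fi
```

You can paste this into the activated shell, or save it to a file and `source` it. Each failing check prints its own message, so a single run reports every problem instead of stopping at the first.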
Finally, run the DLRM experiment test:

```bash
python dlrm_experiment_test.py
```

If you prefer not to manage a virtual environment, or want to deploy as a container, you can use a Docker image. There are two options: (1) build your own image with the Dockerfile provided in this repository, or (2) use our latest prebuilt image from Docker Hub.
To build your own image (option 1), run the command below from the root of the repository. It reads the Dockerfile, installs all dependencies, and creates a ready-to-run image. Building requires the jax-tpu-embedding wheel; instructions for obtaining it are at https://github.com/jax-ml/jax-tpu-embedding.
```bash
docker build -t recml-training .
```

For option 2, the prebuilt image is `docker.io/recsyscmcs/recml-tpu:v1.0.0`. It contains all the latest dependencies and sets up the environment for RecML to run the algorithms on TPU v6 and v7.
The following command runs the container with the image's default command, which is currently set to run DLRM. It uses our prebuilt image; replace the image name with your own (for example, `recml-training`) to run an image you built.
```bash
docker run --rm --privileged \
  --net=host \
  --ipc=host \
  --name recml-experiment \
  docker.io/recsyscmcs/recml-tpu:v1.0.0
```
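If you run the container repeatedly, or want to run something other than the default command, a small wrapper can help. The sketch below is hypothetical: `recml_run` and `RECML_IMAGE` are illustrative names, not part of this repository. It reproduces the flags above and relies on standard Docker behavior, where any arguments placed after the image name replace the image's default command.

```bash
# Hypothetical convenience wrapper around the docker run command above.
# RECML_IMAGE defaults to the prebuilt image; set it (e.g. to recml-training)
# to run an image you built yourself.
RECML_IMAGE="${RECML_IMAGE:-docker.io/recsyscmcs/recml-tpu:v1.0.0}"

recml_run() {
  # Arguments after the image name replace the image's default command;
  # with no arguments the container runs its default command (currently DLRM).
  docker run --rm --privileged \
    --net=host \
    --ipc=host \
    --name recml-experiment \
    "$RECML_IMAGE" "$@"
}
```

Calling `recml_run` with no arguments matches the command above. `RECML_IMAGE=recml-training recml_run` runs a locally built image. `recml_run python3 -c 'import jax; print(jax.devices())'` should list the TPU devices visible inside the container, assuming the image provides `python3` with JAX installed. Because the container name is fixed to `recml-experiment`, only one such container can run at a time.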