37 commits
c174833
update pyproject jax wsl gpu
chataignault Apr 22, 2025
61ff588
jax quick fix
chataignault Apr 22, 2025
abeefd2
jax models pytest integration
chataignault Apr 23, 2025
31d3864
Merge branch 'main' of github.com:chataignault/models into diffusion-…
chataignault Apr 23, 2025
4ebb37d
Merge branch 'main' of github.com:chataignault/models into diffusion-…
chataignault May 7, 2025
17fcecd
start correcting jax diffusion
chataignault May 7, 2025
d082b76
Merge branch 'main' of github.com:chataignault/models into diffusion-…
chataignault Oct 16, 2025
6b20d58
ddpm start ray inference forwarding
chataignault Oct 16, 2025
0f12f39
Merge branch 'main' of github.com:chataignault/models into diffusion-…
chataignault Dec 9, 2025
e50e4a1
update submodules
chataignault Dec 9, 2025
8096cf2
ddpm: jax tpu pyproject
chataignault Dec 10, 2025
1d79513
ddpm: update gitignore and fix pyproject
chataignault Dec 10, 2025
975ac4e
jax: main with tpu and orbax checkpoint
chataignault Dec 10, 2025
67a5253
jax: update training
chataignault Dec 10, 2025
87b00e8
jax: comment checkpointer
chataignault Dec 10, 2025
9cbd1e3
torch: correct dataloader
chataignault Dec 10, 2025
b7dd0f0
jax: tensorboard logger
chataignault Dec 10, 2025
81b3cf2
jax unet: base dim param
chataignault Dec 11, 2025
df5319c
jax: typo unet and format
chataignault Dec 11, 2025
e223758
jax: larger Unet
chataignault Dec 14, 2025
2d928a9
jax: simpleunet in main
chataignault Dec 14, 2025
a81176f
jax: improve simple unet
chataignault Dec 14, 2025
761fae9
jax: solve sample rng split and improve unet
chataignault Dec 14, 2025
9138251
jax: larger model and no image recentering
chataignault Dec 15, 2025
da4dbcf
jax: jit sampling function
chataignault Dec 15, 2025
725a03b
jax: jit other helper functions
chataignault Dec 15, 2025
a8fda15
jax: deeper unet and correct schedule
chataignault Dec 16, 2025
dfe3ab9
jax: upgrade dependencies and torch wheel
chataignault Dec 18, 2025
104d692
jax: fix orbax checkpointing
chataignault Dec 18, 2025
b0bb025
ignore checkpoints folder
chataignault Dec 18, 2025
1378ec3
orbax: no checkpoint for inference
chataignault Dec 18, 2025
01712b9
jax: minor refactor
chataignault Dec 18, 2025
c4d8ac3
jax: fix cifar10 training
chataignault Dec 19, 2025
e9692a2
jax: correct plot with channels
chataignault Dec 19, 2025
9e0f7f8
tweak params
chataignault Dec 27, 2025
2312b60
unet: add middle-layer attention
chataignault Dec 27, 2025
75467ce
fid metric
chataignault Dec 28, 2025
45 changes: 32 additions & 13 deletions .gitignore
@@ -1,44 +1,63 @@
# ============================================================================
# IGNORE FILES AND DIRECTORIES FROM ALL PROJECTS
# ============================================================================

# config
# ----------------------------------------------------------------------------
# CONFIG
# ----------------------------------------------------------------------------
.vscode
.code-workspace
.todos
*.lock
*.pth
*.cache

# data
# ----------------------------------------------------------------------------
# DATA
# ----------------------------------------------------------------------------
MNIST
FashionMNIST
data
checkpoints
*.csv
*.txt
*.gif
*.png
*.zip
*.pt
*.pkl

# logs
# ----------------------------------------------------------------------------
# LOGS
# ----------------------------------------------------------------------------
mlruns
flax_ckpt
tb_logs
logs
log
runs
lightning_logs

# builds
# ----------------------------------------------------------------------------
# BUILDS
# ----------------------------------------------------------------------------
target
out
debug
dist-newstyle
*.whl
*.exe

# other
# ----------------------------------------------------------------------------
# PYTHON
# ----------------------------------------------------------------------------
__pycache__
notebooks
*.ipynb
*.pt
*.pkl
*.pth
*.gif
*.png
*.lock
*.zip
*.cache

# ----------------------------------------------------------------------------
# ROCQ
# ----------------------------------------------------------------------------
*.glob
*.vo
*.vok
2 changes: 1 addition & 1 deletion diffusion/ddpm/.python-version
@@ -1 +1 @@
-3.12
+3.13
217 changes: 217 additions & 0 deletions diffusion/ddpm/README_RAY.md
@@ -0,0 +1,217 @@
# Ray-based Parallel Inference for DDPM

This guide explains how to use Ray to generate DDPM samples in parallel, both locally and on a KubeRay cluster.

## Overview

The `run_ray_inference.py` script loads a trained DDPM model and generates multiple samples in parallel using Ray. This is much faster than sequential generation, especially when generating hundreds or thousands of samples.

## Prerequisites

```bash
pip install ray[default] torch torchvision matplotlib numpy
```

For GPU support, install the same packages, but make sure the `torch` wheel is a CUDA-enabled build matching your driver (see the PyTorch installation selector for the right wheel):
```bash
pip install ray[default] torch torchvision matplotlib numpy
```

## Usage

### 1. Local Mode (Single Machine)

Run inference on your local machine without connecting to a cluster:

```bash
python run_ray_inference.py \
--model_path models/Unet_20241016-14.pt \
--model_name Unet \
--downs 8 16 32 \
--time_emb_dim 4 \
--img_size 28 \
--n_samples 100 \
--timesteps 1000 \
--device cuda \
--ray_address None \
--num_gpus_per_task 0.25
```

**Key parameters:**
- `--model_path`: Path to your trained model checkpoint
- `--n_samples`: Number of samples to generate
- `--ray_address None`: Use local mode (don't connect to cluster)
- `--num_gpus_per_task 0.25`: Allocate 0.25 GPU per task (4 parallel tasks per GPU)

### 2. KubeRay Cluster Mode

#### Step 1: Deploy RayCluster

If using kind (Kubernetes in Docker) for local testing:

```bash
# Create kind cluster if needed
kind create cluster --name ray-cluster

# Install KubeRay operator (if not already installed)
kubectl create -k "github.com/ray-project/kuberay/ray-operator/config/default?ref=v1.0.0"

# Deploy the DDPM inference cluster
kubectl apply -f kuberay-config.yaml

# Check cluster status
kubectl get rayclusters
kubectl get pods
```

#### Step 2: Port Forward to Access Cluster

```bash
# Forward Ray client port
kubectl port-forward service/ddpm-inference-service 10001:10001 &

# Optional: Forward dashboard for monitoring
kubectl port-forward service/ddpm-inference-service 8265:8265 &
```

#### Step 3: Run Inference

```bash
python run_ray_inference.py \
--model_path models/Unet_20241016-14.pt \
--model_name Unet \
--downs 8 16 32 \
--time_emb_dim 4 \
--img_size 28 \
--n_samples 100 \
--timesteps 1000 \
--device cuda \
--ray_address "auto" \
--num_gpus_per_task 0.25
```

**Notes:**
- `--ray_address "auto"` attaches to a Ray cluster discovered from the local environment; through port forwarding this discovery may fail, in which case use the explicit address below
- `--ray_address "ray://localhost:10001"` connects via the Ray Client protocol over the forwarded port and is the reliable choice for KubeRay

#### Step 4: Monitor Progress

Open the Ray dashboard in your browser:
```
http://localhost:8265
```

You can see:
- Active tasks and their progress
- Resource utilization (CPU, GPU, memory)
- Task timeline and execution details

#### Step 5: Cleanup

```bash
# Delete the RayCluster
kubectl delete -f kuberay-config.yaml

# Or delete the entire kind cluster
kind delete cluster --name ray-cluster
```

## How It Works

### Parallelization Strategy

1. **Model Sharing**: The trained model weights are loaded once and put in Ray's object store
2. **Task Distribution**: Each sample generation is a separate Ray task
3. **GPU Allocation**: Each task requests a fraction of a GPU (e.g., 0.25 = 4 tasks per GPU)
4. **Result Collection**: Results are collected as tasks complete using `ray.wait()`

### Key Ray Concepts Demonstrated

- **`@ray.remote`**: Decorator to make functions executable on remote workers
- **`ray.put()`**: Store objects in distributed object store for efficient sharing
- **`ray.get()`**: Retrieve results from remote tasks
- **Resource management**: Specify GPU/CPU requirements per task
- **Cluster connection**: Connect to local or remote Ray clusters

### Performance Benefits

For generating 100 samples with T=1000 timesteps:
- **Sequential** (original code): ~100 * 30 seconds = ~50 minutes
- **Parallel with 1 GPU** (4 tasks): ~25 * 30 seconds = ~13 minutes (4x speedup)
- **Parallel with 4 GPUs** (16 tasks): ~7 * 30 seconds = ~3.5 minutes (14x speedup)
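
These back-of-the-envelope figures come from splitting the samples into waves of concurrent tasks; a quick check, assuming roughly 30 s per sample:

```python
import math

def wall_time_minutes(n_samples: int, secs_per_sample: float, concurrent_tasks: int) -> float:
    # Each "wave" runs `concurrent_tasks` samples at once; waves run back to back.
    waves = math.ceil(n_samples / concurrent_tasks)
    return waves * secs_per_sample / 60

print(wall_time_minutes(100, 30, 1))   # 50.0  -> ~50 minutes sequential
print(wall_time_minutes(100, 30, 4))   # 12.5  -> ~13 minutes on 1 GPU
print(wall_time_minutes(100, 30, 16))  # 3.5   -> ~3.5 minutes on 4 GPUs
```

Real speedups are somewhat lower because of model transfer, scheduling overhead, and GPU contention between co-located tasks.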

## Customizing the KubeRay Configuration

Edit `kuberay-config.yaml` to adjust:

### Worker Count
```yaml
workerGroupSpecs:
- replicas: 3 # Increase for more parallel workers
minReplicas: 1
maxReplicas: 5
```

### GPU Resources
```yaml
resources:
limits:
nvidia.com/gpu: "1" # GPUs per worker pod
```

### CPU/Memory
```yaml
resources:
limits:
cpu: "8"
memory: "16Gi"
```

## Troubleshooting

### Connection Issues

If you can't connect to the Ray cluster:
```bash
# Check Ray head pod
kubectl get pods -l ray.io/node-type=head

# Check Ray head logs
kubectl logs -l ray.io/node-type=head

# Verify port forwarding
lsof -i :10001
```

### GPU Issues

If GPUs aren't detected:
```bash
# Check GPU availability in pods
kubectl exec -it <worker-pod-name> -- nvidia-smi

# Check Ray cluster resources
python -c "import ray; ray.init(address='auto'); print(ray.cluster_resources())"
```

### Out of Memory

If you run out of GPU memory:
- Increase `--num_gpus_per_task` (e.g. from 0.25 to 0.5) so fewer tasks share each GPU and each task gets a larger slice of its memory
- Generate samples in smaller batches
- Increase worker replicas to distribute load

## Learning Resources

- [Ray Documentation](https://docs.ray.io/)
- [KubeRay Documentation](https://docs.ray.io/en/latest/cluster/kubernetes/index.html)
- [Ray Core Walkthrough](https://docs.ray.io/en/latest/ray-core/walkthrough.html)
- [Ray Dashboard Guide](https://docs.ray.io/en/latest/ray-observability/getting-started.html)

## Next Steps

Once you're comfortable with this inference example, you can explore:
1. **Ray Tune** for hyperparameter optimization of your training script
2. **Ray Train** for distributed training across multiple GPUs/nodes
3. **Ray Data** for efficient data preprocessing pipelines
4. **Ray Serve** for serving your model as an API endpoint
87 changes: 87 additions & 0 deletions diffusion/ddpm/kuberay-config.yaml
@@ -0,0 +1,87 @@
apiVersion: ray.io/v1alpha1
kind: RayCluster
metadata:
  name: ddpm-inference-cluster
  namespace: default
spec:
  rayVersion: '2.9.0'
  enableInTreeAutoscaling: false
  headGroupSpec:
    rayStartParams:
      dashboard-host: '0.0.0.0'
      num-cpus: '0'  # Don't schedule workloads on head
    template:
      spec:
        containers:
        - name: ray-head
          image: rayproject/ray:2.9.0-py310-gpu
          resources:
            limits:
              cpu: "2"
              memory: "4Gi"
            requests:
              cpu: "1"
              memory: "2Gi"
          ports:
          - containerPort: 6379
            name: gcs-server
          - containerPort: 8265
            name: dashboard
          - containerPort: 10001
            name: client
          volumeMounts:
          - mountPath: /workspace
            name: workspace
        volumes:
        - name: workspace
          emptyDir: {}
  workerGroupSpecs:
  - replicas: 1
    minReplicas: 1
    maxReplicas: 3
    groupName: gpu-workers
    rayStartParams:
      num-cpus: '4'
      num-gpus: '1'
    template:
      spec:
        containers:
        - name: ray-worker
          image: rayproject/ray:2.9.0-py310-gpu
          lifecycle:
            preStop:
              exec:
                command: ["/bin/sh", "-c", "ray stop"]
          resources:
            limits:
              cpu: "4"
              memory: "8Gi"
              nvidia.com/gpu: "1"  # Request 1 GPU per worker
            requests:
              cpu: "2"
              memory: "4Gi"
              nvidia.com/gpu: "1"
          volumeMounts:
          - mountPath: /workspace
            name: workspace
        volumes:
        - name: workspace
          emptyDir: {}
---
apiVersion: v1
kind: Service
metadata:
  name: ddpm-inference-service
  namespace: default
spec:
  type: ClusterIP
  selector:
    ray.io/cluster: ddpm-inference-cluster
    ray.io/node-type: head
  ports:
  - name: dashboard
    port: 8265
    targetPort: 8265
  - name: client
    port: 10001
    targetPort: 10001