1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default.

1 change: 1 addition & 0 deletions architectures/centralized/client/src/app.rs
Original file line number Diff line number Diff line change
@@ -135,6 +135,7 @@ pub async fn build_app(
        .await?;

    let state_options: RunInitConfig<ClientId, ClientId> = RunInitConfig {
        parallelism_auto: p.parallelism_auto,
        data_parallelism: p.data_parallelism,
        tensor_parallelism: p.tensor_parallelism,
        micro_batch_size: p.micro_batch_size,
1 change: 1 addition & 0 deletions architectures/decentralized/solana-client/src/app.rs
@@ -134,6 +134,7 @@ pub async fn build_app(

    let state_options: RunInitConfig<psyche_solana_coordinator::ClientId, NetworkIdentity> =
        RunInitConfig {
            parallelism_auto: p.parallelism_auto,
            data_parallelism: p.data_parallelism,
            tensor_parallelism: p.tensor_parallelism,
            micro_batch_size: p.micro_batch_size,
24 changes: 24 additions & 0 deletions psyche-book/src/enduser/create-run.md
@@ -86,6 +86,30 @@ run-manager create-run \

At this point, your run has been successfully created.

### Adding parallelism configuration (required for --parallelism-auto)

If you want clients to use `PARALLELISM_AUTO=true` for automatic configuration, you must add a `parallelism_data.json` file to your model's GCS bucket alongside the model files.

```json
{
"H100": {
"1": { "dp": 1, "tp": 1, "micro_batch_size": 4 },
"8": { "dp": 4, "tp": 2, "micro_batch_size": 4 }
},
"H200": {
"8": { "dp": 8, "tp": 1, "micro_batch_size": 8 }
}
}
```

Format: `gpu_type` → `num_gpus` → config

- **gpu_type**: GPU model name (e.g., "H100", "H200")
- **num_gpus**: Number of GPUs available (e.g., "1", "8")
- **dp**: Data parallelism degree (number of model replicas)
- **tp**: Tensor parallelism degree (number of GPUs each replica is sharded across)
- **micro_batch_size**: Micro-batch size (samples processed per GPU per training step)
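
The lookup a client performs over this table can be sketched in a few lines. This is a std-only illustration: `example_table` is a hypothetical stand-in for the parsed `parallelism_data.json` (the real client deserializes the file with serde), built to match the H100 entries in the JSON example above.

```rust
use std::collections::HashMap;

#[derive(Debug, Clone, Copy, PartialEq, Eq)]
struct ParallelismConfig {
    dp: usize,
    tp: usize,
    micro_batch_size: usize,
}

// gpu_type -> num_gpus (as a string, matching the JSON keys) -> config
type Table = HashMap<String, HashMap<String, ParallelismConfig>>;

// Hand-built equivalent of the H100 entries in the JSON example above.
fn example_table() -> Table {
    let mut h100 = HashMap::new();
    h100.insert("1".to_string(), ParallelismConfig { dp: 1, tp: 1, micro_batch_size: 4 });
    h100.insert("8".to_string(), ParallelismConfig { dp: 4, tp: 2, micro_batch_size: 4 });
    let mut table = Table::new();
    table.insert("H100".to_string(), h100);
    table
}

// Resolve the config for a detected GPU type and count; None if either is missing.
fn lookup_in_table(table: &Table, gpu_type: &str, num_gpus: usize) -> Option<ParallelismConfig> {
    table.get(gpu_type)?.get(&num_gpus.to_string()).copied()
}

fn main() {
    let table = example_table();
    let cfg = lookup_in_table(&table, "H100", 8).expect("8 x H100 should be in the table");
    println!("dp={} tp={} micro_batch_size={}", cfg.dp, cfg.tp, cfg.micro_batch_size);
    assert!(lookup_in_table(&table, "A100", 8).is_none());
}
```

Note that a missing GPU type or GPU count is an error at join time, so the table should cover every hardware configuration you expect clients to bring.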

### Initializing configuration

Initially, the run will not have any configuration defined and will remain paused, so no clients can join yet.
10 changes: 10 additions & 0 deletions psyche-book/src/enduser/join-run.md
@@ -122,19 +122,29 @@ though you might need to.

**`NVIDIA_DRIVER_CAPABILITIES`** - An environment variable that the NVIDIA Container Toolkit uses to determine which compute capabilities should be provided to your container. It is recommended to set it to 'all', e.g. `NVIDIA_DRIVER_CAPABILITIES=all`.

**`PARALLELISM_AUTO`** - Set to `true` to automatically detect optimal parallelism settings based on your GPU hardware.

- When enabled, the client fetches a `parallelism_data.json` lookup table from the model's GCS bucket and picks the best `DATA_PARALLELISM`, `TENSOR_PARALLELISM`, and `MICRO_BATCH_SIZE` for your GPU type and count
- Your GPU type and count must be present in the lookup table
- This is the recommended option for most users
- If set, manual parallelism settings below will be ignored

**`DATA_PARALLELISM`** - Number of GPUs to distribute training data across.

- If you have multiple GPUs, you can set this to 2, 4, etc. to speed up training
- If you have 1 GPU, set this to `1`
- Ignored if `PARALLELISM_AUTO=true`

**`TENSOR_PARALLELISM`** - Number of GPUs to distribute the model across; this lets you train a model that doesn't fit on a single GPU.

- If you have 1 GPU, set this to `1`
- If you have `n` GPUs, you can distribute the model across all of them by setting it to `n`
- Ignored if `PARALLELISM_AUTO=true`

**`MICRO_BATCH_SIZE`** - Number of samples processed per GPU per training step.

- Set as high as your GPU memory allows
- Ignored if `PARALLELISM_AUTO=true`
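
As a rough sketch (the numeric values are illustrative, not a recommendation for your hardware), the two ways to configure parallelism look like this in your environment:

```shell
# Recommended: let the client pick dp/tp/micro-batch from the run's parallelism_data.json
export NVIDIA_DRIVER_CAPABILITIES=all
export PARALLELISM_AUTO=true

# Manual alternative (these three are ignored while PARALLELISM_AUTO=true):
export DATA_PARALLELISM=4
export TENSOR_PARALLELISM=2
export MICRO_BATCH_SIZE=4
```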

**`AUTHORIZER`** - The Solana address that authorized your wallet to join this run

1 change: 1 addition & 0 deletions shared/client/Cargo.toml
@@ -37,6 +37,7 @@ sysinfo = "0.32.0"
iroh.workspace = true
iroh-blobs.workspace = true
google-cloud-storage.workspace = true
nvml-wrapper = "0.11.0"

[features]
parallelism = ["psyche-modeling/parallelism"]
4 changes: 4 additions & 0 deletions shared/client/src/cli.rs
@@ -113,6 +113,10 @@ pub struct TrainArgs {
    #[clap(long, env, value_parser = parse_trim_quotes)]
    pub run_id: String,

    /// Auto-detect parallelism settings from lookup table based on model and GPU count
    #[clap(long, env)]
    pub parallelism_auto: bool,

    #[clap(long, default_value_t = 1, env)]
    pub data_parallelism: usize,

1 change: 1 addition & 0 deletions shared/client/src/lib.rs
@@ -1,6 +1,7 @@
mod cli;
mod client;
mod fetch_data;
pub mod parallelism_lookup;
mod protocol;
mod state;
mod tui;
80 changes: 80 additions & 0 deletions shared/client/src/parallelism_lookup.rs
@@ -0,0 +1,80 @@
use anyhow::{Context, Result};
use nvml_wrapper::Nvml;
use psyche_data_provider::{RunDownClient, download_parallelism_data_from_gcs_signed};
use serde::Deserialize;
use std::collections::HashMap;
use std::sync::Arc;
use tracing::info;

#[derive(Debug, Clone, Copy, Deserialize)]
pub struct ParallelismConfig {
    pub dp: usize,
    pub tp: usize,
    pub micro_batch_size: usize,
}

type Table = HashMap<String, HashMap<String, ParallelismConfig>>;

pub async fn lookup(run_down_client: &Arc<RunDownClient>) -> Result<ParallelismConfig> {
    let device_count = tch::Cuda::device_count() as usize;
    if device_count == 0 {
        anyhow::bail!("No GPUs found for parallelism auto-detection");
    }

    let gpu_type = normalize_gpu_name(&get_gpu_type_from_nvml()?);
    info!("Detected {} x {} GPU(s)", device_count, gpu_type);

    info!(
        "Fetching parallelism_data.json from GCS via run-down signed URLs for run {}",
        run_down_client.run_id()
    );
    let json = download_parallelism_data_from_gcs_signed(run_down_client)
        .await
        .map_err(|e| anyhow::anyhow!("{}", e))?;

    let table: Table =
        serde_json::from_str(&json).context("Failed to parse parallelism_data.json")?;

    lookup_in_table(&table, &gpu_type, device_count)
}

fn get_gpu_type_from_nvml() -> Result<String> {
    let nvml = Nvml::init().context("Failed to initialize NVML")?;
    let device = nvml
        .device_by_index(0)
        .context("Failed to get GPU device 0")?;
    device.name().context("Failed to get GPU name")
}

fn normalize_gpu_name(raw_name: &str) -> String {
    let upper = raw_name.to_uppercase();
    if upper.contains("H200") {
        "H200".to_string()
    } else if upper.contains("H100") {
        "H100".to_string()
    } else {
        raw_name.to_string()
    }
}

fn lookup_in_table(table: &Table, gpu_type: &str, num_gpus: usize) -> Result<ParallelismConfig> {
    let gpu_configs = table.get(gpu_type).ok_or_else(|| {
        anyhow::anyhow!(
            "No parallelism config for GPU type '{}'. Available: {:?}",
            gpu_type,
            table.keys().collect::<Vec<_>>()
        )
    })?;

    gpu_configs
        .get(&num_gpus.to_string())
        .copied()
        .ok_or_else(|| {
            anyhow::anyhow!(
                "No parallelism config for {} x {}. Available counts: {:?}",
                num_gpus,
                gpu_type,
                gpu_configs.keys().collect::<Vec<_>>()
            )
        })
}
51 changes: 50 additions & 1 deletion shared/client/src/state/init.rs
@@ -1,3 +1,5 @@
#[cfg(feature = "parallelism")]
use crate::parallelism_lookup;
use crate::{WandBInfo, fetch_data::DataFetcher};
use psyche_coordinator::{
Coordinator, HealthChecks,
@@ -51,6 +53,8 @@ pub struct RunInitConfig<T: NodeIdentity, A: AuthenticatableIdentity> {
    pub device: Devices,
    pub hub_read_token: Option<String>,
    pub hub_max_concurrent_downloads: usize,
    /// If true, auto-detect parallelism from lookup table (overrides dp/tp/micro_batch_size)
    pub parallelism_auto: bool,
    pub data_parallelism: usize,
    pub tensor_parallelism: usize,
    pub micro_batch_size: usize,
@@ -119,6 +123,9 @@ pub enum InitRunError {
    #[error("Unsupported architecture: {0}")]
    UnsupportedArchitecture(String),

    #[error("Parallelism auto-detection failed: {0}")]
    ParallelismLookupFailed(anyhow::Error),

    #[cfg(feature = "python")]
    #[error("Python distributed error: {0}")]
    PythonDistributedError(#[from] psyche_modeling::PythonDistributedCausalLMError),
@@ -173,7 +180,7 @@ impl<T: NodeIdentity, A: AuthenticatableIdentity + 'static> RunInitConfigAndIO<T
        state: Coordinator<T>,
    ) -> Result<StepStateMachine<T, A>, InitRunError> {
        let Self {
            mut init_config,
            tx_witness,
            tx_health_check,
            tx_model,
@@ -197,6 +204,48 @@ impl<T: NodeIdentity, A: AuthenticatableIdentity + 'static> RunInitConfigAndIO<T

        let model::Model::LLM(llm) = state.model;

        // Parallelism auto-detection
        #[cfg(not(feature = "parallelism"))]
        if init_config.parallelism_auto {
            return Err(InitRunError::ParallelismLookupFailed(anyhow::anyhow!(
                "--parallelism-auto requires building with --features=parallelism"
            )));
        }

        #[cfg(feature = "parallelism")]
        if init_config.parallelism_auto {
            if init_config.data_parallelism != 1
                || init_config.tensor_parallelism != 1
                || init_config.micro_batch_size != 1
            {
                tracing::warn!(
                    "--parallelism-auto is set, ignoring manual dp/tp/micro_batch_size values"
                );
            }

            let run_down_client = init_config
                .checkpoint_config
                .run_down_client
                .as_ref()
                .ok_or_else(|| {
                    InitRunError::ParallelismLookupFailed(anyhow::anyhow!(
                        "--parallelism-auto requires a GCS checkpoint type with run-down service"
                    ))
                })?;

            let config = parallelism_lookup::lookup(run_down_client)
                .await
                .map_err(InitRunError::ParallelismLookupFailed)?;

            info!(
                "Parallelism auto-detected: dp={}, tp={}, micro_batch_size={}",
                config.dp, config.tp, config.micro_batch_size
            );
            init_config.data_parallelism = config.dp;
            init_config.tensor_parallelism = config.tp;
            init_config.micro_batch_size = config.micro_batch_size;
        }

        let hub_read_token = init_config.hub_read_token.clone();
        let hub_max_concurrent_downloads = init_config.hub_max_concurrent_downloads;
        let data_future = async {
43 changes: 43 additions & 0 deletions shared/data-provider/src/gcs_signed.rs
@@ -142,6 +142,49 @@ pub async fn upload_to_gcs_signed(
    Ok(())
}

/// Download parallelism_data.json from GCS via signed URLs.
pub async fn download_parallelism_data_from_gcs_signed(
    run_down: &RunDownClient,
) -> Result<String, DownloadError> {
    let http = reqwest::Client::new();

    let download_response = run_down
        .get_download_urls()
        .await
        .map_err(|e| DownloadError::RunDown(e.to_string()))?;

    let entry = download_response
        .urls
        .iter()
        .find(|e| e.path.ends_with("parallelism_data.json"))
        .ok_or_else(|| {
            DownloadError::RunDown(
                "parallelism_data.json not found in GCS. Upload it alongside the model files."
                    .to_string(),
            )
        })?;

    info!("Downloading parallelism_data.json via signed URL");

    let response = http
        .get(&entry.url)
        .send()
        .await
        .map_err(|e| DownloadError::RunDown(e.to_string()))?;

    if !response.status().is_success() {
        return Err(DownloadError::RunDown(format!(
            "Failed to download parallelism_data.json: {}",
            response.status()
        )));
    }

    response
        .text()
        .await
        .map_err(|e| DownloadError::RunDown(e.to_string()))
}

pub async fn download_model_from_gcs_signed_async(
    run_down: &RunDownClient,
) -> Result<Vec<PathBuf>, DownloadError> {
5 changes: 4 additions & 1 deletion shared/data-provider/src/lib.rs
@@ -23,7 +23,10 @@ pub use gcs::{
    GcsCheckpointManifest, GcsManifestMetadata, GcsUploadInfo, ManifestFileEntry, ManifestMetadata,
    download_model_from_gcs_async, download_model_from_gcs_sync, upload_to_gcs,
};
pub use gcs_signed::{
    download_model_from_gcs_signed_async, download_parallelism_data_from_gcs_signed,
    upload_to_gcs_signed,
};
pub use hub::{
    HubUploadInfo, download_dataset_repo_async, download_dataset_repo_sync,
    download_model_repo_async, download_model_repo_sync, upload_to_hub,