shared/client/src/state/init.rs (9 additions, 0 deletions)
@@ -118,6 +118,9 @@ pub enum InitRunError {
#[error("Unsupported architecture: {0}")]
UnsupportedArchitecture(String),

#[error("Device is not usable: {0}")]
DeviceNotUsable(anyhow::Error),

#[cfg(feature = "python")]
#[error("Python distributed error: {0}")]
PythonDistributedError(#[from] psyche_modeling::PythonDistributedCausalLMError),
@@ -193,6 +196,12 @@ impl RunInitConfigAndIO {
));
}

// Fail fast if the device is not usable
init_config
.device
.ensure_usable()
.map_err(InitRunError::DeviceNotUsable)?;

let model::Model::LLM(llm) = state.model;

let hub_read_token = init_config.hub_read_token.clone();
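For context, a minimal sketch of how the new fail-fast error could surface to a caller. Only `InitRunError::DeviceNotUsable` comes from this diff; the function and match arms are hypothetical:

```rust
// Hypothetical caller-side handling; every name here except the
// `DeviceNotUsable` variant is illustrative.
fn report_init_error(err: InitRunError) {
    match err {
        InitRunError::DeviceNotUsable(inner) => {
            // `inner` is the anyhow::Error produced by `ensure_usable`,
            // carrying context such as "cuda:1 is not usable".
            eprintln!("device check failed: {inner:#}");
        }
        other => eprintln!("init failed: {other}"),
    }
}
```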
shared/modeling/src/device_utils.rs (61 additions, 8 deletions)
@@ -1,9 +1,8 @@
use std::{fmt, str::FromStr};

+use anyhow::Context;
use itertools::Itertools;
-use tch::{Device, utils::has_mps};
-#[cfg(test)]
-use tch::{Kind, Tensor};
+use tch::{utils::has_mps, Device, Kind, Tensor};
use thiserror::Error;

/// Get all available CUDA devices
@@ -99,6 +98,22 @@ impl Devices {
Devices::Mps => has_mps(),
}
}

/// Ensures the device is usable
///
/// Currently does nothing for CPU and MPS devices
pub fn ensure_usable(&self) -> anyhow::Result<()> {
match self {
Devices::Cpu | Devices::Mps => Ok(()),
Devices::Cuda(device_indices) => {
for &device_idx in device_indices {
ensure_cuda_device_usable(device_idx)
.with_context(|| format!("cuda:{device_idx} is not usable"))?;
}
Ok(())
}
}
}
}

/// Get all available devices, for debugging purposes
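A quick usage sketch of the new check (the crate path is an assumption; the `"cuda"` device spec follows the parsing tests further down, which resolve it to every visible CUDA device):

```rust
use psyche_modeling::Devices; // assumed export path

fn preflight() -> anyhow::Result<()> {
    // "cuda" parses to all visible CUDA devices, per the tests below.
    let devices = "cuda".parse::<Devices>().expect("valid device spec");
    // Probes each GPU in turn; the first failure returns an error with
    // "cuda:N is not usable" context attached.
    devices.ensure_usable()
}
```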
@@ -217,6 +232,46 @@ impl DevicePytorchStr for Device {
}
}

/// Ensures that a CUDA device is usable
///
/// Runs a causal SDPA forward and backward pass with a large sequence
/// length, which is designed to trigger cuDNN's runtime kernel compilation
fn ensure_cuda_device_usable(device_idx: usize) -> anyhow::Result<()> {
let device = Device::Cuda(device_idx);

let batch: i64 = 2;
let heads: i64 = 32;
let seq_len: i64 = 4096;
let head_dim: i64 = 128;
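// Back-of-envelope footprint (not stated in the PR): each tensor below is
// 2 * 32 * 4096 * 128 = 33_554_432 bf16 elements, i.e. 64 MiB, so q/k/v
// plus their gradients and the SDPA output put the probe's peak usage on
// the order of half a GiB per device.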

let q = Tensor::f_randn([batch, heads, seq_len, head_dim], (Kind::BFloat16, device))?
.f_set_requires_grad(true)
.context("failed to set requires grad")?;
let k = Tensor::f_randn([batch, heads, seq_len, head_dim], (Kind::BFloat16, device))?
.f_set_requires_grad(true)
.context("failed to set requires grad")?;
let v = Tensor::f_randn([batch, heads, seq_len, head_dim], (Kind::BFloat16, device))?
.f_set_requires_grad(true)
.context("failed to set requires grad")?;

let output = Tensor::f_scaled_dot_product_attention::<Tensor>(
&q, &k, &v, None, // attn_mask
0.0, // dropout_p
true, // is_causal
None, // scale (None = default 1/sqrt(head_dim))
false, // enable_gqa
)
.context("failed to run SDPA operation")?;

let loss = output
.f_sum(Kind::BFloat16)
.context("failed to sum SDPA output")?;
loss.f_backward()
.context("failed to run backward pass on SDPA output")?;

Ok(())
}

#[cfg(test)]
mod tests {
use tch::utils::has_cuda;
@@ -242,11 +297,9 @@ mod tests {
.unwrap(),
Devices::Cuda((0..tch::Cuda::device_count() as usize).collect())
);
-assert!(
-    format!("cuda:{}", tch::Cuda::device_count())
-        .parse::<Devices>()
-        .is_err()
-);
+assert!(format!("cuda:{}", tch::Cuda::device_count())
+    .parse::<Devices>()
+    .is_err());
} else {
assert!(matches!(
"cuda".parse::<Devices>(),
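Finally, a hedged sketch of a test that exercises the probe directly, in the style of the existing tests (the test name is invented; since `ensure_cuda_device_usable` is private, this would live in the same `tests` module):

```rust
#[test]
fn cuda_devices_pass_usability_probe() {
    if has_cuda() {
        for idx in 0..tch::Cuda::device_count() as usize {
            // Runs the bf16 causal-SDPA forward + backward probe on each GPU.
            super::ensure_cuda_device_usable(idx)
                .unwrap_or_else(|e| panic!("cuda:{idx} failed probe: {e:#}"));
        }
    }
}
```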