Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 9 additions & 2 deletions shared/client/src/state/init.rs
Original file line number Diff line number Diff line change
Expand Up @@ -596,11 +596,18 @@ impl RunInitConfigAndIO {
let device = device.ok_or_else(|| {
ModelLoadError::NoDeviceForRank(rank, devices)
})?;
// macOS/Metal does not support BFloat16; use Float32 for stability.
// Note that this might not work in production; it's intended for small test runs.
let model_kind = if device == tch::Device::Mps {
Kind::Float
} else {
Kind::BFloat16
};
match architecture {
model::LLMArchitecture::HfLlama => {
LlamaForCausalLM::from_pretrained(
&source.try_into()?,
Some(Kind::BFloat16),
Some(model_kind),
attn_implementation,
Some(device),
tensor_parallelism_world,
Expand All @@ -611,7 +618,7 @@ impl RunInitConfigAndIO {
model::LLMArchitecture::HfDeepseek => {
DeepseekForCausalLM::from_pretrained(
&source.try_into()?,
Some(Kind::BFloat16),
Some(model_kind),
attn_implementation,
Some(device),
tensor_parallelism_world,
Expand Down
55 changes: 45 additions & 10 deletions shared/modeling/src/distro.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,14 @@ impl TransformDCT {

// Pregenerate DCT basis matrices
if let std::collections::hash_map::Entry::Vacant(e) = f_dict.entry(sc) {
let i = Tensor::eye(sc, (Kind::Float, variable.device()));
// Metal does not support FFT/complex operations; compute on CPU instead.
// Since the input here is just an identity matrix, the device shouldn't matter much.
let dct_device = if variable.device() == Device::Mps {
Device::Cpu
} else {
variable.device()
};
let i = Tensor::eye(sc, (Kind::Float, dct_device));
e.insert(
Self::dct(&i, true)
.to_kind(variable.kind())
Expand Down Expand Up @@ -402,36 +409,56 @@ impl CompressDCT {
kind: Kind,
device: Device,
) -> Tensor {
let idx_concat = Tensor::cat(idx, -1).to_device(device);
let val_concat = Tensor::cat(val, -1).to_device(device);
let idx_concat = crate::mps_compat::cat_owned(idx, -1).to_device(device);
let val_concat = crate::mps_compat::cat_owned(val, -1).to_device(device);
// Call the decompress method
Self::decompress(&idx_concat, &val_concat, xshape, totalk, kind, device)
}
}

/// Stage a tensor on the CPU when it lives on an MPS device.
///
/// On MPS the tensor is copied to the CPU so that ops unsupported by Metal can
/// run there; on any other device (e.g. Linux/CUDA) this is a no-op shallow
/// clone with no data movement.
fn move_to_cpu_if_mps(tensor: &Tensor) -> Tensor {
    match tensor.device() {
        // MPS: copy to host memory so downstream ops can run on the CPU.
        Device::Mps => tensor.to(Device::Cpu),
        // Everything else: cheap shallow clone, same underlying storage.
        _ => tensor.shallow_clone(),
    }
}

fn compress_idx(max_value: i64, idx: &Tensor) -> Tensor {
if max_value <= 256 {
idx.to_kind(Kind::Uint8)
} else if max_value <= 65536 {
idx.to_kind(Kind::UInt16).view_dtype(Kind::Uint8)
move_to_cpu_if_mps(idx)
.to_kind(Kind::UInt16)
.view_dtype(Kind::Uint8)
.to(idx.device())
} else if max_value <= 4294967296 {
idx.to_kind(Kind::UInt32).view_dtype(Kind::Uint8)
move_to_cpu_if_mps(idx)
.to_kind(Kind::UInt32)
.view_dtype(Kind::Uint8)
.to(idx.device())
} else {
idx.shallow_clone()
}
}

fn decompress_idx(max_value: i64, idx: &Tensor) -> Tensor {
if max_value <= 256 {
let viewed = if max_value <= 256 {
idx.view_dtype(Kind::Uint8)
} else if max_value <= 65536 {
idx.view_dtype(Kind::UInt16)
move_to_cpu_if_mps(idx)
.view_dtype(Kind::UInt16)
.to(idx.device())
} else if max_value <= 4294967296 {
idx.view_dtype(Kind::UInt32)
move_to_cpu_if_mps(idx)
.view_dtype(Kind::UInt32)
.to(idx.device())
} else {
idx.shallow_clone()
}
.to_kind(Kind::Int64)
};
viewed.to_kind(Kind::Int64)
}

struct State {
Expand Down Expand Up @@ -585,6 +612,14 @@ impl Distro {
let (sparse_idx, sparse_val, xshape, totalk) =
CompressDCT::compress(&self.transform.encode(&full_delta), self.compression_topk);

// Since MPS does not support BFloat16, cast back to BFloat16 for compatibility,
// because the receiving end will expect that dtype.
let sparse_val = if sparse_val.device() == Device::Mps {
sparse_val.to_kind(Kind::BFloat16)
} else {
sparse_val
};

let delta_energy: Option<f64> = match stats {
true => Some(
full_delta
Expand Down
1 change: 1 addition & 0 deletions shared/modeling/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ mod distro;
mod dummy;
mod fp32_gradient_accumulator;
mod models;
pub(crate) mod mps_compat;
mod optimizer;
mod parallelism;
#[cfg(feature = "python")]
Expand Down
12 changes: 6 additions & 6 deletions shared/modeling/src/models/deepseek.rs
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ pub fn apply_rotary_pos_emb(
let freqs = inv_freq_expanded.matmul(&position_ids_expanded); // [pos_b, head_dim_2, seq_len]
let freqs = freqs.transpose(1, 2); // [pos_b, seq_len, head_dim_2]

let emb = Tensor::cat(&[&freqs, &freqs], -1); // [pos_b, seq_len, head_dim]
let emb = crate::mps_compat::cat(&[&freqs, &freqs], -1); // [pos_b, seq_len, head_dim]

let mut cos = emb.cos();
let mut sin = emb.sin();
Expand Down Expand Up @@ -369,8 +369,8 @@ impl MLAAttention {

let (q_pe, k_pe) = apply_rotary_pos_emb(cache, &q_pe, &k_pe, position_ids);

let mut query_states = Tensor::cat(&[&q_nope, &q_pe], -1);
let mut key_states = Tensor::cat(&[&k_nope, &k_pe], -1);
let mut query_states = crate::mps_compat::cat(&[&q_nope, &q_pe], -1);
let mut key_states = crate::mps_compat::cat(&[&k_nope, &k_pe], -1);

let y = match self.attn_implementation {
#[cfg(feature = "parallelism")]
Expand All @@ -379,7 +379,7 @@ impl MLAAttention {
let full_head_dim = self.qk_nope_head_dim + self.qk_rope_head_dim;
if full_head_dim != self.head_v_dim {
let pad_size = full_head_dim - self.head_v_dim;
padded_value_states = Tensor::cat(
padded_value_states = crate::mps_compat::cat(
&[
&value_states,
&Tensor::zeros(
Expand Down Expand Up @@ -442,7 +442,7 @@ impl MLAAttention {
let mut padded_value_states = value_states.shallow_clone();
if self.qk_nope_head_dim + self.qk_rope_head_dim != self.head_v_dim {
let pad_size = self.qk_nope_head_dim + self.qk_rope_head_dim - self.head_v_dim;
padded_value_states = Tensor::cat(
padded_value_states = crate::mps_compat::cat(
&[
&value_states,
&Tensor::zeros(
Expand Down Expand Up @@ -804,7 +804,7 @@ impl DeepseekMoE {
}
}

Tensor::cat(&outputs, 0)
crate::mps_compat::cat_owned(&outputs, 0)
}
}

Expand Down
29 changes: 29 additions & 0 deletions shared/modeling/src/mps_compat.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
use tch::{Device, Tensor};

/// Wrapper around `Tensor::cat`. For Linux/CUDA this is the same as `Tensor::cat`,
/// but for MPS the inputs are staged through the CPU (concatenating directly on
/// MPS crashes) and the result is moved back to the original device.
///
/// Returns an empty Float CPU tensor when `tensors` is empty.
pub fn cat(tensors: &[&Tensor], dim: i64) -> Tensor {
    if tensors.is_empty() {
        return Tensor::empty([0], (tch::Kind::Float, Device::Cpu));
    }
    let device = tensors[0].device();
    if device != Device::Mps {
        return Tensor::cat(tensors, dim);
    }
    // Stage through the CPU, concatenate there, then move the result back to MPS.
    let cpu_tensors: Vec<Tensor> = tensors.iter().map(|t| t.to(Device::Cpu)).collect();
    // `Tensor::cat` accepts any `&[impl Borrow<Tensor>]`, so the owned Vec can be
    // passed directly — no need to re-collect into a Vec of references.
    Tensor::cat(&cpu_tensors, dim).to(device)
}

/// Owned-tensor variant of the MPS-safe concatenation, for callers that already
/// hold a `Vec<Tensor>` / `&[Tensor]`. Behaves like `Tensor::cat` on non-MPS
/// devices; on MPS the inputs are staged through the CPU and the result is moved
/// back, because concatenating directly on MPS crashes.
///
/// Returns an empty Float CPU tensor when `tensors` is empty.
pub fn cat_owned(tensors: &[Tensor], dim: i64) -> Tensor {
    if tensors.is_empty() {
        return Tensor::empty([0], (tch::Kind::Float, Device::Cpu));
    }
    let device = tensors[0].device();
    if device != Device::Mps {
        return Tensor::cat(tensors, dim);
    }
    // Stage through the CPU, concatenate there, then move the result back to MPS.
    // `Tensor::cat` takes `&[impl Borrow<Tensor>]`, so the owned Vec is passed
    // directly instead of re-collecting a Vec of references.
    let cpu_tensors: Vec<Tensor> = tensors.iter().map(|t| t.to(Device::Cpu)).collect();
    Tensor::cat(&cpu_tensors, dim).to(device)
}
4 changes: 2 additions & 2 deletions shared/modeling/src/rope.rs
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ pub fn rotate_half(xs: &Tensor) -> Tensor {
let last_dim = *xs.size().last().unwrap();
let xs1 = xs.narrow(-1, 0, last_dim / 2);
let xs2 = xs.narrow(-1, last_dim / 2, last_dim - last_dim / 2);
Tensor::cat(&[&xs2.neg(), &xs1], -1)
crate::mps_compat::cat(&[&xs2.neg(), &xs1], -1)
}

impl RoPECache {
Expand Down Expand Up @@ -227,7 +227,7 @@ impl RoPECache {
let freqs = inv_freq_expanded.matmul(&position_ids_expanded); // [pos_b, head_dim_2, seq_len]
let freqs = freqs.transpose(1, 2); // [pos_b, seq_len, head_dim_2]

let emb = Tensor::cat(&[&freqs, &freqs], -1); // [pos_b, seq_len, head_dim]
let emb = crate::mps_compat::cat(&[&freqs, &freqs], -1); // [pos_b, seq_len, head_dim]

let mut cos = emb.cos();
let mut sin = emb.sin();
Expand Down
9 changes: 6 additions & 3 deletions shared/modeling/src/trainer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -197,22 +197,25 @@ impl Batch {

let padding_input_ids =
Tensor::zeros([padding_needed as i64, seq_len], (Kind::Int64, device));
gpu_data.input_ids = Tensor::cat(&[&gpu_data.input_ids, &padding_input_ids], 0);
gpu_data.input_ids =
crate::mps_compat::cat(&[&gpu_data.input_ids, &padding_input_ids], 0);

if let Some(labels) = gpu_data.labels.take() {
let padding_labels = Tensor::full(
[padding_needed as i64, seq_len],
-100i64,
(Kind::Int64, device),
);
gpu_data.labels = Some(Tensor::cat(&[&labels, &padding_labels], 0));
gpu_data.labels =
Some(crate::mps_compat::cat(&[&labels, &padding_labels], 0));
}

if gpu_data.position_ids.is_some() {
let pos_row = Tensor::arange(seq_len, (Kind::Int64, device));
let padding_pos = pos_row.unsqueeze(0).repeat([padding_needed as i64, 1]);
if let Some(pos) = gpu_data.position_ids.take() {
gpu_data.position_ids = Some(Tensor::cat(&[&pos, &padding_pos], 0));
gpu_data.position_ids =
Some(crate::mps_compat::cat(&[&pos, &padding_pos], 0));
}
}

Expand Down
Loading