Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions python/python/psyche/models/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ def make_causal_lm(
param_dtype: torch.dtype = torch.bfloat16,
reduce_dtype: torch.dtype = torch.float32,
fsdp_modules: Optional[Iterable[str]] = None,
compile: bool = True,
) -> CausalLM:
if not isinstance(device, torch.device):
device = torch.device(device if isinstance(device, str) else f"cuda:{device}")
Expand Down Expand Up @@ -45,5 +46,6 @@ def make_causal_lm(
param_dtype=param_dtype,
reduce_dtype=reduce_dtype,
fsdp_modules=fsdp_modules,
compile=compile,
)
raise ValueError(f"Unknown architecture {architecture}")
3 changes: 2 additions & 1 deletion python/python/psyche/models/ttitan.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,6 +282,7 @@ def from_pretrained(
param_dtype: torch.dtype = torch.bfloat16,
reduce_dtype: torch.dtype = torch.float32,
fsdp_modules: Optional[Iterable[str]] = None,
compile: bool = True,
):
config_json = None
if isinstance(source, PretrainedSourceStateDict):
Expand All @@ -302,7 +303,7 @@ def from_pretrained(

job_config = JobConfig()
job_config.training.seq_len = config_tt.max_seq_len
job_config.compile.enable = True
job_config.compile.enable = compile
job_config.compile.components = ["model", "loss"]
job_config.compile.fullgraph = False
job_config.activation_checkpoint.mode = "full"
Expand Down
2 changes: 1 addition & 1 deletion shared/client/src/state/evals.rs
Original file line number Diff line number Diff line change
Expand Up @@ -232,7 +232,7 @@ impl ModelTaskRunner {

pub fn start(&self, trainers: Vec<Trainer>) -> RunningEvals {
let cancel = CancellationToken::new();
trace!("Starting evals!");
info!("Starting evals!");

RunningEvals {
cancel: cancel.clone(),
Expand Down
3 changes: 2 additions & 1 deletion shared/client/src/state/stats.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use psyche_modeling::Trainer;
use psyche_network::P2PEndpointInfo;
use std::{collections::HashMap, sync::Arc, time::Duration};
use tokenizers::Tokenizer;
use tracing::{debug, trace, warn};
use tracing::{debug, info, trace, warn};
use wandb::{DataValue, LogData};

use crate::state::evals::{EnumModelTask, PROMPT_TASK_NAME};
Expand Down Expand Up @@ -125,6 +125,7 @@ impl StatsLogger {
.collect::<String>()
);
round_log.insert(formatted_key.clone(), val);
info!("{}: {:.4}", formatted_key, val);

self.metrics.record_eval_metric(&key, val);
}
Expand Down
29 changes: 16 additions & 13 deletions shared/eval/examples/evaluate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -237,19 +237,22 @@ fn run_data_parallel(
{
psyche_python_extension_impl::init_embedded_python()?;

Box::new(psyche_modeling::PythonDistributedCausalLM::new(
python_arch,
psyche_modeling::PretrainedSource::RepoFiles(repo),
device,
psyche_modeling::AttentionImplementation::default(),
psyche_modeling::ParallelismConfig {
dp: data_parallelism,
tp: 1,
},
None,
None,
None,
)?) as Box<dyn CausalLM>
Box::new(
psyche_modeling::PythonDistributedCausalLM::new_with_options(
python_arch,
psyche_modeling::PretrainedSource::RepoFiles(repo),
device,
psyche_modeling::AttentionImplementation::default(),
psyche_modeling::ParallelismConfig {
dp: data_parallelism,
tp: 1,
},
None,
None,
None,
false, // disable torch.compile for eval — avoids recompilation on every different sequence length
)?,
) as Box<dyn CausalLM>
}

#[cfg(not(feature = "python"))]
Expand Down
24 changes: 23 additions & 1 deletion shared/modeling/src/python_causal_lm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,26 @@ impl PythonCausalLM {
attn_implementation: AttentionImplementation,
parallelism: Option<ParallelismConfig>,
override_max_position_embeddings: Option<usize>,
) -> Result<PythonCausalLM, PythonCausalLMError> {
Self::new_with_options(
architecture,
source,
device,
attn_implementation,
parallelism,
override_max_position_embeddings,
true,
)
}

pub fn new_with_options(
architecture: &str,
source: &PretrainedSource<PythonModelConfig>,
device: Device,
attn_implementation: AttentionImplementation,
parallelism: Option<ParallelismConfig>,
override_max_position_embeddings: Option<usize>,
compile: bool,
) -> Result<PythonCausalLM, PythonCausalLMError> {
let config = source.get_config()?;
let result: PyResult<PyObject> = Python::with_gil(|py| {
Expand Down Expand Up @@ -176,7 +196,9 @@ impl PythonCausalLM {
parallelism.as_ref().map(|x| x.tp).unwrap_or(1),
override_max_position_embeddings,
);
let causal_lm = make_causal_lm.call1(args)?;
let kwargs = PyDict::new(py);
kwargs.set_item("compile", compile)?;
let causal_lm = make_causal_lm.call(args, Some(&kwargs))?;
Ok(causal_lm.unbind())
});
let causal_lm = result?;
Expand Down
27 changes: 26 additions & 1 deletion shared/modeling/src/python_distributed_causal_lm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,30 @@ impl PythonDistributedCausalLM {
override_max_position_embeddings: Option<usize>,
port: Option<u16>,
num_local_ranks: Option<i64>,
) -> Result<Self, PythonDistributedCausalLMError> {
Self::new_with_options(
architecture,
source,
device,
attn_implementation,
parallelism,
override_max_position_embeddings,
port,
num_local_ranks,
true,
)
}

pub fn new_with_options(
architecture: String,
source: PretrainedSource<PythonModelConfig>,
device: Device,
attn_implementation: AttentionImplementation,
parallelism: ParallelismConfig,
override_max_position_embeddings: Option<usize>,
port: Option<u16>,
num_local_ranks: Option<i64>,
compile: bool,
) -> Result<Self, PythonDistributedCausalLMError> {
if !tch::Cuda::is_available() {
return Err(PythonDistributedCausalLMError::CUDANotAvailable);
Expand Down Expand Up @@ -323,13 +347,14 @@ impl PythonDistributedCausalLM {
}
comm.set("dp", &format!("{}", parallelism.dp))?;
comm.set("tp", &format!("{}", parallelism.tp))?;
let local = PythonCausalLM::new(
let local = PythonCausalLM::new_with_options(
&architecture,
&source,
device,
attn_implementation,
Some(parallelism),
override_max_position_embeddings,
compile,
)?;
Ok((comm, local))
})
Expand Down
Loading