Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
136 changes: 105 additions & 31 deletions picolm.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,16 @@
import torch
import typer
import wandb
from datasets import IterableDataset, IterableDatasetDict, load_dataset
from lightning.fabric import Fabric
from lightning.fabric.strategies import FSDPStrategy
from datasets import IterableDataset, IterableDatasetDict, load_dataset # IterableDataset already here
from pydantic import BaseModel
from rich.console import Console
from rich.logging import RichHandler

from pico.infer import infer
from pico.model import PICO_XS_PRESET, Pico
from pico.serialization import load, load_metadata, save
from pico.serialization import load, load_metadata
from pico.train import DEFAULT_TRAINING_META, TrainingMeta, TrainingStep, train

console = Console()
Expand Down Expand Up @@ -199,6 +201,10 @@ def train_command(
str,
typer.Argument(help="HuggingFace dataset path"),
],
devices: Annotated[
int,
typer.Option(help="Number of devices to use for training"),
] = torch.cuda.device_count(),
dataset_column_name: Annotated[
str,
typer.Option(help="Name of the column in the dataset to use for training."),
Expand Down Expand Up @@ -240,26 +246,16 @@ def train_command(
warmup_steps: int | None = None,
grad_accumulation_steps: int | None = None,
validation_interval: int | None = None,
test_run: Annotated[
bool,
typer.Option(help="Enable a quick test run with a minimal dataset and config."),
] = False,
):
if not path.is_dir() or not (path / "model.json").exists():
logger.error(f"Invalid model directory: {path}")
raise typer.Exit(code=1)

# Load datasets
train_dataset, validation_dataset = _load_dataset(
dataset_path,
dataset_train_file,
dataset_validation_file,
dataset_train_split,
dataset_validation_split,
dataset_validation_size,
)

# Prepare datasets
train_dataset = _prepare_dataset(train_dataset, dataset_column_name, shuffle=True)
validation_dataset = _prepare_dataset(validation_dataset, dataset_column_name)

# Init training metadata
# Init training metadata - moved up to be potentially overridden by test_run
training_meta = DEFAULT_TRAINING_META.model_copy()
_update_if_set(
training_meta,
Expand Down Expand Up @@ -290,33 +286,109 @@ def train_command(
tracker_project_name=tracker_project_name or path.name,
)
training_logs_file = run_directory / "training.json"
training_logs_file.write_text(training_logs.model_dump_json(indent=2))
# training_logs_file.write_text(training_logs.model_dump_json(indent=2)) # Moved down after potential meta override

if test_run:
logger.info("--- TEST RUN ENABLED ---")
logger.info("Overriding dataset to a minimal in-memory dataset.")
dummy_data = [
{"bytes": b"This is a short test sentence."},
{"bytes": b"Another piece of text for testing."},
{"bytes": b"Pico is learning, byte by byte."},
{"bytes": b"A final entry in our tiny dataset."},
]
train_dataset = IterableDataset.from_iterable(dummy_data[:3])
validation_dataset = IterableDataset.from_iterable(dummy_data[3:])

# Mock num_shards attribute if needed by DataLoader setup in train function
# This is a simple way to ensure format_dataset in train.py doesn't break
# if it expects num_shards for DataLoader(num_workers=...)
train_dataset.num_shards = 1
validation_dataset.num_shards = 1

logger.info(f"Using {len(dummy_data)} dummy samples for test run.")

logger.info("Overriding training metadata for a minimal test run.")
training_meta.max_steps = 2
training_meta.batch_size = 1
training_meta.context_len = 32
training_meta.warmup_steps = 1
training_meta.validation_interval = 2
training_meta.epochs = 1
training_meta.grad_accumulation_steps = 1

wandb_run = wandb.init(mode="disabled")
logger.info("W&B tracking disabled for test run.")
else:
# Load datasets
train_dataset, validation_dataset = _load_dataset(
dataset_path,
dataset_train_file,
dataset_validation_file,
dataset_train_split,
dataset_validation_split,
dataset_validation_size,
)

wandb_run = wandb.init(
project=tracker_project_name or path.name,
name=f"{run_name}-{''.join(random.choice('abcdefghijklmnopqrstuvwxyz') for _ in range(3))}",
config=training_meta.model_dump(),
)
# Prepare datasets
train_dataset = _prepare_dataset(train_dataset, dataset_column_name, shuffle=True)
validation_dataset = _prepare_dataset(validation_dataset, dataset_column_name)

wandb_run = wandb.init(
project=tracker_project_name or path.name,
name=f"{run_name}-{''.join(random.choice('abcdefghijklmnopqrstuvwxyz') for _ in range(3))}",
config=training_meta.model_dump(),
)

# Write training logs *after* potential override by test_run
training_logs_file.write_text(training_logs.model_dump_json(indent=2))

# Load model
if base_weights is None:
model = Pico(load_metadata(path / "model.json"))
else:
metadata_file = base_weights.parent.parent / "model.json"
model = load(base_weights, metadata_file)
# Always initialize a new model instance
model_meta = load_metadata(path / "model.json")
model = Pico(model_meta)

model = torch.compile(model)
logger.debug("Model compiled")

# Initialize Fabric based on devices
if devices == 0:
accelerator = "cpu"
num_fabric_devices = 1
strategy_fabric = "auto"
precision_fabric = None
elif devices == 1:
accelerator = "cuda"
num_fabric_devices = 1
strategy_fabric = "auto"
precision_fabric = "bf16-mixed"
else: # devices > 1
accelerator = "cuda"
num_fabric_devices = devices
strategy_fabric = FSDPStrategy(state_dict_type="full")
precision_fabric = "bf16-mixed"
torch._dynamo.config.optimize_ddp = False

fabric = Fabric(
accelerator=accelerator,
devices=num_fabric_devices,
strategy=strategy_fabric,
precision=precision_fabric,
)
fabric.launch()

# Train loop
tm1 = time.time()
step: TrainingStep | None = None

for step in train(
model,
fabric,
train_dataset,
validation_dataset=validation_dataset,
training_meta=training_meta,
devices=devices, # This argument is no longer used by pico.train.train but is kept for compatibility with fabric init
Copy link

Copilot AI May 21, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

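[nitpick] Since the 'devices' argument is not used in the pico.train.train function, consider removing it from the call to reduce potential confusion and improve maintainability. Note that the updated `train` signature in this change (model, fabric, dataset, validation_dataset, training_meta, base_weights_path, test_run) declares no `devices` parameter and no `**kwargs`, so passing `devices=devices` would raise a `TypeError` at call time rather than being silently ignored; the device count is already consumed when the `Fabric` instance is constructed.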

Suggested change
devices=devices, # This argument is no longer used by pico.train.train but is kept for compatibility with fabric init

Copilot uses AI. Check for mistakes.
base_weights_path=base_weights,
test_run=test_run,
):
tm2 = time.time()

Expand All @@ -332,15 +404,17 @@ def train_command(
if step.i % 500 == 0:
logger.info(f"Saving checkpoint at step: {step.i}")

save(model, run_directory / f"{step.i:08d}.safetensors")
checkpoint_path = run_directory / f"{step.i:08d}.safetensors"
fabric.save(checkpoint_path, {"model": model})
training_logs.checkpoints.append(step)
training_logs_file.write_text(training_logs.model_dump_json(indent=2))

tm1 = time.time()

# Save last step
if step and step.i not in [s.i for s in training_logs.checkpoints]:
save(model, run_directory / f"{step.i:08d}.safetensors")
checkpoint_path = run_directory / f"{step.i:08d}.safetensors"
fabric.save(checkpoint_path, {"model": model})
training_logs.checkpoints.append(step)
training_logs_file.write_text(training_logs.model_dump_json(indent=2))

Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ dependencies = [
"datasets>=3.0.2",
"einops>=0.8.0",
"flash-attn",
"lightning>=2.5.0.post0",
"pydantic>=2.10.6",
"safetensors>=0.5.2",
"setuptools>=75.8.0",
Expand Down
77 changes: 52 additions & 25 deletions src/pico/train.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
import math
import pathlib
from typing import Literal

import einops
import torch
from datasets import Dataset, IterableDataset
from lightning.fabric import Fabric
from pydantic import BaseModel, computed_field
from torch.amp import GradScaler
from torch.nn import DataParallel
from torch.nn import functional as F
from torch.utils.data import DataLoader

Expand Down Expand Up @@ -196,18 +196,17 @@ def loss_fn(
def get_validation_metrics(
model: Pico,
dataloader: DataLoader,
device: torch.device,
) -> TrainingStepMetrics:
total_loss = 0.0
total_aux_loss = 0.0
total_next_token_loss = 0.0
num_steps = 0

for data in dataloader:
x = data["x"].to(device)
y = data["y"].to(device)
x = data["x"]
y = data["y"]

with torch.autocast(device.type, dtype=torch.bfloat16):
with torch.autocast("cuda", dtype=torch.bfloat16):
Copy link

Copilot AI May 21, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The autocast context is hard-coded to 'cuda', which may cause issues when running on CPU. Consider deriving the device type dynamically, for example from `fabric.device.type` as in the suggestion below, to ensure compatibility with both CPU and GPU runs.

Suggested change
with torch.autocast("cuda", dtype=torch.bfloat16):
fabric = Fabric()
with torch.autocast(fabric.device.type, dtype=torch.bfloat16):

Copilot uses AI. Check for mistakes.
pred, router_weights, router_decisions, _ = model(x)

loss, aux_loss, next_token_lm_loss = loss_fn(
Expand All @@ -233,31 +232,49 @@ def get_validation_metrics(

def train(
model: Pico,
fabric: Fabric,
dataset: Dataset | IterableDataset,
validation_dataset: Dataset | IterableDataset | None = None,
training_meta: TrainingMeta = DEFAULT_TRAINING_META,
base_weights_path: pathlib.Path | None = None,
test_run: bool = False, # New parameter
):
if test_run:
# For test runs, especially on CPU, num_workers=0 can be more stable/faster.
# For GPU test runs, 1 might be fine.
loader_num_workers = 0 if fabric.accelerator == "cpu" else 1
fabric.print(f"Test run: Setting DataLoader num_workers to {loader_num_workers}")
else:
# Existing logic, ensure dataset has num_shards or provide a default.
# The IterableDataset from HuggingFace usually has num_shards.
loader_num_workers = dataset.num_shards if hasattr(dataset, "num_shards") and dataset.num_shards > 0 else 1
if not hasattr(dataset, "num_shards") or dataset.num_shards <= 0 :
fabric.print(f"Warning: dataset.num_shards not available or invalid. Defaulting num_workers to 1.")

dataloader = DataLoader(
format_dataset(dataset, model.metadata, training_meta),
batch_size=training_meta.batch_size,
pin_memory=True,
num_workers=dataset.num_shards,
pin_memory=True, # pin_memory is fine for CPU if memory allows, no major harm.
num_workers=loader_num_workers,
)

validation_dataloader = None
if validation_dataset is not None:
# Apply similar logic for validation_dataloader num_workers
if test_run:
val_loader_num_workers = 0 if fabric.accelerator == "cpu" else 1
else:
val_loader_num_workers = validation_dataset.num_shards if hasattr(validation_dataset, "num_shards") and validation_dataset.num_shards > 0 else 1
if not hasattr(validation_dataset, "num_shards") or validation_dataset.num_shards <= 0:
fabric.print(f"Warning: validation_dataset.num_shards not available or invalid. Defaulting num_workers for validation to 1.")

validation_dataloader = DataLoader(
format_dataset(validation_dataset, model.metadata, training_meta),
batch_size=training_meta.batch_size,
pin_memory=True,
num_workers=validation_dataset.num_shards,
num_workers=val_loader_num_workers,
)

device = torch.device("cuda")
model = DataParallel(model)
model = model.to(device)
model.train()

trainable_params = {
name: param for name, param in model.named_parameters() if param.requires_grad
}
Expand All @@ -281,30 +298,40 @@ def train(
lr=training_meta.learning_rate,
)

scaler = GradScaler()
model, optimizer = fabric.setup(model, optimizer)
dataloader, validation_dataloader = fabric.setup_dataloaders(
dataloader, validation_dataloader
)

if base_weights_path:
fabric.print(f"Loading base weights from {base_weights_path}")
states_to_load = {"model": model}
fabric.load(base_weights_path, states_to_load)

model.train()

# Training loop
step = 0
for epoch in range(training_meta.epochs):
for data in dataloader:
x = data["x"].to(device)
y = data["y"].to(device)
x = data["x"]
y = data["y"]

with torch.autocast(device.type, dtype=torch.bfloat16):
pred, router_weights, router_decisions, _ = model(x)
pred, router_weights, router_decisions, _ = model(x)

loss, aux_loss, next_token_lm_loss = loss_fn(
pred, router_weights, router_decisions, y
)

validation_metrics = None
if (
validation_dataloader is not None
fabric.global_rank == 0
and validation_dataloader is not None
and step % training_meta.validation_interval == 0
and step > 0
):
validation_metrics = get_validation_metrics(
model, validation_dataloader, device=device
model, validation_dataloader
)

training_step = TrainingStep(
Expand All @@ -318,18 +345,18 @@ def train(
validation=validation_metrics,
)

yield training_step
if fabric.global_rank == 0:
yield training_step

scaler.scale(loss).backward()
fabric.backward(loss)

if (step + 1) % training_meta.grad_accumulation_steps == 0:
# Update learning rate according to schedule before next optimizer step
lr = lr_schedule(step, training_meta)
for param_group in optimizer.param_groups:
param_group["lr"] = lr

scaler.step(optimizer)
scaler.update()
optimizer.step()
optimizer.zero_grad()

step += 1
Loading