6 changes: 6 additions & 0 deletions .gitignore
@@ -192,3 +192,9 @@ checkpoints
 
 # MAS (madragonse)
 *mastest*
+
+*.csv
+tmp/
+data_long_ctx/
+data_test/
+checkpoint/
5 changes: 4 additions & 1 deletion configs/_eval/basic.yaml
@@ -8,4 +8,7 @@ evaluator:
     - wikitext
     - piqa
   limit: null # No limit, evaluate on full datasets
-  device: cuda
+  device: cuda
+  max_length: ${common.sequence_length}
+  max_gen_toks: 256
+  batch_size: 1
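
The three new evaluator keys match model arguments in EleutherAI's lm-evaluation-harness. A minimal sketch of how they could be wired up, assuming the evaluator wraps the harness's `HFLM` (the `run_lm_eval` helper, its arguments, and the glue code are illustrative, not code from this repo):

```python
# A minimal sketch, assuming the evaluator wraps lm-evaluation-harness.
# HFLM and simple_evaluate are real harness APIs; run_lm_eval and its
# arguments are hypothetical glue.
import lm_eval
from lm_eval.models.huggingface import HFLM


def run_lm_eval(model, tokenizer, sequence_length):
    lm = HFLM(
        pretrained=model,            # an already-instantiated torch module
        tokenizer=tokenizer,
        max_length=sequence_length,  # max_length: ${common.sequence_length}
        max_gen_toks=256,            # max_gen_toks: 256
        batch_size=1,                # batch_size: 1
        device="cuda",               # device: cuda
    )
    return lm_eval.simple_evaluate(
        model=lm,
        tasks=["wikitext", "piqa"],
        limit=None,  # no limit, evaluate on the full datasets
    )
```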
8 changes: 4 additions & 4 deletions configs/_misc/default.yaml
@@ -1,14 +1,14 @@
 
 infrastructure:
   metric_logger:
-    project_name: pmtest/llm-random
     heavy_metrics_calculation_interval: 100
-    new_neptune_job: true
-    new_wandb_job: true
-    type: neptune
+    name: default
+    type: wandb
+    wandb_entity: ideas_cv
+    project_name: llm-random-test
     tags:
       - nano
       - new_wandb_job
 
 git:
   remote_name: cemetery
1 change: 1 addition & 0 deletions configs/_trainer/llama.yaml
@@ -4,6 +4,7 @@ trainer:
   _target_: src.core.trainer.Trainer
   eval_interval: 100
   n_eval_steps: 10
+  lm_eval_interval: 0 # 0 = disabled; set to N to run lm_eval every N steps during training
   gradient_accumulation_steps: 1
   gradient_clipping: 1.0
   n_steps: null
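
For context, a hypothetical sketch of how the trainer might act on `lm_eval_interval`; the actual `src.core.trainer.Trainer` internals are not part of this diff, so everything here beyond the config key and the `evaluator` argument (both introduced in main.py below) is an assumption:

```python
# Hypothetical sketch only; the real src.core.trainer.Trainer is not shown
# in this diff. The 0-disables convention comes from the config comment.
class Trainer:
    def __init__(self, lm_eval_interval: int = 0, evaluator=None, **kwargs):
        self.lm_eval_interval = lm_eval_interval  # 0 = disabled
        self.evaluator = evaluator  # the lm_evaluator built in main.run()

    def _maybe_run_lm_eval(self, step: int) -> None:
        # Run lm_eval every N steps when enabled and an evaluator is available.
        if self.evaluator is None or self.lm_eval_interval <= 0:
            return
        if step > 0 and step % self.lm_eval_interval == 0:
            self.evaluator.eval()  # mirrors evaluator(...).eval() in the old run()
```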
3 changes: 2 additions & 1 deletion configs/tiny_local.yaml
@@ -13,8 +13,9 @@ common:
 
 trainer:
   gradient_accumulation_steps: 1
-  n_steps: 100
+  n_steps: 201
   learning_rate: 1e-3
+  lm_eval_interval: 100 # 0 = disabled; set to N to run lm_eval every N steps during training
 
   checkpoint:
     save:
10 changes: 6 additions & 4 deletions configs/tiny_remote.yaml
@@ -1,5 +1,5 @@
 defaults:
-  - _cluster@_here_: entropy_a100
+  - _cluster@_here_: entropy
   - _model@_here_: tiny
   - _trainer@_here_: llama
   - _dataset@_here_: c4
@@ -13,8 +13,9 @@ common:
 
 trainer:
   gradient_accumulation_steps: 1
-  n_steps: 100
+  n_steps: 1001
   learning_rate: 1e-3
+  lm_eval_interval: 100 # 0 = disabled; set to N to run lm_eval every N steps during training
 
   checkpoint:
     save:
@@ -25,6 +26,7 @@ infrastructure:
   max_concurrent_jobs: 1
 
   metric_logger:
+    # type: stdout
     name: tiny_remote
     tags:
       - nano
@@ -33,8 +35,8 @@ infrastructure:
 
   slurm:
     time: "00:10:00"
-    gres: gpu:1
+    gres: gpu:2
     job-name: ${infrastructure.metric_logger.name}
 
 evaluator:
-  limit: 10
+  limit: 3
43 changes: 0 additions & 43 deletions configs/tiny_remote_wandb.yaml

This file was deleted.

114 changes: 46 additions & 68 deletions main.py
@@ -18,13 +18,12 @@
 import logging
 from hydra.utils import instantiate
 import logging
-from neptune.integrations.python_logger import NeptuneHandler
 from src.core.checkpointing import (
     load_checkpoint_from_file,
     load_training_state,
     get_full_checkpoint_path,
 )
-from src.core.metric_loggers import NeptuneLogger, WandbLogger, get_metric_logger
+from src.core.metric_loggers import WandbLogger, get_metric_logger
 from src.core.model import Residual
 import platform
 
@@ -61,9 +60,7 @@ def upload_config_file(metric_logger):
     slurm_array_task_id = os.environ.get("SLURM_ARRAY_TASK_ID")
     file_path = f"generated_configs/config_{slurm_array_task_id}.yaml"
     if slurm_array_task_id is not None and os.path.exists(file_path):
-        metric_logger.run["yaml_config"].upload(
-            f"generated_configs/config_{slurm_array_task_id}.yaml"
-        )
+        metric_logger.run.save(file_path)
 
 
 def check_env_vars():
@@ -179,8 +176,8 @@ def log_environs(metric_logger):
     ]
 
     environs = os.environ
-    for environ_key in scrap_keys:
-        metric_logger.run[f"job/{environ_key}"] = str(environs.get(environ_key))
+    env_dict = {f"job/{k}": str(environs.get(k)) for k in scrap_keys}
+    metric_logger.run.config.update(env_dict)
 
 
 def get_device():
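
Both replacements above move from Neptune's dict-style run API to standard wandb calls: `run.save()` syncs a file with the run, and `run.config.update()` merges a dict into the run's config. A short usage sketch (the wandb calls are real; the project/entity values are taken from the configs in this PR, and the `scrap_keys` subset is illustrative):

```python
# Usage sketch of the two wandb APIs adopted above. wandb.init, run.save and
# run.config.update are standard wandb calls; the concrete values are examples.
import os

import wandb

run = wandb.init(project="llm-random-test", entity="ideas_cv")

# upload_config_file: sync the generated YAML file alongside the run
run.save("generated_configs/config_0.yaml")

# log_environs: record selected environment variables in the run config
scrap_keys = ["SLURM_JOB_ID", "SLURM_ARRAY_TASK_ID"]  # illustrative subset
run.config.update({f"job/{k}": str(os.environ.get(k)) for k in scrap_keys})
```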
@@ -219,27 +216,9 @@ def initialize_training_components(cfg: OmegaConf, metric_logger=None):
         full_config=cfg,
     )
 
-    # Other loggers do not have `run` method
-    if isinstance(metric_logger, NeptuneLogger):
-        npt_handler = NeptuneHandler(run=metric_logger.run)
-        logger.addHandler(npt_handler)
-
     learning_rate, exp_lr = solve_config_lr(cfg.trainer.learning_rate)
 
-    if isinstance(metric_logger, NeptuneLogger) and (
-        training_state["run_id"] is None
-        or cfg.infrastructure.metric_logger.new_neptune_job
-    ):
-        metric_logger.run["job_config"] = cfg
-        upload_config_file(metric_logger)
-        log_environs(metric_logger)
-        metric_logger.run[f"job/full_save_checkpoints_path"] = get_full_checkpoint_path(
-            cfg.trainer.checkpoint.save.path
-        )
-        metric_logger.run["learning_rate"] = learning_rate
-        metric_logger.run["exp_lr"] = exp_lr
-
-    elif isinstance(metric_logger, WandbLogger) and (
+    if isinstance(metric_logger, WandbLogger) and (
         training_state["run_id"] is None
         or cfg.infrastructure.metric_logger.new_wandb_job
     ):
@@ -325,48 +304,47 @@ def run(cfg: OmegaConf, metric_logger=None):
         cfg, metric_logger
     )
 
-    if model is not None:
-        logger.info(f"Model initialized")
-
-    trainer = instantiate(cfg.trainer)
-
-    if "distillation" in cfg:
-        if cfg.distillation.load.type == "huggingface":
-            teacher_model = instantiate(
-                cfg.distillation.teacher_model, _convert_="all"
-            ).to(get_device())
-            copy_llama_model_weights_from_HF(
-                teacher_model, cfg.distillation.load.path
-            )
-            teacher_model = setup_distributed_training(
-                teacher_model, cfg.trainer.teacher_distributed
-            )
-        elif cfg.distillation.load.type == "pc_memeff_base":
-            teacher_model = model.source_model
-
-        trainer(
-            teacher_model=teacher_model,
-            model=model,
-            optimizer=optimizer,
-            scheduler=scheduler,
-            training_state=training_state,
-            metric_logger=metric_logger,
-        ).train()
-    else:
-        trainer(
-            model=model,
-            optimizer=optimizer,
-            scheduler=scheduler,
-            training_state=training_state,
-            metric_logger=metric_logger,
-        ).train()
-
-    # TODO
-    # finetuning
-
-    evaluator = instantiate(cfg.evaluator)
-    if evaluator is not None:
-        evaluator(metric_logger=metric_logger).eval()
+    logger.info(f"Model initialized")
+
+    evaluator_partial = instantiate(cfg.evaluator)
+    lm_evaluator = (
+        evaluator_partial(metric_logger=metric_logger, model=model)
+        if evaluator_partial is not None
+        else None
+    )
+
+    trainer = instantiate(cfg.trainer)
+
+    if "distillation" in cfg:
+        if cfg.distillation.load.type == "huggingface":
+            teacher_model = instantiate(
+                cfg.distillation.teacher_model, _convert_="all"
+            ).to(get_device())
+            copy_llama_model_weights_from_HF(teacher_model, cfg.distillation.load.path)
+            teacher_model = setup_distributed_training(
+                teacher_model, cfg.trainer.teacher_distributed
+            )
+        elif cfg.distillation.load.type == "pc_memeff_base":
+            teacher_model = model.source_model
+
+        trainer(
+            teacher_model=teacher_model,
+            model=model,
+            optimizer=optimizer,
+            scheduler=scheduler,
+            training_state=training_state,
+            metric_logger=metric_logger,
+            evaluator=lm_evaluator,
+        ).train()
+    else:
+        trainer(
+            model=model,
+            optimizer=optimizer,
+            scheduler=scheduler,
+            training_state=training_state,
+            metric_logger=metric_logger,
+            evaluator=lm_evaluator,
+        ).train()
 
     cleanup()
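
The reworked `run()` instantiates `cfg.evaluator` and then calls the result with the runtime-only arguments (`metric_logger`, `model`), which is Hydra's partial-instantiation pattern. A self-contained sketch, assuming the evaluator config sets `_partial_: true` (the `Evaluator` class below is a toy stand-in, not this repo's implementation):

```python
# Sketch of Hydra's _partial_ pattern, which the new run() appears to rely on:
# instantiate() bakes in config-time kwargs and returns a functools.partial,
# and runtime objects (model, metric_logger) are bound afterwards.
from dataclasses import dataclass
from typing import Any, Optional

from hydra.utils import instantiate
from omegaconf import OmegaConf


@dataclass
class Evaluator:  # toy stand-in for the repo's evaluator class
    tasks: list
    limit: Optional[int] = None
    model: Any = None
    metric_logger: Any = None

    def eval(self):
        print(f"evaluating tasks={list(self.tasks)} limit={self.limit}")


cfg = OmegaConf.create(
    {
        "_target_": "__main__.Evaluator",
        "_partial_": True,  # instantiate() returns functools.partial(Evaluator, ...)
        "tasks": ["wikitext", "piqa"],
        "limit": None,
    }
)

evaluator_partial = instantiate(cfg)
lm_evaluator = evaluator_partial(model="toy-model", metric_logger=None)
lm_evaluator.eval()
```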