From e9acb727fdf72979d4d6224e1d679adc4a032d01 Mon Sep 17 00:00:00 2001 From: Zane Li Date: Mon, 30 Mar 2026 17:36:25 +0800 Subject: [PATCH 01/30] docs: add design philosophy guide --- docs/philosophy.md | 128 +++++++++++++++++++++++++++++++++++++++++++++ mkdocs.yml | 1 + 2 files changed, 129 insertions(+) create mode 100644 docs/philosophy.md diff --git a/docs/philosophy.md b/docs/philosophy.md new file mode 100644 index 0000000..911ad84 --- /dev/null +++ b/docs/philosophy.md @@ -0,0 +1,128 @@ +# Design Philosophy + +TinyExp is intentionally small. + +It is not trying to become a heavy training framework, a trainer abstraction, or a system that takes over the +user's training loop. Its goal is much narrower and, for many research and iteration workflows, much more useful: +make experiment code easy to define, easy to launch, and easy to evolve without hiding plain PyTorch. + +## What TinyExp Is + +TinyExp is best understood as an experiment entry framework with a few lightweight helpers. + +It focuses on: + +- keeping the experiment definition as the main entrypoint +- making configuration explicit and easy to override +- supporting multiple launch styles without changing experiment code too much +- keeping user code close to normal PyTorch +- reducing repeated experiment "plumbing" without owning the full training lifecycle + +In practice, TinyExp wants the file you edit to remain the file you run. + +## What TinyExp Is Not + +TinyExp is intentionally not designed to be: + +- a general-purpose trainer framework +- a runtime system that owns epoch/step control flow +- a callback-heavy abstraction layer +- a framework that forces users into a single lifecycle or DSL +- a system that hides the actual training loop behind too many layers + +If a feature would make experiments feel less like plain PyTorch and more like framework ceremony, it is usually a +bad fit for TinyExp. + +## Core Principles + +### 1. 
The experiment is the entrypoint + +The experiment definition should be the center of the workflow. + +Users should not need to spread one experiment across many disconnected files just to launch, configure, and run it. +The experiment class should stay readable, local, and easy to reason about. + +### 2. Explicit is better than implicit + +TinyExp prefers explicit calls over hidden side effects. + +For example, integrations with external systems such as W&B should remain explicit. A config object can expose the +ability to build an integration, but the user should still decide when to call it. + +### 3. Keep the training loop in user space + +The training loop is often the most task-specific part of an experiment. TinyExp should not rush to abstract it into +a universal trainer. + +Users should be able to: + +- write their own training loop +- define their own evaluation logic +- control when to validate, log, save, or resume +- stay in plain PyTorch as much as possible + +### 4. Helpers are good; control frameworks are not + +TinyExp should provide thin, reusable helpers for common experiment chores, such as: + +- output directory setup +- config dumping +- lightweight metric logging +- checkpoint save/load helpers +- launcher compatibility + +These helpers reduce repeated boilerplate without dictating how the user structures training. + +### 5. Examples are recipes, not just demos + +Examples in TinyExp are not only meant to showcase features. They should also serve as reusable recipes and +inheritance-friendly templates. + +That means examples should remain understandable and useful as starting points for real projects. When common logic +emerges across multiple examples, it may be worth extracting a small helper or a recipe base class. But that logic +should only move into the framework when it is broadly useful and still keeps the system light. + +### 6. 
Framework-level additions must earn their place + +A good question for any new feature is: + +Does this reduce repeated experiment plumbing while preserving user control? + +If yes, it may belong in TinyExp. + +If it starts to own the user workflow, hide core control flow, or push the project toward a heavy trainer-style +architecture, it probably does not belong in TinyExp. + +## Recommended Boundary + +### TinyExp should own + +- configuration structure and CLI overrides +- experiment entry and launch ergonomics +- lightweight utilities shared across many experiments +- minimal artifact helpers that do not take over control flow + +### Examples or user experiments should own + +- model construction details +- training and evaluation loops +- task-specific metrics +- validation timing +- checkpointing policy such as what counts as "best" +- external integration timing and usage + +This boundary keeps TinyExp small while still making it genuinely useful. + +## Design Direction for Future Development + +When extending TinyExp, prefer: + +- small helpers over large abstractions +- explicit calls over automatic behavior +- recipe-style examples over framework-owned trainers +- local clarity over generic indirection +- composable building blocks over lifecycle machinery + +In short: + +TinyExp should help users write less experiment plumbing, not less experiment logic. diff --git a/mkdocs.yml b/mkdocs.yml index 82c6533..7bb1274 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -9,6 +9,7 @@ copyright: Maintained by zengarden. 
nav: - Home: index.md + - Philosophy: philosophy.md - Modules: modules.md plugins: - search From 9d5863eab9190f4e7a705e24ef109f065edf198d Mon Sep 17 00:00:00 2001 From: Zane Li Date: Mon, 30 Mar 2026 17:44:00 +0800 Subject: [PATCH 02/30] docs: clarify project philosophy in readme --- README.md | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/README.md b/README.md index 8e6caaf..a3e66df 100644 --- a/README.md +++ b/README.md @@ -28,6 +28,23 @@ TinyExp focuses on simple, maintainable experiment management: - Your config stays structured and easy to override. - Your execution path stays consistent as experiments grow. +## Design Philosophy + +TinyExp is intentionally light. + +It is not trying to be a heavy trainer framework that owns your epoch loop, callback system, or full runtime +lifecycle. Instead, it focuses on a smaller goal: + +- keep the experiment itself as the main entrypoint +- keep the training loop in user space +- make configuration and launch behavior explicit +- provide thin helpers rather than framework-owned control flow +- treat examples as reusable recipes, not just demos + +In short, TinyExp should help you write less experiment plumbing, not less experiment logic. + +For a longer explanation, see [`docs/philosophy.md`](docs/philosophy.md). + ## Quick Start (1 Minute) ### Option A: Install with pip and use import-based entrypoint From 2ae4ea88720ac924f3d4d924cf686101ab6eef40 Mon Sep 17 00:00:00 2001 From: Zane Li Date: Mon, 30 Mar 2026 17:45:21 +0800 Subject: [PATCH 03/30] docs: align homepage with project philosophy --- docs/index.md | 37 +++++++++++++++++++++++++++++++++---- 1 file changed, 33 insertions(+), 4 deletions(-) diff --git a/docs/index.md b/docs/index.md index 06881fc..7f25e25 100644 --- a/docs/index.md +++ b/docs/index.md @@ -2,10 +2,39 @@ A minimalist Python project for deep learning experiment management. 
-TinyExp lets you launch experiments with one click: the file you edit becomes the entrypoint to your experiment. +TinyExp keeps one idea at the center: +your configured experiment is your entrypoint. -# Features +Instead of splitting config, launcher, and execution across many files, TinyExp keeps them together in one experiment +definition so iteration stays fast and predictable. + +## What You Get + +- Experiment-centered configuration with Hydra/OmegaConf +- CLI overrides without rewriting code +- Training loops that stay close to plain PyTorch +- The same experiment definition from local debug to distributed launch + +## Design Philosophy + +TinyExp is intentionally light. + +It is not trying to be a heavy trainer framework that owns your epoch loop, callback system, or full runtime +lifecycle. Instead, it focuses on a smaller and more explicit goal: + +- keep the experiment itself as the main entrypoint +- keep the training loop in user space +- make configuration and launch behavior explicit +- provide thin helpers instead of framework-owned control flow +- treat examples as reusable recipes, not just demos + +In short, TinyExp should help you write less experiment plumbing, not less experiment logic. + +For the longer version, see [Design Philosophy](philosophy.md). + +## Features - 🚀 One-click experiment launch: The file you edit becomes the entrypoint to your experiment. -- 📊 Experiment tracking: Track your experiments with W&B. -- 🔄 Experiment management: Manage your experiments configuration with Hydra. +- 🔄 Config-driven experiment management with Hydra. +- 🧩 Thin helpers without taking over your training loop. +- 🧪 Examples that can serve as reusable experiment recipes. 
From e5a1c232f2fa4100cef9887910db05d608c50686 Mon Sep 17 00:00:00 2001 From: Zane Li Date: Mon, 30 Mar 2026 17:52:39 +0800 Subject: [PATCH 04/30] docs: add phase 1 minimal helpers plan --- docs/phase1-minimal-helpers.md | 246 +++++++++++++++++++++++++++++++++ mkdocs.yml | 1 + 2 files changed, 247 insertions(+) create mode 100644 docs/phase1-minimal-helpers.md diff --git a/docs/phase1-minimal-helpers.md b/docs/phase1-minimal-helpers.md new file mode 100644 index 0000000..4edab37 --- /dev/null +++ b/docs/phase1-minimal-helpers.md @@ -0,0 +1,246 @@ +# Phase 1: Minimal Helpers Plan + +This document describes the first implementation phase that follows TinyExp's design philosophy. + +The goal is not to turn TinyExp into a trainer framework. The goal is to add a small set of reusable helpers that +remove repeated experiment plumbing while keeping training loops in user space. + +For the broader principles behind these choices, see [Design Philosophy](philosophy.md). + +## Why This Phase Exists + +TinyExp already has a clear core direction: + +- the experiment is the entrypoint +- configuration is explicit and override-friendly +- launch behavior should stay simple +- users should keep control of their training loop + +What is still missing is a minimal layer for common experiment chores that many examples will otherwise repeat by hand. + +This phase adds only thin helpers for those chores. It does not add a trainer, runtime, callback engine, or framework +owned lifecycle. + +## Goals + +Phase 1 should make experiments easier to run and maintain without changing TinyExp's character. 
+ +The goals are: + +- keep plain PyTorch style intact +- reduce repeated setup code in examples +- improve reproducibility through lightweight artifacts +- make resume/eval workflows easier +- create a stable base for future examples and recipe-style inheritance + +## Non-Goals + +Phase 1 explicitly does not aim to add: + +- a generic trainer abstraction +- a runtime layer that owns epoch or step flow +- a callback or hook engine +- automatic external tracker initialization +- a framework-wide best-model policy system +- a heavy experiment lifecycle API + +If a feature starts to own the user workflow instead of helping it, it is out of scope for this phase. + +## Minimal Additions to TinyExp + +The base `TinyExp` class should remain small. This phase only proposes a few minimal additions. + +### New fields + +Recommended additions: + +- `mode: str = "train"` +- `resume_from: str = ""` + +These are intentionally minimal: + +- `mode` provides a small, explicit switch for training, validation, and config help flows +- `resume_from` provides a standard path for loading a checkpoint + +More policy-driven settings should stay in examples unless they prove broadly reusable. + +### New helper methods + +The following methods are the proposed Phase 1 surface area: + +- `get_run_dir() -> str` +- `ensure_run_dir() -> str` +- `dump_config(path: str | None = None) -> str` +- `log_metrics(metrics: dict, *, step: int | None = None, epoch: int | None = None, filename: str = "metrics.jsonl") -> None` +- `save_checkpoint(...) -> str` +- `load_checkpoint(...) -> dict` +- `maybe_resume(...) -> dict | None` + +These are helpers, not control-flow abstractions. + +## Artifact Conventions + +Phase 1 should establish simple, stable artifact conventions. + +The recommended default run layout is: + +- `output//config.yaml` +- `output//metrics.jsonl` +- `output//last.ckpt` +- `output//best.ckpt` +- `output//log.txt` + +This layout is intentionally straightforward. 
It improves usability and reproducibility without introducing a heavy run +management system. + +## Helper Behavior + +### Run directory helpers + +`get_run_dir()` should return the default run directory for the current experiment. + +`ensure_run_dir()` should create that directory if needed and return it. + +These helpers should not introduce a large naming or versioning system in Phase 1. + +### Config dumping + +`dump_config()` should write the effective experiment configuration to YAML. + +Expected behavior: + +- default path is `/config.yaml` +- output reflects current config state after overrides +- writing should happen only from the main process when running distributed + +### Metric logging + +`log_metrics()` should append structured records to a local JSONL file. + +Expected behavior: + +- default file is `/metrics.jsonl` +- each record should include the provided metrics +- helper may also attach lightweight metadata such as timestamp, step, and epoch +- writing should happen only from the main process + +This gives TinyExp a useful local record format without introducing a full tracker framework. + +### Checkpoint helpers + +`save_checkpoint()` and `load_checkpoint()` should provide a standard way to persist and recover experiment state. + +Recommended checkpoint content: + +- `model_state_dict` +- `optimizer_state_dict` when available +- `scheduler_state_dict` when available +- `epoch` +- `global_step` +- `best_metric` +- `meta` + +Recommended metadata: + +- `exp_name` +- `exp_class` +- `saved_at` + +The helper should only standardize the storage format. It should not decide when checkpoints are written. + +### Resume helper + +`maybe_resume()` should be a thin convenience layer over `resume_from`. + +Expected behavior: + +- return `None` when `resume_from` is empty +- otherwise call `load_checkpoint()` +- return the loaded checkpoint state so the example can decide how to resume + +This keeps resume logic explicit while reducing repeated boilerplate. 
+ +## Boundary Between TinyExp and Examples + +This phase depends on keeping a strong boundary between the framework and examples. + +### TinyExp should own + +- configuration structure and override ergonomics +- launch integration +- thin artifact helpers +- small reusable utilities shared across many experiments + +### Examples should own + +- model construction +- data loading details +- the training loop +- evaluation logic +- when validation runs +- when checkpoints are saved +- what metric counts as best +- whether and when external integrations are initialized + +This boundary is central to TinyExp's design. + +## Example Migration Strategy + +The first migration target should be `tinyexp/examples/mnist_exp.py`. + +It is a good candidate because: + +- it is small enough to change safely +- it already represents the intended user-facing workflow +- it can validate whether the helpers are actually reducing useful boilerplate + +The migration should: + +- keep the training loop inside the example +- replace repeated path/config writing code with helpers +- add checkpoint save/load through helpers +- add `mode=val` using `resume_from` + +Only after this works well should TinyExp consider extracting a recipe-style base class from examples. + +## Testing Plan + +Phase 1 should be backed by lightweight tests. + +Recommended test coverage: + +- unit tests for run directory creation +- unit tests for config dumping +- unit tests for metric logging +- unit tests for checkpoint save/load +- unit tests for `maybe_resume()` +- a small integration test for `mode=val` + +The tests should stay CPU-first and deterministic. + +## Implementation Order + +Recommended implementation order: + +1. add run directory helpers +2. add config dumping +3. add metric logging +4. add checkpoint save/load +5. add `maybe_resume()` +6. migrate `mnist_exp.py` +7. add `mode=val` +8. add tests + +This order keeps each change small and easy to validate. 
+ +## Success Criteria + +Phase 1 is successful if TinyExp can do all of the following while still feeling light: + +- keep experiments centered around one explicit entrypoint +- preserve user-owned training loops +- save config and local metrics in a standard way +- save and resume checkpoints with minimal boilerplate +- support a simple validation flow from a checkpoint + +In short, Phase 1 should make TinyExp more practical without making it more framework-heavy. diff --git a/mkdocs.yml b/mkdocs.yml index 7bb1274..95f2e15 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -10,6 +10,7 @@ copyright: Maintained by zengarden. nav: - Home: index.md - Philosophy: philosophy.md + - Phase 1 Plan: phase1-minimal-helpers.md - Modules: modules.md plugins: - search From 62d652561308a4086c616cb30f9a28718d49caa7 Mon Sep 17 00:00:00 2001 From: Zane Li Date: Mon, 30 Mar 2026 18:03:37 +0800 Subject: [PATCH 05/30] docs: refine cfg-driven design philosophy --- README.md | 1 + docs/index.md | 1 + docs/phase1-minimal-helpers.md | 68 ++++++++++++++++++++++------------ docs/philosophy.md | 43 +++++++++++++++++---- 4 files changed, 82 insertions(+), 31 deletions(-) diff --git a/README.md b/README.md index a3e66df..f3b52f1 100644 --- a/README.md +++ b/README.md @@ -38,6 +38,7 @@ lifecycle. Instead, it focuses on a smaller goal: - keep the experiment itself as the main entrypoint - keep the training loop in user space - make configuration and launch behavior explicit +- expose shared capabilities through focused `XXXCfg` components - provide thin helpers rather than framework-owned control flow - treat examples as reusable recipes, not just demos diff --git a/docs/index.md b/docs/index.md index 7f25e25..722039c 100644 --- a/docs/index.md +++ b/docs/index.md @@ -25,6 +25,7 @@ lifecycle. 
Instead, it focuses on a smaller and more explicit goal: - keep the experiment itself as the main entrypoint - keep the training loop in user space - make configuration and launch behavior explicit +- expose shared capabilities through focused `XXXCfg` components - provide thin helpers instead of framework-owned control flow - treat examples as reusable recipes, not just demos diff --git a/docs/phase1-minimal-helpers.md b/docs/phase1-minimal-helpers.md index 4edab37..0c860bf 100644 --- a/docs/phase1-minimal-helpers.md +++ b/docs/phase1-minimal-helpers.md @@ -2,8 +2,8 @@ This document describes the first implementation phase that follows TinyExp's design philosophy. -The goal is not to turn TinyExp into a trainer framework. The goal is to add a small set of reusable helpers that -remove repeated experiment plumbing while keeping training loops in user space. +The goal is not to turn TinyExp into a trainer framework. The goal is to add a small set of reusable, +configuration-driven helpers that remove repeated experiment plumbing while keeping training loops in user space. For the broader principles behind these choices, see [Design Philosophy](philosophy.md). @@ -21,6 +21,12 @@ What is still missing is a minimal layer for common experiment chores that many This phase adds only thin helpers for those chores. It does not add a trainer, runtime, callback engine, or framework owned lifecycle. +It also follows an additional structural rule: + +- shared capabilities should usually be exposed through a focused `XXXCfg` class +- fields inside that class should be Hydra-override-friendly +- behavior should only run when the user explicitly calls a method on that config object + ## Goals Phase 1 should make experiments easier to run and maintain without changing TinyExp's character. @@ -48,7 +54,8 @@ If a feature starts to own the user workflow instead of helping it, it is out of ## Minimal Additions to TinyExp -The base `TinyExp` class should remain small. 
This phase only proposes a few minimal additions. +The base `TinyExp` class should remain small. This phase only proposes a few minimal additions directly on `TinyExp`, +and it prefers feature-specific `XXXCfg` classes for behavior-rich capabilities. ### New fields @@ -64,7 +71,7 @@ These are intentionally minimal: More policy-driven settings should stay in examples unless they prove broadly reusable. -### New helper methods +### New helper methods on `TinyExp` The following methods are the proposed Phase 1 surface area: @@ -72,11 +79,30 @@ The following methods are the proposed Phase 1 surface area: - `ensure_run_dir() -> str` - `dump_config(path: str | None = None) -> str` - `log_metrics(metrics: dict, *, step: int | None = None, epoch: int | None = None, filename: str = "metrics.jsonl") -> None` + +These are helpers, not control-flow abstractions. + +### New `CheckpointCfg` + +Checkpointing should follow the same overall style as existing components such as `logger_cfg` and `wandb_cfg`. + +Recommended addition: + +- `checkpoint_cfg: CheckpointCfg` + +Recommended scope for `CheckpointCfg`: + +- default checkpoint filenames +- standard checkpoint save logic +- standard checkpoint load logic + +Recommended methods: + - `save_checkpoint(...) -> str` - `load_checkpoint(...) -> dict` -- `maybe_resume(...) -> dict | None` -These are helpers, not control-flow abstractions. +This keeps checkpoint behavior grouped with its own configuration instead of expanding `TinyExp` with many feature +specific methods. ## Artifact Conventions @@ -128,7 +154,8 @@ This gives TinyExp a useful local record format without introducing a full track ### Checkpoint helpers -`save_checkpoint()` and `load_checkpoint()` should provide a standard way to persist and recover experiment state. +`checkpoint_cfg.save_checkpoint()` and `checkpoint_cfg.load_checkpoint()` should provide a standard way to persist and +recover experiment state. 
Recommended checkpoint content: @@ -148,17 +175,11 @@ Recommended metadata: The helper should only standardize the storage format. It should not decide when checkpoints are written. -### Resume helper - -`maybe_resume()` should be a thin convenience layer over `resume_from`. - -Expected behavior: - -- return `None` when `resume_from` is empty -- otherwise call `load_checkpoint()` -- return the loaded checkpoint state so the example can decide how to resume +Resume should remain explicit in user code: -This keeps resume logic explicit while reducing repeated boilerplate. +- `resume_from` stores the path +- the example decides whether to call `checkpoint_cfg.load_checkpoint()` +- the example decides how to continue from the loaded state ## Boundary Between TinyExp and Examples @@ -169,6 +190,7 @@ This phase depends on keeping a strong boundary between the framework and exampl - configuration structure and override ergonomics - launch integration - thin artifact helpers +- feature-specific `XXXCfg` components - small reusable utilities shared across many experiments ### Examples should own @@ -198,7 +220,7 @@ The migration should: - keep the training loop inside the example - replace repeated path/config writing code with helpers -- add checkpoint save/load through helpers +- add checkpoint save/load through `checkpoint_cfg` - add `mode=val` using `resume_from` Only after this works well should TinyExp consider extracting a recipe-style base class from examples. @@ -213,7 +235,6 @@ Recommended test coverage: - unit tests for config dumping - unit tests for metric logging - unit tests for checkpoint save/load -- unit tests for `maybe_resume()` - a small integration test for `mode=val` The tests should stay CPU-first and deterministic. @@ -225,11 +246,10 @@ Recommended implementation order: 1. add run directory helpers 2. add config dumping 3. add metric logging -4. add checkpoint save/load -5. add `maybe_resume()` -6. migrate `mnist_exp.py` -7. add `mode=val` -8. 
add tests +4. add `CheckpointCfg` with save/load +5. migrate `mnist_exp.py` +6. add `mode=val` +7. add tests This order keeps each change small and easy to validate. diff --git a/docs/philosophy.md b/docs/philosophy.md index 911ad84..d31fdcd 100644 --- a/docs/philosophy.md +++ b/docs/philosophy.md @@ -14,6 +14,7 @@ It focuses on: - keeping the experiment definition as the main entrypoint - making configuration explicit and easy to override +- expressing shared features through focused `XXXCfg` components - supporting multiple launch styles without changing experiment code too much - keeping user code close to normal PyTorch - reducing repeated experiment "plumbing" without owning the full training lifecycle @@ -49,7 +50,34 @@ TinyExp prefers explicit calls over hidden side effects. For example, integrations with external systems such as W&B should remain explicit. A config object can expose the ability to build an integration, but the user should still decide when to call it. -### 3. Keep the training loop in user space +This same rule applies to TinyExp features more broadly: + +- configuration should live in a small `XXXCfg` class +- config fields should be override-friendly through Hydra +- behavior should only run when the user explicitly calls a method on that config object + +This keeps configuration and execution separate while still letting execution be part of the experiment structure. + +### 3. Prefer `XXXCfg` components for shared features + +When TinyExp grows, new capabilities should usually be introduced as focused config components rather than as many +top-level methods on `TinyExp`. + +For example, a feature is often a better fit as: + +- `logger_cfg.build_logger(...)` +- `wandb_cfg.build_wandb(...)` +- `checkpoint_cfg.save_checkpoint(...)` + +than as a large collection of flat framework methods. 
+ +This pattern keeps a feature's configuration and execution close together: + +- fields describe the feature and can be overridden through Hydra +- methods execute behavior only when the user explicitly calls them +- `TinyExp` itself stays smaller and easier to understand + +### 4. Keep the training loop in user space The training loop is often the most task-specific part of an experiment. TinyExp should not rush to abstract it into a universal trainer. @@ -61,19 +89,19 @@ Users should be able to: - control when to validate, log, save, or resume - stay in plain PyTorch as much as possible -### 4. Helpers are good; control frameworks are not +### 5. Helpers are good; control frameworks are not TinyExp should provide thin, reusable helpers for common experiment chores, such as: - output directory setup - config dumping - lightweight metric logging -- checkpoint save/load helpers +- checkpoint save/load helpers exposed through focused config components - launcher compatibility These helpers reduce repeated boilerplate without dictating how the user structures training. -### 5. Examples are recipes, not just demos +### 6. Examples are recipes, not just demos Examples in TinyExp are not only meant to showcase features. They should also serve as reusable recipes and inheritance-friendly templates. @@ -82,7 +110,7 @@ That means examples should remain understandable and useful as starting points f emerges across multiple examples, it may be worth extracting a small helper or a recipe base class. But that logic should only move into the framework when it is broadly useful and still keeps the system light. -### 6. Framework-level additions must earn their place +### 7. Framework-level additions must earn their place A good question for any new feature is: @@ -100,7 +128,8 @@ architecture, it probably does not belong in TinyExp. 
- configuration structure and CLI overrides - experiment entry and launch ergonomics - lightweight utilities shared across many experiments -- minimal artifact helpers that do not take over control flow +- small `XXXCfg` components for shared capabilities +- minimal helpers that do not take over control flow ### Examples or user experiments should own @@ -117,7 +146,7 @@ This boundary keeps TinyExp small while still making it genuinely useful. When extending TinyExp, prefer: -- small helpers over large abstractions +- small `XXXCfg` components over large abstractions - explicit calls over automatic behavior - recipe-style examples over framework-owned trainers - local clarity over generic indirection From 9305d1cffb024fdea530b38ad31226281873f03e Mon Sep 17 00:00:00 2001 From: Zane Li Date: Mon, 30 Mar 2026 18:19:01 +0800 Subject: [PATCH 06/30] docs: add phase 1 file-by-file plan --- docs/phase1-file-by-file-plan.md | 315 +++++++++++++++++++++++++++++++ mkdocs.yml | 1 + 2 files changed, 316 insertions(+) create mode 100644 docs/phase1-file-by-file-plan.md diff --git a/docs/phase1-file-by-file-plan.md b/docs/phase1-file-by-file-plan.md new file mode 100644 index 0000000..fc811c3 --- /dev/null +++ b/docs/phase1-file-by-file-plan.md @@ -0,0 +1,315 @@ +# Phase 1: File-by-File Implementation Plan + +This document turns the Phase 1 helper direction into a file-by-file implementation plan. + +It follows the same constraints described in: + +- [Design Philosophy](philosophy.md) +- [Phase 1: Minimal Helpers Plan](phase1-minimal-helpers.md) + +The key design rule is unchanged: + +- shared capabilities should usually be exposed through focused `XXXCfg` classes +- those config fields should be Hydra-override-friendly +- behavior should execute only when the user explicitly calls a method + +This plan is intentionally conservative. It aims to establish one clean first slice, not to solve every future need in +one pass. 
+ +## Phase 1 Scope + +The first implementation slice should focus on: + +- a stable run directory helper +- explicit config dumping +- a new `CheckpointCfg` +- `mode=val` support through explicit checkpoint loading +- a migration of the MNIST example to validate the design + +This phase should not introduce: + +- a trainer abstraction +- a runtime layer +- callback systems +- automatic checkpoint policy +- automatic resume behavior + +## Files to Change + +The recommended Phase 1 file set is intentionally small: + +- `tinyexp/__init__.py` +- `tinyexp/examples/mnist_exp.py` +- `tests/test_tinyexp.py` or a new artifact-focused test file +- a new checkpoint-focused test file +- optionally one example-level integration test for validation mode + +The ResNet example should not be part of the first slice. + +## 1. `tinyexp/__init__.py` + +This is the main design anchor for Phase 1. + +### Why this file changes + +This file already defines: + +- `TinyExp` +- `LoggerCfg` +- `WandbCfg` +- `RedisCfgMixin` + +That makes it the natural place to reinforce the cfg-driven model and add the first checkpoint component. + +### What should change + +#### Keep `TinyExp` small + +`TinyExp` should continue to be the root experiment object, but it should not turn into a feature sink. + +It should remain responsible for: + +- experiment-level config structure +- launcher-facing fields +- a small number of experiment-wide helpers +- composition of shared `XXXCfg` components + +#### Add only minimal experiment-wide fields + +Recommended additions or clarifications: + +- `mode: str = "train"` +- `resume_from: str = ""` + +These belong at the experiment level because they describe run intent rather than one isolated feature subsystem. 
+ +#### Add only minimal experiment-wide helpers + +Recommended methods on `TinyExp`: + +- `get_run_dir() -> str` +- `ensure_run_dir() -> str` +- `dump_config(path: str | None = None) -> str` + +These are good fits for `TinyExp` because they are experiment-scoped rather than belonging to a single feature config. + +### What should not be added here + +Avoid adding many feature-specific top-level methods such as: + +- `save_checkpoint(...)` +- `load_checkpoint(...)` +- `maybe_resume(...)` + +Those are better expressed through a focused config component. + +## 2. Add `CheckpointCfg` in `tinyexp/__init__.py` + +For the first slice, `CheckpointCfg` can live in `tinyexp/__init__.py` alongside `LoggerCfg` and `WandbCfg`. + +This keeps the initial implementation simple and consistent with the current project structure. + +If it grows later, it can be split into a dedicated module. + +### Why `CheckpointCfg` + +Checkpointing fits the cfg-driven TinyExp pattern well: + +- filenames and related defaults are configuration +- save/load methods are explicit actions +- users choose when to call those methods + +This is more aligned with TinyExp's style than adding many checkpoint methods directly to `TinyExp`. + +### Recommended `CheckpointCfg` scope + +Fields: + +- `last_ckpt_name: str = "last.ckpt"` +- `best_ckpt_name: str = "best.ckpt"` + +Methods: + +- `save_checkpoint(...) -> str` +- `load_checkpoint(...) 
-> dict` + +### Recommended responsibilities + +`CheckpointCfg` should handle: + +- default checkpoint filenames +- run-dir-relative checkpoint path generation when useful +- standard save format +- standard load behavior +- optional loading into model / optimizer / scheduler objects + +### What `CheckpointCfg` should not own + +Do not put policy into `CheckpointCfg`, including: + +- when to save +- whether to save best checkpoints +- how to compare best metrics +- save frequency +- retention policies +- automatic resume behavior + +Those decisions belong in the example or user code. + +## 3. `tinyexp/examples/mnist_exp.py` + +This file should be the first real migration target. + +### Why this file changes first + +The MNIST example is: + +- small enough to change safely +- representative of the intended user workflow +- a good way to validate whether the cfg-driven helper design actually reduces useful boilerplate + +### What should change + +#### `run()` should adopt experiment-wide helpers + +Expected updates: + +- call `self.ensure_run_dir()` +- build the logger using `self.logger_cfg.build_logger(...)` +- call `self.dump_config()` +- branch on `self.mode` + +#### training should remain explicit + +The training loop should stay in the example. + +What should change is only the repeated plumbing: + +- explicit checkpoint loading when `self.resume_from` is set +- explicit calls to `self.checkpoint_cfg.save_checkpoint(...)` +- explicit best-checkpoint save logic, still decided by the example + +#### validation should also stay explicit + +For `mode=val`, the example should: + +- require a meaningful `resume_from` +- explicitly call `self.checkpoint_cfg.load_checkpoint(...)` +- run evaluation logic in example code + +The example remains responsible for evaluation semantics. 
+ +### What should not change + +Do not try to extract: + +- a trainer +- a recipe base class +- generic evaluation policy + +Those can be revisited later only if repeated patterns clearly emerge. + +## 4. `tinyexp/examples/resnet_exp.py` + +This file should not be part of the first implementation slice. + +### Why it should wait + +The ResNet example includes additional concerns: + +- DDP usage +- ImageNet-specific data loading +- Redis-backed caching +- a more complex training setup + +It is not the right place to define the first minimal checkpoint and artifact API. + +### Recommended Phase 1 stance + +- leave it unchanged +- only revisit after the MNIST migration proves the shape of the APIs + +## 5. Tests + +Phase 1 needs lightweight but meaningful coverage. + +### Recommended test additions + +#### Artifact tests + +Add or expand tests for: + +- `get_run_dir()` +- `ensure_run_dir()` +- `dump_config()` + +These can live in: + +- `tests/test_tinyexp.py` +- or a new `tests/test_tinyexp_artifacts.py` + +#### Checkpoint tests + +Add a dedicated checkpoint test file, for example: + +- `tests/test_tinyexp_checkpoint_cfg.py` + +Cover: + +- save model-only checkpoint +- save/load with optimizer and scheduler +- standard metadata presence +- correct state restoration + +#### Example-level validation test + +Add one small integration-style test for: + +- `mode=val` +- loading a checkpoint through `resume_from` + +This should stay CPU-first and deterministic. + +## 6. Files Not Needed in Phase 1 + +The following files or modules do not need changes in the first slice: + +- `tinyexp/utils/ray_utils.py` +- `tinyexp/tiny_engine/accelerator/*` +- `tinyexp/examples/resnet_exp.py` +- Redis-related utilities + +This is important. + +The first slice should validate the cfg-driven artifact pattern, not broaden the implementation surface. + +## Recommended Implementation Order + +The order below minimizes risk and keeps the design easy to validate. + +1. 
update `tinyexp/__init__.py` with: + - `mode` + - `resume_from` + - `get_run_dir()` + - `ensure_run_dir()` + - `dump_config()` + - `CheckpointCfg` +2. add checkpoint-focused tests +3. migrate `tinyexp/examples/mnist_exp.py` +4. add validation-mode test coverage +5. only then decide whether any further cfg component is worth introducing + +## Stop Point for Phase 1 + +Phase 1 should stop once the following are true: + +- experiment-level artifact basics are available +- checkpointing is exposed through `checkpoint_cfg` +- the MNIST example uses the new pattern successfully +- validation from checkpoint works +- the project still feels light and explicit + +That stop point matters. + +The goal of Phase 1 is not to fully design TinyExp's long-term helper ecosystem. The goal is to establish one clean, +cfg-driven example of how shared capabilities should be added without drifting toward a trainer framework. diff --git a/mkdocs.yml b/mkdocs.yml index 95f2e15..af7a657 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -11,6 +11,7 @@ nav: - Home: index.md - Philosophy: philosophy.md - Phase 1 Plan: phase1-minimal-helpers.md + - Phase 1 File Plan: phase1-file-by-file-plan.md - Modules: modules.md plugins: - search From 033bd678cded0ab7486bbb5a74096aa8b673442c Mon Sep 17 00:00:00 2001 From: Zane Li Date: Mon, 30 Mar 2026 18:33:16 +0800 Subject: [PATCH 07/30] docs: add phase 1 api draft --- docs/phase1-api-draft.md | 305 +++++++++++++++++++++++++++++++++++++++ mkdocs.yml | 1 + 2 files changed, 306 insertions(+) create mode 100644 docs/phase1-api-draft.md diff --git a/docs/phase1-api-draft.md b/docs/phase1-api-draft.md new file mode 100644 index 0000000..3d1002d --- /dev/null +++ b/docs/phase1-api-draft.md @@ -0,0 +1,305 @@ +# Phase 1: API Draft + +This document proposes a concrete API shape for the first implementation slice. 
+ +It should be read together with: + +- [Design Philosophy](philosophy.md) +- [Phase 1: Minimal Helpers Plan](phase1-minimal-helpers.md) +- [Phase 1: File-by-File Implementation Plan](phase1-file-by-file-plan.md) + +This is still a draft. The goal is to make the intended shape explicit before implementation expands further. + +## Drafting Principles + +The Phase 1 API should follow these rules: + +- keep `TinyExp` small +- keep training and evaluation loops in examples +- expose shared capabilities through focused `XXXCfg` components +- make configuration override-friendly through Hydra +- keep execution explicit through method calls +- avoid introducing trainer-like control flow + +## `TinyExp` Draft Surface + +Phase 1 keeps the root experiment object intentionally small. + +### Fields + +Recommended experiment-level fields: + +- `mode: str = "train"` +- `resume_from: str = ""` +- `output_root: str = "./output"` +- `exp_name: str = ...` + +These fields describe the experiment as a whole rather than one isolated feature subsystem. + +### Composed config components + +The experiment object should continue to expose capability-specific config components, such as: + +- `logger_cfg` +- `wandb_cfg` +- `checkpoint_cfg` + +Other feature configs may be added later only if they earn their place. + +### Methods + +Recommended Phase 1 methods on `TinyExp`: + +```python +def get_run_dir(self) -> str: + ... + +def ensure_run_dir(self) -> str: + ... + +def dump_config(self, path: str | None = None) -> str: + ... +``` + +These belong on `TinyExp` because they are experiment-scoped, not feature-scoped. + +## `CheckpointCfg` Draft + +Checkpointing is the main new shared capability in Phase 1. 
+ +It should follow the same cfg-driven pattern as logger and W&B integration: + +- fields define behavior and defaults +- methods perform explicit actions only when called + +### Draft fields + +```python +@dataclass +class CheckpointCfg: + last_ckpt_name: str = "last.ckpt" + best_ckpt_name: str = "best.ckpt" +``` + +Phase 1 should keep this deliberately small. + +### Draft methods + +```python +def save_checkpoint( + self, + *, + run_dir: str, + name: str, + model=None, + optimizer=None, + scheduler=None, + epoch: int | None = None, + global_step: int | None = None, + best_metric: float | None = None, + extra_state: dict | None = None, +) -> str: + ... + +def load_checkpoint( + self, + path: str, + *, + model=None, + optimizer=None, + scheduler=None, + strict: bool = True, + map_location=None, +) -> dict: + ... +``` + +### Responsibilities + +`CheckpointCfg` should: + +- define default checkpoint filenames +- save a standard checkpoint structure +- load a standard checkpoint structure +- optionally restore state into provided model / optimizer / scheduler objects + +### Non-responsibilities + +`CheckpointCfg` should not decide: + +- when to save +- whether to save best checkpoints +- which metric is considered best +- whether resume is automatic +- how many checkpoints to retain + +Those remain example-level or user-level decisions. + +## Draft Checkpoint Format + +The checkpoint format should be simple and explicit. + +Recommended structure: + +```python +{ + "model_state_dict": ..., + "optimizer_state_dict": ..., + "scheduler_state_dict": ..., + "epoch": ..., + "global_step": ..., + "best_metric": ..., + "meta": { + "exp_name": ..., + "exp_class": ..., + "saved_at": ..., + }, + ... 
+}
+```
+
+Notes:
+
+- `optimizer_state_dict` is optional
+- `scheduler_state_dict` is optional
+- `meta` should stay lightweight
+- `extra_state` can extend the structure without forcing premature abstraction
+
+## Config Dump Draft
+
+Configuration dumping should remain an experiment-level helper.
+
+### Draft method
+
+```python
+def dump_config(self, path: str | None = None) -> str:
+    ...
+```
+
+### Expected behavior
+
+- default path is `<run_dir>/config.yaml`
+- output reflects current configuration after Hydra overrides
+- dump should be safe to call from examples
+- distributed runs should avoid duplicate writes
+
+## Run Directory Draft
+
+Run directory behavior should remain simple in Phase 1.
+
+### Draft methods
+
+```python
+def get_run_dir(self) -> str:
+    ...
+
+def ensure_run_dir(self) -> str:
+    ...
+```
+
+### Expected behavior
+
+- `get_run_dir()` returns `os.path.join(self.output_root, self.exp_name)`
+- `ensure_run_dir()` creates the directory and returns it
+
+Phase 1 should not add timestamped run folders, version managers, or heavier run registry behavior.
+
+## Example Usage Draft
+
+Below is the intended style for examples after Phase 1.
+ +### Logger setup + +```python +run_dir = self.ensure_run_dir() +logger = self.logger_cfg.build_logger( + save_dir=run_dir, + distributed_rank=accelerator.rank, +) +self.dump_config() +``` + +### Explicit W&B usage + +```python +if self.wandb_cfg.enable_wandb: + wandb = self.wandb_cfg.build_wandb( + accelerator=accelerator, + project="Baselines", + config=cfg_dict, + ) +``` + +### Explicit checkpoint load + +```python +resume_state = None +if self.resume_from: + resume_state = self.checkpoint_cfg.load_checkpoint( + self.resume_from, + model=model, + optimizer=optimizer, + scheduler=scheduler, + map_location=accelerator.device, + ) +``` + +### Explicit checkpoint save + +```python +self.checkpoint_cfg.save_checkpoint( + run_dir=run_dir, + name=self.checkpoint_cfg.last_ckpt_name, + model=accelerator.unwrap_model(model), + optimizer=optimizer, + scheduler=scheduler, + epoch=epoch, + global_step=global_step, + best_metric=best_metric, +) +``` + +### Explicit best checkpoint policy in example code + +```python +if best_metric is None or val_metric > best_metric: + best_metric = val_metric + self.checkpoint_cfg.save_checkpoint( + run_dir=run_dir, + name=self.checkpoint_cfg.best_ckpt_name, + model=accelerator.unwrap_model(model), + optimizer=optimizer, + scheduler=scheduler, + epoch=epoch, + global_step=global_step, + best_metric=best_metric, + ) +``` + +This is the intended balance: + +- framework provides the capability +- example decides when to use it + +## API Choices Deferred Beyond Phase 1 + +The following questions should stay open until the first implementation slice proves itself: + +- whether metrics deserve their own `MetricCfg` +- whether config dumping should later move into a dedicated artifact cfg +- whether `CheckpointCfg` should move into its own module +- whether shared recipe base classes are worth introducing + +These should not be over-designed before the first slice is working. 
+ +## Success Criteria + +The Phase 1 API draft is successful if it leads to implementation that feels: + +- small +- explicit +- override-friendly +- consistent with the existing `XXXCfg` style +- still close to plain PyTorch + +That is the standard Phase 1 should be judged against. diff --git a/mkdocs.yml b/mkdocs.yml index af7a657..b0ed407 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -12,6 +12,7 @@ nav: - Philosophy: philosophy.md - Phase 1 Plan: phase1-minimal-helpers.md - Phase 1 File Plan: phase1-file-by-file-plan.md + - Phase 1 API Draft: phase1-api-draft.md - Modules: modules.md plugins: - search From 0b89c71d829df114677c722a273dc5e25023cfa3 Mon Sep 17 00:00:00 2001 From: Zane Li Date: Mon, 30 Mar 2026 18:43:23 +0800 Subject: [PATCH 08/30] feat: add cfg-driven checkpoint and artifact helpers --- tests/test_tinyexp_artifacts.py | 79 ++++++++++++++++++++++++ tinyexp/__init__.py | 104 +++++++++++++++++++++++++++++++- 2 files changed, 180 insertions(+), 3 deletions(-) create mode 100644 tests/test_tinyexp_artifacts.py diff --git a/tests/test_tinyexp_artifacts.py b/tests/test_tinyexp_artifacts.py new file mode 100644 index 0000000..706d43c --- /dev/null +++ b/tests/test_tinyexp_artifacts.py @@ -0,0 +1,79 @@ +from __future__ import annotations + +from pathlib import Path + +import torch + +from tinyexp import CheckpointCfg, TinyExp + + +def test_get_run_dir_and_ensure_run_dir(tmp_path: Path) -> None: + exp = TinyExp(output_root=str(tmp_path), exp_name="demo_exp") + + expected = tmp_path / "demo_exp" + assert exp.get_run_dir() == str(expected) + + created = Path(exp.ensure_run_dir()) + assert created == expected + assert created.is_dir() + + +def test_dump_config_writes_yaml(tmp_path: Path, monkeypatch) -> None: + monkeypatch.setenv("RANK", "0") + exp = TinyExp(output_root=str(tmp_path), exp_name="demo_exp", mode="val", resume_from="checkpoint.ckpt") + + dumped = Path(exp.dump_config()) + + assert dumped == tmp_path / "demo_exp" / "config.yaml" + content = 
dumped.read_text(encoding="utf-8") + assert "exp_name: demo_exp" in content + assert "mode: val" in content + assert "resume_from: checkpoint.ckpt" in content + + +def test_checkpoint_cfg_save_and_load_roundtrip(tmp_path: Path) -> None: + model = torch.nn.Linear(2, 1) + optimizer = torch.optim.SGD(model.parameters(), lr=0.1) + scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1) + + with torch.no_grad(): + model.weight.fill_(1.5) + model.bias.fill_(0.5) + + checkpoint_cfg = CheckpointCfg() + checkpoint_path = checkpoint_cfg.save_checkpoint( + run_dir=str(tmp_path), + name=checkpoint_cfg.last_ckpt_name, + model=model, + optimizer=optimizer, + scheduler=scheduler, + epoch=3, + global_step=12, + best_metric=0.9, + exp_name="demo_exp", + exp_class="tests.demo.Exp", + extra_state={"custom_value": 7}, + ) + + reloaded_model = torch.nn.Linear(2, 1) + reloaded_optimizer = torch.optim.SGD(reloaded_model.parameters(), lr=0.1) + reloaded_scheduler = torch.optim.lr_scheduler.StepLR(reloaded_optimizer, step_size=1) + + checkpoint = checkpoint_cfg.load_checkpoint( + checkpoint_path, + model=reloaded_model, + optimizer=reloaded_optimizer, + scheduler=reloaded_scheduler, + ) + + assert Path(checkpoint_path).is_file() + assert checkpoint["epoch"] == 3 + assert checkpoint["global_step"] == 12 + assert checkpoint["best_metric"] == 0.9 + assert checkpoint["custom_value"] == 7 + assert checkpoint["meta"]["exp_name"] == "demo_exp" + assert checkpoint["meta"]["exp_class"] == "tests.demo.Exp" + assert "saved_at" in checkpoint["meta"] + + for original_param, reloaded_param in zip(model.parameters(), reloaded_model.parameters()): + assert torch.equal(original_param, reloaded_param) diff --git a/tinyexp/__init__.py b/tinyexp/__init__.py index d656884..84544c2 100644 --- a/tinyexp/__init__.py +++ b/tinyexp/__init__.py @@ -5,18 +5,20 @@ import os import sys from dataclasses import dataclass, field -from typing import Optional +from datetime import datetime, timezone +from 
pathlib import Path +from typing import Any, Optional from hydra.conf import HydraConf, RunDir from hydra.core.config_store import ConfigStore -from omegaconf import DictConfig +from omegaconf import DictConfig, OmegaConf from omegaconf.listconfig import ListConfig from .exceptions import UnknownConfigurationKeyError from .utils.log_utils import tiny_logger_setup from .utils.ray_utils import simple_launch_exp -__all__ = ["ConfigStore", "RedisCfgMixin", "TinyExp", "simple_launch_exp"] +__all__ = ["CheckpointCfg", "ConfigStore", "RedisCfgMixin", "TinyExp", "simple_launch_exp"] @dataclass @@ -46,6 +48,81 @@ def _default_exp_name() -> str: return "exp" +def _is_main_process() -> bool: + return os.getenv("RANK", "0") == "0" + + +@dataclass +class CheckpointCfg: + last_ckpt_name: str = "last.ckpt" + best_ckpt_name: str = "best.ckpt" + + def save_checkpoint( + self, + *, + run_dir: str, + name: str, + model=None, + optimizer=None, + scheduler=None, + epoch: Optional[int] = None, + global_step: Optional[int] = None, + best_metric: Optional[float] = None, + exp_name: str = "", + exp_class: str = "", + extra_state: Optional[dict[str, Any]] = None, + ) -> str: + import torch + + save_path = Path(run_dir) / name + save_path.parent.mkdir(parents=True, exist_ok=True) + + checkpoint: dict[str, Any] = { + "epoch": epoch, + "global_step": global_step, + "best_metric": best_metric, + "meta": { + "exp_name": exp_name, + "exp_class": exp_class, + "saved_at": datetime.now(timezone.utc).isoformat(), + }, + } + if model is not None: + checkpoint["model_state_dict"] = model.state_dict() + if optimizer is not None: + checkpoint["optimizer_state_dict"] = optimizer.state_dict() + if scheduler is not None: + checkpoint["scheduler_state_dict"] = scheduler.state_dict() + if extra_state is not None: + checkpoint.update(extra_state) + + torch.save(checkpoint, save_path) + return str(save_path) + + def load_checkpoint( + self, + path: str, + *, + model=None, + optimizer=None, + scheduler=None, + 
strict: bool = True, + map_location=None, + ) -> dict[str, Any]: + import torch + + checkpoint = torch.load(path, map_location=map_location) + + if model is not None and "model_state_dict" in checkpoint: + model.load_state_dict(checkpoint["model_state_dict"], strict=strict) + if optimizer is not None and "optimizer_state_dict" in checkpoint: + optimizer.load_state_dict(checkpoint["optimizer_state_dict"]) + if scheduler is not None and "scheduler_state_dict" in checkpoint: + scheduler.load_state_dict(checkpoint["scheduler_state_dict"]) + + return checkpoint + + @dataclass class TinyExp: """ @@ -72,6 +149,8 @@ class TinyExp: # log directory output_root: str = "./output" + mode: str = "train" + resume_from: str = "" # overridden configurations, only for internal use overrided_cfg: dict = field(default_factory=dict) @@ -102,6 +181,25 @@ def build_logger(self, save_dir: str, distributed_rank: int = 0, filename: str = return logger logger_cfg: LoggerCfg = field(default_factory=LoggerCfg) + checkpoint_cfg: CheckpointCfg = field(default_factory=CheckpointCfg) + + def get_run_dir(self) -> str: + return os.path.join(self.output_root, self.exp_name) + + def ensure_run_dir(self) -> str: + run_dir = self.get_run_dir() + Path(run_dir).mkdir(parents=True, exist_ok=True) + return run_dir + + def dump_config(self, path: Optional[str] = None) -> str: + run_dir = self.ensure_run_dir() + dump_path = Path(path) if path is not None else Path(run_dir) / "config.yaml" + + if _is_main_process(): + cfg_dict = OmegaConf.to_container(OmegaConf.structured(self), resolve=True) + dump_path.write_text(OmegaConf.to_yaml(cfg_dict), encoding="utf-8") + + return str(dump_path) def set_cfg(self, cfg_hydra, cfg_object=None): if cfg_object is None: From 07859a1787c38f6184d0045c14894168ffe18f39 Mon Sep 17 00:00:00 2001 From: Zane Li Date: Mon, 30 Mar 2026 18:53:15 +0800 Subject: [PATCH 09/30] feat: add checkpoint-aware mnist example flows --- tests/examples/test_mnist_exp_unit.py | 38 ++++++++++++ 
tinyexp/examples/mnist_exp.py | 83 +++++++++++++++++++++++---- 2 files changed, 111 insertions(+), 10 deletions(-) create mode 100644 tests/examples/test_mnist_exp_unit.py diff --git a/tests/examples/test_mnist_exp_unit.py b/tests/examples/test_mnist_exp_unit.py new file mode 100644 index 0000000..e460d08 --- /dev/null +++ b/tests/examples/test_mnist_exp_unit.py @@ -0,0 +1,38 @@ +from __future__ import annotations + +from types import SimpleNamespace + +import pytest + +from tinyexp.examples.mnist_exp import Exp + + +def test_mnist_val_mode_requires_resume_from(tmp_path) -> None: + exp = Exp(output_root=str(tmp_path), exp_name="mnist_test", mode="val", resume_from="") + + dummy_accelerator = SimpleNamespace(rank=0, device="cpu", is_main_process=True) + dummy_logger = SimpleNamespace(info=lambda *args, **kwargs: None) + + with pytest.raises(ValueError, match="resume_from"): + exp._validate_from_checkpoint(accelerator=dummy_accelerator, logger=dummy_logger) + + +def test_mnist_validate_from_checkpoint_calls_evaluate(monkeypatch, tmp_path) -> None: + exp = Exp(output_root=str(tmp_path), exp_name="mnist_test", mode="val", resume_from="demo.ckpt") + + called: dict[str, object] = {} + + def fake_evaluate(*, accelerator, logger, module_or_module_path, val_dataloader=None): + called["accelerator"] = accelerator + called["logger"] = logger + called["module_or_module_path"] = module_or_module_path + called["val_dataloader"] = val_dataloader + return 0.5 + + monkeypatch.setattr(exp, "_evaluate", fake_evaluate) + + dummy_accelerator = SimpleNamespace(rank=0, device="cpu", is_main_process=True) + dummy_logger = SimpleNamespace(info=lambda *args, **kwargs: None) + exp._validate_from_checkpoint(accelerator=dummy_accelerator, logger=dummy_logger) + + assert called["module_or_module_path"] == "demo.ckpt" diff --git a/tinyexp/examples/mnist_exp.py b/tinyexp/examples/mnist_exp.py index 2a14de7..c934971 100644 --- a/tinyexp/examples/mnist_exp.py +++ b/tinyexp/examples/mnist_exp.py @@ 
-1,6 +1,6 @@ import datetime -import os from dataclasses import dataclass, field +from typing import Any import torch import torch.nn as nn @@ -15,6 +15,11 @@ from tinyexp.exceptions import UnknownAcceleratorTypeError +class ResumeFromRequiredError(ValueError): + def __init__(self) -> None: + super().__init__("resume_from") + + class Net(nn.Module): def __init__(self) -> None: super().__init__() @@ -142,21 +147,22 @@ def build_lr_scheduler(self, optimizer): # ------------------------------ bellowing is the execution part --------------------- # def run(self) -> None: accelerator = self.accelerator_cfg.build_accelerator() - logger = self.logger_cfg.build_logger( - save_dir=os.path.join(self.output_root, self.exp_name), - distributed_rank=accelerator.rank, - ) + run_dir = self.ensure_run_dir() + logger = self.logger_cfg.build_logger(save_dir=run_dir, distributed_rank=accelerator.rank) cfg_dict = OmegaConf.to_container(OmegaConf.structured(self), resolve=True) del cfg_dict["hydra"] + self.dump_config() cfg_msg = OmegaConf.to_yaml(cfg_dict).strip().replace("\n", "\n ") logger.info(f"-------- Configurations --------\n {cfg_msg}") if self.mode == "train": - self._train(accelerator=accelerator, logger=logger, cfg_dict=cfg_dict) + self._train(accelerator=accelerator, logger=logger, cfg_dict=cfg_dict, run_dir=run_dir) + elif self.mode == "val": + self._validate_from_checkpoint(accelerator=accelerator, logger=logger) else: raise NotImplementedError(f"Mode {self.mode} is not implemented") - def _evaluate(self, accelerator, logger, module_or_module_path, val_dataloader=None) -> None: + def _evaluate(self, accelerator, logger, module_or_module_path, val_dataloader=None) -> float: if isinstance(module_or_module_path, str): module = Net() module.load_state_dict(torch.load(module_or_module_path)) @@ -187,7 +193,9 @@ def _evaluate(self, accelerator, logger, module_or_module_path, val_dataloader=N if self.wandb_cfg.enable_wandb and accelerator.is_main_process: 
wandb.log({"val_metric": eval_metric}) - def _train(self, accelerator, logger, cfg_dict) -> None: + return eval_metric + + def _train(self, accelerator, logger, cfg_dict, run_dir: str) -> None: train_dataloader = self.dataloader_cfg.build_train_dataloader(accelerator) val_dataloader = self.dataloader_cfg.build_val_dataloader(accelerator) ori_module = self.module_cfg.build_module() @@ -195,6 +203,12 @@ def _train(self, accelerator, logger, cfg_dict) -> None: lr_scheduler = self.lr_scheduler_cfg.build_lr_scheduler(ori_optimizer) module, optimizer = accelerator.prepare(ori_module, ori_optimizer) + start_epoch, global_step, best_metric = self._load_training_state_if_needed( + accelerator=accelerator, + module=accelerator.unwrap_model(module), + optimizer=optimizer, + scheduler=lr_scheduler, + ) train_iter = iter(train_dataloader) if self.wandb_cfg.enable_wandb and accelerator.rank == 0: @@ -204,7 +218,7 @@ def _train(self, accelerator, logger, cfg_dict) -> None: config=cfg_dict, ) - for epoch in range(3): + for epoch in range(start_epoch, 3): module.train() for step in range(len(train_dataloader)): @@ -221,6 +235,7 @@ def _train(self, accelerator, logger, cfg_dict) -> None: optimizer.zero_grad() accelerator.backward(loss) optimizer.step() + global_step += 1 if (step + 1) % 20 == 0: logger.info(f"epoch {epoch} loss: {loss.item(): .4f} lr: {optimizer.param_groups[0]['lr']: .4f}") if self.wandb_cfg.enable_wandb and accelerator.rank == 0: @@ -231,12 +246,60 @@ def _train(self, accelerator, logger, cfg_dict) -> None: "lr": optimizer.param_groups[0]["lr"], } ) - self._evaluate( + eval_metric = self._evaluate( accelerator=accelerator, logger=logger, module_or_module_path=module, val_dataloader=val_dataloader ) + if accelerator.is_main_process: + self.checkpoint_cfg.save_checkpoint( + run_dir=run_dir, + name=self.checkpoint_cfg.last_ckpt_name, + model=accelerator.unwrap_model(module), + optimizer=optimizer, + scheduler=lr_scheduler, + epoch=epoch, + global_step=global_step, + 
best_metric=best_metric, + exp_name=self.exp_name, + exp_class=self.exp_class, + ) + if best_metric is None or eval_metric > best_metric: + best_metric = eval_metric + self.checkpoint_cfg.save_checkpoint( + run_dir=run_dir, + name=self.checkpoint_cfg.best_ckpt_name, + model=accelerator.unwrap_model(module), + optimizer=optimizer, + scheduler=lr_scheduler, + epoch=epoch, + global_step=global_step, + best_metric=best_metric, + exp_name=self.exp_name, + exp_class=self.exp_class, + ) lr_scheduler.step() + def _load_training_state_if_needed(self, accelerator, module, optimizer, scheduler) -> tuple[int, int, Any]: + if not self.resume_from: + return 0, 0, None + + checkpoint = self.checkpoint_cfg.load_checkpoint( + self.resume_from, + model=module, + optimizer=optimizer, + scheduler=scheduler, + map_location=accelerator.device, + ) + start_epoch = int(checkpoint.get("epoch", -1)) + 1 + global_step = int(checkpoint.get("global_step", 0)) + best_metric = checkpoint.get("best_metric") + return start_epoch, global_step, best_metric + + def _validate_from_checkpoint(self, accelerator, logger) -> None: + if not self.resume_from: + raise ResumeFromRequiredError + self._evaluate(accelerator=accelerator, logger=logger, module_or_module_path=self.resume_from) + # import hydra # @hydra.main(version_base=None, config_name="cfg") From 567e142f790b6da264801a3f2f438848f35e50b1 Mon Sep 17 00:00:00 2001 From: Zane Li Date: Mon, 30 Mar 2026 18:55:00 +0800 Subject: [PATCH 10/30] test: cover mnist validation run flow --- tests/examples/test_mnist_exp_run.py | 43 ++++++++++++++++++++++++++++ 1 file changed, 43 insertions(+) create mode 100644 tests/examples/test_mnist_exp_run.py diff --git a/tests/examples/test_mnist_exp_run.py b/tests/examples/test_mnist_exp_run.py new file mode 100644 index 0000000..2213c76 --- /dev/null +++ b/tests/examples/test_mnist_exp_run.py @@ -0,0 +1,43 @@ +from __future__ import annotations + +from pathlib import Path +from types import SimpleNamespace + +from 
tinyexp.examples.mnist_exp import Exp + + +def test_mnist_run_val_mode_uses_checkpoint_and_dumps_config(tmp_path: Path, monkeypatch) -> None: + exp_for_ckpt = Exp(output_root=str(tmp_path), exp_name="mnist_val") + checkpoint_path = exp_for_ckpt.checkpoint_cfg.save_checkpoint( + run_dir=str(tmp_path / "mnist_val"), + name="demo.ckpt", + model=exp_for_ckpt.module_cfg.build_module(), + exp_name=exp_for_ckpt.exp_name, + exp_class=exp_for_ckpt.exp_class, + ) + + exp = Exp(output_root=str(tmp_path), exp_name="mnist_val", mode="val", resume_from=checkpoint_path) + + dummy_accelerator = SimpleNamespace(rank=0, device="cpu", is_main_process=True) + dummy_logger = SimpleNamespace(info=lambda *args, **kwargs: None) + + monkeypatch.setattr(exp.accelerator_cfg, "build_accelerator", lambda: dummy_accelerator) + monkeypatch.setattr(exp.logger_cfg, "build_logger", lambda **kwargs: dummy_logger) + + called: dict[str, object] = {} + + def fake_evaluate(*, accelerator, logger, module_or_module_path, val_dataloader=None): + called["accelerator"] = accelerator + called["logger"] = logger + called["module_or_module_path"] = module_or_module_path + called["val_dataloader"] = val_dataloader + return 0.5 + + monkeypatch.setattr(exp, "_evaluate", fake_evaluate) + + exp.run() + + assert called["accelerator"] is dummy_accelerator + assert called["logger"] is dummy_logger + assert called["module_or_module_path"] == checkpoint_path + assert (tmp_path / "mnist_val" / "config.yaml").is_file() From 09bcfc52ed3d34318b86e7613fe04c11d5517b36 Mon Sep 17 00:00:00 2001 From: Zane Li Date: Mon, 30 Mar 2026 19:34:46 +0800 Subject: [PATCH 11/30] fix: load model state from checkpoint files --- tests/examples/test_mnist_exp_unit.py | 44 +++++++++++++++++++++++++++ tinyexp/examples/mnist_exp.py | 6 +++- tinyexp/examples/resnet_exp.py | 6 +++- 3 files changed, 54 insertions(+), 2 deletions(-) diff --git a/tests/examples/test_mnist_exp_unit.py b/tests/examples/test_mnist_exp_unit.py index e460d08..55e3325 
100644 --- a/tests/examples/test_mnist_exp_unit.py +++ b/tests/examples/test_mnist_exp_unit.py @@ -3,6 +3,7 @@ from types import SimpleNamespace import pytest +import torch from tinyexp.examples.mnist_exp import Exp @@ -36,3 +37,46 @@ def fake_evaluate(*, accelerator, logger, module_or_module_path, val_dataloader= exp._validate_from_checkpoint(accelerator=dummy_accelerator, logger=dummy_logger) assert called["module_or_module_path"] == "demo.ckpt" + + +def test_mnist_evaluate_loads_model_state_from_checkpoint(tmp_path) -> None: + exp = Exp(output_root=str(tmp_path), exp_name="mnist_test") + checkpoint_path = exp.checkpoint_cfg.save_checkpoint( + run_dir=str(tmp_path), + name="demo.ckpt", + model=exp.module_cfg.build_module(), + exp_name=exp.exp_name, + exp_class=exp.exp_class, + ) + + class DummyAccelerator: + device = "cpu" + rank = 0 + world_size = 1 + is_main_process = True + + def prepare(self, module): + return module + + def reduce_sum(self, tensor): + return tensor + + def wait_for_everyone(self) -> None: + return None + + dummy_logger = SimpleNamespace(info=lambda *args, **kwargs: None) + + class DummyDataLoader(list): + def __init__(self): + super().__init__([(torch.zeros(1, 1, 28, 28), torch.zeros(1, dtype=torch.long))]) + self.dataset = [0] + + val_dataloader = DummyDataLoader() + metric = exp._evaluate( + accelerator=DummyAccelerator(), + logger=dummy_logger, + module_or_module_path=checkpoint_path, + val_dataloader=val_dataloader, + ) + + assert isinstance(metric, float) diff --git a/tinyexp/examples/mnist_exp.py b/tinyexp/examples/mnist_exp.py index c934971..8d23aca 100644 --- a/tinyexp/examples/mnist_exp.py +++ b/tinyexp/examples/mnist_exp.py @@ -165,7 +165,11 @@ def run(self) -> None: def _evaluate(self, accelerator, logger, module_or_module_path, val_dataloader=None) -> float: if isinstance(module_or_module_path, str): module = Net() - module.load_state_dict(torch.load(module_or_module_path)) + self.checkpoint_cfg.load_checkpoint( + 
module_or_module_path, + model=module, + map_location=accelerator.device, + ) module = accelerator.prepare(module) else: module = module_or_module_path diff --git a/tinyexp/examples/resnet_exp.py b/tinyexp/examples/resnet_exp.py index 5447497..f25fd44 100644 --- a/tinyexp/examples/resnet_exp.py +++ b/tinyexp/examples/resnet_exp.py @@ -343,7 +343,11 @@ def run(self) -> None: def _evaluate(self, accelerator, logger, module_or_module_path, val_dataloader=None) -> None: if isinstance(module_or_module_path, str): module: nn.Module = self.module_cfg.build_module() - module.load_state_dict(torch.load(module_or_module_path)) + self.checkpoint_cfg.load_checkpoint( + module_or_module_path, + model=module, + map_location=accelerator.device, + ) module = accelerator.prepare_model(module) else: module = module_or_module_path From 17d3a3ef7d694bc3a2af25409e9252cd6da4d74f Mon Sep 17 00:00:00 2001 From: Zane Li Date: Mon, 30 Mar 2026 22:23:36 +0800 Subject: [PATCH 12/30] refactor: simplify validation flow and checkpoint imports --- tests/examples/test_mnist_exp_run.py | 15 +++++++++++++ tests/examples/test_mnist_exp_unit.py | 32 --------------------------- tinyexp/__init__.py | 5 +---- tinyexp/examples/mnist_exp.py | 9 +++----- 4 files changed, 19 insertions(+), 42 deletions(-) diff --git a/tests/examples/test_mnist_exp_run.py b/tests/examples/test_mnist_exp_run.py index 2213c76..4c21354 100644 --- a/tests/examples/test_mnist_exp_run.py +++ b/tests/examples/test_mnist_exp_run.py @@ -3,9 +3,24 @@ from pathlib import Path from types import SimpleNamespace +import pytest + from tinyexp.examples.mnist_exp import Exp +def test_mnist_run_val_mode_requires_resume_from(tmp_path: Path, monkeypatch) -> None: + exp = Exp(output_root=str(tmp_path), exp_name="mnist_val", mode="val", resume_from="") + + dummy_accelerator = SimpleNamespace(rank=0, device="cpu", is_main_process=True) + dummy_logger = SimpleNamespace(info=lambda *args, **kwargs: None) + + monkeypatch.setattr(exp.accelerator_cfg, 
"build_accelerator", lambda: dummy_accelerator) + monkeypatch.setattr(exp.logger_cfg, "build_logger", lambda **kwargs: dummy_logger) + + with pytest.raises(ValueError, match="resume_from"): + exp.run() + + def test_mnist_run_val_mode_uses_checkpoint_and_dumps_config(tmp_path: Path, monkeypatch) -> None: exp_for_ckpt = Exp(output_root=str(tmp_path), exp_name="mnist_val") checkpoint_path = exp_for_ckpt.checkpoint_cfg.save_checkpoint( diff --git a/tests/examples/test_mnist_exp_unit.py b/tests/examples/test_mnist_exp_unit.py index 55e3325..be6b826 100644 --- a/tests/examples/test_mnist_exp_unit.py +++ b/tests/examples/test_mnist_exp_unit.py @@ -2,43 +2,11 @@ from types import SimpleNamespace -import pytest import torch from tinyexp.examples.mnist_exp import Exp -def test_mnist_val_mode_requires_resume_from(tmp_path) -> None: - exp = Exp(output_root=str(tmp_path), exp_name="mnist_test", mode="val", resume_from="") - - dummy_accelerator = SimpleNamespace(rank=0, device="cpu", is_main_process=True) - dummy_logger = SimpleNamespace(info=lambda *args, **kwargs: None) - - with pytest.raises(ValueError, match="resume_from"): - exp._validate_from_checkpoint(accelerator=dummy_accelerator, logger=dummy_logger) - - -def test_mnist_validate_from_checkpoint_calls_evaluate(monkeypatch, tmp_path) -> None: - exp = Exp(output_root=str(tmp_path), exp_name="mnist_test", mode="val", resume_from="demo.ckpt") - - called: dict[str, object] = {} - - def fake_evaluate(*, accelerator, logger, module_or_module_path, val_dataloader=None): - called["accelerator"] = accelerator - called["logger"] = logger - called["module_or_module_path"] = module_or_module_path - called["val_dataloader"] = val_dataloader - return 0.5 - - monkeypatch.setattr(exp, "_evaluate", fake_evaluate) - - dummy_accelerator = SimpleNamespace(rank=0, device="cpu", is_main_process=True) - dummy_logger = SimpleNamespace(info=lambda *args, **kwargs: None) - exp._validate_from_checkpoint(accelerator=dummy_accelerator, 
logger=dummy_logger) - - assert called["module_or_module_path"] == "demo.ckpt" - - def test_mnist_evaluate_loads_model_state_from_checkpoint(tmp_path) -> None: exp = Exp(output_root=str(tmp_path), exp_name="mnist_test") checkpoint_path = exp.checkpoint_cfg.save_checkpoint( diff --git a/tinyexp/__init__.py b/tinyexp/__init__.py index 84544c2..52075ff 100644 --- a/tinyexp/__init__.py +++ b/tinyexp/__init__.py @@ -9,6 +9,7 @@ from pathlib import Path from typing import Any, Optional +import torch from hydra.conf import HydraConf, RunDir from hydra.core.config_store import ConfigStore from omegaconf import DictConfig, OmegaConf @@ -72,8 +73,6 @@ def save_checkpoint( exp_class: str = "", extra_state: Optional[dict[str, Any]] = None, ) -> str: - import torch - save_path = Path(run_dir) / name save_path.parent.mkdir(parents=True, exist_ok=True) @@ -109,8 +108,6 @@ def load_checkpoint( strict: bool = True, map_location=None, ) -> dict[str, Any]: - import torch - checkpoint = torch.load(path, map_location=map_location) if model is not None and "model_state_dict" in checkpoint: diff --git a/tinyexp/examples/mnist_exp.py b/tinyexp/examples/mnist_exp.py index 8d23aca..982d9e3 100644 --- a/tinyexp/examples/mnist_exp.py +++ b/tinyexp/examples/mnist_exp.py @@ -158,7 +158,9 @@ def run(self) -> None: if self.mode == "train": self._train(accelerator=accelerator, logger=logger, cfg_dict=cfg_dict, run_dir=run_dir) elif self.mode == "val": - self._validate_from_checkpoint(accelerator=accelerator, logger=logger) + if not self.resume_from: + raise ResumeFromRequiredError + self._evaluate(accelerator=accelerator, logger=logger, module_or_module_path=self.resume_from) else: raise NotImplementedError(f"Mode {self.mode} is not implemented") @@ -299,11 +301,6 @@ def _load_training_state_if_needed(self, accelerator, module, optimizer, schedul best_metric = checkpoint.get("best_metric") return start_epoch, global_step, best_metric - def _validate_from_checkpoint(self, accelerator, logger) -> 
None: - if not self.resume_from: - raise ResumeFromRequiredError - self._evaluate(accelerator=accelerator, logger=logger, module_or_module_path=self.resume_from) - # import hydra # @hydra.main(version_base=None, config_name="cfg") From b311181c81ab6543e36e91139cf4f04e108e7dc7 Mon Sep 17 00:00:00 2001 From: Zane Li Date: Mon, 30 Mar 2026 22:28:42 +0800 Subject: [PATCH 13/30] refactor: inline mnist training resume state --- tinyexp/examples/mnist_exp.py | 37 +++++++++++++---------------------- 1 file changed, 14 insertions(+), 23 deletions(-) diff --git a/tinyexp/examples/mnist_exp.py b/tinyexp/examples/mnist_exp.py index 982d9e3..e411a1a 100644 --- a/tinyexp/examples/mnist_exp.py +++ b/tinyexp/examples/mnist_exp.py @@ -1,6 +1,5 @@ import datetime from dataclasses import dataclass, field -from typing import Any import torch import torch.nn as nn @@ -209,12 +208,20 @@ def _train(self, accelerator, logger, cfg_dict, run_dir: str) -> None: lr_scheduler = self.lr_scheduler_cfg.build_lr_scheduler(ori_optimizer) module, optimizer = accelerator.prepare(ori_module, ori_optimizer) - start_epoch, global_step, best_metric = self._load_training_state_if_needed( - accelerator=accelerator, - module=accelerator.unwrap_model(module), - optimizer=optimizer, - scheduler=lr_scheduler, - ) + start_epoch = 0 + global_step = 0 + best_metric = None + if self.resume_from: + checkpoint = self.checkpoint_cfg.load_checkpoint( + self.resume_from, + model=accelerator.unwrap_model(module), + optimizer=optimizer, + scheduler=lr_scheduler, + map_location=accelerator.device, + ) + start_epoch = int(checkpoint.get("epoch", -1)) + 1 + global_step = int(checkpoint.get("global_step", 0)) + best_metric = checkpoint.get("best_metric") train_iter = iter(train_dataloader) if self.wandb_cfg.enable_wandb and accelerator.rank == 0: @@ -285,22 +292,6 @@ def _train(self, accelerator, logger, cfg_dict, run_dir: str) -> None: lr_scheduler.step() - def _load_training_state_if_needed(self, accelerator, module, 
optimizer, scheduler) -> tuple[int, int, Any]: - if not self.resume_from: - return 0, 0, None - - checkpoint = self.checkpoint_cfg.load_checkpoint( - self.resume_from, - model=module, - optimizer=optimizer, - scheduler=scheduler, - map_location=accelerator.device, - ) - start_epoch = int(checkpoint.get("epoch", -1)) + 1 - global_step = int(checkpoint.get("global_step", 0)) - best_metric = checkpoint.get("best_metric") - return start_epoch, global_step, best_metric - # import hydra # @hydra.main(version_base=None, config_name="cfg") From a17bc69afac94c96463152ba778719bb653df2fb Mon Sep 17 00:00:00 2001 From: Zane Li Date: Mon, 30 Mar 2026 22:51:10 +0800 Subject: [PATCH 14/30] refactor: remove ensure_run_dir helper --- tests/test_tinyexp_artifacts.py | 16 +++++++++++----- tinyexp/__init__.py | 9 +++------ tinyexp/examples/mnist_exp.py | 2 +- 3 files changed, 15 insertions(+), 12 deletions(-) diff --git a/tests/test_tinyexp_artifacts.py b/tests/test_tinyexp_artifacts.py index 706d43c..8d48c72 100644 --- a/tests/test_tinyexp_artifacts.py +++ b/tests/test_tinyexp_artifacts.py @@ -7,16 +7,12 @@ from tinyexp import CheckpointCfg, TinyExp -def test_get_run_dir_and_ensure_run_dir(tmp_path: Path) -> None: +def test_get_run_dir(tmp_path: Path) -> None: exp = TinyExp(output_root=str(tmp_path), exp_name="demo_exp") expected = tmp_path / "demo_exp" assert exp.get_run_dir() == str(expected) - created = Path(exp.ensure_run_dir()) - assert created == expected - assert created.is_dir() - def test_dump_config_writes_yaml(tmp_path: Path, monkeypatch) -> None: monkeypatch.setenv("RANK", "0") @@ -31,6 +27,16 @@ def test_dump_config_writes_yaml(tmp_path: Path, monkeypatch) -> None: assert "resume_from: checkpoint.ckpt" in content +def test_logger_cfg_creates_run_dir(tmp_path: Path) -> None: + exp = TinyExp(output_root=str(tmp_path), exp_name="demo_exp") + run_dir = Path(exp.get_run_dir()) + + exp.logger_cfg.build_logger(save_dir=str(run_dir), distributed_rank=0) + + assert 
run_dir.is_dir() + assert (run_dir / "log.txt").is_file() + + def test_checkpoint_cfg_save_and_load_roundtrip(tmp_path: Path) -> None: model = torch.nn.Linear(2, 1) optimizer = torch.optim.SGD(model.parameters(), lr=0.1) diff --git a/tinyexp/__init__.py b/tinyexp/__init__.py index 52075ff..ef093bd 100644 --- a/tinyexp/__init__.py +++ b/tinyexp/__init__.py @@ -173,6 +173,7 @@ def build_wandb(self, accelerator=None, **kwargs): @dataclass class LoggerCfg: def build_logger(self, save_dir: str, distributed_rank: int = 0, filename: str = "log.txt", mode: str = "a"): + Path(save_dir).mkdir(parents=True, exist_ok=True) logger = tiny_logger_setup(save_dir, distributed_rank, filename, mode) logger.info(f"==> log file: {os.path.join(save_dir, filename)}") return logger @@ -183,16 +184,12 @@ def build_logger(self, save_dir: str, distributed_rank: int = 0, filename: str = def get_run_dir(self) -> str: return os.path.join(self.output_root, self.exp_name) - def ensure_run_dir(self) -> str: - run_dir = self.get_run_dir() - Path(run_dir).mkdir(parents=True, exist_ok=True) - return run_dir - def dump_config(self, path: Optional[str] = None) -> str: - run_dir = self.ensure_run_dir() + run_dir = self.get_run_dir() dump_path = Path(path) if path is not None else Path(run_dir) / "config.yaml" if _is_main_process(): + dump_path.parent.mkdir(parents=True, exist_ok=True) cfg_dict = OmegaConf.to_container(OmegaConf.structured(self), resolve=True) dump_path.write_text(OmegaConf.to_yaml(cfg_dict), encoding="utf-8") diff --git a/tinyexp/examples/mnist_exp.py b/tinyexp/examples/mnist_exp.py index e411a1a..80e46b9 100644 --- a/tinyexp/examples/mnist_exp.py +++ b/tinyexp/examples/mnist_exp.py @@ -146,7 +146,7 @@ def build_lr_scheduler(self, optimizer): # ------------------------------ bellowing is the execution part --------------------- # def run(self) -> None: accelerator = self.accelerator_cfg.build_accelerator() - run_dir = self.ensure_run_dir() + run_dir = self.get_run_dir() logger = 
self.logger_cfg.build_logger(save_dir=run_dir, distributed_rank=accelerator.rank) cfg_dict = OmegaConf.to_container(OmegaConf.structured(self), resolve=True) del cfg_dict["hydra"] From 182d40e522b0b245124d3305e7c8657c0bdcf57b Mon Sep 17 00:00:00 2001 From: Zane Li Date: Mon, 30 Mar 2026 23:01:43 +0800 Subject: [PATCH 15/30] fix: harden checkpoint format handling --- tests/test_tinyexp_artifacts.py | 32 +++++++++++++++++++++++++++++++- tinyexp/__init__.py | 17 +++++++++++++++-- tinyexp/exceptions.py | 15 +++++++++++++++ 3 files changed, 61 insertions(+), 3 deletions(-) diff --git a/tests/test_tinyexp_artifacts.py b/tests/test_tinyexp_artifacts.py index 8d48c72..2378cfc 100644 --- a/tests/test_tinyexp_artifacts.py +++ b/tests/test_tinyexp_artifacts.py @@ -2,9 +2,11 @@ from pathlib import Path +import pytest import torch from tinyexp import CheckpointCfg, TinyExp +from tinyexp.exceptions import UnsupportedCheckpointFormatError def test_get_run_dir(tmp_path: Path) -> None: @@ -76,10 +78,38 @@ def test_checkpoint_cfg_save_and_load_roundtrip(tmp_path: Path) -> None: assert checkpoint["epoch"] == 3 assert checkpoint["global_step"] == 12 assert checkpoint["best_metric"] == 0.9 - assert checkpoint["custom_value"] == 7 + assert checkpoint["format_version"] == 1 + assert checkpoint["extra_state"]["custom_value"] == 7 assert checkpoint["meta"]["exp_name"] == "demo_exp" assert checkpoint["meta"]["exp_class"] == "tests.demo.Exp" assert "saved_at" in checkpoint["meta"] for original_param, reloaded_param in zip(model.parameters(), reloaded_model.parameters()): assert torch.equal(original_param, reloaded_param) + + +def test_checkpoint_cfg_extra_state_does_not_override_reserved_keys(tmp_path: Path) -> None: + checkpoint_cfg = CheckpointCfg() + + checkpoint_path = checkpoint_cfg.save_checkpoint( + run_dir=str(tmp_path), + name=checkpoint_cfg.last_ckpt_name, + epoch=3, + extra_state={"epoch": 99, "meta": {"exp_name": "bad"}}, + ) + + checkpoint = 
checkpoint_cfg.load_checkpoint(checkpoint_path) + + assert checkpoint["epoch"] == 3 + assert checkpoint["extra_state"]["epoch"] == 99 + assert checkpoint["meta"]["exp_name"] == "" + + +def test_checkpoint_cfg_rejects_unsupported_model_only_format(tmp_path: Path) -> None: + checkpoint_path = tmp_path / "model_only.ckpt" + torch.save({"state_dict": {"weight": torch.tensor([1.0])}}, checkpoint_path) + + checkpoint_cfg = CheckpointCfg() + + with pytest.raises(UnsupportedCheckpointFormatError, match="not a supported tinyexp checkpoint format"): + checkpoint_cfg.load_checkpoint(str(checkpoint_path)) diff --git a/tinyexp/__init__.py b/tinyexp/__init__.py index ef093bd..c0b1007 100644 --- a/tinyexp/__init__.py +++ b/tinyexp/__init__.py @@ -15,7 +15,7 @@ from omegaconf import DictConfig, OmegaConf from omegaconf.listconfig import ListConfig -from .exceptions import UnknownConfigurationKeyError +from .exceptions import InvalidCheckpointTypeError, UnknownConfigurationKeyError, UnsupportedCheckpointFormatError from .utils.log_utils import tiny_logger_setup from .utils.ray_utils import simple_launch_exp @@ -55,6 +55,7 @@ def _is_main_process() -> bool: @dataclass class CheckpointCfg: + format_version: int = 1 last_ckpt_name: str = "last.ckpt" best_ckpt_name: str = "best.ckpt" @@ -77,6 +78,7 @@ def save_checkpoint( save_path.parent.mkdir(parents=True, exist_ok=True) checkpoint: dict[str, Any] = { + "format_version": self.format_version, "epoch": epoch, "global_step": global_step, "best_metric": best_metric, @@ -93,7 +95,7 @@ def save_checkpoint( if scheduler is not None: checkpoint["scheduler_state_dict"] = scheduler.state_dict() if extra_state is not None: - checkpoint.update(extra_state) + checkpoint["extra_state"] = extra_state torch.save(checkpoint, save_path) return str(save_path) @@ -110,6 +112,17 @@ def load_checkpoint( ) -> dict[str, Any]: checkpoint = torch.load(path, map_location=map_location) + if not isinstance(checkpoint, dict): + raise 
InvalidCheckpointTypeError(path, type(checkpoint).__name__) + + if any( + key in checkpoint + for key in ("epoch", "global_step", "best_metric", "meta", "extra_state", "format_version") + ): + pass + elif "model_state_dict" not in checkpoint: + raise UnsupportedCheckpointFormatError(path) + if model is not None and "model_state_dict" in checkpoint: model.load_state_dict(checkpoint["model_state_dict"], strict=strict) if optimizer is not None and "optimizer_state_dict" in checkpoint: diff --git a/tinyexp/exceptions.py b/tinyexp/exceptions.py index 38bb8c7..a92c707 100644 --- a/tinyexp/exceptions.py +++ b/tinyexp/exceptions.py @@ -44,3 +44,18 @@ def __init__(self, launcher: str, allowed: Sequence[str] = ("python", "torchrun" class CudaNotAvailableError(RuntimeError): def __init__(self) -> None: super().__init__("CUDA is required but not available.") + + +class InvalidCheckpointTypeError(TypeError): + def __init__(self, path: str, checkpoint_type: str) -> None: + self.path = path + self.checkpoint_type = checkpoint_type + super().__init__(f"Checkpoint at {path} must be a dict, got {checkpoint_type}.") + + +class UnsupportedCheckpointFormatError(ValueError): + def __init__(self, path: str) -> None: + self.path = path + super().__init__( + f"Checkpoint at {path} is not a supported tinyexp checkpoint format and does not contain model_state_dict." 
+ ) From b3085eda36ac2a5afa11bcb28c054ce8bc73d9ea Mon Sep 17 00:00:00 2001 From: Zane Li Date: Mon, 30 Mar 2026 23:09:49 +0800 Subject: [PATCH 16/30] fix: validate required checkpoint state --- tests/test_tinyexp_artifacts.py | 38 ++++++++++++++++++++++++++++++++- tinyexp/__init__.py | 19 +++++++++++++---- tinyexp/exceptions.py | 7 ++++++ 3 files changed, 59 insertions(+), 5 deletions(-) diff --git a/tests/test_tinyexp_artifacts.py b/tests/test_tinyexp_artifacts.py index 2378cfc..948cb59 100644 --- a/tests/test_tinyexp_artifacts.py +++ b/tests/test_tinyexp_artifacts.py @@ -6,7 +6,7 @@ import torch from tinyexp import CheckpointCfg, TinyExp -from tinyexp.exceptions import UnsupportedCheckpointFormatError +from tinyexp.exceptions import MissingCheckpointStateError, UnsupportedCheckpointFormatError def test_get_run_dir(tmp_path: Path) -> None: @@ -113,3 +113,39 @@ def test_checkpoint_cfg_rejects_unsupported_model_only_format(tmp_path: Path) -> with pytest.raises(UnsupportedCheckpointFormatError, match="not a supported tinyexp checkpoint format"): checkpoint_cfg.load_checkpoint(str(checkpoint_path)) + + +def test_checkpoint_cfg_requires_model_state_when_model_is_provided(tmp_path: Path) -> None: + checkpoint_path = tmp_path / "missing_model_state.ckpt" + torch.save({"format_version": 1, "meta": {}, "epoch": 1}, checkpoint_path) + + checkpoint_cfg = CheckpointCfg() + model = torch.nn.Linear(2, 1) + + with pytest.raises(MissingCheckpointStateError, match="model_state_dict"): + checkpoint_cfg.load_checkpoint(str(checkpoint_path), model=model) + + +def test_checkpoint_cfg_requires_optimizer_state_when_optimizer_is_provided(tmp_path: Path) -> None: + checkpoint_path = tmp_path / "missing_optimizer_state.ckpt" + torch.save({"format_version": 1, "meta": {}, "model_state_dict": {}}, checkpoint_path) + + checkpoint_cfg = CheckpointCfg() + model = torch.nn.Linear(2, 1) + optimizer = torch.optim.SGD(model.parameters(), lr=0.1) + + with pytest.raises(MissingCheckpointStateError, 
match="optimizer_state_dict"): + checkpoint_cfg.load_checkpoint(str(checkpoint_path), optimizer=optimizer) + + +def test_checkpoint_cfg_requires_scheduler_state_when_scheduler_is_provided(tmp_path: Path) -> None: + checkpoint_path = tmp_path / "missing_scheduler_state.ckpt" + torch.save({"format_version": 1, "meta": {}, "model_state_dict": {}, "optimizer_state_dict": {}}, checkpoint_path) + + checkpoint_cfg = CheckpointCfg() + model = torch.nn.Linear(2, 1) + optimizer = torch.optim.SGD(model.parameters(), lr=0.1) + scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1) + + with pytest.raises(MissingCheckpointStateError, match="scheduler_state_dict"): + checkpoint_cfg.load_checkpoint(str(checkpoint_path), scheduler=scheduler) diff --git a/tinyexp/__init__.py b/tinyexp/__init__.py index c0b1007..cbe9078 100644 --- a/tinyexp/__init__.py +++ b/tinyexp/__init__.py @@ -15,7 +15,12 @@ from omegaconf import DictConfig, OmegaConf from omegaconf.listconfig import ListConfig -from .exceptions import InvalidCheckpointTypeError, UnknownConfigurationKeyError, UnsupportedCheckpointFormatError +from .exceptions import ( + InvalidCheckpointTypeError, + MissingCheckpointStateError, + UnknownConfigurationKeyError, + UnsupportedCheckpointFormatError, +) from .utils.log_utils import tiny_logger_setup from .utils.ray_utils import simple_launch_exp @@ -123,11 +128,17 @@ def load_checkpoint( elif "model_state_dict" not in checkpoint: raise UnsupportedCheckpointFormatError(path) - if model is not None and "model_state_dict" in checkpoint: + if model is not None: + if "model_state_dict" not in checkpoint: + raise MissingCheckpointStateError(path, "model_state_dict") model.load_state_dict(checkpoint["model_state_dict"], strict=strict) - if optimizer is not None and "optimizer_state_dict" in checkpoint: + if optimizer is not None: + if "optimizer_state_dict" not in checkpoint: + raise MissingCheckpointStateError(path, "optimizer_state_dict") 
optimizer.load_state_dict(checkpoint["optimizer_state_dict"]) - if scheduler is not None and "scheduler_state_dict" in checkpoint: + if scheduler is not None: + if "scheduler_state_dict" not in checkpoint: + raise MissingCheckpointStateError(path, "scheduler_state_dict") scheduler.load_state_dict(checkpoint["scheduler_state_dict"]) return checkpoint diff --git a/tinyexp/exceptions.py b/tinyexp/exceptions.py index a92c707..bd8a3ba 100644 --- a/tinyexp/exceptions.py +++ b/tinyexp/exceptions.py @@ -59,3 +59,10 @@ def __init__(self, path: str) -> None: super().__init__( f"Checkpoint at {path} is not a supported tinyexp checkpoint format and does not contain model_state_dict." ) + + +class MissingCheckpointStateError(KeyError): + def __init__(self, path: str, state_name: str) -> None: + self.path = path + self.state_name = state_name + super().__init__(f"Checkpoint at {path} does not contain required {state_name}.") From 662744ee2fb0a108aba732fe0b9fcc42ad213fd1 Mon Sep 17 00:00:00 2001 From: Zane Li Date: Mon, 30 Mar 2026 23:17:10 +0800 Subject: [PATCH 17/30] refactor: simplify checkpoint error types --- tests/test_tinyexp_artifacts.py | 8 ++++---- tinyexp/__init__.py | 15 +++++---------- tinyexp/exceptions.py | 14 -------------- 3 files changed, 9 insertions(+), 28 deletions(-) diff --git a/tests/test_tinyexp_artifacts.py b/tests/test_tinyexp_artifacts.py index 948cb59..69239f7 100644 --- a/tests/test_tinyexp_artifacts.py +++ b/tests/test_tinyexp_artifacts.py @@ -6,7 +6,7 @@ import torch from tinyexp import CheckpointCfg, TinyExp -from tinyexp.exceptions import MissingCheckpointStateError, UnsupportedCheckpointFormatError +from tinyexp.exceptions import UnsupportedCheckpointFormatError def test_get_run_dir(tmp_path: Path) -> None: @@ -122,7 +122,7 @@ def test_checkpoint_cfg_requires_model_state_when_model_is_provided(tmp_path: Pa checkpoint_cfg = CheckpointCfg() model = torch.nn.Linear(2, 1) - with pytest.raises(MissingCheckpointStateError, match="model_state_dict"): + 
with pytest.raises(KeyError, match="model_state_dict"): checkpoint_cfg.load_checkpoint(str(checkpoint_path), model=model) @@ -134,7 +134,7 @@ def test_checkpoint_cfg_requires_optimizer_state_when_optimizer_is_provided(tmp_ model = torch.nn.Linear(2, 1) optimizer = torch.optim.SGD(model.parameters(), lr=0.1) - with pytest.raises(MissingCheckpointStateError, match="optimizer_state_dict"): + with pytest.raises(KeyError, match="optimizer_state_dict"): checkpoint_cfg.load_checkpoint(str(checkpoint_path), optimizer=optimizer) @@ -147,5 +147,5 @@ def test_checkpoint_cfg_requires_scheduler_state_when_scheduler_is_provided(tmp_ optimizer = torch.optim.SGD(model.parameters(), lr=0.1) scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1) - with pytest.raises(MissingCheckpointStateError, match="scheduler_state_dict"): + with pytest.raises(KeyError, match="scheduler_state_dict"): checkpoint_cfg.load_checkpoint(str(checkpoint_path), scheduler=scheduler) diff --git a/tinyexp/__init__.py b/tinyexp/__init__.py index cbe9078..f7a7801 100644 --- a/tinyexp/__init__.py +++ b/tinyexp/__init__.py @@ -15,12 +15,7 @@ from omegaconf import DictConfig, OmegaConf from omegaconf.listconfig import ListConfig -from .exceptions import ( - InvalidCheckpointTypeError, - MissingCheckpointStateError, - UnknownConfigurationKeyError, - UnsupportedCheckpointFormatError, -) +from .exceptions import UnknownConfigurationKeyError, UnsupportedCheckpointFormatError from .utils.log_utils import tiny_logger_setup from .utils.ray_utils import simple_launch_exp @@ -118,7 +113,7 @@ def load_checkpoint( checkpoint = torch.load(path, map_location=map_location) if not isinstance(checkpoint, dict): - raise InvalidCheckpointTypeError(path, type(checkpoint).__name__) + raise TypeError(type(checkpoint).__name__) if any( key in checkpoint @@ -130,15 +125,15 @@ def load_checkpoint( if model is not None: if "model_state_dict" not in checkpoint: - raise MissingCheckpointStateError(path, "model_state_dict") + 
raise KeyError("model_state_dict") model.load_state_dict(checkpoint["model_state_dict"], strict=strict) if optimizer is not None: if "optimizer_state_dict" not in checkpoint: - raise MissingCheckpointStateError(path, "optimizer_state_dict") + raise KeyError("optimizer_state_dict") optimizer.load_state_dict(checkpoint["optimizer_state_dict"]) if scheduler is not None: if "scheduler_state_dict" not in checkpoint: - raise MissingCheckpointStateError(path, "scheduler_state_dict") + raise KeyError("scheduler_state_dict") scheduler.load_state_dict(checkpoint["scheduler_state_dict"]) return checkpoint diff --git a/tinyexp/exceptions.py b/tinyexp/exceptions.py index bd8a3ba..b190481 100644 --- a/tinyexp/exceptions.py +++ b/tinyexp/exceptions.py @@ -46,23 +46,9 @@ def __init__(self) -> None: super().__init__("CUDA is required but not available.") -class InvalidCheckpointTypeError(TypeError): - def __init__(self, path: str, checkpoint_type: str) -> None: - self.path = path - self.checkpoint_type = checkpoint_type - super().__init__(f"Checkpoint at {path} must be a dict, got {checkpoint_type}.") - - class UnsupportedCheckpointFormatError(ValueError): def __init__(self, path: str) -> None: self.path = path super().__init__( f"Checkpoint at {path} is not a supported tinyexp checkpoint format and does not contain model_state_dict." 
) - - -class MissingCheckpointStateError(KeyError): - def __init__(self, path: str, state_name: str) -> None: - self.path = path - self.state_name = state_name - super().__init__(f"Checkpoint at {path} does not contain required {state_name}.") From c10999775c8c9669060b4d51a0439ec7c61fa4b1 Mon Sep 17 00:00:00 2001 From: Zane Li Date: Mon, 30 Mar 2026 23:21:45 +0800 Subject: [PATCH 18/30] fix: validate checkpoint format version --- tests/test_tinyexp_artifacts.py | 20 ++++++++++++ tinyexp/__init__.py | 55 +++++++++++++++++++++++---------- 2 files changed, 58 insertions(+), 17 deletions(-) diff --git a/tests/test_tinyexp_artifacts.py b/tests/test_tinyexp_artifacts.py index 69239f7..5670e7f 100644 --- a/tests/test_tinyexp_artifacts.py +++ b/tests/test_tinyexp_artifacts.py @@ -115,6 +115,26 @@ def test_checkpoint_cfg_rejects_unsupported_model_only_format(tmp_path: Path) -> checkpoint_cfg.load_checkpoint(str(checkpoint_path)) +def test_checkpoint_cfg_rejects_unsupported_format_version(tmp_path: Path) -> None: + checkpoint_path = tmp_path / "unsupported_version.ckpt" + torch.save({"format_version": 999, "meta": {}, "model_state_dict": {}}, checkpoint_path) + + checkpoint_cfg = CheckpointCfg() + + with pytest.raises(ValueError, match="unsupported format_version 999"): + checkpoint_cfg.load_checkpoint(str(checkpoint_path)) + + +def test_checkpoint_cfg_rejects_non_dict_payload(tmp_path: Path) -> None: + checkpoint_path = tmp_path / "not_a_dict.ckpt" + torch.save([1, 2, 3], checkpoint_path) + + checkpoint_cfg = CheckpointCfg() + + with pytest.raises(TypeError, match="must be a dict, got list"): + checkpoint_cfg.load_checkpoint(str(checkpoint_path)) + + def test_checkpoint_cfg_requires_model_state_when_model_is_provided(tmp_path: Path) -> None: checkpoint_path = tmp_path / "missing_model_state.ckpt" torch.save({"format_version": 1, "meta": {}, "epoch": 1}, checkpoint_path) diff --git a/tinyexp/__init__.py b/tinyexp/__init__.py index f7a7801..13515a9 100644 --- 
a/tinyexp/__init__.py +++ b/tinyexp/__init__.py @@ -100,29 +100,37 @@ def save_checkpoint( torch.save(checkpoint, save_path) return str(save_path) - def load_checkpoint( - self, - path: str, - *, - model=None, - optimizer=None, - scheduler=None, - strict: bool = True, - map_location=None, - ) -> dict[str, Any]: - checkpoint = torch.load(path, map_location=map_location) + def _invalid_checkpoint_type_error(self, path: str, checkpoint: Any) -> TypeError: + return TypeError(f"Checkpoint at {path} must be a dict, got {type(checkpoint).__name__}.") + def _unsupported_checkpoint_version_error(self, path: str, checkpoint_format_version: Any) -> ValueError: + return ValueError( + f"Checkpoint at {path} has unsupported format_version {checkpoint_format_version}; " + f"expected {self.format_version}." + ) + + def _validate_checkpoint_payload(self, path: str, checkpoint: Any) -> dict[str, Any]: if not isinstance(checkpoint, dict): - raise TypeError(type(checkpoint).__name__) + raise self._invalid_checkpoint_type_error(path, checkpoint) + + checkpoint_format_version = checkpoint.get("format_version") + if checkpoint_format_version is not None and checkpoint_format_version != self.format_version: + raise self._unsupported_checkpoint_version_error(path, checkpoint_format_version) - if any( - key in checkpoint - for key in ("epoch", "global_step", "best_metric", "meta", "extra_state", "format_version") + if ( + not any( + key in checkpoint + for key in ("epoch", "global_step", "best_metric", "meta", "extra_state", "format_version") + ) + and "model_state_dict" not in checkpoint ): - pass - elif "model_state_dict" not in checkpoint: raise UnsupportedCheckpointFormatError(path) + return checkpoint + + def _load_required_state( + self, checkpoint: dict[str, Any], *, model=None, optimizer=None, scheduler=None, strict: bool = True + ) -> None: if model is not None: if "model_state_dict" not in checkpoint: raise KeyError("model_state_dict") @@ -136,6 +144,19 @@ def load_checkpoint( 
raise KeyError("scheduler_state_dict") scheduler.load_state_dict(checkpoint["scheduler_state_dict"]) + def load_checkpoint( + self, + path: str, + *, + model=None, + optimizer=None, + scheduler=None, + strict: bool = True, + map_location=None, + ) -> dict[str, Any]: + checkpoint = self._validate_checkpoint_payload(path, torch.load(path, map_location=map_location)) + self._load_required_state(checkpoint, model=model, optimizer=optimizer, scheduler=scheduler, strict=strict) + return checkpoint From 0953f8213ff180cbbf2c97f48853ff73f85b8a84 Mon Sep 17 00:00:00 2001 From: Zane Li Date: Mon, 30 Mar 2026 23:34:20 +0800 Subject: [PATCH 19/30] refactor: remove checkpoint error helpers --- tinyexp/__init__.py | 16 +++++----------- 1 file changed, 5 insertions(+), 11 deletions(-) diff --git a/tinyexp/__init__.py b/tinyexp/__init__.py index 13515a9..c6b15d1 100644 --- a/tinyexp/__init__.py +++ b/tinyexp/__init__.py @@ -100,22 +100,16 @@ def save_checkpoint( torch.save(checkpoint, save_path) return str(save_path) - def _invalid_checkpoint_type_error(self, path: str, checkpoint: Any) -> TypeError: - return TypeError(f"Checkpoint at {path} must be a dict, got {type(checkpoint).__name__}.") - - def _unsupported_checkpoint_version_error(self, path: str, checkpoint_format_version: Any) -> ValueError: - return ValueError( - f"Checkpoint at {path} has unsupported format_version {checkpoint_format_version}; " - f"expected {self.format_version}." 
- ) - def _validate_checkpoint_payload(self, path: str, checkpoint: Any) -> dict[str, Any]: if not isinstance(checkpoint, dict): - raise self._invalid_checkpoint_type_error(path, checkpoint) + raise TypeError(f"Checkpoint at {path} must be a dict, got {type(checkpoint).__name__}.") # noqa: TRY003 checkpoint_format_version = checkpoint.get("format_version") if checkpoint_format_version is not None and checkpoint_format_version != self.format_version: - raise self._unsupported_checkpoint_version_error(path, checkpoint_format_version) + raise ValueError( # noqa: TRY003 + f"Checkpoint at {path} has unsupported format_version {checkpoint_format_version}; " + f"expected {self.format_version}." + ) if ( not any( From 7738d5b385adc3020708a4e5357a99ef4f67dd6c Mon Sep 17 00:00:00 2001 From: Zane Li Date: Mon, 30 Mar 2026 23:38:59 +0800 Subject: [PATCH 20/30] refactor: remove checkpoint format version --- tests/test_tinyexp_artifacts.py | 17 +++-------------- tinyexp/__init__.py | 14 +------------- 2 files changed, 4 insertions(+), 27 deletions(-) diff --git a/tests/test_tinyexp_artifacts.py b/tests/test_tinyexp_artifacts.py index 5670e7f..84c72f7 100644 --- a/tests/test_tinyexp_artifacts.py +++ b/tests/test_tinyexp_artifacts.py @@ -78,7 +78,6 @@ def test_checkpoint_cfg_save_and_load_roundtrip(tmp_path: Path) -> None: assert checkpoint["epoch"] == 3 assert checkpoint["global_step"] == 12 assert checkpoint["best_metric"] == 0.9 - assert checkpoint["format_version"] == 1 assert checkpoint["extra_state"]["custom_value"] == 7 assert checkpoint["meta"]["exp_name"] == "demo_exp" assert checkpoint["meta"]["exp_class"] == "tests.demo.Exp" @@ -115,16 +114,6 @@ def test_checkpoint_cfg_rejects_unsupported_model_only_format(tmp_path: Path) -> checkpoint_cfg.load_checkpoint(str(checkpoint_path)) -def test_checkpoint_cfg_rejects_unsupported_format_version(tmp_path: Path) -> None: - checkpoint_path = tmp_path / "unsupported_version.ckpt" - torch.save({"format_version": 999, "meta": {}, 
"model_state_dict": {}}, checkpoint_path) - - checkpoint_cfg = CheckpointCfg() - - with pytest.raises(ValueError, match="unsupported format_version 999"): - checkpoint_cfg.load_checkpoint(str(checkpoint_path)) - - def test_checkpoint_cfg_rejects_non_dict_payload(tmp_path: Path) -> None: checkpoint_path = tmp_path / "not_a_dict.ckpt" torch.save([1, 2, 3], checkpoint_path) @@ -137,7 +126,7 @@ def test_checkpoint_cfg_rejects_non_dict_payload(tmp_path: Path) -> None: def test_checkpoint_cfg_requires_model_state_when_model_is_provided(tmp_path: Path) -> None: checkpoint_path = tmp_path / "missing_model_state.ckpt" - torch.save({"format_version": 1, "meta": {}, "epoch": 1}, checkpoint_path) + torch.save({"meta": {}, "epoch": 1}, checkpoint_path) checkpoint_cfg = CheckpointCfg() model = torch.nn.Linear(2, 1) @@ -148,7 +137,7 @@ def test_checkpoint_cfg_requires_model_state_when_model_is_provided(tmp_path: Pa def test_checkpoint_cfg_requires_optimizer_state_when_optimizer_is_provided(tmp_path: Path) -> None: checkpoint_path = tmp_path / "missing_optimizer_state.ckpt" - torch.save({"format_version": 1, "meta": {}, "model_state_dict": {}}, checkpoint_path) + torch.save({"meta": {}, "model_state_dict": {}}, checkpoint_path) checkpoint_cfg = CheckpointCfg() model = torch.nn.Linear(2, 1) @@ -160,7 +149,7 @@ def test_checkpoint_cfg_requires_optimizer_state_when_optimizer_is_provided(tmp_ def test_checkpoint_cfg_requires_scheduler_state_when_scheduler_is_provided(tmp_path: Path) -> None: checkpoint_path = tmp_path / "missing_scheduler_state.ckpt" - torch.save({"format_version": 1, "meta": {}, "model_state_dict": {}, "optimizer_state_dict": {}}, checkpoint_path) + torch.save({"meta": {}, "model_state_dict": {}, "optimizer_state_dict": {}}, checkpoint_path) checkpoint_cfg = CheckpointCfg() model = torch.nn.Linear(2, 1) diff --git a/tinyexp/__init__.py b/tinyexp/__init__.py index c6b15d1..47a5cb8 100644 --- a/tinyexp/__init__.py +++ b/tinyexp/__init__.py @@ -55,7 +55,6 @@ def 
_is_main_process() -> bool: @dataclass class CheckpointCfg: - format_version: int = 1 last_ckpt_name: str = "last.ckpt" best_ckpt_name: str = "best.ckpt" @@ -78,7 +77,6 @@ def save_checkpoint( save_path.parent.mkdir(parents=True, exist_ok=True) checkpoint: dict[str, Any] = { - "format_version": self.format_version, "epoch": epoch, "global_step": global_step, "best_metric": best_metric, @@ -104,18 +102,8 @@ def _validate_checkpoint_payload(self, path: str, checkpoint: Any) -> dict[str, if not isinstance(checkpoint, dict): raise TypeError(f"Checkpoint at {path} must be a dict, got {type(checkpoint).__name__}.") # noqa: TRY003 - checkpoint_format_version = checkpoint.get("format_version") - if checkpoint_format_version is not None and checkpoint_format_version != self.format_version: - raise ValueError( # noqa: TRY003 - f"Checkpoint at {path} has unsupported format_version {checkpoint_format_version}; " - f"expected {self.format_version}." - ) - if ( - not any( - key in checkpoint - for key in ("epoch", "global_step", "best_metric", "meta", "extra_state", "format_version") - ) + not any(key in checkpoint for key in ("epoch", "global_step", "best_metric", "meta", "extra_state")) and "model_state_dict" not in checkpoint ): raise UnsupportedCheckpointFormatError(path) From bae724e0a322e120aa308ba50a27b211f039ad8d Mon Sep 17 00:00:00 2001 From: Zane Li Date: Tue, 31 Mar 2026 09:29:12 +0800 Subject: [PATCH 21/30] fix: register missing hydra env config --- tinyexp/__init__.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/tinyexp/__init__.py b/tinyexp/__init__.py index 47a5cb8..353e01f 100644 --- a/tinyexp/__init__.py +++ b/tinyexp/__init__.py @@ -28,10 +28,27 @@ class _HydraConfig(HydraConf): To avoid hydra output the config in unexpected directory. 
""" + defaults: list[Any] = field( + default_factory=lambda: [ + {"output": "default"}, + {"launcher": "basic"}, + {"sweeper": "basic"}, + {"help": "default"}, + {"hydra_help": "default"}, + {"hydra_logging": "default"}, + {"job_logging": "default"}, + {"callbacks": None}, + ] + ) output_subdir: Optional[str] = None run: RunDir = field(default_factory=lambda: RunDir("./output")) +# Hydra 1.3.2 in this environment expects hydra/env/default during composition, +# but does not register a built-in config for it. +ConfigStore.instance().store(group="hydra/env", name="default", node={}, provider="tinyexp") + + def _default_exp_name() -> str: """ Get the default experiment name from the main module or the command line. From 186a62c6b0fbd9ea76290e9bee57275a1559965f Mon Sep 17 00:00:00 2001 From: Zane Li Date: Tue, 31 Mar 2026 09:33:54 +0800 Subject: [PATCH 22/30] fix: backfill missing hydra env config safely --- tinyexp/__init__.py | 21 ++++++++++++++++++--- 1 file changed, 18 insertions(+), 3 deletions(-) diff --git a/tinyexp/__init__.py b/tinyexp/__init__.py index 353e01f..16a6fd5 100644 --- a/tinyexp/__init__.py +++ b/tinyexp/__init__.py @@ -44,9 +44,24 @@ class _HydraConfig(HydraConf): run: RunDir = field(default_factory=lambda: RunDir("./output")) -# Hydra 1.3.2 in this environment expects hydra/env/default during composition, -# but does not register a built-in config for it. -ConfigStore.instance().store(group="hydra/env", name="default", node={}, provider="tinyexp") +def _ensure_hydra_env_default() -> None: + cs = ConfigStore.instance() + cur = cs.repo + for group in ("hydra", "env"): + next_cur = cur.get(group) + if not isinstance(next_cur, dict): + next_cur = {} + cur[group] = next_cur + cur = next_cur + + if "default.yaml" not in cur: + cs.store(group="hydra/env", name="default", node={}, provider="tinyexp") + + +# Some Hydra installations expect hydra/env/default during composition, but do +# not register a built-in config for it. 
Only backfill it when missing so older +# or fuller Hydra setups keep their own env config untouched. +_ensure_hydra_env_default() def _default_exp_name() -> str: From 15027ee573498123061324cbf2d1b9280e4d150c Mon Sep 17 00:00:00 2001 From: Zane Li Date: Tue, 31 Mar 2026 10:11:30 +0800 Subject: [PATCH 23/30] revert: drop hydra env backfill workaround --- tinyexp/__init__.py | 32 -------------------------------- 1 file changed, 32 deletions(-) diff --git a/tinyexp/__init__.py b/tinyexp/__init__.py index 16a6fd5..47a5cb8 100644 --- a/tinyexp/__init__.py +++ b/tinyexp/__init__.py @@ -28,42 +28,10 @@ class _HydraConfig(HydraConf): To avoid hydra output the config in unexpected directory. """ - defaults: list[Any] = field( - default_factory=lambda: [ - {"output": "default"}, - {"launcher": "basic"}, - {"sweeper": "basic"}, - {"help": "default"}, - {"hydra_help": "default"}, - {"hydra_logging": "default"}, - {"job_logging": "default"}, - {"callbacks": None}, - ] - ) output_subdir: Optional[str] = None run: RunDir = field(default_factory=lambda: RunDir("./output")) -def _ensure_hydra_env_default() -> None: - cs = ConfigStore.instance() - cur = cs.repo - for group in ("hydra", "env"): - next_cur = cur.get(group) - if not isinstance(next_cur, dict): - next_cur = {} - cur[group] = next_cur - cur = next_cur - - if "default.yaml" not in cur: - cs.store(group="hydra/env", name="default", node={}, provider="tinyexp") - - -# Some Hydra installations expect hydra/env/default during composition, but do -# not register a built-in config for it. Only backfill it when missing so older -# or fuller Hydra setups keep their own env config untouched. -_ensure_hydra_env_default() - - def _default_exp_name() -> str: """ Get the default experiment name from the main module or the command line. 
From d29d07e79e6d3b78f82f5884b7d0a29a091086ff Mon Sep 17 00:00:00 2001 From: Zane Li Date: Tue, 31 Mar 2026 11:20:28 +0800 Subject: [PATCH 24/30] docs: sync phase 1 plans with current helpers --- docs/phase1-api-draft.md | 10 ++-------- docs/phase1-file-by-file-plan.md | 6 +----- docs/phase1-minimal-helpers.md | 7 ++++--- 3 files changed, 7 insertions(+), 16 deletions(-) diff --git a/docs/phase1-api-draft.md b/docs/phase1-api-draft.md index 3d1002d..9b7ed29 100644 --- a/docs/phase1-api-draft.md +++ b/docs/phase1-api-draft.md @@ -54,9 +54,6 @@ Recommended Phase 1 methods on `TinyExp`: def get_run_dir(self) -> str: ... -def ensure_run_dir(self) -> str: - ... - def dump_config(self, path: str | None = None) -> str: ... ``` @@ -192,15 +189,12 @@ Run directory behavior should remain simple in Phase 1. ```python def get_run_dir(self) -> str: ... - -def ensure_run_dir(self) -> str: - ... ``` ### Expected behavior - `get_run_dir()` returns `os.path.join(self.output_root, self.exp_name)` -- `ensure_run_dir()` creates the directory and returns it +- methods that write files should create parent directories when needed Phase 1 should not add timestamped run folders, version managers, or heavier run registry behavior. @@ -211,7 +205,7 @@ Below is the intended style for examples after Phase 1. 
### Logger setup ```python -run_dir = self.ensure_run_dir() +run_dir = self.get_run_dir() logger = self.logger_cfg.build_logger( save_dir=run_dir, distributed_rank=accelerator.rank, diff --git a/docs/phase1-file-by-file-plan.md b/docs/phase1-file-by-file-plan.md index fc811c3..21e865a 100644 --- a/docs/phase1-file-by-file-plan.md +++ b/docs/phase1-file-by-file-plan.md @@ -88,7 +88,6 @@ These belong at the experiment level because they describe run intent rather tha Recommended methods on `TinyExp`: - `get_run_dir() -> str` -- `ensure_run_dir() -> str` - `dump_config(path: str | None = None) -> str` These are good fits for `TinyExp` because they are experiment-scoped rather than belonging to a single feature config. @@ -99,7 +98,6 @@ Avoid adding many feature-specific top-level methods such as: - `save_checkpoint(...)` - `load_checkpoint(...)` -- `maybe_resume(...)` Those are better expressed through a focused config component. @@ -174,7 +172,7 @@ The MNIST example is: Expected updates: -- call `self.ensure_run_dir()` +- call `self.get_run_dir()` - build the logger using `self.logger_cfg.build_logger(...)` - call `self.dump_config()` - branch on `self.mode` @@ -240,7 +238,6 @@ Phase 1 needs lightweight but meaningful coverage. Add or expand tests for: - `get_run_dir()` -- `ensure_run_dir()` - `dump_config()` These can live in: @@ -291,7 +288,6 @@ The order below minimizes risk and keeps the design easy to validate. - `mode` - `resume_from` - `get_run_dir()` - - `ensure_run_dir()` - `dump_config()` - `CheckpointCfg` 2. 
add checkpoint-focused tests diff --git a/docs/phase1-minimal-helpers.md b/docs/phase1-minimal-helpers.md index 0c860bf..e3e2fb2 100644 --- a/docs/phase1-minimal-helpers.md +++ b/docs/phase1-minimal-helpers.md @@ -76,7 +76,6 @@ More policy-driven settings should stay in examples unless they prove broadly re The following methods are the proposed Phase 1 surface area: - `get_run_dir() -> str` -- `ensure_run_dir() -> str` - `dump_config(path: str | None = None) -> str` - `log_metrics(metrics: dict, *, step: int | None = None, epoch: int | None = None, filename: str = "metrics.jsonl") -> None` @@ -125,9 +124,11 @@ management system. `get_run_dir()` should return the default run directory for the current experiment. -`ensure_run_dir()` should create that directory if needed and return it. +Directory creation should happen in the method that actually writes files, such as logger setup, +config dumping, or checkpoint saving. -These helpers should not introduce a large naming or versioning system in Phase 1. +This keeps `TinyExp` smaller and avoids a separate side-effect helper whose behavior can stay explicit +at the write boundary. 
### Config dumping From 326842270a8aa64cddcbe567cc1c0613dffb7e7f Mon Sep 17 00:00:00 2001 From: Zane Li Date: Tue, 31 Mar 2026 11:23:35 +0800 Subject: [PATCH 25/30] docs: remove unimplemented metric logging plans --- docs/phase1-minimal-helpers.md | 27 +++++---------------------- docs/philosophy.md | 1 - 2 files changed, 5 insertions(+), 23 deletions(-) diff --git a/docs/phase1-minimal-helpers.md b/docs/phase1-minimal-helpers.md index e3e2fb2..e88b55c 100644 --- a/docs/phase1-minimal-helpers.md +++ b/docs/phase1-minimal-helpers.md @@ -77,7 +77,6 @@ The following methods are the proposed Phase 1 surface area: - `get_run_dir() -> str` - `dump_config(path: str | None = None) -> str` -- `log_metrics(metrics: dict, *, step: int | None = None, epoch: int | None = None, filename: str = "metrics.jsonl") -> None` These are helpers, not control-flow abstractions. @@ -110,7 +109,6 @@ Phase 1 should establish simple, stable artifact conventions. The recommended default run layout is: - `output//config.yaml` -- `output//metrics.jsonl` - `output//last.ckpt` - `output//best.ckpt` - `output//log.txt` @@ -140,19 +138,6 @@ Expected behavior: - output reflects current config state after overrides - writing should happen only from the main process when running distributed -### Metric logging - -`log_metrics()` should append structured records to a local JSONL file. - -Expected behavior: - -- default file is `/metrics.jsonl` -- each record should include the provided metrics -- helper may also attach lightweight metadata such as timestamp, step, and epoch -- writing should happen only from the main process - -This gives TinyExp a useful local record format without introducing a full tracker framework. 
- ### Checkpoint helpers `checkpoint_cfg.save_checkpoint()` and `checkpoint_cfg.load_checkpoint()` should provide a standard way to persist and @@ -234,7 +219,6 @@ Recommended test coverage: - unit tests for run directory creation - unit tests for config dumping -- unit tests for metric logging - unit tests for checkpoint save/load - a small integration test for `mode=val` @@ -246,11 +230,10 @@ Recommended implementation order: 1. add run directory helpers 2. add config dumping -3. add metric logging -4. add `CheckpointCfg` with save/load -5. migrate `mnist_exp.py` -6. add `mode=val` -7. add tests +3. add `CheckpointCfg` with save/load +4. migrate `mnist_exp.py` +5. add `mode=val` +6. add tests This order keeps each change small and easy to validate. @@ -260,7 +243,7 @@ Phase 1 is successful if TinyExp can do all of the following while still feeling - keep experiments centered around one explicit entrypoint - preserve user-owned training loops -- save config and local metrics in a standard way +- save config in a standard way - save and resume checkpoints with minimal boilerplate - support a simple validation flow from a checkpoint diff --git a/docs/philosophy.md b/docs/philosophy.md index d31fdcd..0373fed 100644 --- a/docs/philosophy.md +++ b/docs/philosophy.md @@ -95,7 +95,6 @@ TinyExp should provide thin, reusable helpers for common experiment chores, such - output directory setup - config dumping -- lightweight metric logging - checkpoint save/load helpers exposed through focused config components - launcher compatibility From 5c6bc5f4868769d5e5be6dcf069b038e9ae4682d Mon Sep 17 00:00:00 2001 From: Zane Li Date: Tue, 31 Mar 2026 11:39:31 +0800 Subject: [PATCH 26/30] refactor: remove config dump helper --- docs/phase1-api-draft.md | 25 +------------------------ docs/phase1-file-by-file-plan.md | 7 +------ docs/phase1-minimal-helpers.md | 24 +++++------------------- docs/philosophy.md | 1 - tests/examples/test_mnist_exp_run.py | 3 +-- 
tests/test_tinyexp_artifacts.py | 13 ------------- tinyexp/__init__.py | 13 +------------ tinyexp/examples/mnist_exp.py | 1 - 8 files changed, 9 insertions(+), 78 deletions(-) diff --git a/docs/phase1-api-draft.md b/docs/phase1-api-draft.md index 9b7ed29..1f50735 100644 --- a/docs/phase1-api-draft.md +++ b/docs/phase1-api-draft.md @@ -53,12 +53,9 @@ Recommended Phase 1 methods on `TinyExp`: ```python def get_run_dir(self) -> str: ... - -def dump_config(self, path: str | None = None) -> str: - ... ``` -These belong on `TinyExp` because they are experiment-scoped, not feature-scoped. +This belongs on `TinyExp` because it is experiment-scoped, not feature-scoped. ## `CheckpointCfg` Draft @@ -162,24 +159,6 @@ Notes: - `meta` should stay lightweight - `extra_state` can extend the structure without forcing premature abstraction -## Config Dump Draft - -Configuration dumping should remain an experiment-level helper. - -### Draft method - -```python -def dump_config(self, path: str | None = None) -> str: - ... -``` - -### Expected behavior - -- default path is `/config.yaml` -- output reflects current configuration after Hydra overrides -- dump should be safe to call from examples -- distributed runs should avoid duplicate writes - ## Run Directory Draft Run directory behavior should remain simple in Phase 1. 
@@ -210,7 +189,6 @@ logger = self.logger_cfg.build_logger( save_dir=run_dir, distributed_rank=accelerator.rank, ) -self.dump_config() ``` ### Explicit W&B usage @@ -280,7 +258,6 @@ This is the intended balance: The following questions should stay open until the first implementation slice proves itself: - whether metrics deserve their own `MetricCfg` -- whether config dumping should later move into a dedicated artifact cfg - whether `CheckpointCfg` should move into its own module - whether shared recipe base classes are worth introducing diff --git a/docs/phase1-file-by-file-plan.md b/docs/phase1-file-by-file-plan.md index 21e865a..5cb8e51 100644 --- a/docs/phase1-file-by-file-plan.md +++ b/docs/phase1-file-by-file-plan.md @@ -21,7 +21,6 @@ one pass. The first implementation slice should focus on: - a stable run directory helper -- explicit config dumping - a new `CheckpointCfg` - `mode=val` support through explicit checkpoint loading - a migration of the MNIST example to validate the design @@ -88,9 +87,8 @@ These belong at the experiment level because they describe run intent rather tha Recommended methods on `TinyExp`: - `get_run_dir() -> str` -- `dump_config(path: str | None = None) -> str` -These are good fits for `TinyExp` because they are experiment-scoped rather than belonging to a single feature config. +This is a good fit for `TinyExp` because it is experiment-scoped rather than belonging to a single feature config. ### What should not be added here @@ -174,7 +172,6 @@ Expected updates: - call `self.get_run_dir()` - build the logger using `self.logger_cfg.build_logger(...)` -- call `self.dump_config()` - branch on `self.mode` #### training should remain explicit @@ -238,7 +235,6 @@ Phase 1 needs lightweight but meaningful coverage. Add or expand tests for: - `get_run_dir()` -- `dump_config()` These can live in: @@ -288,7 +284,6 @@ The order below minimizes risk and keeps the design easy to validate. 
- `mode` - `resume_from` - `get_run_dir()` - - `dump_config()` - `CheckpointCfg` 2. add checkpoint-focused tests 3. migrate `tinyexp/examples/mnist_exp.py` diff --git a/docs/phase1-minimal-helpers.md b/docs/phase1-minimal-helpers.md index e88b55c..51d05bf 100644 --- a/docs/phase1-minimal-helpers.md +++ b/docs/phase1-minimal-helpers.md @@ -76,7 +76,6 @@ More policy-driven settings should stay in examples unless they prove broadly re The following methods are the proposed Phase 1 surface area: - `get_run_dir() -> str` -- `dump_config(path: str | None = None) -> str` These are helpers, not control-flow abstractions. @@ -108,7 +107,6 @@ Phase 1 should establish simple, stable artifact conventions. The recommended default run layout is: -- `output//config.yaml` - `output//last.ckpt` - `output//best.ckpt` - `output//log.txt` @@ -123,21 +121,11 @@ management system. `get_run_dir()` should return the default run directory for the current experiment. Directory creation should happen in the method that actually writes files, such as logger setup, -config dumping, or checkpoint saving. +or checkpoint saving. This keeps `TinyExp` smaller and avoids a separate side-effect helper whose behavior can stay explicit at the write boundary. -### Config dumping - -`dump_config()` should write the effective experiment configuration to YAML. - -Expected behavior: - -- default path is `/config.yaml` -- output reflects current config state after overrides -- writing should happen only from the main process when running distributed - ### Checkpoint helpers `checkpoint_cfg.save_checkpoint()` and `checkpoint_cfg.load_checkpoint()` should provide a standard way to persist and @@ -218,7 +206,6 @@ Phase 1 should be backed by lightweight tests. Recommended test coverage: - unit tests for run directory creation -- unit tests for config dumping - unit tests for checkpoint save/load - a small integration test for `mode=val` @@ -229,11 +216,10 @@ The tests should stay CPU-first and deterministic. 
Recommended implementation order: 1. add run directory helpers -2. add config dumping -3. add `CheckpointCfg` with save/load -4. migrate `mnist_exp.py` -5. add `mode=val` -6. add tests +2. add `CheckpointCfg` with save/load +3. migrate `mnist_exp.py` +4. add `mode=val` +5. add tests This order keeps each change small and easy to validate. diff --git a/docs/philosophy.md b/docs/philosophy.md index 0373fed..1f5b589 100644 --- a/docs/philosophy.md +++ b/docs/philosophy.md @@ -94,7 +94,6 @@ Users should be able to: TinyExp should provide thin, reusable helpers for common experiment chores, such as: - output directory setup -- config dumping - checkpoint save/load helpers exposed through focused config components - launcher compatibility diff --git a/tests/examples/test_mnist_exp_run.py b/tests/examples/test_mnist_exp_run.py index 4c21354..29a6ab2 100644 --- a/tests/examples/test_mnist_exp_run.py +++ b/tests/examples/test_mnist_exp_run.py @@ -21,7 +21,7 @@ def test_mnist_run_val_mode_requires_resume_from(tmp_path: Path, monkeypatch) -> exp.run() -def test_mnist_run_val_mode_uses_checkpoint_and_dumps_config(tmp_path: Path, monkeypatch) -> None: +def test_mnist_run_val_mode_uses_checkpoint(tmp_path: Path, monkeypatch) -> None: exp_for_ckpt = Exp(output_root=str(tmp_path), exp_name="mnist_val") checkpoint_path = exp_for_ckpt.checkpoint_cfg.save_checkpoint( run_dir=str(tmp_path / "mnist_val"), @@ -55,4 +55,3 @@ def fake_evaluate(*, accelerator, logger, module_or_module_path, val_dataloader= assert called["accelerator"] is dummy_accelerator assert called["logger"] is dummy_logger assert called["module_or_module_path"] == checkpoint_path - assert (tmp_path / "mnist_val" / "config.yaml").is_file() diff --git a/tests/test_tinyexp_artifacts.py b/tests/test_tinyexp_artifacts.py index 84c72f7..e1d74a2 100644 --- a/tests/test_tinyexp_artifacts.py +++ b/tests/test_tinyexp_artifacts.py @@ -16,19 +16,6 @@ def test_get_run_dir(tmp_path: Path) -> None: assert exp.get_run_dir() == 
str(expected) -def test_dump_config_writes_yaml(tmp_path: Path, monkeypatch) -> None: - monkeypatch.setenv("RANK", "0") - exp = TinyExp(output_root=str(tmp_path), exp_name="demo_exp", mode="val", resume_from="checkpoint.ckpt") - - dumped = Path(exp.dump_config()) - - assert dumped == tmp_path / "demo_exp" / "config.yaml" - content = dumped.read_text(encoding="utf-8") - assert "exp_name: demo_exp" in content - assert "mode: val" in content - assert "resume_from: checkpoint.ckpt" in content - - def test_logger_cfg_creates_run_dir(tmp_path: Path) -> None: exp = TinyExp(output_root=str(tmp_path), exp_name="demo_exp") run_dir = Path(exp.get_run_dir()) diff --git a/tinyexp/__init__.py b/tinyexp/__init__.py index 47a5cb8..ecc7c2b 100644 --- a/tinyexp/__init__.py +++ b/tinyexp/__init__.py @@ -12,7 +12,7 @@ import torch from hydra.conf import HydraConf, RunDir from hydra.core.config_store import ConfigStore -from omegaconf import DictConfig, OmegaConf +from omegaconf import DictConfig from omegaconf.listconfig import ListConfig from .exceptions import UnknownConfigurationKeyError, UnsupportedCheckpointFormatError @@ -206,17 +206,6 @@ def build_logger(self, save_dir: str, distributed_rank: int = 0, filename: str = def get_run_dir(self) -> str: return os.path.join(self.output_root, self.exp_name) - def dump_config(self, path: Optional[str] = None) -> str: - run_dir = self.get_run_dir() - dump_path = Path(path) if path is not None else Path(run_dir) / "config.yaml" - - if _is_main_process(): - dump_path.parent.mkdir(parents=True, exist_ok=True) - cfg_dict = OmegaConf.to_container(OmegaConf.structured(self), resolve=True) - dump_path.write_text(OmegaConf.to_yaml(cfg_dict), encoding="utf-8") - - return str(dump_path) - def set_cfg(self, cfg_hydra, cfg_object=None): if cfg_object is None: cfg_object = self diff --git a/tinyexp/examples/mnist_exp.py b/tinyexp/examples/mnist_exp.py index 80e46b9..dd47019 100644 --- a/tinyexp/examples/mnist_exp.py +++ b/tinyexp/examples/mnist_exp.py 
@@ -150,7 +150,6 @@ def run(self) -> None: logger = self.logger_cfg.build_logger(save_dir=run_dir, distributed_rank=accelerator.rank) cfg_dict = OmegaConf.to_container(OmegaConf.structured(self), resolve=True) del cfg_dict["hydra"] - self.dump_config() cfg_msg = OmegaConf.to_yaml(cfg_dict).strip().replace("\n", "\n ") logger.info(f"-------- Configurations --------\n {cfg_msg}") From 18e2cfcadcd503e13a779c84a385ad9bf03dabe3 Mon Sep 17 00:00:00 2001 From: Zane Li Date: Tue, 31 Mar 2026 11:43:04 +0800 Subject: [PATCH 27/30] docs: strengthen framework design guidance --- docs/philosophy.md | 27 +++++++++++++++++++++++++++ 1 file changed, 27 insertions(+) diff --git a/docs/philosophy.md b/docs/philosophy.md index 1f5b589..dd3f530 100644 --- a/docs/philosophy.md +++ b/docs/philosophy.md @@ -119,6 +119,31 @@ If yes, it may belong in TinyExp. If it starts to own the user workflow, hide core control flow, or push the project toward a heavy trainer-style architecture, it probably does not belong in TinyExp. +This rule should be applied strictly. + +TinyExp should not add a new helper, artifact, field, or abstraction just because it sounds generally useful or is a +common pattern in other frameworks. A framework-level addition should only be kept when its value is clear in the +current project, not as a placeholder for possible future needs. + +In practice, that means asking: + +- is there real repeated boilerplate across examples today? +- does this introduce a genuinely useful capability, or only another way to express something already visible? +- if this were removed, would users lose something important, or only a convenience wrapper? +- is the benefit strong enough to justify one more method, field, file, or documented convention? + +If these questions do not have a strong answer, the feature should usually stay out of TinyExp. 
+ +Examples of things that often fail this test are: + +- thin one-line wrapper helpers added only for style or lint appeasement +- duplicate artifacts that do not add clear value over the experiment definition and logs +- speculative schema fields added only for future-proofing before any real compatibility need exists + +For example, if the experiment class already defines the configuration and the runtime logger already records the +effective config, that does not automatically justify a separate `dump_config()` helper or a default `config.yaml` +artifact. Those should exist only if they solve a concrete current problem that the existing structure does not. + ## Recommended Boundary ### TinyExp should own @@ -149,6 +174,8 @@ When extending TinyExp, prefer: - recipe-style examples over framework-owned trainers - local clarity over generic indirection - composable building blocks over lifecycle machinery +- removing weak abstractions rather than keeping them "just in case" +- one clear representation of a concept over several overlapping ones In short: From 63d58861a0dcfe6546457cb50060084ff4c53bbf Mon Sep 17 00:00:00 2001 From: Zane Li Date: Tue, 31 Mar 2026 13:18:38 +0800 Subject: [PATCH 28/30] feat: add resnet val checkpoint flow --- tests/examples/test_resnet_exp_run.py | 59 +++++++++++++++++++++++++++ tinyexp/examples/mnist_exp.py | 7 +--- tinyexp/examples/resnet_exp.py | 9 ++-- 3 files changed, 66 insertions(+), 9 deletions(-) create mode 100644 tests/examples/test_resnet_exp_run.py diff --git a/tests/examples/test_resnet_exp_run.py b/tests/examples/test_resnet_exp_run.py new file mode 100644 index 0000000..30022be --- /dev/null +++ b/tests/examples/test_resnet_exp_run.py @@ -0,0 +1,59 @@ +from __future__ import annotations + +from pathlib import Path +from types import SimpleNamespace + +import pytest +import torch.nn as nn + +from tinyexp.examples.resnet_exp import ResNetExp + + +def test_resnet_run_val_mode_requires_resume_from(tmp_path: Path, monkeypatch) 
-> None: + exp = ResNetExp(output_root=str(tmp_path), exp_name="resnet_val", mode="val", resume_from="") + + dummy_accelerator = SimpleNamespace(rank=0, device="cpu", is_main_process=True) + dummy_logger = SimpleNamespace(info=lambda *args, **kwargs: None) + + monkeypatch.setattr(exp.accelerator_cfg, "build_accelerator", lambda: dummy_accelerator) + monkeypatch.setattr(exp.logger_cfg, "build_logger", lambda **kwargs: dummy_logger) + + with pytest.raises(ValueError, match="resume_from"): + exp.run() + + +def test_resnet_run_val_mode_uses_checkpoint(tmp_path: Path, monkeypatch) -> None: + exp_for_ckpt = ResNetExp(output_root=str(tmp_path), exp_name="resnet_val") + checkpoint_path = exp_for_ckpt.checkpoint_cfg.save_checkpoint( + run_dir=str(tmp_path / "resnet_val"), + name="demo.ckpt", + model=nn.Linear(2, 2), + exp_name=exp_for_ckpt.exp_name, + exp_class=exp_for_ckpt.exp_class, + ) + + exp = ResNetExp(output_root=str(tmp_path), exp_name="resnet_val", mode="val", resume_from=checkpoint_path) + + dummy_accelerator = SimpleNamespace(rank=0, device="cpu", is_main_process=True) + dummy_logger = SimpleNamespace(info=lambda *args, **kwargs: None) + + monkeypatch.setattr(exp.accelerator_cfg, "build_accelerator", lambda: dummy_accelerator) + monkeypatch.setattr(exp.logger_cfg, "build_logger", lambda **kwargs: dummy_logger) + + called: dict[str, object] = {} + + def fake_evaluate(*, accelerator, logger, module_or_module_path, val_dataloader=None): + called["accelerator"] = accelerator + called["logger"] = logger + called["module_or_module_path"] = module_or_module_path + called["val_dataloader"] = val_dataloader + return 0.5 + + monkeypatch.setattr(exp, "_evaluate", fake_evaluate) + + exp.run() + + assert called["accelerator"] is dummy_accelerator + assert called["logger"] is dummy_logger + assert called["module_or_module_path"] == checkpoint_path + assert called["val_dataloader"] is None diff --git a/tinyexp/examples/mnist_exp.py b/tinyexp/examples/mnist_exp.py index 
dd47019..ccdfdd1 100644 --- a/tinyexp/examples/mnist_exp.py +++ b/tinyexp/examples/mnist_exp.py @@ -14,11 +14,6 @@ from tinyexp.exceptions import UnknownAcceleratorTypeError -class ResumeFromRequiredError(ValueError): - def __init__(self) -> None: - super().__init__("resume_from") - - class Net(nn.Module): def __init__(self) -> None: super().__init__() @@ -157,7 +152,7 @@ def run(self) -> None: self._train(accelerator=accelerator, logger=logger, cfg_dict=cfg_dict, run_dir=run_dir) elif self.mode == "val": if not self.resume_from: - raise ResumeFromRequiredError + raise ValueError("resume_from is required when mode='val'") # noqa: TRY003 self._evaluate(accelerator=accelerator, logger=logger, module_or_module_path=self.resume_from) else: raise NotImplementedError(f"Mode {self.mode} is not implemented") diff --git a/tinyexp/examples/resnet_exp.py b/tinyexp/examples/resnet_exp.py index f25fd44..cdc2594 100644 --- a/tinyexp/examples/resnet_exp.py +++ b/tinyexp/examples/resnet_exp.py @@ -327,9 +327,8 @@ def build_val_dataloader(self, accelerator): def run(self) -> None: accelerator = self.accelerator_cfg.build_accelerator() - logger = self.logger_cfg.build_logger( - save_dir=os.path.join(self.output_root, self.exp_name), distributed_rank=accelerator.rank - ) + run_dir = self.get_run_dir() + logger = self.logger_cfg.build_logger(save_dir=run_dir, distributed_rank=accelerator.rank) cfg_dict = OmegaConf.to_container(OmegaConf.structured(self), resolve=True) del cfg_dict["hydra"] cfg_msg = OmegaConf.to_yaml(cfg_dict).strip().replace("\n", "\n ") @@ -337,6 +336,10 @@ def run(self) -> None: if self.mode == "train": self._train(accelerator=accelerator, logger=logger, cfg_dict=cfg_dict) + elif self.mode == "val": + if not self.resume_from: + raise ValueError("resume_from is required when mode='val'") # noqa: TRY003 + self._evaluate(accelerator=accelerator, logger=logger, module_or_module_path=self.resume_from) else: raise NotImplementedError(f"Mode {self.mode} is not 
implemented") From e110ce5b7a5c8b7aaa2a6543177c3b0f25409f6d Mon Sep 17 00:00:00 2001 From: Zane Li Date: Tue, 31 Mar 2026 13:39:42 +0800 Subject: [PATCH 29/30] feat: add resnet train checkpoint resume flow --- tests/examples/test_resnet_exp_run.py | 115 ++++++++++++++++++++++++++ tinyexp/examples/resnet_exp.py | 37 +++++++-- 2 files changed, 147 insertions(+), 5 deletions(-) diff --git a/tests/examples/test_resnet_exp_run.py b/tests/examples/test_resnet_exp_run.py index 30022be..a1b6e57 100644 --- a/tests/examples/test_resnet_exp_run.py +++ b/tests/examples/test_resnet_exp_run.py @@ -4,6 +4,7 @@ from types import SimpleNamespace import pytest +import torch import torch.nn as nn from tinyexp.examples.resnet_exp import ResNetExp @@ -57,3 +58,117 @@ def fake_evaluate(*, accelerator, logger, module_or_module_path, val_dataloader= assert called["logger"] is dummy_logger assert called["module_or_module_path"] == checkpoint_path assert called["val_dataloader"] is None + + +def test_resnet_train_saves_checkpoint(tmp_path: Path, monkeypatch) -> None: + exp = ResNetExp(output_root=str(tmp_path), exp_name="resnet_train") + + class DummyAccelerator: + rank = 0 + device = "cpu" + is_main_process = True + world_size = 1 + + def prepare(self, module, optimizer): + return module, optimizer + + def unwrap_model(self, module): + return module + + def backward(self, loss): + loss.backward() + + train_batch = [(torch.randn(2, 2), torch.tensor([0, 1]))] + val_batch = [(torch.randn(2, 2), torch.tensor([0, 1]))] + + monkeypatch.setattr(exp.dataloader_cfg, "build_train_dataloader", lambda accelerator, redis_cache_cfg: train_batch) + monkeypatch.setattr(exp.dataloader_cfg, "build_val_dataloader", lambda accelerator: val_batch) + monkeypatch.setattr(exp.module_cfg, "build_module", lambda: nn.Linear(2, 2)) + monkeypatch.setattr( + exp.optimizer_cfg, + "build_optimizer", + lambda module, dataloader, accelerator: torch.optim.SGD(module.parameters(), lr=0.1), + ) + + saved: list[dict[str, 
object]] = [] + + def fake_save_checkpoint(**kwargs): + saved.append(kwargs) + if len(saved) >= 1: + raise StopIteration + return str(tmp_path / "resnet_train" / "last.ckpt") + + monkeypatch.setattr(exp.checkpoint_cfg, "save_checkpoint", fake_save_checkpoint) + monkeypatch.setattr(exp, "_evaluate", lambda **kwargs: 0.5) + + with pytest.raises(StopIteration): + exp._train( + accelerator=DummyAccelerator(), + logger=SimpleNamespace(info=lambda *args, **kwargs: None), + cfg_dict={}, + run_dir=str(tmp_path / "resnet_train"), + ) + + assert saved[0]["run_dir"] == str(tmp_path / "resnet_train") + assert saved[0]["name"] == exp.checkpoint_cfg.last_ckpt_name + assert saved[0]["epoch"] == 0 + assert saved[0]["global_step"] == 1 + + +def test_resnet_train_resume_loads_checkpoint_state(tmp_path: Path, monkeypatch) -> None: + exp = ResNetExp(output_root=str(tmp_path), exp_name="resnet_train", resume_from="resume.ckpt") + + class DummyAccelerator: + rank = 0 + device = "cpu" + is_main_process = True + world_size = 1 + + def prepare(self, module, optimizer): + return module, optimizer + + def unwrap_model(self, module): + return module + + def backward(self, loss): + loss.backward() + + train_batch = [(torch.randn(2, 2), torch.tensor([0, 1]))] + val_batch = [(torch.randn(2, 2), torch.tensor([0, 1]))] + + monkeypatch.setattr(exp.dataloader_cfg, "build_train_dataloader", lambda accelerator, redis_cache_cfg: train_batch) + monkeypatch.setattr(exp.dataloader_cfg, "build_val_dataloader", lambda accelerator: val_batch) + monkeypatch.setattr(exp.module_cfg, "build_module", lambda: nn.Linear(2, 2)) + monkeypatch.setattr( + exp.optimizer_cfg, + "build_optimizer", + lambda module, dataloader, accelerator: torch.optim.SGD(module.parameters(), lr=0.1), + ) + + load_calls: list[dict[str, object]] = [] + saved: list[dict[str, object]] = [] + + def fake_load_checkpoint(path, **kwargs): + load_calls.append({"path": path, **kwargs}) + return {"epoch": 4, "global_step": 17} + + def 
fake_save_checkpoint(**kwargs): + saved.append(kwargs) + raise StopIteration + + monkeypatch.setattr(exp.checkpoint_cfg, "load_checkpoint", fake_load_checkpoint) + monkeypatch.setattr(exp.checkpoint_cfg, "save_checkpoint", fake_save_checkpoint) + monkeypatch.setattr(exp, "_evaluate", lambda **kwargs: 0.5) + + with pytest.raises(StopIteration): + exp._train( + accelerator=DummyAccelerator(), + logger=SimpleNamespace(info=lambda *args, **kwargs: None), + cfg_dict={}, + run_dir=str(tmp_path / "resnet_train"), + ) + + assert load_calls[0]["path"] == "resume.ckpt" + assert load_calls[0]["map_location"] == "cpu" + assert saved[0]["epoch"] == 5 + assert saved[0]["global_step"] == 18 diff --git a/tinyexp/examples/resnet_exp.py b/tinyexp/examples/resnet_exp.py index cdc2594..1ab4c30 100644 --- a/tinyexp/examples/resnet_exp.py +++ b/tinyexp/examples/resnet_exp.py @@ -335,7 +335,7 @@ def run(self) -> None: logger.info(f"-------- Configurations --------\n {cfg_msg}") if self.mode == "train": - self._train(accelerator=accelerator, logger=logger, cfg_dict=cfg_dict) + self._train(accelerator=accelerator, logger=logger, cfg_dict=cfg_dict, run_dir=run_dir) elif self.mode == "val": if not self.resume_from: raise ValueError("resume_from is required when mode='val'") # noqa: TRY003 @@ -380,13 +380,28 @@ def _evaluate(self, accelerator, logger, module_or_module_path, val_dataloader=N if self.wandb_cfg.enable_wandb and accelerator.is_main_process: wandb.log({"val_metric": eval_metric}) - def _train(self, accelerator, logger, cfg_dict) -> None: + return eval_metric + + def _train(self, accelerator, logger, cfg_dict, run_dir: str) -> None: train_dataloader = self.dataloader_cfg.build_train_dataloader(accelerator, self.redis_cache_cfg) val_dataloader = self.dataloader_cfg.build_val_dataloader(accelerator) ori_module = self.module_cfg.build_module() ori_optimizer = self.optimizer_cfg.build_optimizer(ori_module, train_dataloader, accelerator) module, optimizer = 
accelerator.prepare(ori_module, ori_optimizer) lr_scheduler = self.lr_scheduler_cfg.build_lr_scheduler(optimizer) + start_epoch = 0 + global_step = 0 + + if self.resume_from: + checkpoint = self.checkpoint_cfg.load_checkpoint( + self.resume_from, + model=accelerator.unwrap_model(module), + optimizer=optimizer, + scheduler=lr_scheduler, + map_location=accelerator.device, + ) + start_epoch = int(checkpoint.get("epoch", -1)) + 1 + global_step = int(checkpoint.get("global_step", 0)) if self.wandb_cfg.enable_wandb and accelerator.rank == 0: self.wandb_cfg.build_wandb( @@ -394,9 +409,8 @@ def _train(self, accelerator, logger, cfg_dict) -> None: ) train_iter = iter(train_dataloader) - global_step = 0 - for global_epoch in range(90): + for global_epoch in range(start_epoch, 90): module.train() epoch_start_time = time.time() @@ -432,9 +446,22 @@ def _train(self, accelerator, logger, cfg_dict) -> None: ) lr_scheduler.step() - self._evaluate( + eval_metric = self._evaluate( accelerator=accelerator, logger=logger, module_or_module_path=module, val_dataloader=val_dataloader ) + if accelerator.is_main_process: + self.checkpoint_cfg.save_checkpoint( + run_dir=run_dir, + name=self.checkpoint_cfg.last_ckpt_name, + model=accelerator.unwrap_model(module), + optimizer=optimizer, + scheduler=lr_scheduler, + epoch=global_epoch, + global_step=global_step, + best_metric=eval_metric, + exp_name=self.exp_name, + exp_class=self.exp_class, + ) if __name__ == "__main__": From 16ce0a2f48a1cb275ada0143465ba7f4a86ed73a Mon Sep 17 00:00:00 2001 From: Zane Li Date: Tue, 31 Mar 2026 13:52:29 +0800 Subject: [PATCH 30/30] refactor: keep latest and best resnet checkpoints --- tests/examples/test_resnet_exp_run.py | 10 +++++++--- tinyexp/examples/resnet_exp.py | 18 +++++++++++++++++- 2 files changed, 24 insertions(+), 4 deletions(-) diff --git a/tests/examples/test_resnet_exp_run.py b/tests/examples/test_resnet_exp_run.py index a1b6e57..e2bb573 100644 --- a/tests/examples/test_resnet_exp_run.py +++ 
b/tests/examples/test_resnet_exp_run.py @@ -60,7 +60,7 @@ def fake_evaluate(*, accelerator, logger, module_or_module_path, val_dataloader= assert called["val_dataloader"] is None -def test_resnet_train_saves_checkpoint(tmp_path: Path, monkeypatch) -> None: +def test_resnet_train_saves_last_and_best_checkpoints(tmp_path: Path, monkeypatch) -> None: exp = ResNetExp(output_root=str(tmp_path), exp_name="resnet_train") class DummyAccelerator: @@ -94,7 +94,7 @@ def backward(self, loss): def fake_save_checkpoint(**kwargs): saved.append(kwargs) - if len(saved) >= 1: + if len(saved) >= 2: raise StopIteration return str(tmp_path / "resnet_train" / "last.ckpt") @@ -113,6 +113,9 @@ def fake_save_checkpoint(**kwargs): assert saved[0]["name"] == exp.checkpoint_cfg.last_ckpt_name assert saved[0]["epoch"] == 0 assert saved[0]["global_step"] == 1 + assert saved[0]["best_metric"] is None + assert saved[1]["name"] == exp.checkpoint_cfg.best_ckpt_name + assert saved[1]["best_metric"] == 0.5 def test_resnet_train_resume_loads_checkpoint_state(tmp_path: Path, monkeypatch) -> None: @@ -150,7 +153,7 @@ def backward(self, loss): def fake_load_checkpoint(path, **kwargs): load_calls.append({"path": path, **kwargs}) - return {"epoch": 4, "global_step": 17} + return {"epoch": 4, "global_step": 17, "best_metric": 0.7} def fake_save_checkpoint(**kwargs): saved.append(kwargs) @@ -172,3 +175,4 @@ def fake_save_checkpoint(**kwargs): assert load_calls[0]["map_location"] == "cpu" assert saved[0]["epoch"] == 5 assert saved[0]["global_step"] == 18 + assert saved[0]["best_metric"] == 0.7 diff --git a/tinyexp/examples/resnet_exp.py b/tinyexp/examples/resnet_exp.py index 1ab4c30..2239f70 100644 --- a/tinyexp/examples/resnet_exp.py +++ b/tinyexp/examples/resnet_exp.py @@ -391,6 +391,7 @@ def _train(self, accelerator, logger, cfg_dict, run_dir: str) -> None: lr_scheduler = self.lr_scheduler_cfg.build_lr_scheduler(optimizer) start_epoch = 0 global_step = 0 + best_metric = None if self.resume_from: checkpoint 
= self.checkpoint_cfg.load_checkpoint( @@ -402,6 +403,7 @@ def _train(self, accelerator, logger, cfg_dict, run_dir: str) -> None: ) start_epoch = int(checkpoint.get("epoch", -1)) + 1 global_step = int(checkpoint.get("global_step", 0)) + best_metric = checkpoint.get("best_metric") if self.wandb_cfg.enable_wandb and accelerator.rank == 0: self.wandb_cfg.build_wandb( @@ -458,10 +460,24 @@ def _train(self, accelerator, logger, cfg_dict, run_dir: str) -> None: scheduler=lr_scheduler, epoch=global_epoch, global_step=global_step, - best_metric=eval_metric, + best_metric=best_metric, exp_name=self.exp_name, exp_class=self.exp_class, ) + if best_metric is None or eval_metric > best_metric: + best_metric = eval_metric + self.checkpoint_cfg.save_checkpoint( + run_dir=run_dir, + name=self.checkpoint_cfg.best_ckpt_name, + model=accelerator.unwrap_model(module), + optimizer=optimizer, + scheduler=lr_scheduler, + epoch=global_epoch, + global_step=global_step, + best_metric=best_metric, + exp_name=self.exp_name, + exp_class=self.exp_class, + ) if __name__ == "__main__":