From e529219817898c2f9f1bd492b8bb42cd818b0263 Mon Sep 17 00:00:00 2001 From: Zane Li Date: Fri, 23 Jan 2026 11:19:37 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=F0=9F=8E=B8=20automatically=20get=20ex?= =?UTF-8?q?periments=20name?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit automatically get experiments name --- tests/test_tinyexp.py | 40 ++++++++++++++++++++++++++++++++++ tinyexp/__init__.py | 25 +++++++++++++++++++++ tinyexp/examples/mnist_exp.py | 2 +- tinyexp/examples/resnet_exp.py | 2 +- 4 files changed, 67 insertions(+), 2 deletions(-) diff --git a/tests/test_tinyexp.py b/tests/test_tinyexp.py index a100265..0773e17 100644 --- a/tests/test_tinyexp.py +++ b/tests/test_tinyexp.py @@ -1,3 +1,5 @@ +import sys +import types from dataclasses import dataclass, field import pytest @@ -29,6 +31,44 @@ class MyExperiment(TinyExp): _ = MyExperiment() +def test_exp_name_defaults_from_main_module_file(monkeypatch: pytest.MonkeyPatch, tmp_path) -> None: + dummy_main = types.ModuleType("__main__") + dummy_main.__file__ = str(tmp_path / "resnet_exp.py") + monkeypatch.setitem(sys.modules, "__main__", dummy_main) + + exp = TinyExp() + assert exp.exp_name == "resnet_exp" + + +def test_exp_name_falls_back_to_argv(monkeypatch: pytest.MonkeyPatch, tmp_path) -> None: + dummy_main = types.ModuleType("__main__") + monkeypatch.setitem(sys.modules, "__main__", dummy_main) + monkeypatch.setattr(sys, "argv", [str(tmp_path / "mnist_exp.py")]) + + exp = TinyExp() + assert exp.exp_name == "mnist_exp" + + +def test_exp_name_defaults_to_exp_for_dash_c(monkeypatch: pytest.MonkeyPatch) -> None: + dummy_main = types.ModuleType("__main__") + monkeypatch.setitem(sys.modules, "__main__", dummy_main) + monkeypatch.setattr(sys, "argv", ["-c"]) + + exp = TinyExp() + assert exp.exp_name == "exp" + + +def test_set_cfg_overrides_exp_name(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setenv("RANK", "1") # avoid noisy stdout prints during tests + exp = TinyExp() + + cfg = OmegaConf.create({"exp_name": "my_exp"}) + exp.set_cfg(cfg) + + assert exp.exp_name == "my_exp" + assert exp.overrided_cfg["exp_name"] == "my_exp" + + def test_set_cfg_overrides_nested(monkeypatch: pytest.MonkeyPatch) -> None: monkeypatch.setenv("RANK", "1") # avoid noisy stdout prints during tests exp = _CfgExp() diff --git a/tinyexp/__init__.py b/tinyexp/__init__.py index 6f73a6c..6b238a9 100644 --- a/tinyexp/__init__.py +++ b/tinyexp/__init__.py @@ -3,6 +3,7 @@ __license__ = "MIT" import os +import sys from dataclasses import dataclass, field from typing import Optional @@ -28,6 +29,23 @@ class _HydraConfig(HydraConf): run: RunDir = field(default_factory=lambda: RunDir("./output")) +def _default_exp_name() -> str: + """ + Get the default experiment name from the main module or the command line. + e.g. if the main module is `resnet_exp.py`, the experiment name will be `resnet_exp`. + """ + main_module = sys.modules.get("__main__") + main_file = getattr(main_module, "__file__", None) + if isinstance(main_file, str) and main_file: + return os.path.splitext(os.path.basename(main_file))[0] + + argv0 = sys.argv[0] if sys.argv else "" + if argv0 and argv0 != "-c": + return os.path.splitext(os.path.basename(argv0))[0] + + return "exp" + + @dataclass class TinyExp: """ @@ -43,8 +61,15 @@ class TinyExp: num_gpus_per_worker: float = 1.0 # Number of GPUs per worker, should be a float value between 0 and 1. # Fully qualified import path for the experiment class, e.g. "tinyexp.examples.mnist_exp.Exp". + # It is used in Hydra config store to instantiate the experiment class, and in store_and_run_exp, the exp_class will be automatically set to the fully qualified import path of the experiment class. + # If you do not use store_and_run_exp, you can set this field to an empty string. exp_class: str = "" + # The experiment name, will be used as the subdirectory name in the output directory. + # If not provided, the default experiment name will be the name of the main module or the command line. + # e.g. if the main module is `resnet_exp.py`, the experiment name will be `resnet_exp`. + exp_name: str = field(default_factory=_default_exp_name) + # log directory output_root: str = "./output" diff --git a/tinyexp/examples/mnist_exp.py b/tinyexp/examples/mnist_exp.py index 2c3f908..fd3b714 100644 --- a/tinyexp/examples/mnist_exp.py +++ b/tinyexp/examples/mnist_exp.py @@ -143,7 +143,7 @@ def build_lr_scheduler(self, optimizer): def run(self) -> None: accelerator = self.accelerator_cfg.build_accelerator() logger = self.logger_cfg.build_logger( - save_dir=os.path.join(self.output_root, self.__class__.__name__), + save_dir=os.path.join(self.output_root, self.exp_name), distributed_rank=accelerator.rank, ) cfg_dict = OmegaConf.to_container(OmegaConf.structured(self), resolve=True) diff --git a/tinyexp/examples/resnet_exp.py b/tinyexp/examples/resnet_exp.py index 574e781..4554804 100644 --- a/tinyexp/examples/resnet_exp.py +++ b/tinyexp/examples/resnet_exp.py @@ -328,7 +328,7 @@ def build_val_dataloader(self, accelerator): def run(self) -> None: accelerator = self.accelerator_cfg.build_accelerator() logger = self.logger_cfg.build_logger( - save_dir=os.path.join(self.output_root, self.__class__.__name__), distributed_rank=accelerator.rank + save_dir=os.path.join(self.output_root, self.exp_name), distributed_rank=accelerator.rank ) cfg_dict = OmegaConf.to_container(OmegaConf.structured(self), resolve=True) del cfg_dict["hydra"]