Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 40 additions & 0 deletions tests/test_tinyexp.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import sys
import types
from dataclasses import dataclass, field

import pytest
Expand Down Expand Up @@ -29,6 +31,44 @@ class MyExperiment(TinyExp):
_ = MyExperiment()


def test_exp_name_defaults_from_main_module_file(monkeypatch: pytest.MonkeyPatch, tmp_path) -> None:
dummy_main = types.ModuleType("__main__")
dummy_main.__file__ = str(tmp_path / "resnet_exp.py")
monkeypatch.setitem(sys.modules, "__main__", dummy_main)

exp = TinyExp()
assert exp.exp_name == "resnet_exp"


def test_exp_name_falls_back_to_argv(monkeypatch: pytest.MonkeyPatch, tmp_path) -> None:
dummy_main = types.ModuleType("__main__")
monkeypatch.setitem(sys.modules, "__main__", dummy_main)
monkeypatch.setattr(sys, "argv", [str(tmp_path / "mnist_exp.py")])

exp = TinyExp()
assert exp.exp_name == "mnist_exp"


def test_exp_name_defaults_to_exp_for_dash_c(monkeypatch: pytest.MonkeyPatch) -> None:
dummy_main = types.ModuleType("__main__")
monkeypatch.setitem(sys.modules, "__main__", dummy_main)
monkeypatch.setattr(sys, "argv", ["-c"])

exp = TinyExp()
assert exp.exp_name == "exp"


def test_set_cfg_overrides_exp_name(monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setenv("RANK", "1") # avoid noisy stdout prints during tests
exp = TinyExp()

cfg = OmegaConf.create({"exp_name": "my_exp"})
exp.set_cfg(cfg)

assert exp.exp_name == "my_exp"
assert exp.overrided_cfg["exp_name"] == "my_exp"


def test_set_cfg_overrides_nested(monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setenv("RANK", "1") # avoid noisy stdout prints during tests
exp = _CfgExp()
Expand Down
25 changes: 25 additions & 0 deletions tinyexp/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
__license__ = "MIT"

import os
import sys
from dataclasses import dataclass, field
from typing import Optional

Expand All @@ -28,6 +29,23 @@ class _HydraConfig(HydraConf):
run: RunDir = field(default_factory=lambda: RunDir("./output"))


def _default_exp_name() -> str:
"""
Get the default experiment name from the main module or the command line.
e.g. if the main module is `resnet_exp.py`, the experiment name will be `resnet_exp`.
"""
main_module = sys.modules.get("__main__")
main_file = getattr(main_module, "__file__", None)
if isinstance(main_file, str) and main_file:
return os.path.splitext(os.path.basename(main_file))[0]

argv0 = sys.argv[0] if sys.argv else ""
if argv0 and argv0 != "-c":
return os.path.splitext(os.path.basename(argv0))[0]

return "exp"


@dataclass
class TinyExp:
"""
Expand All @@ -43,8 +61,15 @@ class TinyExp:
num_gpus_per_worker: float = 1.0 # Number of GPUs per worker, should be a float value between 0 and 1.

# Fully qualified import path for the experiment class, e.g. "tinyexp.examples.mnist_exp.Exp".
# It is used in Hydra config store to instantiate the experiment class, and in store_and_run_exp, the exp_class will be automatically set to the fully qualified import path of the experiment class.
# If you do not use store_and_run_exp, you can set this field to an empty string.
exp_class: str = ""

# The experiment name, will be used as the subdirectory name in the output directory.
# If not provided, the default experiment name will be the name of the main module or the command line.
# e.g. if the main module is `resnet_exp.py`, the experiment name will be `resnet_exp`.
exp_name: str = field(default_factory=_default_exp_name)

# log directory
output_root: str = "./output"

Expand Down
2 changes: 1 addition & 1 deletion tinyexp/examples/mnist_exp.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ def build_lr_scheduler(self, optimizer):
def run(self) -> None:
accelerator = self.accelerator_cfg.build_accelerator()
logger = self.logger_cfg.build_logger(
save_dir=os.path.join(self.output_root, self.__class__.__name__),
save_dir=os.path.join(self.output_root, self.exp_name),
distributed_rank=accelerator.rank,
)
cfg_dict = OmegaConf.to_container(OmegaConf.structured(self), resolve=True)
Expand Down
2 changes: 1 addition & 1 deletion tinyexp/examples/resnet_exp.py
Original file line number Diff line number Diff line change
Expand Up @@ -328,7 +328,7 @@ def build_val_dataloader(self, accelerator):
def run(self) -> None:
accelerator = self.accelerator_cfg.build_accelerator()
logger = self.logger_cfg.build_logger(
save_dir=os.path.join(self.output_root, self.__class__.__name__), distributed_rank=accelerator.rank
save_dir=os.path.join(self.output_root, self.exp_name), distributed_rank=accelerator.rank
)
cfg_dict = OmegaConf.to_container(OmegaConf.structured(self), resolve=True)
del cfg_dict["hydra"]
Expand Down
Loading