Skip to content

Commit 5e1534f

Browse files
Merge pull request #45 from gridfm/add_profiler
add lightning profiler cli argument
2 parents 974cdcb + a73721e commit 5e1534f

3 files changed

Lines changed: 39 additions & 1 deletion

File tree

.gitignore

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,8 @@
1-
venv
1+
*venv
2+
data_out
3+
logs
4+
site
5+
.julia
26
venv_pp
37
/data/
48
__pycache__/
@@ -10,3 +14,4 @@ gridfm_graphkit.egg-info
1014
mlruns
1115
*.pt
1216
.DS_Store
17+
.venv

gridfm_graphkit/__main__.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,13 @@ def main():
1818
train_parser.add_argument("--run_name", type=str, default="run")
1919
train_parser.add_argument("--log_dir", type=str, default="mlruns")
2020
train_parser.add_argument("--data_path", type=str, default="data")
21+
train_parser.add_argument(
22+
"--profiler",
23+
type=str,
24+
default=None,
25+
choices=["simple", "advanced", "pytorch"],
26+
help="Enable Lightning profiler: 'simple', 'advanced', or 'pytorch'.",
27+
)
2128

2229
# ---- FINETUNE SUBCOMMAND ----
2330
finetune_parser = subparsers.add_parser("finetune", help="Run fine-tuning")
@@ -27,6 +34,13 @@ def main():
2734
finetune_parser.add_argument("--run_name", type=str, default="run")
2835
finetune_parser.add_argument("--log_dir", type=str, default="mlruns")
2936
finetune_parser.add_argument("--data_path", type=str, default="data")
37+
finetune_parser.add_argument(
38+
"--profiler",
39+
type=str,
40+
default=None,
41+
choices=["simple", "advanced", "pytorch"],
42+
help="Enable Lightning profiler: 'simple', 'advanced', or 'pytorch'.",
43+
)
3044

3145
# ---- EVALUATE SUBCOMMAND ----
3246
evaluate_parser = subparsers.add_parser(
@@ -46,6 +60,13 @@ def main():
4660
evaluate_parser.add_argument("--run_name", type=str, default="run")
4761
evaluate_parser.add_argument("--log_dir", type=str, default="mlruns")
4862
evaluate_parser.add_argument("--data_path", type=str, default="data")
63+
evaluate_parser.add_argument(
64+
"--profiler",
65+
type=str,
66+
default=None,
67+
choices=["simple", "advanced", "pytorch"],
68+
help="Enable Lightning profiler: 'simple', 'advanced', or 'pytorch'.",
69+
)
4970
evaluate_parser.add_argument(
5071
"--compute_dc_ac_metrics",
5172
action="store_true",
@@ -72,6 +93,13 @@ def main():
7293
predict_parser.add_argument("--log_dir", type=str, default="mlruns")
7394
predict_parser.add_argument("--data_path", type=str, default="data")
7495
predict_parser.add_argument("--output_path", type=str, default="data")
96+
predict_parser.add_argument(
97+
"--profiler",
98+
type=str,
99+
default=None,
100+
choices=["simple", "advanced", "pytorch"],
101+
help="Enable Lightning profiler: 'simple', 'advanced', or 'pytorch'.",
102+
)
75103

76104
args = parser.parse_args()
77105
main_cli(args)

gridfm_graphkit/cli.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,8 @@ def main_cli(args):
6666
state_dict = torch.load(args.model_path, map_location="cpu")
6767
model.load_state_dict(state_dict)
6868

69+
profiler = getattr(args, "profiler", None)
70+
6971
trainer = L.Trainer(
7072
logger=logger,
7173
accelerator=config_args.training.accelerator,
@@ -75,6 +77,7 @@ def main_cli(args):
7577
default_root_dir=args.log_dir,
7678
max_epochs=config_args.training.epochs,
7779
callbacks=get_training_callbacks(config_args),
80+
profiler=profiler,
7881
)
7982
if args.command == "train" or args.command == "finetune":
8083
trainer.fit(model=model, datamodule=litGrid)
@@ -87,6 +90,7 @@ def main_cli(args):
8790
num_nodes=1,
8891
log_every_n_steps=1,
8992
default_root_dir=args.log_dir,
93+
profiler=profiler,
9094
)
9195
test_trainer.test(model=model, datamodule=litGrid)
9296

@@ -119,6 +123,7 @@ def main_cli(args):
119123
num_nodes=1,
120124
log_every_n_steps=1,
121125
default_root_dir=args.log_dir,
126+
profiler=profiler,
122127
)
123128
predictions = predict_trainer.predict(model=model, datamodule=litGrid)
124129

0 commit comments

Comments
 (0)