@@ -18,6 +18,13 @@ def main():
1818 train_parser .add_argument ("--run_name" , type = str , default = "run" )
1919 train_parser .add_argument ("--log_dir" , type = str , default = "mlruns" )
2020 train_parser .add_argument ("--data_path" , type = str , default = "data" )
21+ train_parser .add_argument (
22+ "--profiler" ,
23+ type = str ,
24+ default = None ,
25+ choices = ["simple" , "advanced" , "pytorch" ],
26+ help = "Enable Lightning profiler: 'simple', 'advanced', or 'pytorch'." ,
27+ )
2128
2229 # ---- FINETUNE SUBCOMMAND ----
2330 finetune_parser = subparsers .add_parser ("finetune" , help = "Run fine-tuning" )
@@ -27,6 +34,13 @@ def main():
2734 finetune_parser .add_argument ("--run_name" , type = str , default = "run" )
2835 finetune_parser .add_argument ("--log_dir" , type = str , default = "mlruns" )
2936 finetune_parser .add_argument ("--data_path" , type = str , default = "data" )
37+ finetune_parser .add_argument (
38+ "--profiler" ,
39+ type = str ,
40+ default = None ,
41+ choices = ["simple" , "advanced" , "pytorch" ],
42+ help = "Enable Lightning profiler: 'simple', 'advanced', or 'pytorch'." ,
43+ )
3044
3145 # ---- EVALUATE SUBCOMMAND ----
3246 evaluate_parser = subparsers .add_parser (
@@ -46,6 +60,13 @@ def main():
4660 evaluate_parser .add_argument ("--run_name" , type = str , default = "run" )
4761 evaluate_parser .add_argument ("--log_dir" , type = str , default = "mlruns" )
4862 evaluate_parser .add_argument ("--data_path" , type = str , default = "data" )
63+ evaluate_parser .add_argument (
64+ "--profiler" ,
65+ type = str ,
66+ default = None ,
67+ choices = ["simple" , "advanced" , "pytorch" ],
68+ help = "Enable Lightning profiler: 'simple', 'advanced', or 'pytorch'." ,
69+ )
4970 evaluate_parser .add_argument (
5071 "--compute_dc_ac_metrics" ,
5172 action = "store_true" ,
@@ -72,6 +93,13 @@ def main():
7293 predict_parser .add_argument ("--log_dir" , type = str , default = "mlruns" )
7394 predict_parser .add_argument ("--data_path" , type = str , default = "data" )
7495 predict_parser .add_argument ("--output_path" , type = str , default = "data" )
96+ predict_parser .add_argument (
97+ "--profiler" ,
98+ type = str ,
99+ default = None ,
100+ choices = ["simple" , "advanced" , "pytorch" ],
101+ help = "Enable Lightning profiler: 'simple', 'advanced', or 'pytorch'." ,
102+ )
75103
76104 args = parser .parse_args ()
77105 main_cli (args )
0 commit comments