-
Notifications
You must be signed in to change notification settings - Fork 4
Expand file tree
/
Copy pathtrain_simple.py
More file actions
68 lines (55 loc) · 2.24 KB
/
train_simple.py
File metadata and controls
68 lines (55 loc) · 2.24 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
#!/usr/bin/env python3
import logging
from encoding.assembly.assembly_loader import load_assembly
from encoding.features.factory import FeatureExtractorFactory
from encoding.downsample.downsampling import Downsampler
from encoding.models.nested_cv import NestedCVModel
from encoding.trainer import AbstractTrainer
# Default location of the packaged Lebel UTS03 assembly on the shared cluster
# storage. Pass ``assembly_path=...`` to ``main`` when running elsewhere.
_DEFAULT_ASSEMBLY_PATH = (
    "/storage/coda1/p-aivanova7/0/shared/litcoder_core/scripts/"
    "assembly_lebel_uts03.pkl"
)


def main(
    assembly_path: str = _DEFAULT_ASSEMBLY_PATH,
    cache_dir: str = "cache",
    results_dir: str = "results",
    wandb_project_name: str = "lebel-wordrate",
):
    """Train a word-rate-only encoding model on the Lebel assembly.

    Loads a packaged assembly and builds a single ``wordrate`` feature
    extractor. It then fits a ``NestedCVModel`` (``ridge_regression``) through
    ``AbstractTrainer`` with FIR delays 1-4, logging to Weights & Biases.

    Args:
        assembly_path: Path of the packaged assembly passed to
            ``load_assembly``. Defaults to the shared-cluster copy.
        cache_dir: Cache directory handed to the feature extractor.
        results_dir: Directory handed to the trainer for its outputs.
        wandb_project_name: W&B project the trainer logs to.

    Returns:
        Whatever ``trainer.train()`` returns. This function reads it as a
        mapping with optional ``"median_score"`` and ``"n_significant"`` keys.
    """
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s - %(levelname)s - %(message)s",
    )
    logger = logging.getLogger(__name__)

    # 1) Load the packaged assembly
    logger.info(f"Loading assembly from {assembly_path}")
    assembly = load_assembly(assembly_path)

    # 2) Create the wordrate-only feature extractor
    extractor = FeatureExtractorFactory.create_extractor(
        modality="wordrate",
        model_name="wordrate",
        config={},
        cache_dir=cache_dir,
    )

    # 3) Set up other components
    downsampler = Downsampler()
    model = NestedCVModel(model_name="ridge_regression")
    # Delays used to build the FIR-expanded design matrix (units presumably
    # TRs -- confirm against AbstractTrainer).
    fir_delays = [1, 2, 3, 4]

    # Correct Lebel trimming configuration (matches train_lebel.py/unified.py)
    # NOTE(review): values look like slice start/end offsets applied to the
    # feature/target time series; None presumably means "no trimming at end".
    trimming_config = {
        "train_features_start": 10, "train_features_end": -5,
        "train_targets_start": 0, "train_targets_end": None,
        "test_features_start": 50, "test_features_end": -5,
        "test_targets_start": 40, "test_targets_end": None,
    }
    downsample_config = {}

    trainer = AbstractTrainer(
        assembly=assembly,
        feature_extractors=[extractor],
        downsampler=downsampler,
        model=model,
        fir_delays=fir_delays,
        trimming_config=trimming_config,
        use_train_test_split=True,
        logger_backend="wandb",
        wandb_project_name=wandb_project_name,
        dataset_type="lebel",
        results_dir=results_dir,
        downsample_config=downsample_config,
    )

    logger.info("Starting training (wordrate only, no extra kwargs)...")
    metrics = trainer.train()

    logger.info("\n=== Final Results ===")
    logger.info(f"Median correlation: {metrics.get('median_score', float('nan')):.4f}")
    if "n_significant" in metrics:
        logger.info(f"Significant voxels: {metrics['n_significant']}")

    # Returned so the pipeline can be driven programmatically; the script
    # entry point ignores it.
    return metrics


if __name__ == "__main__":
    # Executed as a script (not imported): run the training pipeline once.
    main()