-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathrun_efficient_training.py
More file actions
126 lines (105 loc) · 5 KB
/
run_efficient_training.py
File metadata and controls
126 lines (105 loc) · 5 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
# run_efficient_training.py
"""Command-line entry point for efficient pedestrian trajectory training.

Preprocesses a dataset with ``preprocess_dataset_memmap`` into train and
validation splits, then trains the model with
``train_trajectory_model_efficient``.
"""
import argparse
import logging
import os
import tempfile
import warnings
from typing import Optional

# Redirect temporary files to a large partition. This deliberately runs
# BEFORE the project imports below: anything those modules (or libraries they
# pull in — presumably NumPy, TODO confirm) do with temp paths at import time
# would otherwise still use the system default. Set SOCIAL_NAV_TMPDIR to
# override the location on other machines.
_TMPDIR = os.environ.get("SOCIAL_NAV_TMPDIR", "/home/jack/data/social_nav/tmp")
try:
    # tempfile does not create its target directory; pointing tempfile.tempdir
    # at a missing directory makes every later mkstemp()/NamedTemporaryFile()
    # raise, so make sure it exists first.
    os.makedirs(_TMPDIR, exist_ok=True)
except OSError as exc:
    warnings.warn(
        f"Cannot use {_TMPDIR!r} as temporary directory ({exc}); "
        "falling back to the system default"
    )
else:
    os.environ["TMPDIR"] = _TMPDIR  # inherited by subprocesses started later
    tempfile.tempdir = _TMPDIR      # overrides tempfile's cached default in-process

from preprocess_dataset import preprocess_dataset_memmap  # noqa: E402
from trajectory_model import train_trajectory_model_efficient  # noqa: E402
def main():
"""Main function for efficient trajectory prediction training."""
parser = argparse.ArgumentParser(description="Efficient Pedestrian Trajectory Prediction")
parser.add_argument("--dataset_path", type=str, required=True,
help="Path to dataset (e.g., /home/jack/data/social_nav/crossroad)")
parser.add_argument("--preprocess_only", action="store_true",
help="Only preprocess the dataset, don't train")
parser.add_argument("--output_dir", type=str, default="./model_output",
help="Output directory for model and results")
parser.add_argument("--preprocessed_dir", type=str, default="/home/jack/data/social_nav/preprocessed_data",
help="Directory to store preprocessed data")
parser.add_argument("--num_epochs", type=int, default=30,
help="Number of training epochs")
parser.add_argument("--batch_size", type=int, default=8,
help="Batch size")
parser.add_argument("--sequence_length", type=int, default=5,
help="Sequence length")
parser.add_argument("--learning_rate", type=float, default=1e-4,
help="Learning rate")
parser.add_argument("--target_width", type=int, default=320,
help="Target frame width")
parser.add_argument("--target_height", type=int, default=320,
help="Target frame height")
parser.add_argument("--yolo_model_path", type=str,
default="/home/jack/src/attention/models/yolo11n.onnx",
help="Path to YOLO model")
parser.add_argument("--embedding_dim", type=int, default=128,
help="Embedding dimension")
parser.add_argument("--num_heads", type=int, default=4,
help="Number of attention heads")
parser.add_argument("--tensorboard_dir", type=str, default="./log_dir",
help="Directory to save TensorBoard logs")
parser.add_argument("--resume_checkpoint", type=str, default=None,
help="Path to checkpoint file to resume training from")
args = parser.parse_args()
# Configure logging
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger("efficient_trajectory")
# Import necessary modules
# Create preprocessed data directory
os.makedirs(args.preprocessed_dir, exist_ok=True)
# Preprocess training data (stride=1)
#train_data_path = preprocess_dataset(
train_data_path = preprocess_dataset_memmap(
dataset_path=args.dataset_path,
output_path=args.preprocessed_dir,
sequence_length=args.sequence_length,
target_width=args.target_width,
target_height=args.target_height,
yolo_model_path=args.yolo_model_path,
stride=1, # Use stride 1 for training
max_per_sequence=1000 # Limit for faster iteration
)
# Preprocess validation data (stride=2)
#val_data_path = preprocess_dataset(
val_data_path = preprocess_dataset_memmap(
dataset_path=args.dataset_path,
output_path=args.preprocessed_dir,
sequence_length=args.sequence_length,
target_width=args.target_width,
target_height=args.target_height,
yolo_model_path=args.yolo_model_path,
stride=2, # Use stride 2 for validation (different subset)
max_per_sequence=500 # Smaller set for validation
)
if not train_data_path:
logger.error("Preprocessing failed on train data, cannot continue")
return
if not val_data_path:
logger.error("Preprocessing failed on val data, cannot continue")
return
if args.preprocess_only:
logger.info("Preprocessing completed, skipping training as requested")
return
# Train the model with preprocessed data
train_trajectory_model_efficient(
preprocessed_train_path=train_data_path,
preprocessed_val_path=val_data_path,
output_dir=args.output_dir,
num_epochs=args.num_epochs,
batch_size=args.batch_size,
learning_rate=args.learning_rate,
embedding_dim=args.embedding_dim,
num_heads=args.num_heads,
debug_image_dir=os.path.join(args.output_dir, "debug_images"),
tensorboard_dir=args.tensorboard_dir,
resume_checkpoint=args.resume_checkpoint
)
logger.info("Training completed successfully")
if __name__ == "__main__":
main()