-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy path: test.py
More file actions
117 lines (93 loc) · 4.74 KB
/
test.py
File metadata and controls
117 lines (93 loc) · 4.74 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
import torch
from torch.utils.data import DataLoader
from datetime import datetime
import logging
import pprint
import os
import argparse
from modules.models.vqvae import VQVAE
from modules.models.forecasting import Forecaster
from modules.utils.visualize import visualize_reconstructions, visualize_codebook, visualize_codebook_utilization, visualize_forecasts
from modules.data.preprocessing import TimeSeriesDataset, ForecastDataset
from modules.utils.helpers import count_parameters, load_configuration
def main() -> None:
    """Evaluate a pretrained VQVAE and, optionally, a Forecaster on the test set.

    Parses command-line arguments, sets up file logging under a fresh
    timestamped ``results/run_<timestamp>`` directory, loads the YAML
    configuration, and runs the visualization routines for each model whose
    weights path was supplied on the command line.

    Raises:
        ValueError: If no pretrained VQVAE weights path is given, or if the
            Forecaster ``input_length`` is not divisible by the VQVAE
            ``compression_factor``.
    """
    # Parse command-line arguments
    parser = argparse.ArgumentParser(description='Testing VQVAE and/or Forecaster')
    parser.add_argument("--config", type=str, default="config/default.yaml", help="Path to configuration file (default: config/default.yaml)")
    parser.add_argument("--vqvae", type=str, default=None, help="Path to pretrained VQVAE model weights")
    parser.add_argument("--forecaster", type=str, default=None, help="Path to pretrained Forecaster model weights")
    args = parser.parse_args()

    # Logging: everything for this run (log file + figures) lands in one
    # timestamped directory so repeated runs never clobber each other.
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    save_dir = f"results/run_{timestamp}"
    os.makedirs(save_dir, exist_ok=True)
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s | %(levelname)s | %(message)s",
        handlers=[
            logging.FileHandler(os.path.join(save_dir, "test.log"))
        ]
    )
    logging.info(f"Command-line arguments: {vars(args)}")
    config = load_configuration(args.config)
    logging.info(f"CONFIGURATION\n {pprint.pformat(config)}")

    # Configuration — pull every hyperparameter out of the config dict up
    # front so a missing key fails here rather than mid-evaluation.
    num_embeddings = config["vqvae"]["num_embeddings"]
    embedding_dim = config["vqvae"]["embedding_dim"]
    commitment_cost = config["vqvae"]["commitment_cost"]
    hidden_channels = config["vqvae"]["hidden_channels"]
    compression_factor = config["vqvae"]["compression_factor"]
    input_length = config["forecaster"]["input_length"]
    output_length = config["forecaster"]["output_length"]
    d_model = config["forecaster"]["d_model"]
    num_heads = config["forecaster"]["num_heads"]
    num_encoder_layers = config["forecaster"]["num_encoder_layers"]
    ff_dim = config["forecaster"]["ff_dim"]
    dropout = config["forecaster"]["dropout"]
    test_folder = config["data"]["test_path"]
    num_workers = config["data"]["num_workers"]
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Initialize and Test Models
    if args.vqvae:
        vqvae = VQVAE(
            num_embeddings=num_embeddings,
            embedding_dim=embedding_dim,
            commitment_cost=commitment_cost,
            hidden_channels=hidden_channels,
            compression_factor=compression_factor
        ).to(device)
        # NOTE(review): torch.load on a user-supplied path deserializes
        # arbitrary pickled data on older torch versions — consider
        # weights_only=True (default from torch 2.6) once the minimum
        # torch version supports it.
        vqvae.load_state_dict(torch.load(args.vqvae, map_location=device))
        logging.info(f"Loaded pretrained VQVAE weights from {args.vqvae}")
        logging.info(f"VQVAE number of parameters: {count_parameters(vqvae)}")
        timesteps = config["vqvae"]["training"]["timesteps"]
        batch_size = config["vqvae"]["training"]["batch_size"]
        test_dataset = TimeSeriesDataset(timesteps, folder_path=test_folder)
        # NOTE(review): shuffle=True on a *test* loader only randomizes which
        # windows the visualizations sample — confirm this is intentional.
        test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
        visualize_reconstructions(vqvae, test_dataloader, save_dir, device)
        visualize_codebook(vqvae, save_dir)
        visualize_codebook_utilization(vqvae, test_dataloader, vqvae.num_embeddings, save_dir, device)
    else:
        # The Forecaster operates on VQVAE codebook indices, so there is no
        # meaningful evaluation without a pretrained VQVAE.
        raise ValueError("You must specify a pretrained VQVAE")

    if args.forecaster:
        # Validate with an explicit raise instead of `assert`: asserts are
        # stripped under `python -O`, which would let a misconfigured run
        # fail later with an obscure shape error.
        if input_length % vqvae.compression_factor != 0:
            raise ValueError(
                "The Forecaster `input_length` must be divisible by the `compression_factor` of the VQVAE"
            )
        forecaster = Forecaster(
            # The forecaster consumes the VQVAE's compressed token sequence,
            # hence context_length is input_length after compression.
            context_length=input_length // vqvae.compression_factor,
            input_length=input_length,
            output_length=output_length,
            vocab_size=vqvae.num_embeddings,
            d_model=d_model,
            num_heads=num_heads,
            num_encoder_layers=num_encoder_layers,
            ff_dim=ff_dim,
            dropout=dropout
        ).to(device)
        forecaster.load_state_dict(torch.load(args.forecaster, map_location=device))
        logging.info(f"Loaded pretrained Forecaster weights from {args.forecaster}")
        logging.info(f"Forecaster number of parameters: {count_parameters(forecaster)}")
        batch_size = config["forecaster"]["training"]["batch_size"]
        test_dataset = ForecastDataset(input_length, output_length, folder_path=test_folder)
        test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
        visualize_forecasts(forecaster, vqvae, test_dataloader, save_dir, device)
# Script entry point: run the evaluation only when executed directly,
# not when this module is imported.
if __name__ == "__main__":
    main()