-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtest_train.py
More file actions
55 lines (43 loc) · 2.06 KB
/
test_train.py
File metadata and controls
55 lines (43 loc) · 2.06 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
# FitLayout - Python GNN Demo
# (c) 2026 Radek Burget <burgetr@fit.vut.cz>
# This script trains a Graph Convolutional Network (GCN) on a dataset of AreaTree objects.
# The resulting model is saved in the 'models' directory.
import os
import torch
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
from torch.utils.data import random_split
from flclient.flclient import default_prefix_string, R, SEGM
from graph.creator import AreaGraphCreator
from graph.dataset import RemoteDataset, LocalDataset
from train import Train
from models import GCNC
from config import fl, tags, relations, params
# Build the graph creator that turns AreaTree objects into graph samples,
# using the tag set and node relations configured in config.py.
gc = AreaGraphCreator(fl, relations, tags)

# Load the dataset and report its basic statistics.
#dataset = RemoteDataset(gc, limit=100) # Using the remote dataset directly
dataset = LocalDataset('data/graphs') # Using the locally saved graphs created using convert_all.py
print("# of classes: ", dataset.num_classes)
print("# of features: ", dataset.num_features)
print("# of node features: ", dataset.num_node_features)
print("# of edge features: ", dataset.num_edge_features)

# Pick the compute device and instantiate the GCN classifier,
# sized from the dataset's node-feature and class counts.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = GCNC(dataset.num_node_features, dataset.num_classes).to(device)

# Split the dataset into train (70%) / validation (15%) / test (remainder).
torch.manual_seed(42) # Use a fixed seed for reproducibility of the split
train_size = int(0.7 * len(dataset))
val_size = int(0.15 * len(dataset))
# Take the remainder for test so the three parts always sum to len(dataset)
# despite int() truncation.
test_size = len(dataset) - train_size - val_size
train_dataset, val_dataset, test_dataset = random_split(dataset, [train_size, val_size, test_size])

# Evaluation loaders: batch size 1, no shuffling, for deterministic evaluation.
val_dataloader = DataLoader(val_dataset, batch_size=1, shuffle=False)
test_dataloader = DataLoader(test_dataset, batch_size=1, shuffle=False)
# Training loader: batching and shuffling come from the shared params config.
train_dataloader = DataLoader(train_dataset, batch_size=params["batch_size"], shuffle=params["shuffle"])

# Train the model
train = Train(model, train_dataloader, val_dataloader, params)
train.train_loop()

# Save the trained model weights.
# exist_ok=True replaces the racy "check then create" pattern
# (os.path.exists followed by os.makedirs).
os.makedirs("models", exist_ok=True)
torch.save(model.state_dict(), "models/last.pt")
print("Model saved to models/last.pt")