-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtrain.cpp
More file actions
130 lines (93 loc) · 3.79 KB
/
train.cpp
File metadata and controls
130 lines (93 loc) · 3.79 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
#include "train.h"
// Renders a single grayscale image tensor in an OpenCV window and blocks
// until a key is pressed.
// Assumes the tensor holds float32 values in [0, 1] — TODO confirm against
// the dataset loader.
void showImage(const torch::Tensor& tensor, const std::string& window_name) {
    // Move to host memory and drop autograd tracking before touching raw data.
    auto img_tensor = tensor.cpu().detach().squeeze();
    // BUGFIX: data_ptr() only describes the image correctly for contiguous
    // memory; squeeze() or upstream indexing (e.g. images[i]) can produce a
    // non-contiguous view, which would render garbage.
    img_tensor = img_tensor.contiguous();
    // Use the tensor's own spatial dims instead of hard-coding 28x28, so the
    // helper works for any single-channel 2-D float image.
    const int rows = static_cast<int>(img_tensor.size(0));
    const int cols = static_cast<int>(img_tensor.size(1));
    // cv::Mat wraps the tensor's buffer without copying; img_tensor stays
    // alive for the whole function, so the view remains valid.
    cv::Mat img(rows, cols, CV_32FC1, img_tensor.data_ptr<float>());
    img.convertTo(img, CV_8UC1, 255.0);  // scale [0,1] floats to 8-bit grayscale
    cv::imshow(window_name, img);
    cv::waitKey(0);  // wait indefinitely for a key press
}
void trainMNIST() {
torch::manual_seed(1);
torch::Device device(torch::cuda::is_available() ? torch::kCUDA : torch::kCPU);
// Hyperparameters
int input_size = 28 * 28;
int hidden_size = 64;
int num_classes = 10;
int num_epochs = 15;
int batch_size = 64;
double learning_rate = 1e-3;
double weight_decay = 1e-4;
double gamma = 0.8;
// Load data
std::string data_path = "../data/train.csv";
MNISTDataset dataset(data_path);
auto [train_data, test_data] = dataset.createDataset();
std::cout << "Train data size: " << train_data.size() << std::endl;
std::cout << "Test data size: " << test_data.size() << std::endl;
auto train_batches = createBatches(train_data, batch_size);
auto test_batches = createBatches(test_data, batch_size);
// Model Set up
std::vector<int64_t> layers_hidden = {input_size, hidden_size, num_classes};
KAN kan(layers_hidden);
kan->to(device);
torch::optim::AdamW optimizer(kan->parameters(), torch::optim::AdamWOptions(learning_rate).weight_decay(weight_decay));
torch::nn::CrossEntropyLoss criterion;
// Training loop
for (int epoch = 0; epoch < num_epochs; ++epoch) {
kan->train();
double epoch_loss = 0.0;
for (const auto& batch : train_batches) {
auto images = batch.first.to(device);
auto labels = batch.second.to(device);
// Forward pass
auto outputs = kan->forward(images);
auto loss = criterion(outputs, labels);
// Backward pass and optimization
optimizer.zero_grad();
loss.backward();
optimizer.step();
epoch_loss += loss.item<double>();
}
std::cout << "Epoch [" << epoch + 1 << "/" << num_epochs << "], Loss: " << epoch_loss / train_batches.size() << std::endl;
}
// Evalution steps
kan->eval();
double test_loss = 0.0;
int correct = 0;
int total = 0;
torch::NoGradGuard no_grad;
for (const auto& batch : test_batches) {
auto images = batch.first.to(device);
auto labels = batch.second.to(device);
// Forward
auto outputs = kan->forward(images);
auto loss = criterion(outputs, labels);
test_loss += loss.item<double>();
// Get predicted labels
auto predicted = outputs.argmax(1);
auto actual = labels.argmax(1);
correct += predicted.eq(actual).sum().item<int>();
total += labels.size(0);
}
std::cout << "Test Loss: " << test_loss / test_batches.size() << std::endl;
std::cout << "Test Accuracy: " << static_cast<double>(correct) / total << std::endl;
// Testing with one batch in test set:
for (const auto& batch : test_batches) {
auto images = batch.first.to(device);
auto labels = batch.second.to(device);
// Forward pass
auto outputs = kan->forward(images);
auto predicted = outputs.argmax(1);
auto actual = labels.argmax(1);
// Find the first correct prediction and visualize it
for (int i = 0; i < batch_size; ++i) {
if (predicted[i].item<int>() == actual[i].item<int>()) {
std::cout << "Corrected: " << std::to_string(predicted[i].item<int>()) << std::endl;
std::cout << "Prediction: Label " << std::to_string(predicted[i].item<int>()) << std::endl;
}
}
return;
}
}