-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtestKAN.cpp
More file actions
68 lines (56 loc) · 2.68 KB
/
testKAN.cpp
File metadata and controls
68 lines (56 loc) · 2.68 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
#include "testKAN.h"
/// Smoke-trains a single KANLinear layer (2 inputs -> 1 output) with L-BFGS to
/// fit the rational target f(u, v) = (u + v) / (1 + u * v) on random data,
/// printing the total loss each epoch and a few sample predictions at the end.
/// No parameters, no return value; output goes to stdout.
void test_KANLinear() {
    const int64_t in_features = 2;
    const int64_t out_features = 1;
    std::cout << "CHECKING: ..." << std::endl;
    KANLinear kan_linear(in_features, out_features);
    const int64_t epochs = 30;
    // L-BFGS with lr = 1 and up to 100 inner iterations per optimizer step.
    auto optimizer = torch::optim::LBFGS(kan_linear->parameters(), torch::optim::LBFGSOptions(1).max_iter(100));
    // Fixed training batch; the closure recomputes the loss on it every time
    // L-BFGS re-evaluates the objective during a step.
    torch::Tensor input = torch::randn({2024, in_features});
    for (int64_t i = 0; i < epochs; ++i) {  // int64_t matches the bound's type
        auto closure = [&]() -> torch::Tensor {
            optimizer.zero_grad();
            torch::Tensor output = kan_linear->forward(input);
            // Regularization term from the KAN layer (weights 1 and 0).
            torch::Tensor reg_loss = kan_linear->regularization_loss(1, 0);
            auto u = input.index({torch::indexing::Slice(), 0});
            auto v = input.index({torch::indexing::Slice(), 1});
            torch::Tensor target = (u + v) / (1 + u * v);
            torch::Tensor loss = torch::nn::functional::mse_loss(output.squeeze(-1), target);
            torch::Tensor total_loss = loss + 1e-5 * reg_loss;
            // BUGFIX: original printed "epoch: 5Total Loss: ..." with no
            // separator; also report 1-based epochs, consistent with test_KAN.
            std::cout << "epoch: " << i + 1 << " Total Loss: " << total_loss.item<double>() << std::endl;
            total_loss.backward();
            return total_loss;
        };
        optimizer.step(closure);
    }
    // Show predictions on a few fresh random inputs as a qualitative check.
    torch::Tensor new_input = torch::randn({3, in_features});
    std::cout << "Input: " << new_input << std::endl;
    std::cout << "Output: " << kan_linear->forward(new_input) << std::endl;
}
/// Smoke-trains a small KAN network (layers 2 -> 5 -> 1) with L-BFGS to fit
/// the target f(x1, x2) = exp(sin(pi * x1) + x2^2) on random data in [0, 1),
/// printing the total loss each epoch and a few sample predictions at the end.
/// No parameters, no return value; output goes to stdout.
void test_KAN() {
    std::vector<int64_t> layers_hidden = {2, 5, 1};
    KAN kan(layers_hidden);
    // L-BFGS with lr = 1 and up to 100 inner iterations per optimizer step.
    auto optimizer = torch::optim::LBFGS(kan->parameters(), torch::optim::LBFGSOptions(1).max_iter(100));
    const int64_t epochs = 20;
    // Fixed training batch; the closure recomputes the loss on it every time
    // L-BFGS re-evaluates the objective during a step.
    auto input = torch::rand({1024, 2});
    for (int64_t i = 0; i < epochs; ++i) {  // int64_t matches the bound's type
        auto closure = [&]() -> torch::Tensor {
            optimizer.zero_grad();
            torch::Tensor output = kan->forward(input);
            // Regularization term from the KAN network (weights 1 and 0).
            torch::Tensor reg_loss = kan->regularization_loss(1, 0);
            auto x1 = input.index({torch::indexing::Slice(), 0});
            auto x2 = input.index({torch::indexing::Slice(), 1});
            const double pi = M_PI;
            torch::Tensor target = torch::exp(torch::sin(pi * x1) + x2 * x2);
            torch::Tensor loss = torch::nn::functional::mse_loss(output.squeeze(-1), target);
            torch::Tensor total_loss = loss + 1e-5 * reg_loss;
            // BUGFIX: original printed "epoch: 1Total Loss: ..." with no
            // separator between the epoch number and the label.
            std::cout << "epoch: " << i + 1 << " Total Loss: " << total_loss.item<double>() << std::endl;
            total_loss.backward();
            return total_loss;
        };
        optimizer.step(closure);
    }
    // Show predictions on a few fresh random inputs as a qualitative check.
    auto new_input = torch::rand({5, 2});
    std::cout << "Input: " << new_input << std::endl;
    std::cout << "Output: " << kan->forward(new_input) << std::endl;
}