'''
Description: Top-k recommendation example built on JittorGeometric. Trains LightGCN,
    SimGCL, XSimGCL, or DirectAU on MovieLens-100K/1M or Yelp2018 and reports
    Hit/MRR/NDCG/Recall@10, with early stopping on validation NDCG@10.
Author: zhengyp
Date: 2025-07-13
'''
import os.path as osp
import argparse
import jittor as jt
from jittor import nn
import sys
root = osp.dirname(osp.dirname(osp.abspath(__file__)))
sys.path.append(root)
from jittor_geometric.dataloader import RecsysDataLoader
from jittor_geometric.datasets import MovieLens1M, MovieLens100K, Yelp2018
from jittor_geometric.datasets.recsys import Hit, MRR, NDCG, Recall, DataStruct
from jittor_geometric.nn.models import LightGCN, SimGCL, XSimGCL, DirectAU
from tqdm import tqdm
# Enable CUDA for GPU acceleration
jt.flags.use_cuda = 1
# Parse command line arguments
parser = argparse.ArgumentParser()
parser.add_argument('--dataset', default="ml-1m", help='graph dataset')
parser.add_argument('--model', default="lightgcn", help='model name')
parser.add_argument('--embedding_size', type=int, default=64, help='size of embedding')
parser.add_argument('--nlayer', type=int, default=2, help='number of layers')
parser.add_argument('--num_epochs', type=int, default=300, help='Training epochs')
parser.add_argument('--patience', type=int, default=10, help='patience')
parser.add_argument('--eval_step', type=int, default=1, help='eval step')
parser.add_argument('--reg_weight', type=float, default=1e-4, help='weight of regularization loss')
parser.add_argument('--lr', type=float, default=0.001, help='learning rate')
parser.add_argument('--weight_decay', type=float, default=0, help='weight decay')
args = parser.parse_args()
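# Example invocation (flags are the ones defined above; values are illustrative):
#   python recsys_example.py --dataset ml-1m --model lightgcn --nlayer 2 --lr 0.001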
path = osp.join(osp.dirname(osp.realpath(__file__)), '../data')
# Load recommendation dataset
if args.dataset == 'ml-100k':
    dataset = MovieLens100K(root=path)
elif args.dataset == 'ml-1m':
    dataset = MovieLens1M(root=path)
elif args.dataset == 'yelp2018':
    dataset = Yelp2018(root=path)
else:
    raise ValueError(f"Unsupported dataset: {args.dataset}")
data = dataset.get(0)
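# The loaded Data object is assumed to expose train/val/test edge indices together with
# num_users and num_items; these fields are used below by the loader, the models, and evaluation.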
num_epochs = args.num_epochs + 1
# Create data loader for training
train_loader = RecsysDataLoader(
    edge_index=data.train_edge_index,
    num_items=int(data.num_items),
    batch_size=4096,
    num_neg=1,
    shuffle=True,
)
# Initialize recommendation model
if args.model.lower() == 'lightgcn':
    model = LightGCN(data.num_users, data.num_items, args.embedding_size, args.nlayer, data.train_edge_index, reg_weight=args.reg_weight)
elif args.model.lower() == 'simgcl':
    model = SimGCL(data.num_users, data.num_items, args.embedding_size, args.nlayer, data.train_edge_index, reg_weight=args.reg_weight)
elif args.model.lower() == 'xsimgcl':
    model = XSimGCL(data.num_users, data.num_items, args.embedding_size, args.nlayer, data.train_edge_index, reg_weight=args.reg_weight)
elif args.model.lower() == 'directau':
    model = DirectAU(data.num_users, data.num_items, args.embedding_size, data.train_edge_index)
else:
    raise ValueError(f"Unsupported model: {args.model}")
optimizer = nn.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
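# Note: Jittor's optimizer API folds the backward pass into the update, so the training
# loop below calls optimizer.step(loss) instead of a separate loss.backward() + step().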
# Evaluation function for different data splits
def evaluate(model, data, split="val", k=[10]):
    model.eval()
    with jt.no_grad():
        if split == "val":
            split_edge_index = data.val_edge_index
        elif split == "test":
            split_edge_index = data.test_edge_index
        elif split == "train":
            split_edge_index = data.train_edge_index
        else:
            raise ValueError(f"Unknown split: {split}")
        return evaluate_topk_from_edges(model, split_edge_index, data.train_edge_index, data.num_items, k=k)
# Build evaluation results from top-k predictions
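# Each returned row concatenates the hit flags of the top-k ranked items with the number of
# ground-truth positives for that user, which is the layout the metric classes below consume.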
def build_eval_result(scores, positive_u, positive_i, ks=[10]):
    topk_scores, topk_idx = jt.topk(scores, k=max(ks), dim=1, largest=True)
    pos_matrix = jt.zeros(scores.shape, dtype=jt.int)
    pos_matrix[positive_u, positive_i] = 1
    pos_len = pos_matrix.sum(dim=1, keepdims=True)
    pos_idx = pos_matrix.gather(1, topk_idx)
    result = jt.concat([pos_idx, pos_len], dim=1)
    return result
# Evaluate top-k recommendation performance
def evaluate_topk_from_edges(model, split_edge_index, train_edge_index, num_items, k=[10]):
    data_struct = DataStruct()
    num_items = int(num_items)
    test_users = jt.unique(split_edge_index[0]).numpy().tolist()
    pbar = tqdm(test_users, desc="Evaluating", unit="user")
    for uid in pbar:
        # Ground-truth items in the evaluation set
        true_items = split_edge_index[1][split_edge_index[0] == uid].numpy().tolist()
        # Items already interacted with in the training set
        known_items = train_edge_index[1][train_edge_index[0] == uid].numpy().tolist()
        user_tensor = jt.array([int(uid)], dtype="int32")
        scores = model.full_predict(user_tensor).reshape(-1, num_items)
        scores[:, known_items] = -1e9  # Filter out known items
        batch_result = build_eval_result(scores, [0] * len(true_items), true_items, ks=k)
        data_struct.update_tensor("rec.topk", batch_result.numpy())
    rec_mat = jt.Var(data_struct.get_tensor("rec.topk"))
    metrics = {
        **Hit(k=k).calculate_metric(rec_mat),
        **MRR(k=k).calculate_metric(rec_mat),
        **NDCG(k=k).calculate_metric(rec_mat),
        **Recall(k=k).calculate_metric(rec_mat),
    }
    return metrics
# Training loop with early stopping
best_score = -float("inf")
patience_counter = 0
for epoch in range(1, num_epochs):
print(f"Epoch {epoch}/{num_epochs}")
model.train()
# Train on batches of user-item interactions
for users, pos_items, neg_items in train_loader:
optimizer.zero_grad()
loss = model(users, pos_items, neg_items)
optimizer.step(loss)
# Evaluate and check early stopping
if epoch % args.eval_step == 0 or epoch == num_epochs:
val_metrics = evaluate(model, data, split="val", k=[10])
print(f"[Validation @ Epoch {epoch}]")
for metric, value in val_metrics.items():
print(f"{metric}: {value:.4f}")
current_score = val_metrics["ndcg@10"]
if current_score > best_score:
best_score = current_score
patience_counter = 0
else:
patience_counter += 1
if patience_counter >= args.patience:
print(f"Early stopping at epoch {epoch}. Best NDCG@10={best_score:.4f}")
break
# Final test evaluation
test_metrics = evaluate(model, data, split="test", k=[10])
print("Test result:")
for metric, value in test_metrics.items():
    print(f"{metric}: {value:.4f}")