-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathevaluate.py
More file actions
242 lines (213 loc) · 7.97 KB
/
evaluate.py
File metadata and controls
242 lines (213 loc) · 7.97 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
import argparse
import time
import numpy as np
from baselines import greedy_distance, random_mmd
from pretender import pretender_mmd
from utils import (
compute_mmd,
create_source_matrix,
create_target_matrix,
sample_items,
split_items,
)
def _parse_args():
    """Parse and validate the command-line options for the evaluation.

    Returns:
        argparse.Namespace with ``dataset``, ``split``, ``K``, ``L``,
        ``rounding_trial``, ``C`` and ``sigma`` attributes.
    """
    parser = argparse.ArgumentParser(
        description="Evaluate Pretender for preference transfer on a selected dataset, reporting average MMD metrics and remaining time."
    )
    parser.add_argument(
        "--dataset",
        type=str,
        default="movielens",
        choices=["movielens", "lastfm", "amazon"],
        help="Dataset to use: 'movielens', 'lastfm', or 'amazon'.",
    )
    parser.add_argument(
        "--split",
        type=str,
        default="disjoint",
        choices=["disjoint", "overlap"],
        help="Split type.",
    )
    parser.add_argument(
        "--K", type=int, default=100, help="Number of items to interact with (K)."
    )
    parser.add_argument(
        "--L",
        type=int,
        default=1000,
        help="Number of iterations for the Frank-Wolfe algorithm.",
    )
    parser.add_argument(
        "--rounding_trial", type=int, default=100, help="Number of rounding trials."
    )
    parser.add_argument(
        "--C",
        type=float,
        default=10,
        help="Scaling constant that controls the emphasis on labels.",
    )
    parser.add_argument(
        "--sigma", type=float, default=1.0, help="Sigma parameter for the RBF kernel."
    )
    args = parser.parse_args()
    # K is the size of every discrete selection; reject non-positive values
    # with a usage error instead of failing later inside the per-user loop.
    if args.K <= 0:
        parser.error("--K must be a positive integer.")
    # A zero or negative bandwidth is degenerate for an RBF kernel.
    if args.sigma <= 0:
        parser.error("--sigma must be positive.")
    return args


def _load_dataset(name):
    """Load item features and per-user preferences for the named dataset.

    Returns:
        (X, title_dict, preferences) as produced by the dataset module's
        ``*_items`` / ``*_preferences`` loaders.

    Raises:
        ValueError: if ``name`` is not a supported dataset.
    """
    if name == "lastfm":
        from lastfm import lastfm_items, lastfm_preferences

        # lastfm_items returns (X, title_dict, mapping_orig_to_new)
        X, title_dict, mapping_orig_to_new = lastfm_items(
            "hetrec/user_taggedartists.dat"
        )
        preferences = lastfm_preferences(
            "hetrec/user_artists.dat", mapping_orig_to_new=mapping_orig_to_new
        )
    elif name == "amazon":
        from amazon import amazon_items, amazon_preferences

        # amazon_items returns (X, title_dict, mapping_asin_to_new)
        X, title_dict, mapping_asin_to_new = amazon_items(
            "reviews_Home_and_Kitchen_5.json", min_item_interactions=32
        )
        preferences = amazon_preferences(
            "reviews_Home_and_Kitchen_5.json",
            mapping_asin_to_new=mapping_asin_to_new,
            min_user_interactions=32,
        )
    elif name == "movielens":
        from movielens import movielens_items, movielens_preferences

        # movielens_items returns (X, title_dict)
        X, title_dict = movielens_items("ml-100k/u.item")
        preferences = movielens_preferences("ml-100k/u.data")
    else:
        raise ValueError("Unsupported dataset.")
    return X, title_dict, preferences


def _split_source_target(X, title_dict, split):
    """Divide the item catalogue into a source set and a target set.

    Returns:
        (X_source, indices_source, X_target). ``indices_source`` holds, for
        each row of ``X_source``, the item id that is matched against the
        users' favorite/unfavorite lists.

    Raises:
        ValueError: if ``split`` is not a supported split type.
    """
    if split == "overlap":
        # Two separate samples of the catalogue; presumably drawn
        # independently, so the sets may share items (hence "overlap").
        X_source, _title_dict_source, indices_source = sample_items(
            X, title_dict, prob=0.5
        )
        X_target, _title_dict_target, _indices_target = sample_items(
            X, title_dict, prob=0.5
        )
    elif split == "disjoint":
        (
            X_source,
            _title_dict_source,
            indices_source,
            X_target,
            _title_dict_target,
            _indices_target,
        ) = split_items(X, title_dict)
    else:
        raise ValueError("Unsupported split type.")
    return X_source, indices_source, X_target


def _uniform_weights_on(mask):
    """Return weights that are uniform over the truthy entries of ``mask``.

    Unselected entries get weight 0. The weights are normalised by the number
    of selected entries, not by the requested K. They therefore always sum to
    1 even if a selector returns a different number of items than K. When
    exactly K items are selected, each gets ``1 / K``. An empty selection
    yields an all-zero vector.
    """
    mask = np.asarray(mask, dtype=bool)
    count = np.count_nonzero(mask)
    if count == 0:
        return np.zeros(mask.shape)
    return np.where(mask, 1.0 / count, 0.0)


def _evaluate_user(user_pref, X_source, indices_source, X_target, args):
    """Score Pretender and the baselines for a single user.

    ``user_pref`` must provide ``"favorite"`` and ``"unfavorite"`` item-id
    collections. Each method's target weighting is compared by MMD against
    the uniform measure on the user-labelled source matrix.

    The randomized routines are called in a fixed order (Pretender, then
    greedy, then random), so any shared global NumPy random state is consumed
    in a reproducible sequence.

    Returns:
        dict mapping "Continuous", "Discrete", "Greedy", "Random" to MMD values.
    """
    favorite_items = set(user_pref["favorite"])
    unfavorite_items = set(user_pref["unfavorite"])
    # Row positions in X_source whose item id the user labelled.
    fav_positions = [
        j for j, orig_idx in enumerate(indices_source) if orig_idx in favorite_items
    ]
    unfav_positions = [
        j for j, orig_idx in enumerate(indices_source) if orig_idx in unfavorite_items
    ]

    # Augmented source (with the user's labels) and target matrices.
    Xy_source = create_source_matrix(X_source, fav_positions, unfav_positions, args.C)
    Xy_target = create_target_matrix(X_target, args.C)

    # Uniform probability measure over every source row.
    n_source = Xy_source.shape[0]
    w_source = np.ones(n_source) / n_source

    def mmd_to(w_target):
        # MMD between the uniform source measure and a weighting of the target.
        return compute_mmd(
            Xy_source, Xy_target, w1=w_source, w2=w_target, sigma=args.sigma
        )

    # Pretender: continuous Frank-Wolfe weights plus a rounded K-item mask.
    selected_mask, w = pretender_mmd(
        Xy_source,
        Xy_target,
        K_val=args.K,
        L=args.L,
        rounding_trial=args.rounding_trial,
        sigma=args.sigma,
        return_continuous=True,
    )
    mmd_cont = mmd_to(w)
    mmd_disc = mmd_to(_uniform_weights_on(selected_mask))

    greedy_mask = greedy_distance(Xy_source, Xy_target, args.K)
    mmd_greedy = mmd_to(_uniform_weights_on(greedy_mask))

    random_mask = random_mmd(Xy_source, Xy_target, args.K)
    mmd_random = mmd_to(_uniform_weights_on(random_mask))

    return {
        "Continuous": mmd_cont,
        "Discrete": mmd_disc,
        "Greedy": mmd_greedy,
        "Random": mmd_random,
    }


def _print_summary(results):
    """Print the mean and standard deviation of each method's MMD values."""
    for label, values in results.items():
        print(f" {label}: mean = {np.mean(values):.4f}, std = {np.std(values):.4f}")


def main():
    """Evaluate Pretender and the greedy/random baselines for every user.

    For each user, the source items are labelled with that user's
    favorite/unfavorite items. Four target weightings are then scored by MMD
    against the uniform source measure:

    - Pretender's continuous weights;
    - Pretender's rounded K-item selection;
    - a greedy K-item selection;
    - a random K-item selection.

    Running averages and a remaining-time estimate are printed after every
    user, followed by a final summary.
    """
    args = _parse_args()
    # Fix NumPy's global seed so any downstream np.random usage is reproducible.
    np.random.seed(0)

    X, title_dict, preferences = _load_dataset(args.dataset)
    X_source, indices_source, X_target = _split_source_target(
        X, title_dict, args.split
    )

    # Per-method lists of per-user MMD values, in report order.
    results = {"Continuous": [], "Discrete": [], "Greedy": [], "Random": []}

    user_ids = sorted(preferences.keys())
    num_users = len(user_ids)
    print(f"Processing {num_users} users...")
    if num_users == 0:
        # Nothing to average; avoid np.mean/np.std on empty lists (nan output).
        print("No users to evaluate.")
        return

    start_time = time.perf_counter()
    for i, user in enumerate(user_ids):
        # ETA: mean time per finished user times the users still to process.
        elapsed = time.perf_counter() - start_time
        avg_time = elapsed / i if i > 0 else 0.0
        remaining = avg_time * (num_users - i)
        print(
            f"Processing user {i + 1}/{num_users}, estimated remaining time: {remaining:.2f} sec"
        )

        scores = _evaluate_user(
            preferences[user], X_source, indices_source, X_target, args
        )
        for label, value in scores.items():
            results[label].append(value)

        print(
            f"\nCurrent Results (averaged across {i + 1}/{num_users} users processed so far):"
        )
        _print_summary(results)

    print("\nOverall Results (averaged across all users):")
    print(
        f"dataset = {args.dataset}, split = {args.split}, K = {args.K}, L = {args.L}, rounding_trial = {args.rounding_trial}, C = {args.C}, sigma = {args.sigma}"
    )
    _print_summary(results)
# Script entry point: run the evaluation only when executed directly,
# never as a side effect of importing this module.
if __name__ == "__main__":
    main()