Skip to content

Commit af36752

Browse files
authored
[Fix] Update swing score & additional recommends (#527)
1 parent e0dd499 commit af36752

1 file changed

Lines changed: 29 additions & 18 deletions

File tree

libreco/algorithms/swing.py

Lines changed: 29 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -79,19 +79,26 @@ def fit(
7979
self.show_start_time()
8080
user_interacts = build_sparse(train_data.sparse_interaction)
8181
item_interacts = build_sparse(train_data.sparse_interaction, transpose=True)
82-
self.rs_model = recfarm.Swing(
83-
self.top_k,
84-
self.alpha,
85-
self.max_cache_num,
86-
self.n_users,
87-
self.n_items,
88-
user_interacts,
89-
item_interacts,
90-
self.user_consumed,
91-
self.default_pred,
92-
)
93-
with time_block("swing computing", verbose=1):
94-
self.rs_model.compute_swing(self.num_threads, self.incremental)
82+
if self.incremental:
83+
assert self.rs_model is not None
84+
with time_block("update swing", verbose=1):
85+
self.rs_model.update_swing(
86+
self.num_threads, user_interacts, item_interacts
87+
)
88+
else:
89+
self.rs_model = recfarm.Swing(
90+
self.top_k,
91+
self.alpha,
92+
self.max_cache_num,
93+
self.n_users,
94+
self.n_items,
95+
user_interacts,
96+
item_interacts,
97+
self.user_consumed,
98+
self.default_pred,
99+
)
100+
with time_block("swing computing", verbose=1):
101+
self.rs_model.compute_swing(self.num_threads)
95102

96103
num = self.rs_model.num_swing_elements()
97104
density_ratio = 100 * num / (self.n_items * self.n_items)
@@ -137,17 +144,21 @@ def recommend_user(
137144
result_recs[u] = popular_recommendations(
138145
self.data_info, inner_id, n_rec
139146
)
147+
140148
if user_ids:
141-
computed_recs, no_rec_indices = self.rs_model.recommend(
149+
computed_recs, additional_rec_counts = self.rs_model.recommend(
142150
user_ids,
143151
n_rec,
144152
filter_consumed,
145153
random_rec,
146154
)
147-
for i in no_rec_indices:
148-
computed_recs[i] = popular_recommendations(
149-
self.data_info, inner_id=True, n_rec=n_rec
150-
)
155+
for rec, arc in zip(computed_recs, additional_rec_counts):
156+
if arc > 0:
157+
additional_recs = popular_recommendations(
158+
self.data_info, inner_id=True, n_rec=arc
159+
)
160+
rec.extend(additional_recs)
161+
151162
user_recs = construct_rec(self.data_info, user_ids, computed_recs, inner_id)
152163
result_recs.update(user_recs)
153164
return result_recs

0 commit comments

Comments
 (0)