@@ -79,19 +79,26 @@ def fit(
7979 self .show_start_time ()
8080 user_interacts = build_sparse (train_data .sparse_interaction )
8181 item_interacts = build_sparse (train_data .sparse_interaction , transpose = True )
82- self .rs_model = recfarm .Swing (
83- self .top_k ,
84- self .alpha ,
85- self .max_cache_num ,
86- self .n_users ,
87- self .n_items ,
88- user_interacts ,
89- item_interacts ,
90- self .user_consumed ,
91- self .default_pred ,
92- )
93- with time_block ("swing computing" , verbose = 1 ):
94- self .rs_model .compute_swing (self .num_threads , self .incremental )
82+ if self .incremental :
83+ assert self .rs_model is not None
84+ with time_block ("update swing" , verbose = 1 ):
85+ self .rs_model .update_swing (
86+ self .num_threads , user_interacts , item_interacts
87+ )
88+ else :
89+ self .rs_model = recfarm .Swing (
90+ self .top_k ,
91+ self .alpha ,
92+ self .max_cache_num ,
93+ self .n_users ,
94+ self .n_items ,
95+ user_interacts ,
96+ item_interacts ,
97+ self .user_consumed ,
98+ self .default_pred ,
99+ )
100+ with time_block ("swing computing" , verbose = 1 ):
101+ self .rs_model .compute_swing (self .num_threads )
95102
96103 num = self .rs_model .num_swing_elements ()
97104 density_ratio = 100 * num / (self .n_items * self .n_items )
@@ -137,17 +144,21 @@ def recommend_user(
137144 result_recs [u ] = popular_recommendations (
138145 self .data_info , inner_id , n_rec
139146 )
147+
140148 if user_ids :
141- computed_recs , no_rec_indices = self .rs_model .recommend (
149+ computed_recs , additional_rec_counts = self .rs_model .recommend (
142150 user_ids ,
143151 n_rec ,
144152 filter_consumed ,
145153 random_rec ,
146154 )
147- for i in no_rec_indices :
148- computed_recs [i ] = popular_recommendations (
149- self .data_info , inner_id = True , n_rec = n_rec
150- )
155+ for rec , arc in zip (computed_recs , additional_rec_counts ):
156+ if arc > 0 :
157+ additional_recs = popular_recommendations (
158+ self .data_info , inner_id = True , n_rec = arc
159+ )
160+ rec .extend (additional_recs )
161+
151162 user_recs = construct_rec (self .data_info , user_ids , computed_recs , inner_id )
152163 result_recs .update (user_recs )
153164 return result_recs
0 commit comments