88from . import stats
99
1010
11- class EMMMixPLResult :
11+ #deprecated
12+ class _EMMMixPLResult_legacy :
13+ """
14+ Description:
15+ Legacy class used to generate EMM solutions files for all experiments
16+ prior to publication of Zhao, Piech, & Xia (2016). All new code should
17+ use the new class.
18+ """
1219 def __init__ (self , num_alts , num_votes , num_mix , true_params , epsilon , max_iters , epsilon_mm , max_iters_mm , init_guess , soln_params , runtime ):
1320 self .num_alts = num_alts
1421 self .num_votes = num_votes
@@ -22,6 +29,25 @@ def __init__(self, num_alts, num_votes, num_mix, true_params, epsilon, max_iters
2229 self .soln_params = soln_params
2330 self .runtime = runtime
2431
class EMMMixPLResult:
    """
    Description:
        Plain record of one execution of the EMM algorithm. It keeps the
        arguments handed to the constructor as same-named attributes so a
        solution can be inspected later on.
    Parameters:
        num_alts, num_votes, num_mix: stored as given
        true_params: stored as given
        epsilon, epsilon_mm, iters: stored as given
        init_guess, soln_params, runtime: stored as given
    """
    def __init__(self, num_alts, num_votes, num_mix, true_params, epsilon,
                 epsilon_mm, iters, init_guess, soln_params, runtime):
        # Attributes are assigned in the same order as the constructor
        # arguments, so vars(self) keeps a stable, predictable layout.
        self.num_alts, self.num_votes, self.num_mix = (
            num_alts, num_votes, num_mix)
        self.true_params = true_params
        self.epsilon, self.epsilon_mm, self.iters = (
            epsilon, epsilon_mm, iters)
        self.init_guess, self.soln_params = init_guess, soln_params
        self.runtime = runtime
50+
2551class EMMMixPLAggregator (aggregate .RankAggregator ):
2652
2753 def c (x_i , j ):
@@ -71,10 +97,51 @@ def omega(k, j, z, x):
7197 sum_out += sum_in
7298 return sum_out
7399
74- def aggregate (self , rankings , K , epsilon , max_iters , epsilon_mm , max_iters_mm ):
100+ def aggregate (self , rankings , K , epsilon , epsilon_mm , iters ):
101+ """
102+ Description:
103+ Takes in a set of rankings and computes the model
104+ parameters for a mixture of Plackett-Luce models.
105+ Parameters:
106+ rankings: set of rankings to aggregate
107+ K: number of mixture components to compute
108+ epsilon: convergence condition threshold value for overall EM algorithm
109+ epsilon_mm: convergence condition threshold value for MM algorithm
110+ iters: dict, iterations configuration for EM and MM algorithms
111+ """
75112 x = rankings # shorter pseudonym for voting data
76113 self .n = len (rankings ) # number of votes
77114
115+ # "fixed" iterations type variables
116+ outer_iters = None
117+ inner_iters = None
118+ inner_range = None
119+
120+ # Additional "scaling" iterations type variables
121+ inner_iters_base = None
122+ scaling_divisor = None
123+
124+ # Additional "total" iterations type variables
125+ total_iters = None
126+ isIncremented = False
127+
128+ if "type" not in iters :
129+ raise ValueError ("iters dict must contain key \" type\" " )
130+ iters_type = iters ["type" ]
131+ if iters_type == "fixed" :
132+ outer_iters = iters ["em_iters" ]
133+ inner_iters = iters ["mm_iters" ]
134+ elif iters_type == "scaling" :
135+ outer_iters = iters ["em_iters" ]
136+ inner_iters_base = iters ["mm_iters_base" ]
137+ scaling_divisor = iters ["scaling_divisor" ]
138+ elif iters_type == "total" :
139+ total_iters = iters ["total_iters" ]
140+ outer_iters = iters ["em_iters" ]
141+ inner_iters = total_iters // outer_iters
142+ else :
143+ raise ValueError ("iters dict value for key \" type\" is invalid: " + str (iters_type ))
144+
78145 # pre-compute the delta values
79146 delta_i_j_s = np .empty ((self .n , self .m , self .m + 1 ))
80147 for i in range (self .n ):
@@ -92,45 +159,29 @@ def aggregate(self, rankings, K, epsilon, max_iters, epsilon_mm, max_iters_mm):
92159 p_h = np .copy (p_h0 )
93160 pi_h = np .copy (pi_h0 )
94161
95- for g in range (max_iters ):
162+ for g in range (outer_iters ):
96163
97164 p_h1 = np .empty ((K , self .m ))
98165 pi_h1 = np .empty (K )
99166 z_h1 = np .empty ((self .n , K ))
100167
101168 # E-Step:
102- for i in range (self .n ):
103- for k in range (K ):
104- denom_sum = 0
105- for k2 in range (K ):
106- denom_sum += pi_h [k2 ] * EMMMixPLAggregator .f (x [i ], p_h [k2 ])
107- z_h1 [i ][k ] = (pi_h [k ] * EMMMixPLAggregator .f (x [i ], p_h [k ])) / denom_sum
169+ self ._EStep (K , x , z_h1 , pi_h , p_h )
108170
109171 # M-Step:
110- for l in range (max_iters_mm ):
111- #for l in range(int(g/50) + 5):
112- for k in range (K ):
113- normconst = 0
114- if l == 0 : # only need to compute pi at first MM iteration
115- pi_h1 [k ] = np .sum (z_h1 .T [k ]) / len (z_h1 )
116- for j in range (self .m ):
117- omega_k_j = EMMMixPLAggregator .omega (k , j , z_h1 , x ) # numerator
118- denom_sum = 0
119- for i in range (self .n ):
120- sum1 = 0
121- for t in range (len (x [i ])):
122- sum2 = 0
123- sum3 = 0
124- for s in range (t , self .m ):
125- sum2 += p_h [k ][EMMMixPLAggregator .c (x [i ], s )]
126- for s in range (t , self .m + 1 ):
127- sum3 += delta_i_j_s [i ][j ][s ]
128- sum1 += z_h1 [i ][k ] * (sum2 ** - 1 ) * sum3
129- denom_sum += sum1
130- p_h1 [k ][j ] = omega_k_j / denom_sum
131- normconst += p_h1 [k ][j ]
132- for j in range (self .m ):
133- p_h1 [k ][j ] /= normconst
172+ if iters_type == "fixed" :
173+ inner_range = range (inner_iters )
174+ elif iters_type == "total" and not isIncremented :
175+ test = (g + 1 ) * inner_iters + (outer_iters - g - 1 ) * (inner_iters + 1 )
176+ if test < total_iters :
177+ inner_iters += 1
178+ isIncremented = True
179+ inner_range = range (inner_iters )
180+ elif iters_type == "scaling" :
181+ inner_range = range (int (g / scaling_divisor ) + inner_iters_base )
182+
183+ for l in inner_range :
184+ self ._MStep (l , K , x , z_h1 , pi_h1 , p_h1 , p_h , delta_i_j_s )
134185
135186 if (epsilon_mm != None and
136187 np .all (np .absolute (p_h1 - p_h ) < epsilon_mm )):
@@ -150,6 +201,51 @@ def aggregate(self, rankings, K, epsilon, max_iters, epsilon_mm, max_iters_mm):
150201
151202 return (pi_h1 , p_h1 , pi_h0 , p_h0 )
152203
204+
def _EStep(self, K, x, z_h1, pi_h, p_h):
    """
    Description:
        Internal function for computing the E-Step of the EMM algorithm.
        For every vote i and component k it writes
            z_h1[i][k] = pi_h[k] * f(x[i], p_h[k]) / sum_k2 pi_h[k2] * f(x[i], p_h[k2])
        into z_h1 in place; nothing is returned.
    Parameters:
        K:    number of mixture components
        x:    voting data, indexed by vote (x[i] for i in range(self.n))
        z_h1: output container indexed as z_h1[i][k]; overwritten
        pi_h: current mixing proportions, indexed by component
        p_h:  current per-component parameters, indexed by component
              (p_h[k] is passed to EMMMixPLAggregator.f; presumably the
              Plackett-Luce parameters of component k -- confirm against f)
    """
    for i in range(self.n):
        # The normalizing denominator is the same for every k, so each
        # weighted term is evaluated exactly once per vote and then reused
        # for both the denominator and the numerators.
        weighted = [pi_h[k] * EMMMixPLAggregator.f(x[i], p_h[k])
                    for k in range(K)]

        # Accumulate left to right starting from 0 to keep the summation
        # order (and therefore floating point results) of a plain loop.
        denom_sum = 0
        for term in weighted:
            denom_sum += term

        for k in range(K):
            z_h1[i][k] = weighted[k] / denom_sum
217+
218+
def _MStep(self, l, K, x, z_h1, pi_h1, p_h1, p_h, delta_i_j_s):
    """
    Description:
        Internal function for computing the M-Step of the EMM algorithm,
        which is itself an MM algorithm. Performs one MM iteration and
        writes its results into pi_h1 and p_h1 in place; nothing is
        returned.

        For every component k and index j in range(self.m):
            p_h1[k][j] = omega(k, j, z_h1, x) / D
            D = sum_i sum_t  z_h1[i][k]
                             * (sum_{s=t}^{m-1} p_h[k][c(x[i], s)]) ** -1
                             * (sum_{s=t}^{m} delta_i_j_s[i][j][s])
        after which p_h1[k] is divided by the sum of its entries.
    Parameters:
        l:           index of the current MM iteration; pi_h1 is only
                     computed when l == 0
        K:           number of mixture components
        x:           voting data, indexed by vote
        z_h1:        array indexed [i][k] (must support .T and len())
        pi_h1:       output, indexed by component; written only when l == 0
        p_h1:        output, indexed [k][j]; overwritten
        p_h:         current parameters, indexed [k][...]; only read here
        delta_i_j_s: precomputed delta values, indexed [i][j][s]
    Notes:
        The inner sum over p_h does not depend on j, so it is computed
        once per (k, i, t) instead of once per (k, j, i, t). This assumes
        p_h and p_h1 are distinct arrays (writes to p_h1 must not be
        visible through p_h while an iteration is in progress).
    """
    for k in range(K):
        if l == 0:  # only need to compute pi at first MM iteration
            pi_h1[k] = np.sum(z_h1.T[k]) / len(z_h1)

        # weights[i][t] = z_h1[i][k] * (tail sum of p_h[k] from position t) ** -1
        # Independent of j, hence hoisted out of the loop over j below.
        weights = []
        for i in range(self.n):
            row = []
            for t in range(len(x[i])):
                tail = 0
                for s in range(t, self.m):
                    tail += p_h[k][EMMMixPLAggregator.c(x[i], s)]
                row.append(z_h1[i][k] * (tail ** -1))
            weights.append(row)

        normconst = 0
        for j in range(self.m):
            omega_k_j = EMMMixPLAggregator.omega(k, j, z_h1, x)  # numerator
            denom_sum = 0
            for i in range(self.n):
                sum1 = 0
                for t, weight in enumerate(weights[i]):
                    sum3 = 0
                    for s in range(t, self.m + 1):
                        sum3 += delta_i_j_s[i][j][s]
                    sum1 += weight * sum3
                denom_sum += sum1
            p_h1[k][j] = omega_k_j / denom_sum
            normconst += p_h1[k][j]

        # Rescale so the entries of p_h1[k] sum to one.
        for j in range(self.m):
            p_h1[k][j] /= normconst
247+
248+
153249def main ():
154250 n = 100
155251 m = 4
@@ -161,7 +257,14 @@ def main():
161257 print ("EMM Algorithm:" )
162258
163259 emmagg = EMMMixPLAggregator (cand_set )
164- pi , p = emmagg .aggregate (votes , K = 2 , epsilon = 1e-8 , max_iters = 1000 , epsilon_mm = 1e-8 , max_iters_mm = 10 )
260+ pi , p , pi_h0 , p_h0 = emmagg .aggregate (votes ,
261+ K = 2 ,
262+ epsilon = None ,
263+ epsilon_mm = None ,
264+ iters = {"type" : "fixed" ,
265+ "em_iters" : 20 ,
266+ "mm_iters" : 5 }
267+ )
165268
166269 sol_params = np .empty (2 * m + 1 )
167270 sol_params [0 ] = pi [0 ]
@@ -171,7 +274,7 @@ def main():
171274 print ("Ground-Truth Parameters:\n " + str (params ))
172275 print ("Final Solution:\n " + str (sol_params ))
173276 print ("\t \" 1 - alpha\" = " + str (pi [1 ]))
174- print ("WSSE :\n " + str (stats .mix2PL_wsse (params , sol_params , m )))
277+ print ("MSE :\n " + str (stats .mix2PL_sse (params , sol_params , m )))
175278
176279if __name__ == "__main__" :
177280 main ()
0 commit comments