Skip to content

Commit 2eabd90

Browse files
committed
Merge branch 'release_revisions'
- begin post-ICML implementation cleanup, consolidation, and revision
2 parents 648c3b0 + aae59fc commit 2eabd90

3 files changed

Lines changed: 141 additions & 219 deletions

File tree

prefpy/evbwie.py

Lines changed: 138 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,14 @@
88
from . import stats
99

1010

11-
class EMMMixPLResult:
11+
#deprecated
12+
class _EMMMixPLResult_legacy:
13+
"""
14+
Description:
15+
Legacy class used to generate EMM solutions files for all experiments
16+
prior to publication of Zhao, Piech, & Xia (2016). All new code should
17+
use the new class.
18+
"""
1219
def __init__(self, num_alts, num_votes, num_mix, true_params, epsilon, max_iters, epsilon_mm, max_iters_mm, init_guess, soln_params, runtime):
1320
self.num_alts = num_alts
1421
self.num_votes = num_votes
@@ -22,6 +29,25 @@ def __init__(self, num_alts, num_votes, num_mix, true_params, epsilon, max_iters
2229
self.soln_params = soln_params
2330
self.runtime = runtime
2431

32+
class EMMMixPLResult:
    """
    Description:
        Record of one execution of the EMM algorithm: the settings it was
        run with, the starting point, the solution it returned and how
        long it took. Instances only store the values they are given;
        nothing is validated, copied or computed here.
    Parameters:
        num_alts:    stored as-is (presumably the number of alternatives)
        num_votes:   stored as-is (presumably the number of votes)
        num_mix:     stored as-is (presumably the number of mixture components)
        true_params: stored as-is (presumably the ground-truth parameters)
        epsilon:     stored as-is (presumably the EM convergence threshold)
        epsilon_mm:  stored as-is (presumably the MM convergence threshold)
        iters:       stored as-is (presumably the iterations configuration
                     dict handed to the aggregator; confirm with caller)
        init_guess:  stored as-is (presumably the initial parameter guess)
        soln_params: stored as-is (presumably the parameters returned)
        runtime:     stored as-is (presumably elapsed time; units not
                     visible from this class)
    """
    def __init__(self, num_alts, num_votes, num_mix, true_params, epsilon, epsilon_mm, iters, init_guess, soln_params, runtime):
        # Problem description. Attributes are assigned left to right, so
        # the instance's attribute order matches the signature order.
        self.num_alts, self.num_votes, self.num_mix = num_alts, num_votes, num_mix
        self.true_params = true_params

        # Algorithm configuration.
        self.epsilon, self.epsilon_mm, self.iters = epsilon, epsilon_mm, iters

        # Starting point, outcome and cost of the run.
        self.init_guess, self.soln_params = init_guess, soln_params
        self.runtime = runtime
50+
2551
class EMMMixPLAggregator(aggregate.RankAggregator):
2652

2753
def c(x_i, j):
@@ -71,10 +97,51 @@ def omega(k, j, z, x):
7197
sum_out += sum_in
7298
return sum_out
7399

74-
def aggregate(self, rankings, K, epsilon, max_iters, epsilon_mm, max_iters_mm):
100+
def aggregate(self, rankings, K, epsilon, epsilon_mm, iters):
101+
"""
102+
Description:
103+
Takes in a set of rankings and computes the model
104+
parameters for a mixture of Plackett-Luce models.
105+
Parameters:
106+
rankings: set of rankings to aggregate
107+
K: number of mixture components to compute
108+
epsilon: convergence condition threshold value for overall EM algorithm
109+
epsilon_mm: convergence condition threshold value for MM algorithm
110+
iters: dict, iterations configuration for EM and MM algorithms
111+
"""
75112
x = rankings # shorter pseudonym for voting data
76113
self.n = len(rankings) # number of votes
77114

115+
# "fixed" iterations type variables
116+
outer_iters = None
117+
inner_iters = None
118+
inner_range = None
119+
120+
# Additional "scaling" iterations type variables
121+
inner_iters_base = None
122+
scaling_divisor = None
123+
124+
# Additional "total" iterations type variables
125+
total_iters = None
126+
isIncremented = False
127+
128+
if "type" not in iters:
129+
raise ValueError("iters dict must contain key \"type\"")
130+
iters_type = iters["type"]
131+
if iters_type == "fixed":
132+
outer_iters = iters["em_iters"]
133+
inner_iters = iters["mm_iters"]
134+
elif iters_type == "scaling":
135+
outer_iters = iters["em_iters"]
136+
inner_iters_base = iters["mm_iters_base"]
137+
scaling_divisor = iters["scaling_divisor"]
138+
elif iters_type == "total":
139+
total_iters = iters["total_iters"]
140+
outer_iters = iters["em_iters"]
141+
inner_iters = total_iters // outer_iters
142+
else:
143+
raise ValueError("iters dict value for key \"type\" is invalid: " + str(iters_type))
144+
78145
# pre-compute the delta values
79146
delta_i_j_s = np.empty((self.n, self.m, self.m + 1))
80147
for i in range(self.n):
@@ -92,45 +159,29 @@ def aggregate(self, rankings, K, epsilon, max_iters, epsilon_mm, max_iters_mm):
92159
p_h = np.copy(p_h0)
93160
pi_h = np.copy(pi_h0)
94161

95-
for g in range(max_iters):
162+
for g in range(outer_iters):
96163

97164
p_h1 = np.empty((K, self.m))
98165
pi_h1 = np.empty(K)
99166
z_h1 = np.empty((self.n, K))
100167

101168
# E-Step:
102-
for i in range(self.n):
103-
for k in range(K):
104-
denom_sum = 0
105-
for k2 in range(K):
106-
denom_sum += pi_h[k2] * EMMMixPLAggregator.f(x[i], p_h[k2])
107-
z_h1[i][k] = (pi_h[k] * EMMMixPLAggregator.f(x[i], p_h[k])) / denom_sum
169+
self._EStep(K, x, z_h1, pi_h, p_h)
108170

109171
# M-Step:
110-
for l in range(max_iters_mm):
111-
#for l in range(int(g/50) + 5):
112-
for k in range(K):
113-
normconst = 0
114-
if l == 0: # only need to compute pi at first MM iteration
115-
pi_h1[k] = np.sum(z_h1.T[k]) / len(z_h1)
116-
for j in range(self.m):
117-
omega_k_j = EMMMixPLAggregator.omega(k, j, z_h1, x) # numerator
118-
denom_sum = 0
119-
for i in range(self.n):
120-
sum1 = 0
121-
for t in range(len(x[i])):
122-
sum2 = 0
123-
sum3 = 0
124-
for s in range(t, self.m):
125-
sum2 += p_h[k][EMMMixPLAggregator.c(x[i], s)]
126-
for s in range(t, self.m + 1):
127-
sum3 += delta_i_j_s[i][j][s]
128-
sum1 += z_h1[i][k] * (sum2 ** -1) * sum3
129-
denom_sum += sum1
130-
p_h1[k][j] = omega_k_j / denom_sum
131-
normconst += p_h1[k][j]
132-
for j in range(self.m):
133-
p_h1[k][j] /= normconst
172+
if iters_type == "fixed":
173+
inner_range = range(inner_iters)
174+
elif iters_type == "total" and not isIncremented:
175+
test = (g + 1) * inner_iters + (outer_iters - g - 1) * (inner_iters + 1)
176+
if test < total_iters:
177+
inner_iters += 1
178+
isIncremented = True
179+
inner_range = range(inner_iters)
180+
elif iters_type == "scaling":
181+
inner_range = range(int(g/scaling_divisor) + inner_iters_base)
182+
183+
for l in inner_range:
184+
self._MStep(l, K, x, z_h1, pi_h1, p_h1, p_h, delta_i_j_s)
134185

135186
if (epsilon_mm != None and
136187
np.all(np.absolute(p_h1 - p_h) < epsilon_mm)):
@@ -150,6 +201,51 @@ def aggregate(self, rankings, K, epsilon, max_iters, epsilon_mm, max_iters_mm):
150201

151202
return (pi_h1, p_h1, pi_h0, p_h0)
152203

204+
205+
def _EStep(self, K, x, z_h1, pi_h, p_h):
    """
    Description:
        Internal function for computing the E-Step of the EMM algorithm.
        For every vote i and component k, z_h1[i][k] is set to
        pi_h[k] * f(x[i], p_h[k]) divided by the sum of that same
        product over all K components.
    Parameters:
        K:    number of mixture components
        x:    voting data; x[i] is passed to EMMMixPLAggregator.f
        z_h1: output, indexable as z_h1[i][k]; overwritten in place
        pi_h: per-component weights multiplied into f, indexed by k
        p_h:  per-component parameters passed to f, indexed by k
    Returns:
        None; the result is written into z_h1.
    """
    for i in range(self.n):
        # Every weighted likelihood serves both as a numerator and as a
        # term of the shared denominator, so evaluate f once per
        # component rather than rebuilding the denominator for each k.
        weighted = [pi_h[k] * EMMMixPLAggregator.f(x[i], p_h[k])
                    for k in range(K)]

        # Accumulate in component order starting from 0, exactly as the
        # per-k loop did, so the floating-point result is unchanged.
        denom_sum = 0
        for term in weighted:
            denom_sum += term

        for k in range(K):
            z_h1[i][k] = weighted[k] / denom_sum
217+
218+
219+
def _MStep(self, l, K, x, z_h1, pi_h1, p_h1, p_h, delta_i_j_s):
    """
    Description:
        Internal function for computing the M-Step of the EMM algorithm,
        which is itself an MM algorithm. Performs one MM update: fills
        p_h1 (and, when l == 0, pi_h1) in place from z_h1, p_h and the
        pre-computed delta values, then rescales each row of p_h1 so
        that it sums to 1.
    Parameters:
        l:           index of the current MM iteration; pi_h1 is only
                     written when this is 0
        K:           number of mixture components
        x:           voting data, indexed by vote
        z_h1:        values indexable as z_h1[i][k]; must support .T
        pi_h1:       output, indexed by component; written in place
        p_h1:        output, indexable as p_h1[k][j]; written in place
        p_h:         current parameters, indexable as p_h[k][...]
        delta_i_j_s: pre-computed values indexable as [i][j][s]
    Returns:
        None; results are written into pi_h1 and p_h1.
    """
    for comp in range(K):
        if l == 0:  # pi is the same for every MM iteration; do it once
            pi_h1[comp] = np.sum(z_h1.T[comp]) / len(z_h1)

        row_total = 0
        for alt in range(self.m):
            numer = EMMMixPLAggregator.omega(comp, alt, z_h1, x)

            denom = 0
            for vote in range(self.n):
                ranking = x[vote]
                vote_term = 0
                for pos in range(len(ranking)):
                    # NOTE(review): this tail sum depends on comp, vote
                    # and pos but not on alt; it is recomputed per alt.
                    tail_params = 0
                    for s in range(pos, self.m):
                        tail_params += p_h[comp][EMMMixPLAggregator.c(ranking, s)]

                    tail_delta = 0
                    for s in range(pos, self.m + 1):
                        tail_delta += delta_i_j_s[vote][alt][s]

                    vote_term += z_h1[vote][comp] * (tail_params ** -1) * tail_delta
                denom += vote_term

            p_h1[comp][alt] = numer / denom
            row_total += p_h1[comp][alt]

        # Normalise the freshly computed row by its own sum.
        for alt in range(self.m):
            p_h1[comp][alt] /= row_total
247+
248+
153249
def main():
154250
n = 100
155251
m = 4
@@ -161,7 +257,14 @@ def main():
161257
print("EMM Algorithm:")
162258

163259
emmagg = EMMMixPLAggregator(cand_set)
164-
pi, p = emmagg.aggregate(votes, K=2, epsilon=1e-8, max_iters=1000, epsilon_mm=1e-8, max_iters_mm=10)
260+
pi, p, pi_h0, p_h0 = emmagg.aggregate(votes,
261+
K=2,
262+
epsilon=None,
263+
epsilon_mm=None,
264+
iters={"type" : "fixed",
265+
"em_iters": 20,
266+
"mm_iters": 5}
267+
)
165268

166269
sol_params = np.empty(2*m+1)
167270
sol_params[0] = pi[0]
@@ -171,7 +274,7 @@ def main():
171274
print("Ground-Truth Parameters:\n" + str(params))
172275
print("Final Solution:\n" + str(sol_params))
173276
print("\t\"1 - alpha\" = " + str(pi[1]))
174-
print("WSSE:\n" + str(stats.mix2PL_wsse(params, sol_params, m)))
277+
print("MSE:\n" + str(stats.mix2PL_sse(params, sol_params, m)))
175278

176279
if __name__ == "__main__":
177280
main()

0 commit comments

Comments
 (0)