Skip to content

Commit a793094

Browse files
authored
Merge pull request #764 from camsys/overflow-protection-2
Overflow protection
2 parents d8e836f + ff17220 commit a793094

8 files changed

Lines changed: 97 additions & 17 deletions

File tree

activitysim/core/interaction_sample.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -404,6 +404,7 @@ def _interaction_sample(
404404
allow_zero_probs=allow_zero_probs,
405405
trace_label=trace_label,
406406
trace_choosers=choosers,
407+
overflow_protection=not allow_zero_probs,
407408
)
408409
chunk_sizer.log_df(trace_label, "probs", probs)
409410

activitysim/core/interaction_sample_simulate.py

Lines changed: 18 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -248,20 +248,27 @@ def _interaction_sample_simulate(
248248

249249
# convert to probabilities (utilities exponentiated and normalized to probs)
250250
# probs is same shape as utilities, one row per chooser and one column for alternative
251-
probs = logit.utils_to_probs(
252-
state,
253-
utilities_df,
254-
allow_zero_probs=allow_zero_probs,
255-
trace_label=trace_label,
256-
trace_choosers=choosers,
257-
)
258-
chunk_sizer.log_df(trace_label, "probs", probs)
259-
260251
if want_logsums:
261-
logsums = logit.utils_to_logsums(
262-
utilities_df, allow_zero_probs=allow_zero_probs
252+
probs, logsums = logit.utils_to_probs(
253+
state,
254+
utilities_df,
255+
allow_zero_probs=allow_zero_probs,
256+
trace_label=trace_label,
257+
trace_choosers=choosers,
258+
overflow_protection=not allow_zero_probs,
259+
return_logsums=True,
263260
)
264261
chunk_sizer.log_df(trace_label, "logsums", logsums)
262+
else:
263+
probs = logit.utils_to_probs(
264+
state,
265+
utilities_df,
266+
allow_zero_probs=allow_zero_probs,
267+
trace_label=trace_label,
268+
trace_choosers=choosers,
269+
overflow_protection=not allow_zero_probs,
270+
)
271+
chunk_sizer.log_df(trace_label, "probs", probs)
265272

266273
del utilities_df
267274
chunk_sizer.log_df(trace_label, "utilities_df", None)

activitysim/core/logit.py

Lines changed: 48 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from __future__ import annotations
44

55
import logging
6+
import warnings
67
from builtins import object
78

89
import numpy as np
@@ -130,6 +131,8 @@ def utils_to_probs(
130131
exponentiated=False,
131132
allow_zero_probs=False,
132133
trace_choosers=None,
134+
overflow_protection: bool = True,
135+
return_logsums: bool = False,
133136
):
134137
"""
135138
Convert a table of utilities to probabilities.
@@ -155,6 +158,20 @@ def utils_to_probs(
155158
by report_bad_choices because it can't deduce hh_id from the interaction_dataset
156159
which is indexed on index values from alternatives df
157160
161+
overflow_protection : bool, default True
162+
Always shift utility values such that the maximum utility in each row is
163+
zero. This constant per-row shift should not fundamentally alter the
164+
computed probabilities, but will ensure that an overflow does not occur
165+
that will create infinite or NaN values. This will also provide effective
166+
protection against underflow; extremely rare probabilities will round to
167+
zero, but by definition they are extremely rare and losing them entirely
168+
should not impact the simulation in a measureable fashion, and at least one
169+
(and sometimes only one) alternative is guaranteed to have non-zero
170+
probability, as long as at least one alternative has a finite utility value.
171+
If utility values are certain to be well-behaved and non-extreme, enabling
172+
overflow_protection will have no benefit but impose a modest computational
173+
overhead cost.
174+
158175
Returns
159176
-------
160177
probs : pandas.DataFrame
@@ -167,9 +184,27 @@ def utils_to_probs(
167184
# utils_arr = utils.values.astype('float')
168185
utils_arr = utils.values
169186

170-
if utils_arr.dtype == np.float32 and utils_arr.max() > 85:
187+
if allow_zero_probs:
188+
if overflow_protection:
189+
warnings.warn(
190+
"cannot set overflow_protection with allow_zero_probs", stacklevel=2
191+
)
192+
overflow_protection = utils_arr.dtype == np.float32 and utils_arr.max() > 85
193+
if overflow_protection:
194+
raise ValueError(
195+
"cannot prevent expected overflow with allow_zero_probs"
196+
)
197+
else:
198+
overflow_protection = overflow_protection or (
199+
utils_arr.dtype == np.float32 and utils_arr.max() > 85
200+
)
201+
202+
if overflow_protection:
171203
# exponentiated utils will overflow, downshift them
172-
utils_arr -= utils_arr.max(1, keepdims=True)
204+
shifts = utils_arr.max(1, keepdims=True)
205+
utils_arr -= shifts
206+
else:
207+
shifts = None
173208

174209
if not exponentiated:
175210
# TODO: reduce memory usage by exponentiating in-place.
@@ -185,6 +220,15 @@ def utils_to_probs(
185220

186221
arr_sum = utils_arr.sum(axis=1)
187222

223+
if return_logsums:
224+
with np.errstate(divide="ignore" if allow_zero_probs else "warn"):
225+
logsums = np.log(arr_sum)
226+
if shifts is not None:
227+
logsums += np.squeeze(shifts, 1)
228+
logsums = pd.Series(logsums, index=utils.index)
229+
else:
230+
logsums = None
231+
188232
if not allow_zero_probs:
189233
zero_probs = arr_sum == 0.0
190234
if zero_probs.any():
@@ -222,6 +266,8 @@ def utils_to_probs(
222266

223267
probs = pd.DataFrame(utils_arr, columns=utils.columns, index=utils.index)
224268

269+
if return_logsums:
270+
return probs, logsums
225271
return probs
226272

227273

activitysim/core/pathbuilder.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -994,6 +994,7 @@ def build_virtual_path(
994994
utilities_df,
995995
allow_zero_probs=True,
996996
trace_label=trace_label,
997+
overflow_protection=False,
997998
)
998999
chunk_sizer.log_df(trace_label, "probs", probs)
9991000

activitysim/core/simulate.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1032,6 +1032,7 @@ def compute_nested_probabilities(
10321032
trace_label=trace_label,
10331033
exponentiated=True,
10341034
allow_zero_probs=True,
1035+
overflow_protection=False,
10351036
)
10361037

10371038
nested_probabilities = pd.concat([nested_probabilities, probs], axis=1)

activitysim/core/test/test_logit.py

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -81,16 +81,40 @@ def test_utils_to_probs_raises():
8181
idx = pd.Index(name="household_id", data=[1])
8282
with pytest.raises(RuntimeError) as excinfo:
8383
logit.utils_to_probs(
84-
state, pd.DataFrame([[1, 2, np.inf, 3]], index=idx), trace_label=None
84+
state,
85+
pd.DataFrame([[1, 2, np.inf, 3]], index=idx),
86+
trace_label=None,
87+
overflow_protection=False,
8588
)
8689
assert "infinite exponentiated utilities" in str(excinfo.value)
8790

8891
with pytest.raises(RuntimeError) as excinfo:
8992
logit.utils_to_probs(
90-
state, pd.DataFrame([[-999, -999, -999, -999]], index=idx), trace_label=None
93+
state,
94+
pd.DataFrame([[1, 2, 9999, 3]], index=idx),
95+
trace_label=None,
96+
overflow_protection=False,
97+
)
98+
assert "infinite exponentiated utilities" in str(excinfo.value)
99+
100+
with pytest.raises(RuntimeError) as excinfo:
101+
logit.utils_to_probs(
102+
state,
103+
pd.DataFrame([[-999, -999, -999, -999]], index=idx),
104+
trace_label=None,
105+
overflow_protection=False,
91106
)
92107
assert "all probabilities are zero" in str(excinfo.value)
93108

109+
# test that overflow protection works
110+
z = logit.utils_to_probs(
111+
state,
112+
pd.DataFrame([[1, 2, 9999, 3]], index=idx),
113+
trace_label=None,
114+
overflow_protection=True,
115+
)
116+
assert np.asarray(z).ravel() == pytest.approx(np.asarray([0.0, 0.0, 1.0, 0.0]))
117+
94118

95119
def test_make_choices_only_one():
96120
state = workflow.State().default_settings()

activitysim/examples/placeholder_sandag/test/regress/final_1_zone_tours_sh.csv

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ tour_id,person_id,tour_type,tour_type_count,tour_type_num,tour_num,tour_count,to
2020
2373898,57899,work,1,1,1,1,mandatory,1,3402.0,3746.0,20552,47.0,7.0,17.0,10.0,,,WALK,1.0388895039783694,no_subtours,,0out_0in,work
2121
2373980,57901,work,2,1,1,2,mandatory,1,3115.0,3746.0,20552,25.0,6.0,12.0,6.0,,,SHARED3FREE,0.6022315390131013,no_subtours,,0out_0in,work
2222
2373981,57901,work,2,2,2,2,mandatory,1,3115.0,3746.0,20552,150.0,15.0,20.0,5.0,,,SHARED2FREE,0.6232767878249469,no_subtours,,1out_0in,work
23-
2563802,62531,school,1,1,1,1,mandatory,1,3460.0,3316.0,21869,180.0,20.0,20.0,0.0,,,SHARED3FREE,-0.7094603590463964,,,0out_0in,school
23+
2563802,62531,school,1,1,1,1,mandatory,1,3460.0,3316.0,21869,181.0,20.0,21.0,1.0,,,SHARED3FREE,-0.7094603590463964,,,0out_0in,school
2424
2563821,62532,escort,1,1,1,1,non_mandatory,1,3398.0,3316.0,21869,20.0,6.0,7.0,1.0,,12.499268454965652,SHARED2FREE,-1.4604154628072699,,,0out_0in,escort
2525
2563862,62533,escort,3,1,1,4,non_mandatory,1,3402.0,3316.0,21869,1.0,5.0,6.0,1.0,,12.534424209198946,SHARED3FREE,-1.2940574569954848,,,0out_3in,escort
2626
2563863,62533,escort,3,2,2,4,non_mandatory,1,3519.0,3316.0,21869,99.0,11.0,11.0,0.0,,12.466623656700463,SHARED2FREE,-0.9326373013150777,,,0out_0in,escort

activitysim/examples/placeholder_sandag/test/regress/final_1_zone_trips_sh.csv

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ trip_id,person_id,household_id,primary_purpose,trip_num,outbound,trip_count,dest
5555
18991850,57901,20552,work,2,True,2,3115,3460,2373981,work,,16,DRIVEALONEFREE,0.10597046751418379
5656
18991853,57901,20552,work,1,False,1,3746,3115,2373981,home,,20,SHARED2FREE,0.23660752783217825
5757
20510417,62531,21869,school,1,True,1,3460,3316,2563802,school,,20,SHARED3FREE,-1.4448137456466916
58-
20510421,62531,21869,school,1,False,1,3316,3460,2563802,home,,20,WALK,-1.5207459403958272
58+
20510421,62531,21869,school,1,False,1,3316,3460,2563802,home,,21,WALK,-1.5207459403958272
5959
20510569,62532,21869,escort,1,True,1,3398,3316,2563821,escort,,6,SHARED2FREE,0.17869598454022895
6060
20510573,62532,21869,escort,1,False,1,3316,3398,2563821,home,,7,DRIVEALONEFREE,0.20045149458253975
6161
20510897,62533,21869,escort,1,True,1,3402,3316,2563862,escort,,5,SHARED3FREE,0.7112775892674524

0 commit comments

Comments
 (0)