Towards BEMB v1.0

Corresponding branch: `api-update`

We are planning to refine and expand the current API of `BEMBFlex`.

The `pred_item` and multiple class prediction

… `item_index[i]`). This only allows us to do binary classification. We might have to drop this feature.

… `batch.label` or multi-class `batch.item_index`. We plan to support arbitrary multi-class classifications.

- If `pred_item=True`, the model knows that the number of classes is exactly the `num_items` parameter. In this case, your `ChoiceDataset` object does not need a `label` attribute, since the model uses `item_index` as the ground truth for training.
- If `pred_item=False`, you need to supply `num_classes` to the `BEMBFlex.__init__()` method, and your `ChoiceDataset` object needs a `label` attribute. The `label` attribute should be a `LongTensor` with values in `{0, 1, ..., num_classes - 1}`.

Post-Estimation

The post-estimation prediction method is named `predict_proba()`, the same name as the inference methods of scikit-learn models. The method uses `@torch.no_grad()` as a decorator, so you can call it freely without worrying about gradient tracking.

- When `pred_item=True`, the `batch` needs an `item_index` attribute only if it is involved in the utility computation (e.g., within-category computation).
- When `pred_item=False`, the `batch` does not need a `label` attribute.

`predict_proba()` is used as follows:

```python
# Case 1: predicting the chosen item (pred_item=True).
batch = ChoiceDataset(...)
bemb = BEMBFlex(..., pred_item=True, ...)
proba = bemb.predict_proba(batch)  # shape = (len(batch), num_items)

# Case 2: predicting a class label (pred_item=False).
batch = ChoiceDataset(...)
# Note that batch doesn't need to have a label attribute.
bemb = BEMBFlex(..., pred_item=False, num_classes=..., ...)
proba = bemb.predict_proba(batch)  # shape = (len(batch), num_classes)
```

Renaming Variables

The term price-variation is ambiguous; we propose to change it to session-item-variation instead (this is precisely the definition of such variables).
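The label convention described under "The `pred_item` and multiple class prediction" can be sketched as two small checks. This is a minimal sketch, not part of the `BEMBFlex` API: `resolve_num_classes` and `check_labels` are hypothetical helpers, and the sketch assumes labels are stored as a 1-D `torch.long` tensor.

```python
import torch


def resolve_num_classes(pred_item: bool, num_items: int, num_classes=None) -> int:
    """Hypothetical helper: number of output classes implied by pred_item."""
    if pred_item:
        # item_index is the ground truth, so every item is one class.
        return num_items
    if num_classes is None:
        raise ValueError("num_classes must be supplied when pred_item=False")
    return num_classes


def check_labels(label: torch.Tensor, num_classes: int) -> None:
    """Hypothetical validator for the label attribute used when pred_item=False."""
    if label.dtype != torch.long:
        raise TypeError("label must be a LongTensor (dtype torch.long)")
    # Valid class indices are 0, 1, ..., num_classes - 1.
    if label.numel() > 0 and (label.min() < 0 or label.max() >= num_classes):
        raise ValueError(f"label values must lie in {{0, ..., {num_classes - 1}}}")


# Usage: with pred_item=True the class count is simply num_items.
print(resolve_num_classes(True, num_items=5))
check_labels(torch.tensor([0, 2, 1]), num_classes=3)  # passes silently
```

Note that a label equal to `num_classes` is rejected, since with `num_classes` classes the largest valid index is `num_classes - 1`.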
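The Post-Estimation behavior of `predict_proba()` (decorated with `@torch.no_grad()`, returning one probability row per observation) can be illustrated with a toy PyTorch module. This is a sketch under stated assumptions: `ToyUtilityModel` is hypothetical, utilities are a single linear layer, and probabilities come from a plain softmax over classes, which is a simplification and not necessarily how `BEMBFlex` computes them.

```python
import torch
import torch.nn as nn


class ToyUtilityModel(nn.Module):
    """Hypothetical stand-in for a BEMB-style model; not the BEMBFlex API."""

    def __init__(self, num_features: int, num_classes: int) -> None:
        super().__init__()
        # One linear utility per class: U[i, c] = x_i . w_c + b_c
        self.linear = nn.Linear(num_features, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Utilities (logits) of shape (batch_size, num_classes).
        return self.linear(x)

    @torch.no_grad()
    def predict_proba(self, x: torch.Tensor) -> torch.Tensor:
        # No autograd graph is recorded inside this method.
        return torch.softmax(self.forward(x), dim=-1)


torch.manual_seed(0)
model = ToyUtilityModel(num_features=4, num_classes=3)
proba = model.predict_proba(torch.randn(10, 4))
print(proba.shape)          # torch.Size([10, 3])
print(proba.requires_grad)  # False
```

Because the decorator disables gradient tracking, the returned probabilities are detached from the parameters and can be passed directly to NumPy or scikit-learn style evaluation code.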