When compute_dynamics_loss is called in world_model.py, it receives an argument preds, which is the return value of dyn_model.predict(). For the z prediction, preds contains the keys z_dist and z_hat_probs, but no key z.
When computing the dynamics loss for z, the code checks whether z is among the keys of preds. It never is, because the keys are z_dist and z_hat_probs, so z_pred_loss is never computed or used for training.
A fix should be made here, something like:
if 'z_dist' in preds: ...
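To make the failure mode concrete, here is a minimal, self-contained sketch of the key check. It is not the repository's code: the function bodies, the targets argument, and the placeholder loss (negative log-probability of the target class under z_hat_probs) are assumptions made for illustration. Only the key names z, z_dist, and z_hat_probs and the name z_pred_loss come from the description above.

```python
import math


def placeholder_z_loss(preds, targets):
    # Assumption: z_hat_probs is a list of class probabilities and the
    # target z is a class index. The real loss may be defined differently.
    return -math.log(preds['z_hat_probs'][targets['z']])


def compute_dynamics_loss_buggy(preds, targets):
    losses = {}
    # Reported behaviour: this branch is never taken, because preds
    # has the keys 'z_dist' and 'z_hat_probs' but no key 'z'.
    if 'z' in preds:
        losses['z_pred_loss'] = placeholder_z_loss(preds, targets)
    return losses


def compute_dynamics_loss_fixed(preds, targets):
    losses = {}
    # Suggested fix: test for a key that predict() actually returns.
    if 'z_dist' in preds:
        losses['z_pred_loss'] = placeholder_z_loss(preds, targets)
    return losses


# Hypothetical stand-in for the output of dyn_model.predict().
example_preds = {'z_dist': None, 'z_hat_probs': [0.1, 0.7, 0.2]}
example_targets = {'z': 1}

buggy_losses = compute_dynamics_loss_buggy(example_preds, example_targets)
fixed_losses = compute_dynamics_loss_fixed(example_preds, example_targets)
print('buggy:', buggy_losses)
print('fixed:', fixed_losses)
```

With the original check the returned dictionary contains no z_pred_loss entry at all, so the term silently drops out of training instead of raising an error. That is probably why the problem is easy to miss.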