diff --git a/2.vae.ipynb b/2.vae.ipynb index a1a9e37..e35f326 100644 --- a/2.vae.ipynb +++ b/2.vae.ipynb @@ -399,10 +399,10 @@ "\n", "def load_atten(model, param):\n", " model.norm.load_state_dict(param.group_norm.state_dict())\n", - " model.q.load_state_dict(param.query.state_dict())\n", - " model.k.load_state_dict(param.key.state_dict())\n", - " model.v.load_state_dict(param.value.state_dict())\n", - " model.out.load_state_dict(param.proj_attn.state_dict())\n", + " model.q.load_state_dict(param.to_q.state_dict())\n", + " model.k.load_state_dict(param.to_k.state_dict())\n", + " model.v.load_state_dict(param.to_v.state_dict())\n", + " model.out.load_state_dict(param.to_out[0].state_dict())\n", "\n", "\n", "#encoder.in\n",