From 3ee4394c3de37ada3739928d8e175b1f912fe546 Mon Sep 17 00:00:00 2001 From: luxuantao Date: Tue, 13 Jun 2023 15:23:22 +0800 Subject: [PATCH] fix 2.vae.ipynb load weight bug --- 2.vae.ipynb | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/2.vae.ipynb b/2.vae.ipynb index a1a9e37..e35f326 100644 --- a/2.vae.ipynb +++ b/2.vae.ipynb @@ -399,10 +399,10 @@ "\n", "def load_atten(model, param):\n", " model.norm.load_state_dict(param.group_norm.state_dict())\n", - " model.q.load_state_dict(param.query.state_dict())\n", - " model.k.load_state_dict(param.key.state_dict())\n", - " model.v.load_state_dict(param.value.state_dict())\n", - " model.out.load_state_dict(param.proj_attn.state_dict())\n", + " model.q.load_state_dict(param.to_q.state_dict())\n", + " model.k.load_state_dict(param.to_k.state_dict())\n", + " model.v.load_state_dict(param.to_v.state_dict())\n", + " model.out.load_state_dict(param.to_out[0].state_dict())\n", "\n", "\n", "#encoder.in\n",