diff --git a/wan/modules/model.py b/wan/modules/model.py index 6982fa15..1c240ee6 100644 --- a/wan/modules/model.py +++ b/wan/modules/model.py @@ -458,6 +458,9 @@ def forward( # time embeddings if t.dim() == 1: + # Note: expand only works for dimensions of size 1; view(-1, 1) must be called first. + t = t.view(-1, 1).expand(t.size(0), seq_len) + elif t.dim() == 2 and t.size(1) == 1: t = t.expand(t.size(0), seq_len) with torch.amp.autocast('cuda', dtype=torch.float32): bt = t.size(0)