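How can I fix the following error? It is raised inside `flash_attention` while running Wan2.2's `generate.py` on the text-image-to-video (`wan_ti2v`) path: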
Traceback (most recent call last):
File "/mnt/hdfs/user/zengbohan/code/video_gen/Wan2.2/generate.py", line 522, in <module>
generate(args)
File "/mnt/hdfs/user/zengbohan/code/video_gen/Wan2.2/generate.py", line 415, in generate
video = wan_ti2v.generate(
^^^^^^^^^^^^^^^^^^
File "/mnt/hdfs/user/zengbohan/code/video_gen/Wan2.2/wan/textimage2video.py", line 227, in generate
return self.t2v(
^^^^^^^^^
File "/mnt/hdfs/user/zengbohan/code/video_gen/Wan2.2/wan/textimage2video.py", line 380, in t2v
noise_pred_cond = self.model(
^^^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1747, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/mnt/hdfs/user/zengbohan/code/video_gen/Wan2.2/wan/modules/model.py", line 490, in forward
x = block(x, **kwargs)
^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1747, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/mnt/hdfs/user/zengbohan/code/video_gen/Wan2.2/wan/modules/model.py", line 243, in forward
y = self.self_attn(
^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1747, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/mnt/hdfs/user/zengbohan/code/video_gen/Wan2.2/wan/modules/model.py", line 145, in forward
x = flash_attention(
^^^^^^^^^^^^^^^^
File "/mnt/hdfs/user/zengbohan/code/video_gen/Wan2.2/wan/modules/attention.py", line 110, in flash_attention
deterministic=deterministic)[0].unflatten(0, (b, lq))
^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/torch/_tensor.py", line 1376, in unflatten
return super().unflatten(dim, sizes)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: unflatten: Provided sizes [1, 27280] don't multiply up to the size of dim 0 (24) in the input tensor
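How can I fix the following error? It is raised inside `flash_attention` while running Wan2.2's `generate.py` on the text-image-to-video (`wan_ti2v`) path: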
Traceback (most recent call last):
File "/mnt/hdfs/user/zengbohan/code/video_gen/Wan2.2/generate.py", line 522, in <module>
generate(args)
File "/mnt/hdfs/user/zengbohan/code/video_gen/Wan2.2/generate.py", line 415, in generate
video = wan_ti2v.generate(
^^^^^^^^^^^^^^^^^^
File "/mnt/hdfs/user/zengbohan/code/video_gen/Wan2.2/wan/textimage2video.py", line 227, in generate
return self.t2v(
^^^^^^^^^
File "/mnt/hdfs/user/zengbohan/code/video_gen/Wan2.2/wan/textimage2video.py", line 380, in t2v
noise_pred_cond = self.model(
^^^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1747, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/mnt/hdfs/user/zengbohan/code/video_gen/Wan2.2/wan/modules/model.py", line 490, in forward
x = block(x, **kwargs)
^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1747, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/mnt/hdfs/user/zengbohan/code/video_gen/Wan2.2/wan/modules/model.py", line 243, in forward
y = self.self_attn(
^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1747, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/mnt/hdfs/user/zengbohan/code/video_gen/Wan2.2/wan/modules/model.py", line 145, in forward
x = flash_attention(
^^^^^^^^^^^^^^^^
File "/mnt/hdfs/user/zengbohan/code/video_gen/Wan2.2/wan/modules/attention.py", line 110, in flash_attention
deterministic=deterministic)[0].unflatten(0, (b, lq))
^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/torch/_tensor.py", line 1376, in unflatten
return super().unflatten(dim, sizes)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: unflatten: Provided sizes [1, 27280] don't multiply up to the size of dim 0 (24) in the input tensor