
unflatten error #141

@zengbohan0217


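How can I solve this problem? Running generate.py with the TI2V pipeline (wan_ti2v) fails inside flash_attention with the following traceback: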

Traceback (most recent call last):
  File "/mnt/hdfs/user/zengbohan/code/video_gen/Wan2.2/generate.py", line 522, in <module>
    generate(args)
  File "/mnt/hdfs/user/zengbohan/code/video_gen/Wan2.2/generate.py", line 415, in generate
    video = wan_ti2v.generate(
            ^^^^^^^^^^^^^^^^^^
  File "/mnt/hdfs/user/zengbohan/code/video_gen/Wan2.2/wan/textimage2video.py", line 227, in generate
    return self.t2v(
           ^^^^^^^^^
  File "/mnt/hdfs/user/zengbohan/code/video_gen/Wan2.2/wan/textimage2video.py", line 380, in t2v
    noise_pred_cond = self.model(
                      ^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/hdfs/user/zengbohan/code/video_gen/Wan2.2/wan/modules/model.py", line 490, in forward
    x = block(x, **kwargs)
        ^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/hdfs/user/zengbohan/code/video_gen/Wan2.2/wan/modules/model.py", line 243, in forward
    y = self.self_attn(
        ^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/hdfs/user/zengbohan/code/video_gen/Wan2.2/wan/modules/model.py", line 145, in forward
    x = flash_attention(
        ^^^^^^^^^^^^^^^^
  File "/mnt/hdfs/user/zengbohan/code/video_gen/Wan2.2/wan/modules/attention.py", line 110, in flash_attention
    deterministic=deterministic)[0].unflatten(0, (b, lq))
                                    ^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/torch/_tensor.py", line 1376, in unflatten
    return super().unflatten(dim, sizes)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: unflatten: Provided sizes [1, 27280] don't multiply up to the size of dim 0 (24) in the input tensor
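One plausible reading of this error, not confirmed in the issue: Tensor.unflatten(0, (b, lq)) requires b * lq to equal the size of dim 0, so the tensor reaching it should have 1 * 27280 = 27280 rows but has only 24. Line 110 indexes the attention kernel's return value with [0], which assumes a tuple whose first element is the output. If the installed flash-attention build instead returns the output tensor by itself, [0] selects the first token, leaving a [num_heads, head_dim] tensor whose dim 0 is the head count (24 here). The sketch below reproduces that mismatch with plain PyTorch tensors standing in for the kernel output. The helper take_attn_output is hypothetical, the sizes are small stand-ins, and the head count of 24 is inferred from the error message rather than stated in the issue.

```python
import torch


def take_attn_output(ret):
    """Hypothetical helper: return the attention output tensor whether the
    kernel returned it bare or as the first element of a tuple/list."""
    if isinstance(ret, (tuple, list)):
        return ret[0]
    return ret


# Small stand-in sizes so the sketch runs anywhere; only num_heads = 24 is
# taken from the "(24)" in the error, the other values are assumptions.
b, lq, num_heads, head_dim = 1, 4, 24, 8

# Varlen attention output layout: [total_tokens, num_heads, head_dim].
out = torch.zeros(b * lq, num_heads, head_dim)

# If the kernel returns a bare tensor, "[0]" picks the first token,
# leaving shape [num_heads, head_dim] instead of [b * lq, ...].
wrong = out[0]
raised = False
try:
    wrong.unflatten(0, (b, lq))
except RuntimeError as err:
    raised = True
    print("reproduced:", err)

# Normalizing the return value first restores the expected layout,
# for both a bare-tensor return and a tuple return.
x = take_attn_output(out).unflatten(0, (b, lq))
x2 = take_attn_output((out, None)).unflatten(0, (b, lq))
print(tuple(x.shape))  # (1, 4, 24, 8)
```

If this is the cause, printing type() of the kernel's return value just before line 110 in the failing environment would confirm or rule it out; normalizing that value as above before calling .unflatten(0, (b, lq)) would then avoid the mismatch.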
