diff --git a/swift/model/chunk_gated_delta_rule.py b/swift/model/chunk_gated_delta_rule.py index c54423b3a0..d7da2a71f5 100644 --- a/swift/model/chunk_gated_delta_rule.py +++ b/swift/model/chunk_gated_delta_rule.py @@ -4,13 +4,13 @@ import torch import warnings -from mindspeed.lite.ops.triton.chunk_delta_h import chunk_gated_delta_rule_bwd_dhu, chunk_gated_delta_rule_fwd_h -from mindspeed.lite.ops.triton.chunk_o import chunk_bwd_dqkwg, chunk_bwd_dv_local, chunk_fwd_o -from mindspeed.lite.ops.triton.chunk_scaled_dot_kkt import chunk_scaled_dot_kkt_fwd -from mindspeed.lite.ops.triton.cumsum import chunk_local_cumsum -from mindspeed.lite.ops.triton.solve_tril import solve_tril -from mindspeed.lite.ops.triton.utils import autocast_custom_bwd, autocast_custom_fwd, input_guard -from mindspeed.lite.ops.triton.wy_fast import prepare_wy_repr_bwd, recompute_w_u_fwd +from mindspeed.ops.triton.chunk_delta_h import chunk_gated_delta_rule_bwd_dhu, chunk_gated_delta_rule_fwd_h +from mindspeed.ops.triton.chunk_o import chunk_bwd_dqkwg, chunk_bwd_dv_local, chunk_fwd_o +from mindspeed.ops.triton.chunk_scaled_dot_kkt import chunk_scaled_dot_kkt_fwd +from mindspeed.ops.triton.cumsum import chunk_local_cumsum +from mindspeed.ops.triton.solve_tril import solve_tril +from mindspeed.ops.triton.utils import autocast_custom_bwd, autocast_custom_fwd, input_guard +from mindspeed.ops.triton.wy_fast import prepare_wy_repr_bwd, recompute_w_u_fwd from typing import Optional