
Commit 06d37df

[Advanced Compiler] Add Unfold backward (flagos-ai#1784)
1 parent 8bc56db commit 06d37df

5 files changed

Lines changed: 162 additions & 0 deletions
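
For context, a minimal sketch of the op being differentiated (not part of the commit): Tensor.unfold(dim, size, step) extracts sliding windows of length size with stride step along dim, replacing that dim with the window count and appending the window size as a new last dim. aten.unfold_backward is its gradient: it scatters the windowed gradient back into the input shape, summing where windows overlap.

import torch

x = torch.arange(6.0)    # shape (6,)
w = x.unfold(0, 4, 2)    # shape (2, 4): windows x[0:4] and x[2:6]
print(w)
# tensor([[0., 1., 2., 3.],
#         [2., 3., 4., 5.]])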


benchmark/test_special_perf.py

Lines changed: 40 additions & 0 deletions
@@ -905,3 +905,43 @@ def input_kwargs(shape, dtype, device):
     )
     bench.set_gems(flag_gems.per_token_group_quant_fp8)
     bench.run()
+
+
+@pytest.mark.unfold
+def test_perf_unfold_backward():
+    def unfold_backward_input_fn(config, dtype, device):
+        input_sizes, dim, size, step = config
+        d = dim % len(input_sizes)
+        num_windows = (input_sizes[d] - size) // step + 1
+        grad_shape = (
+            list(input_sizes[:d]) + [num_windows] + list(input_sizes[d + 1 :]) + [size]
+        )
+        grad_in = torch.randn(grad_shape, dtype=dtype, device=device)
+        yield grad_in, list(input_sizes), dim, size, step
+
+    class UnfoldBackwardBenchmark(Benchmark):
+        def set_shapes(self, shape_file_path=None):
+            self.shapes = [
+                ((32, 64), 1, 16, 16),
+                ((16, 33), 0, 5, 2),
+                ((4, 8, 12), -1, 6, 4),
+                ((7, 13), 1, 13, 3),
+                ((6, 20), 1, 7, 4),
+                ((2, 3, 17), -1, 9, 1),
+                ((2, 17), 1, 4, 6),
+            ]
+
+        def set_more_shapes(self):
+            return None
+
+        def get_input_iter(self, cur_dtype):
+            for config in self.shapes:
+                yield from unfold_backward_input_fn(config, cur_dtype, self.device)
+
+    bench = UnfoldBackwardBenchmark(
+        op_name="unfold_backward",
+        torch_op=torch.ops.aten.unfold_backward,
+        dtypes=[torch.float16, torch.float32, torch.bfloat16],
+    )
+    bench.set_gems(flag_gems.unfold_backward)
+    bench.run()
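
To make the generator's shape arithmetic concrete, a worked example for the second config above (illustrative, not part of the diff): unfolding dim 0 of a (16, 33) input with size=5, step=2 gives (16 - 5) // 2 + 1 = 6 windows, so grad_in must have shape (6, 33, 5).

import torch

x = torch.empty(16, 33)
print(x.unfold(0, 5, 2).shape)  # torch.Size([6, 33, 5]), the expected grad_in shape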

src/flag_gems/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -344,6 +344,7 @@ def torch_ge(v):
     ("true_divide.Tensor", true_divide),
     ("true_divide_.Scalar", true_divide_),
     ("true_divide_.Tensor", true_divide_),
+    ("unfold_backward", unfold_backward),
     ("uniform_", uniform_),
     ("upsample_linear1d", upsample_linear1d),
     ("upsample_nearest1d", upsample_nearest1d),

src/flag_gems/ops/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -217,6 +217,7 @@
 from flag_gems.ops.topk import topk
 from flag_gems.ops.trace import trace
 from flag_gems.ops.triu import triu, triu_
+from flag_gems.ops.unfold_backward import unfold_backward
 from flag_gems.ops.uniform import uniform_
 from flag_gems.ops.unique import _unique2
 from flag_gems.ops.upsample_bicubic2d_aa import _upsample_bicubic2d_aa
@@ -528,6 +529,7 @@
     "true_divide",
     "true_divide_",
     "true_divide_out",
+    "unfold_backward",
     "uniform_",
     "upsample_linear1d",
     "upsample_nearest1d",
src/flag_gems/ops/unfold_backward.py

Lines changed: 88 additions & 0 deletions

import logging

import torch
import triton
import triton.language as tl

logger = logging.getLogger(__name__)


@triton.jit
def _unfold_backward_kernel(
    grad_in_ptr,
    grad_out_ptr,
    numel_in,
    prod_after,
    L,
    size,
    step,
    D,
    inner_total,
    BLOCK: tl.constexpr,
):
    pid = tl.program_id(0)
    offs = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offs < numel_in

    vals = tl.load(grad_in_ptr + offs, mask=mask, other=0)
    vals_f32 = tl.cast(vals, tl.float32)

    # grad_in is treated as a contiguous (before..., L, after..., size)
    # tensor; decompose the linear offset into those coordinates.
    k = offs % size  # position inside the window
    tmp1 = offs // size
    after_lin = tmp1 % prod_after  # linear index over trailing dims
    tmp2 = offs // (prod_after * size)
    s = tmp2 % L  # window index along the unfolded dim
    before_lin = offs // inner_total  # linear index over leading dims

    # Window s starts at s * step, so element k of the window maps to
    # position s * step + k of the original dimension (length D).
    pos = s * step + k

    out_id = ((before_lin * D) + pos) * prod_after + after_lin

    # Overlapping windows (step < size) write to the same output slot,
    # so the accumulation must be atomic.
    tl.atomic_add(grad_out_ptr + out_id, vals_f32, mask=mask)


def unfold_backward(
    grad_in: torch.Tensor, input_sizes, dim: int, size: int, step: int
) -> torch.Tensor:
    logger.debug("GEMS UNFOLD BACKWARD")
    if step <= 0:
        raise ValueError("step must be > 0")

    if not isinstance(input_sizes, (list, tuple)):
        input_sizes = list(input_sizes)
    input_sizes = [int(s) for s in input_sizes]
    ndim = len(input_sizes)
    d = dim % ndim

    D = int(input_sizes[d])
    L = (D - int(size)) // int(step) + 1  # window count, as in Tensor.unfold

    prod_after = 1
    for s_ in input_sizes[d + 1 :]:
        prod_after *= int(s_)
    inner_total = int(L) * int(prod_after) * int(size)

    device = grad_in.device
    # Accumulate in float32 regardless of input dtype: atomic adds on
    # low-precision dtypes lose accuracy.
    grad_out_f32 = torch.zeros(input_sizes, dtype=torch.float32, device=device)

    numel_in = grad_in.numel()

    BLOCK = 128
    grid = lambda meta: (triton.cdiv(numel_in, meta["BLOCK"]),)

    _unfold_backward_kernel[grid](
        grad_in,  # NOTE: the kernel's linear indexing assumes grad_in is contiguous
        grad_out_f32,
        numel_in,
        prod_after,
        L,
        size,
        step,
        D,
        inner_total,
        BLOCK=BLOCK,
    )

    if grad_in.dtype != torch.float32:
        return grad_out_f32.to(grad_in.dtype)
    return grad_out_f32
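
The tl.atomic_add is needed because windows overlap whenever step < size, so several grad_in elements land on the same output slot. A tiny worked check against the ATen reference (illustrative only, not part of the commit):

import torch

# size=4, step=2 on a length-6 dim: windows [0:4] and [2:6] overlap at
# positions 2 and 3, so the gradient sums two contributions there.
g = torch.ones(2, 4)  # gradient for 2 windows of size 4
print(torch.ops.aten.unfold_backward(g, [6], 0, 4, 2))
# tensor([1., 1., 2., 2., 1., 1.])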

tests/test_special_ops.py

Lines changed: 31 additions & 0 deletions
@@ -1929,3 +1929,34 @@ def _verify_expert_level_sorting(
     gems_assert_close(
         num_tokens_post_pad, to_reference(num_tokens_post_pad_vllm), dtype=dtype
     )
+
+
+@pytest.mark.unfold
+@pytest.mark.parametrize(
+    "input_sizes, dim, size, step",
+    [
+        ((32, 64), 1, 16, 16),
+        ((16, 33), 0, 5, 2),
+        ((4, 8, 12), -1, 6, 4),
+        ((7, 13), 1, 13, 3),
+        ((6, 20), 1, 7, 4),
+        ((2, 3, 17), -1, 9, 1),
+        ((2, 17), 1, 4, 6),
+    ],
+)
+@pytest.mark.parametrize("dtype", [torch.float16, torch.float32, torch.bfloat16])
+def test_unfold_backward(input_sizes, dim, size, step, dtype):
+    d = dim % len(input_sizes)
+    num_windows = (input_sizes[d] - size) // step + 1
+    grad_shape = (
+        list(input_sizes[:d]) + [num_windows] + list(input_sizes[d + 1 :]) + [size]
+    )
+
+    grad_in = torch.randn(grad_shape, dtype=dtype, device=device)
+
+    ref_grad = to_reference(grad_in, True)
+    ref_out = torch.ops.aten.unfold_backward(ref_grad, input_sizes, dim, size, step)
+
+    with flag_gems.use_gems():
+        res_out = flag_gems.unfold_backward(grad_in, input_sizes, dim, size, step)
+    gems_assert_close(res_out, ref_out, dtype, reduce_dim=size)
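
An independent cross-check of the kernel's index math, for anyone extending the tests: backpropagating through Tensor.unfold itself yields the reference gradient by construction. A sketch (the helper name is hypothetical, not part of the suite):

import torch

def unfold_backward_autograd(grad_in, input_sizes, dim, size, step):
    # Differentiate the forward op directly; what autograd produces is,
    # by definition, what unfold_backward must return.
    x = torch.zeros(
        input_sizes, dtype=grad_in.dtype, device=grad_in.device, requires_grad=True
    )
    x.unfold(dim, size, step).backward(grad_in)
    return x.grad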
