
Commit 5589e3b

[SiliconFlow] Op: Reflection pad2d (flagos-ai#1930)
1 parent 2ea7fad commit 5589e3b

5 files changed: 288 additions & 0 deletions

File tree

    benchmark/test_special_perf.py
    src/flag_gems/__init__.py
    src/flag_gems/ops/__init__.py
    src/flag_gems/ops/reflection_pad2d.py
    tests/test_special_ops.py

benchmark/test_special_perf.py

Lines changed: 49 additions & 0 deletions

@@ -3,6 +3,7 @@

 import pytest
 import torch
+import triton

 import flag_gems
 from benchmark.attri_util import BOOL_DTYPES, FLOAT_DTYPES, INT_DTYPES, BenchLevel
@@ -1025,6 +1026,54 @@ def input_kwargs(shape, dtype, device):
     bench.run()


+@pytest.mark.reflection_pad2d
+@pytest.mark.parametrize(
+    "shape",
+    [
+        (3, 33, 33),
+        (2, 4, 32, 64),
+        (8, 16, 64, 64),
+        (32, 64, 128, 256),
+        (16, 32, 64, 128),
+    ],
+)
+@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16])
+@pytest.mark.parametrize(
+    "padding",
+    [
+        (1, 1, 1, 1),
+        (2, 3, 2, 3),
+        (3, 5, 3, 5),
+        (0, 4, 0, 4),
+    ],
+)
+def test_reflection_pad2d_benchmark_tensor(shape, dtype, padding):
+    quantiles = [0.5, 0.2, 0.8]
+
+    x = torch.randn(shape, dtype=dtype, device=flag_gems.device)
+    ref_x = x.clone()
+
+    # PyTorch reference implementation
+    ms_torch, _, _ = triton.testing.do_bench(
+        lambda: torch.ops.aten.reflection_pad2d(ref_x, padding),
+        rep=100,
+        quantiles=quantiles,
+    )
+
+    # Triton implementation
+    with flag_gems.use_gems():
+        ms_triton, _, _ = triton.testing.do_bench(
+            lambda: flag_gems.reflection_pad2d(x, padding), rep=100, quantiles=quantiles
+        )
+
+    # Calculate speedup
+    speedup = ms_torch / ms_triton
+
+    print(f"reflection_pad2d {shape} {dtype}:")
+    print(f"  FlagGems: {ms_triton:.3f}ms")
+    print(f"  Speedup: {speedup:.2f}x")
+
+
 @pytest.mark.upsample_bicubic2d
 @pytest.mark.parametrize("align_corners", [False, True])
 def test_perf_upsample_bicubic2d(align_corners):
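Note: triton.testing.do_bench called with quantiles=[0.5, 0.2, 0.8] returns the measured runtimes (in milliseconds) at those quantiles, which is why the benchmark unpacks three values and keeps only the median. A minimal standalone sketch, with an arbitrary torch.add workload standing in for the padding op and assuming a CUDA device:

import torch
import triton

x = torch.randn(1024, 1024, device="cuda")
# quantiles=[0.5, 0.2, 0.8] -> [median, 20th, 80th percentile] times in ms
ms, ms_p20, ms_p80 = triton.testing.do_bench(
    lambda: torch.add(x, x), rep=100, quantiles=[0.5, 0.2, 0.8]
)
print(f"median {ms:.3f} ms (p20 {ms_p20:.3f}, p80 {ms_p80:.3f})")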

src/flag_gems/__init__.py

Lines changed: 2 additions & 0 deletions

@@ -307,6 +307,8 @@ def torch_ge(v):
     ("randperm", randperm),
     ("reciprocal", reciprocal),
     ("reciprocal_", reciprocal_),
+    ("reflection_pad2d", reflection_pad2d),
+    ("reflection_pad2d.out", reflection_pad2d_out),
     ("relu", relu),
     ("relu_", relu_),
     ("relu6", relu6),
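Note: each pair in this table maps an ATen operator name to its FlagGems implementation; the bare name covers the functional overload and the ".out" name covers the out= overload, so both dispatch to the new Triton kernels inside flag_gems.use_gems(). A minimal usage sketch (shapes arbitrary; relies on F.pad with mode="reflect" lowering to aten.reflection_pad2d for 4D inputs):

import torch
import flag_gems

x = torch.randn(2, 4, 32, 64, device=flag_gems.device)
# Inside use_gems(), aten.reflection_pad2d resolves to the FlagGems kernel,
# so the standard F.pad entry point picks it up transparently.
with flag_gems.use_gems():
    y = torch.nn.functional.pad(x, (1, 1, 1, 1), mode="reflect")
print(y.shape)  # torch.Size([2, 4, 34, 66])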

src/flag_gems/ops/__init__.py

Lines changed: 3 additions & 0 deletions

@@ -205,6 +205,7 @@
 from flag_gems.ops.randn_like import randn_like
 from flag_gems.ops.randperm import randperm
 from flag_gems.ops.reciprocal import reciprocal, reciprocal_
+from flag_gems.ops.reflection_pad2d import reflection_pad2d, reflection_pad2d_out
 from flag_gems.ops.relu import relu, relu_
 from flag_gems.ops.relu6 import relu6
 from flag_gems.ops.repeat import repeat
@@ -523,6 +524,8 @@
     "randperm",
     "reciprocal",
     "reciprocal_",
+    "reflection_pad2d",
+    "reflection_pad2d_out",
     "relu",
     "relu_",
     "relu6",

src/flag_gems/ops/reflection_pad2d.py

Lines changed: 157 additions & 0 deletions

@@ -0,0 +1,157 @@
+import logging
+import math
+
+import torch
+import triton
+import triton.language as tl
+
+logger = logging.getLogger(__name__)
+
+
+@triton.jit
+def reflection_pad2d_kernel(
+    in_ptr,
+    out_ptr,
+    B,
+    H_in,
+    W_in,
+    pad_left,
+    pad_top,
+    H_out,
+    W_out,
+    BLOCK_HW: tl.constexpr,
+):
+    pid_b = tl.program_id(axis=0)
+    pid_n = tl.program_id(axis=1)
+
+    # Flatten 2D index to 1D for block processing
+    offs_n = pid_n * BLOCK_HW + tl.arange(0, BLOCK_HW)
+    # Decode to (h, w) coordinates
+    h_idx = offs_n // W_out
+    w_idx = offs_n % W_out
+
+    mask = (offs_n < H_out * W_out) & (pid_b < B)
+
+    base_in = pid_b * (H_in * W_in)
+    base_out = pid_b * (H_out * W_out)
+
+    # Compute reflected indices for height
+    y = h_idx.to(tl.int32) - pad_top
+    Hm1 = H_in - 1
+    pH = 2 * Hm1
+    t_h = tl.abs(y)
+    m_h = t_h % pH
+    ih = tl.where(m_h < H_in, m_h, pH - m_h)
+
+    # Compute reflected indices for width
+    x = w_idx.to(tl.int32) - pad_left
+    Wm1 = W_in - 1
+    pW = 2 * Wm1
+    t_w = tl.abs(x)
+    m_w = t_w % pW
+    iw = tl.where(m_w < W_in, m_w, pW - m_w)
+
+    # Load from input and store to output
+    in_offs = ih * W_in + iw
+    vals = tl.load(in_ptr + base_in + in_offs, mask=mask, other=0)
+    tl.store(out_ptr + base_out + offs_n, vals, mask=mask)
+
+
+@triton.jit
+def copy_tensor_kernel(in_ptr, out_ptr, B, H, W, BLOCK_HW: tl.constexpr):
+    pid_b = tl.program_id(axis=0)
+    pid_n = tl.program_id(axis=1)
+
+    offs_n = pid_n * BLOCK_HW + tl.arange(0, BLOCK_HW)
+    mask = (offs_n < H * W) & (pid_b < B)
+
+    base = pid_b * (H * W)
+    vals = tl.load(in_ptr + base + offs_n, mask=mask, other=0)
+    tl.store(out_ptr + base + offs_n, vals, mask=mask)
+
+
+def launch_reflection_pad2d(input: torch.Tensor, padding, out: torch.Tensor = None):
+    # Validate padding format
+    if not isinstance(padding, (list, tuple)):
+        raise ValueError("padding must be a sequence")
+    if len(padding) != 4:
+        raise ValueError(
+            "padding must be a sequence of length 4: (pad_left, pad_right, pad_top, pad_bottom)"
+        )
+    pad_left, pad_right, pad_top, pad_bottom = [int(p) for p in padding]
+
+    # Validate padding values
+    if pad_left < 0 or pad_right < 0 or pad_top < 0 or pad_bottom < 0:
+        raise ValueError("padding values must be >= 0")
+
+    # Validate input
+    if input.dim() < 3:
+        raise ValueError("input must have at least 3 dimensions")
+    if not input.is_cuda:
+        raise ValueError("input must be a CUDA tensor")
+
+    x = input.contiguous()
+    H_in = int(x.shape[-2])
+    W_in = int(x.shape[-1])
+    # Validate reflection padding constraints
+    if H_in < 2 or W_in < 2:
+        raise ValueError(
+            "input spatial dimensions must be at least 2 for reflection padding when padding > 0"
+        )
+    if H_in <= 0 or W_in <= 0:
+        raise ValueError("spatial dimensions must be > 0")
+    if pad_left >= W_in or pad_right >= W_in or pad_top >= H_in or pad_bottom >= H_in:
+        raise ValueError(
+            "padding values must be less than the input spatial dimensions for reflection padding"
+        )
+
+    H_out = H_in + pad_top + pad_bottom
+    W_out = W_in + pad_left + pad_right
+
+    leading_shape = x.shape[:-2]
+    B = int(math.prod(leading_shape)) if len(leading_shape) > 0 else 1
+
+    # Handle output tensor
+    if out is None:
+        out = torch.empty(
+            (*leading_shape, H_out, W_out), device=x.device, dtype=x.dtype
+        )
+    else:
+        if not out.is_cuda:
+            raise ValueError("out must be a CUDA tensor")
+        expected_shape = (*leading_shape, H_out, W_out)
+        if tuple(out.shape) != expected_shape:
+            raise ValueError(
+                f"out tensor has shape {tuple(out.shape)}, expected {expected_shape}"
+            )
+        if out.dtype != x.dtype:
+            raise ValueError(
+                f"out dtype {out.dtype} does not match input dtype {x.dtype}"
+            )
+        if out.device != x.device:
+            raise ValueError("out must be on the same device as input")
+        out = out.contiguous()
+
+    # No padding: just copy
+    if pad_left == 0 and pad_right == 0 and pad_top == 0 and pad_bottom == 0:
+        BLOCK_HW = 256
+        grid = (B, triton.cdiv(H_in * W_in, BLOCK_HW))
+        copy_tensor_kernel[grid](x, out, B, H_in, W_in, BLOCK_HW=BLOCK_HW)
+        return out
+
+    BLOCK_HW = 256
+    grid = (B, triton.cdiv(H_out * W_out, BLOCK_HW))
+    reflection_pad2d_kernel[grid](
+        x, out, B, H_in, W_in, pad_left, pad_top, H_out, W_out, BLOCK_HW=BLOCK_HW
+    )
+    return out
+
+
+def reflection_pad2d(input: torch.Tensor, padding):
+    logger.debug("GEMS REFLECTION_PAD2D")
+    return launch_reflection_pad2d(input, padding, out=None)
+
+
+def reflection_pad2d_out(input: torch.Tensor, padding, out: torch.Tensor):
+    logger.debug("GEMS REFLECTION_PAD2D_OUT")
+    return launch_reflection_pad2d(input, padding, out=out)
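Note: the kernel reduces reflection to three vector ops per axis: shift by the pad amount, abs() to fold the top/left reflection, and a modulo-2*(N-1) fold for the bottom/right one. A quick host-side sanity check of the same formula against ATen (reflect_index is a hypothetical helper mirroring the kernel's math; the launcher enforces pad < dim, matching ATen's own constraint):

import torch

def reflect_index(idx: torch.Tensor, size: int) -> torch.Tensor:
    # Mirrors the kernel: abs() folds negative (top/left) indices,
    # the where() folds indices past size - 1 (bottom/right).
    period = 2 * (size - 1)
    t = idx.abs() % period
    return torch.where(t < size, t, period - t)

H_in, pad_top, pad_bottom = 5, 3, 2
rows = torch.arange(H_in + pad_top + pad_bottom) - pad_top  # signed input rows
ih = reflect_index(rows, H_in)
# Reflection-pad a 0..H_in-1 ramp so the output values ARE the source rows.
ramp = torch.arange(H_in, dtype=torch.float32).view(1, 1, H_in, 1)
ref = torch.ops.aten.reflection_pad2d(ramp, (0, 0, pad_top, pad_bottom))
assert torch.equal(ih, ref.view(-1).long())  # [3, 2, 1, 0, 1, 2, 3, 4, 3, 2]

The launch geometry is a plain 2D tiling on top of this: grid axis 0 walks the B flattened leading dimensions, axis 1 walks ceil(H_out * W_out / 256) tiles of each output plane, e.g. grid = (8, 9) for a (2, 4, 32, 64) input with padding (1, 1, 1, 1), since there are 8 batch-channel slices of 34 * 66 = 2244 output elements each.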

tests/test_special_ops.py

Lines changed: 77 additions & 0 deletions

@@ -2060,6 +2060,83 @@ def _verify_expert_level_sorting(
     )


+@pytest.mark.reflection_pad2d
+@pytest.mark.parametrize(
+    "shape", [(3, 33, 33), (2, 4, 32, 64), (8, 16, 64, 64), (32, 64, 128, 256)]
+)
+@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16])
+@pytest.mark.parametrize(
+    "padding",
+    [
+        (1, 1, 1, 1),
+        (2, 3, 2, 3),
+        (3, 5, 3, 5),
+        (0, 4, 0, 4),
+        (4, 0, 4, 0),
+    ],
+)
+def test_reflection_pad2d(shape, dtype, padding):
+    x = torch.randn(shape, dtype=dtype, device=flag_gems.device)
+
+    ref_x = to_reference(x)
+    ref_out = torch.ops.aten.reflection_pad2d(ref_x, padding)
+
+    with flag_gems.use_gems():
+        act_out = flag_gems.reflection_pad2d(x, padding)
+
+    gems_assert_close(act_out, ref_out, dtype, equal_nan=True)
+
+
+@pytest.mark.reflection_pad2d
+@pytest.mark.parametrize("padding", [[1, 1, 1, 1], [2, 3, 4, 5]])
+def test_reflection_pad2d_list_padding(padding):
+    # Test with list format: [pad_left, pad_right, pad_top, pad_bottom]
+    shape = (2, 4, 32, 64)
+    dtype = torch.float32
+    x = torch.randn(shape, dtype=dtype, device=flag_gems.device)
+
+    ref_x = to_reference(x.clone())
+    ref_out = torch.ops.aten.reflection_pad2d(ref_x, padding)
+
+    with flag_gems.use_gems():
+        act_out = flag_gems.reflection_pad2d(x, padding)
+
+    gems_assert_close(act_out, ref_out, dtype, equal_nan=True)
+
+
+@pytest.mark.reflection_pad2d
+def test_reflection_pad2d_empty_padding():
+    shape = (2, 4, 32, 64)
+    dtype = torch.float32
+    padding = (0, 0, 0, 0)
+    x = torch.randn(shape, dtype=dtype, device=flag_gems.device)
+
+    ref_x = to_reference(x.clone())
+    ref_out = torch.ops.aten.reflection_pad2d(ref_x, padding)
+
+    with flag_gems.use_gems():
+        act_out = flag_gems.reflection_pad2d(x, padding)
+
+    gems_assert_close(act_out, ref_out, dtype, equal_nan=True)
+
+
+@pytest.mark.reflection_pad2d
+@pytest.mark.parametrize("padding", [(1, 1, 1, 1), (2, 3, 4, 5)])
+def test_reflection_pad2d_3d_input(padding):
+    # Test with 3D input (C, H, W) - no batch dimension
+    shape = (3, 32, 64)
+    dtype = torch.float32
+    x = torch.randn(shape, dtype=dtype, device=flag_gems.device)
+
+    ref_x = to_reference(x.clone())
+    ref_out = torch.ops.aten.reflection_pad2d(ref_x, padding)
+
+    with flag_gems.use_gems():
+        act_out = flag_gems.reflection_pad2d(x, padding)
+
+    gems_assert_close(act_out, ref_out, dtype, equal_nan=True)
+
+
 @pytest.mark.upsample_bicubic2d
 @pytest.mark.parametrize(
     "N, C, H, W, outH, outW, align_corners, use_scale",

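Note: the tests above exercise only the functional overload. A minimal, hypothetical check of the out= variant registered as "reflection_pad2d.out", assuming flag_gems.reflection_pad2d_out is exported as in src/flag_gems/ops/__init__.py:

import torch
import flag_gems

x = torch.randn(2, 4, 32, 64, device=flag_gems.device)
padding = (1, 1, 1, 1)
# Pre-allocate the result: H_out = 32 + 2 = 34, W_out = 64 + 2 = 66.
out = torch.empty((2, 4, 34, 66), device=x.device, dtype=x.dtype)
flag_gems.reflection_pad2d_out(x, padding, out)
torch.testing.assert_close(out, torch.ops.aten.reflection_pad2d(x, padding))

The new tests can be selected by marker, e.g. pytest tests/test_special_ops.py -m reflection_pad2d, assuming the reflection_pad2d marker is registered in the project's pytest configuration.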