
Commit b15c338

factnn and claude authored
【KernelGen】Migrate replication_pad1d from experimental_ops to ops (flagos-ai#2005)
* Migrate replication_pad1d from experimental_ops to ops

  Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

* feat: add logger.debug for tracing

* fix: add missing blank lines in test_special_ops.py

* fix: complete truncated benchmark function (E999) and remove blank lines between decorators (E304)

---------

Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
1 parent f748021 commit b15c338

7 files changed

Lines changed: 115 additions & 229 deletions


benchmark/test_special_perf.py

Lines changed: 31 additions & 0 deletions
@@ -1047,6 +1047,37 @@ def upsample_bicubic2d_input_fn(shape, dtype, device):
     bench.run()


+@pytest.mark.replication_pad1d
+def test_perf_replication_pad1d():
+    def replication_pad1d_input_fn(config, dtype, device):
+        shape, padding = config
+        x = torch.randn(shape, dtype=dtype, device=device)
+        yield x, list(padding)
+
+    class ReplicationPad1dBenchmark(Benchmark):
+        def set_shapes(self, shape_file_path=None):
+            self.shapes = [
+                ((2, 3, 7), (1, 2)),
+                ((4, 16, 64), (3, 1)),
+                ((8, 32, 256), (1, 2)),
+                ((32, 256), (3, 1)),
+            ]
+
+        def set_more_shapes(self):
+            return None
+
+        def get_input_iter(self, cur_dtype):
+            for config in self.shapes:
+                yield from replication_pad1d_input_fn(config, cur_dtype, self.device)
+
+    bench = ReplicationPad1dBenchmark(
+        op_name="replication_pad1d",
+        torch_op=torch.ops.aten.replication_pad1d,
+        dtypes=FLOAT_DTYPES,
+    )
+    bench.run()
+
+
 @pytest.mark.unfold
 def test_perf_unfold_backward():
     def unfold_backward_input_fn(config, dtype, device):
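For reference (not part of the diff): replication_pad1d pads the last dimension by repeating the edge elements, which is the aten behavior this benchmark measures against. A quick sanity check with stock PyTorch:

import torch
import torch.nn.functional as F

x = torch.tensor([[[1.0, 2.0, 3.0]]])      # shape (1, 1, 3)
y = F.pad(x, (2, 1), mode="replicate")     # pad 2 on the left, 1 on the right
print(y)                                   # tensor([[[1., 1., 1., 2., 3., 3.]]])

torch.ops.aten.replication_pad1d(x, [2, 1]) produces the same result.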

experimental_tests/performance/replication_pad1d_test.py

Lines changed: 0 additions & 122 deletions
This file was deleted.

experimental_tests/unit/replication_pad1d_test.py

Lines changed: 0 additions & 87 deletions
This file was deleted.

src/flag_gems/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -314,6 +314,8 @@ def torch_ge(v):
         ("repeat_interleave.self_int", repeat_interleave_self_int),
         ("repeat_interleave.self_Tensor", repeat_interleave_self_tensor),
         ("repeat_interleave.Tensor", repeat_interleave_tensor),
+        ("replication_pad1d", replication_pad1d),
+        ("replication_pad1d.out", replication_pad1d_out),
         ("replication_pad3d", replication_pad3d),
         ("resolve_conj", resolve_conj),
         ("resolve_neg", resolve_neg),

src/flag_gems/ops/__init__.py

Lines changed: 3 additions & 0 deletions
@@ -210,6 +210,7 @@
     repeat_interleave_self_tensor,
     repeat_interleave_tensor,
 )
+from flag_gems.ops.replication_pad1d import replication_pad1d, replication_pad1d_out
 from flag_gems.ops.replication_pad3d import replication_pad3d
 from flag_gems.ops.resolve_conj import resolve_conj
 from flag_gems.ops.resolve_neg import resolve_neg

@@ -519,6 +520,8 @@
     "repeat_interleave_self_int",
     "repeat_interleave_self_tensor",
     "repeat_interleave_tensor",
+    "replication_pad1d",
+    "replication_pad1d_out",
     "replication_pad3d",
     "resolve_conj",
     "resolve_neg",

src/flag_gems/experimental_ops/replication_pad1d.py renamed to src/flag_gems/ops/replication_pad1d.py

Lines changed: 26 additions & 20 deletions
@@ -1,7 +1,14 @@
+# Generated by KernelGen: https://github.com/flagos-ai/KernelGen
+import logging
+
 import torch
 import triton
 import triton.language as tl

+from flag_gems.runtime import torch_device_fn
+
+logger = logging.getLogger(__name__)
+

 @triton.jit
 def replication_pad1d_kernel(
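The new module logger backs the logger.debug tracing added in the later hunks. To surface those trace lines when debugging dispatch (plain stdlib logging, nothing FlagGems-specific):

import logging

# Enable debug output for just this module's logger:
logging.basicConfig()
logging.getLogger("flag_gems.ops.replication_pad1d").setLevel(logging.DEBUG)
# Subsequent calls log "GEMS REPLICATION_PAD1D" / "GEMS REPLICATION_PAD1D_OUT".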
@@ -46,9 +53,6 @@ def replication_pad1d_kernel(


 def _launch_replication_pad1d_kernel(input: torch.Tensor, padding, out: torch.Tensor):
-    if not input.is_cuda or not out.is_cuda:
-        raise RuntimeError("Triton kernels require CUDA tensors")
-
     if isinstance(padding, torch.Tensor):
         padding = tuple(padding.tolist())
     left, right = int(padding[0]), int(padding[1])
@@ -68,7 +72,6 @@ def _launch_replication_pad1d_kernel(input: torch.Tensor, padding, out: torch.Tensor):
     else:
         C, W_in = input.shape
         B = 1
-        N = 1  # dummy
         in_s_c, in_s_w = input.stride()
         in_s_n = 0
     if out.dim() == 2:
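The else branch above is the unbatched path: aten.replication_pad1d accepts both (C, W) and (N, C, W) inputs, and the launcher folds the 2D case into a batch of one, which is why the unused dummy N could be dropped. For reference:

import torch

x2 = torch.randn(32, 256)        # (C, W), unbatched
x3 = torch.randn(8, 32, 256)     # (N, C, W), batched
y2 = torch.ops.aten.replication_pad1d(x2, [3, 1])   # -> torch.Size([32, 260])
y3 = torch.ops.aten.replication_pad1d(x3, [1, 2])   # -> torch.Size([8, 32, 259])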
@@ -89,26 +92,28 @@ def _launch_replication_pad1d_kernel(input: torch.Tensor, padding, out: torch.Tensor):
     )

     grid = (triton.cdiv(W_out, 256), B * C)
-    replication_pad1d_kernel[grid](
-        input,
-        out,
-        B,
-        C,
-        W_in,
-        W_out,
-        left,
-        in_s_n if dim == 3 else in_s_n,
-        in_s_c,
-        in_s_w,
-        out_s_n if (dim == 3 or out.dim() == 3) else 0,
-        out_s_c,
-        out_s_w,
-        BLOCK_SIZE=256,
-    )
+    with torch_device_fn.device(input.device):
+        replication_pad1d_kernel[grid](
+            input,
+            out,
+            B,
+            C,
+            W_in,
+            W_out,
+            left,
+            in_s_n if dim == 3 else in_s_n,
+            in_s_c,
+            in_s_w,
+            out_s_n if (dim == 3 or out.dim() == 3) else 0,
+            out_s_c,
+            out_s_w,
+            BLOCK_SIZE=256,
+        )
     return out


 def replication_pad1d(input: torch.Tensor, padding):
+    logger.debug("GEMS REPLICATION_PAD1D")
     if isinstance(padding, torch.Tensor):
         padding = tuple(padding.tolist())
     left, right = int(padding[0]), int(padding[1])
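Wrapping the launch in torch_device_fn.device(...) replaces the deleted is_cuda assertion: rather than hard-requiring CUDA, the launch is bound to the input tensor's device through FlagGems' runtime abstraction. On NVIDIA this is roughly equivalent to the stock device guard (illustrative sketch, not the FlagGems internals):

import torch

x = torch.randn(8, 32, 256, device="cuda:1")
# Triton launches target the *current* device; without a guard, a kernel for a
# tensor living on cuda:1 could be launched while cuda:0 is current.
with torch.cuda.device(x.device):
    pass  # kernel launch goes here, correctly bound to cuda:1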
@@ -134,6 +139,7 @@ def replication_pad1d(input: torch.Tensor, padding):


 def replication_pad1d_out(input: torch.Tensor, padding, out: torch.Tensor):
+    logger.debug("GEMS REPLICATION_PAD1D_OUT")
     if isinstance(padding, torch.Tensor):
         padding = tuple(padding.tolist())
     left, right = int(padding[0]), int(padding[1])
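The kernel body itself is untouched by the move and not shown in the diff. The core of any replication-pad kernel is clamping the source index to the valid range so out-of-range output columns re-read the edge element; a minimal single-row sketch of that idea (not the shipped replication_pad1d_kernel):

import triton
import triton.language as tl

@triton.jit
def _replication_pad1d_row(in_ptr, out_ptr, W_in, W_out, left, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(0)
    offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offs < W_out
    # Shift into input coordinates, then clamp to [0, W_in - 1]: columns left of
    # the input replicate element 0, columns past the end replicate W_in - 1.
    src = tl.minimum(tl.maximum(offs - left, 0), W_in - 1)
    val = tl.load(in_ptr + src, mask=mask)
    tl.store(out_ptr + offs, val, mask=mask)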
