|
| 1 | +import logging |
| 2 | +import math |
| 3 | + |
| 4 | +import torch |
| 5 | +import triton |
| 6 | +import triton.language as tl |
| 7 | + |
| 8 | +logger = logging.getLogger(__name__) |
| 9 | + |
| 10 | + |
| 11 | +@triton.jit |
| 12 | +def reflection_pad2d_kernel( |
| 13 | + in_ptr, |
| 14 | + out_ptr, |
| 15 | + B, |
| 16 | + H_in, |
| 17 | + W_in, |
| 18 | + pad_left, |
| 19 | + pad_top, |
| 20 | + H_out, |
| 21 | + W_out, |
| 22 | + BLOCK_HW: tl.constexpr, |
| 23 | +): |
| 24 | + pid_b = tl.program_id(axis=0) |
| 25 | + pid_n = tl.program_id(axis=1) |
| 26 | + |
| 27 | + # Flatten 2D index to 1D for block processing |
| 28 | + offs_n = pid_n * BLOCK_HW + tl.arange(0, BLOCK_HW) |
| 29 | + # Decode to (h, w) coordinates |
| 30 | + h_idx = offs_n // W_out |
| 31 | + w_idx = offs_n % W_out |
| 32 | + |
| 33 | + mask = (offs_n < H_out * W_out) & (pid_b < B) |
| 34 | + |
| 35 | + base_in = pid_b * (H_in * W_in) |
| 36 | + base_out = pid_b * (H_out * W_out) |
| 37 | + |
| 38 | + # Compute reflected indices for height |
| 39 | + y = h_idx.to(tl.int32) - pad_top |
| 40 | + Hm1 = H_in - 1 |
| 41 | + pH = 2 * Hm1 |
| 42 | + t_h = tl.abs(y) |
| 43 | + m_h = t_h % pH |
| 44 | + ih = tl.where(m_h < H_in, m_h, pH - m_h) |
| 45 | + |
| 46 | + # Compute reflected indices for width |
| 47 | + x = w_idx.to(tl.int32) - pad_left |
| 48 | + Wm1 = W_in - 1 |
| 49 | + pW = 2 * Wm1 |
| 50 | + t_w = tl.abs(x) |
| 51 | + m_w = t_w % pW |
| 52 | + iw = tl.where(m_w < W_in, m_w, pW - m_w) |
| 53 | + |
| 54 | + # Load from input and store to output |
| 55 | + in_offs = ih * W_in + iw |
| 56 | + vals = tl.load(in_ptr + base_in + in_offs, mask=mask, other=0) |
| 57 | + tl.store(out_ptr + base_out + offs_n, vals, mask=mask) |
| 58 | + |
| 59 | + |
| 60 | +@triton.jit |
| 61 | +def copy_tensor_kernel(in_ptr, out_ptr, B, H, W, BLOCK_HW: tl.constexpr): |
| 62 | + pid_b = tl.program_id(axis=0) |
| 63 | + pid_n = tl.program_id(axis=1) |
| 64 | + |
| 65 | + offs_n = pid_n * BLOCK_HW + tl.arange(0, BLOCK_HW) |
| 66 | + mask = (offs_n < H * W) & (pid_b < B) |
| 67 | + |
| 68 | + base = pid_b * (H * W) |
| 69 | + vals = tl.load(in_ptr + base + offs_n, mask=mask, other=0) |
| 70 | + tl.store(out_ptr + base + offs_n, vals, mask=mask) |
| 71 | + |
| 72 | + |
| 73 | +def launch_reflection_pad2d(input: torch.Tensor, padding, out: torch.Tensor = None): |
| 74 | + # Validate padding format |
| 75 | + if not isinstance(padding, (list, tuple)): |
| 76 | + raise ValueError("padding must be a sequence") |
| 77 | + if len(padding) != 4: |
| 78 | + raise ValueError( |
| 79 | + "padding must be a sequence of length 4: (pad_left, pad_right, pad_top, pad_bottom)" |
| 80 | + ) |
| 81 | + pad_left, pad_right, pad_top, pad_bottom = [int(p) for p in padding] |
| 82 | + |
| 83 | + # Validate padding values |
| 84 | + if pad_left < 0 or pad_right < 0 or pad_top < 0 or pad_bottom < 0: |
| 85 | + raise ValueError("padding values must be >= 0") |
| 86 | + |
| 87 | + # Validate input |
| 88 | + if input.dim() < 3: |
| 89 | + raise ValueError("input must have at least 3 dimensions") |
| 90 | + if not input.is_cuda: |
| 91 | + raise ValueError("input must be a CUDA tensor") |
| 92 | + |
| 93 | + x = input.contiguous() |
| 94 | + H_in = int(x.shape[-2]) |
| 95 | + W_in = int(x.shape[-1]) |
| 96 | + # Validate reflection padding constraints |
| 97 | + if H_in < 2 or W_in < 2: |
| 98 | + raise ValueError( |
| 99 | + "input spatial dimensions must be at least 2 for reflection padding when padding > 0" |
| 100 | + ) |
| 101 | + if H_in <= 0 or W_in <= 0: |
| 102 | + raise ValueError("spatial dimensions must be > 0") |
| 103 | + if pad_left >= W_in or pad_right >= W_in or pad_top >= H_in or pad_bottom >= H_in: |
| 104 | + raise ValueError( |
| 105 | + "padding values must be less than the input spatial dimensions for reflection padding" |
| 106 | + ) |
| 107 | + |
| 108 | + H_out = H_in + pad_top + pad_bottom |
| 109 | + W_out = W_in + pad_left + pad_right |
| 110 | + |
| 111 | + leading_shape = x.shape[:-2] |
| 112 | + B = int(math.prod(leading_shape)) if len(leading_shape) > 0 else 1 |
| 113 | + |
| 114 | + # Handle output tensor |
| 115 | + if out is None: |
| 116 | + out = torch.empty( |
| 117 | + (*leading_shape, H_out, W_out), device=x.device, dtype=x.dtype |
| 118 | + ) |
| 119 | + else: |
| 120 | + if not out.is_cuda: |
| 121 | + raise ValueError("out must be a CUDA tensor") |
| 122 | + expected_shape = (*leading_shape, H_out, W_out) |
| 123 | + if tuple(out.shape) != expected_shape: |
| 124 | + raise ValueError( |
| 125 | + f"out tensor has shape {tuple(out.shape)}, expected {expected_shape}" |
| 126 | + ) |
| 127 | + if out.dtype != x.dtype: |
| 128 | + raise ValueError( |
| 129 | + f"out dtype {out.dtype} does not match input dtype {x.dtype}" |
| 130 | + ) |
| 131 | + if out.device != x.device: |
| 132 | + raise ValueError("out must be on the same device as input") |
| 133 | + out = out.contiguous() |
| 134 | + |
| 135 | + # No padding: just copy |
| 136 | + if pad_left == 0 and pad_right == 0 and pad_top == 0 and pad_bottom == 0: |
| 137 | + BLOCK_HW = 256 |
| 138 | + grid = (B, triton.cdiv(H_in * W_in, BLOCK_HW)) |
| 139 | + copy_tensor_kernel[grid](x, out, B, H_in, W_in, BLOCK_HW=BLOCK_HW) |
| 140 | + return out |
| 141 | + |
| 142 | + BLOCK_HW = 256 |
| 143 | + grid = (B, triton.cdiv(H_out * W_out, BLOCK_HW)) |
| 144 | + reflection_pad2d_kernel[grid]( |
| 145 | + x, out, B, H_in, W_in, pad_left, pad_top, H_out, W_out, BLOCK_HW=BLOCK_HW |
| 146 | + ) |
| 147 | + return out |
| 148 | + |
| 149 | + |
| 150 | +def reflection_pad2d(input: torch.Tensor, padding): |
| 151 | + logger.debug("GEMS REFLECTION_PAD2D") |
| 152 | + return launch_reflection_pad2d(input, padding, out=None) |
| 153 | + |
| 154 | + |
| 155 | +def reflection_pad2d_out(input: torch.Tensor, padding, out: torch.Tensor): |
| 156 | + logger.debug("GEMS REFLECTION_PAD2D_OUT") |
| 157 | + return launch_reflection_pad2d(input, padding, out=out) |
0 commit comments