diff --git a/.github/workflows/python-publish.yml b/.github/workflows/python-publish.yml
index 5589de9c..f6f99017 100644
--- a/.github/workflows/python-publish.yml
+++ b/.github/workflows/python-publish.yml
@@ -15,12 +15,12 @@ on:
jobs:
- build_wheels:
- name: Build wheels on ${{ matrix.os }}
+ build_native_wheels:
+ name: Build native wheels on ${{ matrix.os }}
runs-on: ${{ matrix.os }}
strategy:
matrix:
- os: [ubuntu-latest, windows-latest, macos-13, macos-14]
+ os: [ubuntu-latest, windows-latest, macos-15-intel, macos-15]
steps:
- uses: actions/checkout@v4
@@ -28,15 +28,17 @@ jobs:
# Used to host cibuildwheel
- uses: actions/setup-python@v5
- - name: Install cibuildwheel
+ - name: Install cibuildwheel and native deps
run: |
python -m pip install --upgrade pip
python -m pip install cibuildwheel==3.2.1
python fetch_dependencies.py
- - name: Build wheels
+ - name: Build native wheels
env:
CIBW_SKIP: 'pp* manylinux_i686 musllinux*'
+ VKDISPATCH_BUILD_TARGET: native
+ CIBW_ENVIRONMENT: VKDISPATCH_BUILD_TARGET=native
run: python -m cibuildwheel --output-dir wheelhouse
# to supply options, put them in 'env', like:
@@ -47,28 +49,44 @@ jobs:
with:
name: cibw-wheels-${{ matrix.os }}-${{ strategy.job-index }}
path: ./wheelhouse/*.whl
- build_sdist:
- name: Build source distribution
+ build_python_dists:
+ name: Build native/core/meta sdists and pure wheels
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- - name: Install dependencies
+ - uses: actions/setup-python@v5
+
+ - name: Install build tooling
run: |
python -m pip install --upgrade pip
+ python -m pip install build
+
+ - name: Build native source distribution
+ env:
+ VKDISPATCH_BUILD_TARGET: native
+ run: |
python fetch_dependencies.py
+ python -m build --sdist --outdir dist
+
+ - name: Build core wheel and source distribution
+ env:
+ VKDISPATCH_BUILD_TARGET: core
+ run: python -m build --wheel --sdist --outdir dist
- - name: Build sdist
- run: pipx run build --sdist
+ - name: Build meta wheel and source distribution
+ env:
+ VKDISPATCH_BUILD_TARGET: meta
+ run: python -m build --wheel --sdist --outdir dist
- uses: actions/upload-artifact@v4
with:
- name: cibw-sdist
- path: dist/*.tar.gz
+ name: cibw-python-dists
+ path: dist/*
publish-to-pypi:
name: Publish Python package to PyPI
# if: startsWith(github.ref, 'refs/tags/') # only publish to PyPI on tag pushes
- needs: [build_wheels, build_sdist]
+ needs: [build_native_wheels, build_python_dists]
runs-on: ubuntu-latest
environment:
name: pypi
diff --git a/docs/Makefile b/docs/Makefile
index 4bf195e2..4c660da8 100644
--- a/docs/Makefile
+++ b/docs/Makefile
@@ -27,11 +27,9 @@ bundle_lib:
@rm -rf "$(LIB_DEST)/vkdispatch"
@mkdir -p "$(LIB_DEST)"
@rm -f "$(LIB_DEST)/$(LIB_BUNDLE)"
- @rm -f "$(LIB_DEST)/vkdispatch_native.brython.js"
@rm -rf "$(LIB_STAGE)"
@mkdir -p "$(LIB_STAGE)"
@cp -r ../vkdispatch "$(LIB_STAGE)/vkdispatch"
- @cp -r special_pages/libs/vkdispatch_native "$(LIB_STAGE)/vkdispatch_native"
@cd "$(LIB_STAGE)" && $(PYTHON) -m brython make_package vkdispatch \
--src-dir . \
--output-path "$(CURDIR)/$(LIB_DEST)/$(LIB_BUNDLE)"
diff --git a/docs/special/brython_shader_lab.rst b/docs/special/brython_shader_lab.rst
deleted file mode 100644
index aeeffe87..00000000
--- a/docs/special/brython_shader_lab.rst
+++ /dev/null
@@ -1,16 +0,0 @@
-Brython Shader Lab
-==================
-
-This page redirects to a standalone HTML app page.
-
-.. raw:: html
-
-
-
-
- Redirecting to the Brython shader lab page.
- If you are not redirected, open
- the standalone HTML page.
-
diff --git a/docs/special/index.rst b/docs/special/index.rst
index da840951..370fe451 100644
--- a/docs/special/index.rst
+++ b/docs/special/index.rst
@@ -6,4 +6,4 @@ Standalone pages integrated into the docs navigation.
.. toctree::
:maxdepth: 1
- brython_shader_lab
+ shader_playground
diff --git a/docs/special/shader_playground.rst b/docs/special/shader_playground.rst
new file mode 100644
index 00000000..364329e1
--- /dev/null
+++ b/docs/special/shader_playground.rst
@@ -0,0 +1,16 @@
+Shader Playground
+==================
+
+This page redirects to a standalone HTML app page.
+
+.. raw:: html
+
+
+ Redirecting to the shader playground page.
+ If you are not redirected, open
+ the standalone HTML page.
+
diff --git a/docs/special_pages/libs/vkdispatch_native/__init__.py b/docs/special_pages/libs/vkdispatch_native/__init__.py
deleted file mode 100644
index 673b054f..00000000
--- a/docs/special_pages/libs/vkdispatch_native/__init__.py
+++ /dev/null
@@ -1,1107 +0,0 @@
-"""Brython-friendly pure-Python shim for ``vkdispatch_native``.
-
-This module mirrors the Cython-exposed API used by ``vkdispatch`` and provides
-an in-memory fake runtime suitable for docs execution and shader-source
-compilation paths.
-"""
-
-# NOTE: Keep this file dependency-light so it works under Brython.
-
-LOG_LEVEL_VERBOSE = 0
-LOG_LEVEL_INFO = 1
-LOG_LEVEL_WARNING = 2
-LOG_LEVEL_ERROR = 3
-
-# Descriptor type enum values mirrored from vkdispatch_native/stages_extern.pxd.
-DESCRIPTOR_TYPE_STORAGE_BUFFER = 1
-DESCRIPTOR_TYPE_STORAGE_IMAGE = 2
-DESCRIPTOR_TYPE_UNIFORM_BUFFER = 3
-DESCRIPTOR_TYPE_UNIFORM_IMAGE = 4
-DESCRIPTOR_TYPE_SAMPLER = 5
-
-# Image format block sizes for formats exposed in vkdispatch.base.image.image_format.
-_IMAGE_BLOCK_SIZES = {
- 13: 1,
- 14: 1,
- 20: 2,
- 21: 2,
- 27: 3,
- 28: 3,
- 41: 4,
- 42: 4,
- 74: 2,
- 75: 2,
- 76: 2,
- 81: 4,
- 82: 4,
- 83: 4,
- 88: 6,
- 89: 6,
- 90: 6,
- 95: 8,
- 96: 8,
- 97: 8,
- 98: 4,
- 99: 4,
- 100: 4,
- 101: 8,
- 102: 8,
- 103: 8,
- 104: 12,
- 105: 12,
- 106: 12,
- 107: 16,
- 108: 16,
- 109: 16,
- 110: 8,
- 111: 8,
- 112: 8,
- 113: 16,
- 114: 16,
- 115: 16,
- 116: 24,
- 117: 24,
- 118: 24,
- 119: 32,
- 120: 32,
- 121: 32,
-}
-
-# --- Runtime state ---
-
-_initialized = False
-_debug_mode = False
-_log_level = LOG_LEVEL_WARNING
-_error_string = None
-_next_handle = 1
-
-_contexts = {}
-_signals = {}
-_buffers = {}
-_command_lists = {}
-_compute_plans = {}
-_descriptor_sets = {}
-_images = {}
-_samplers = {}
-_fft_plans = {}
-
-# Device limits exposed through get_devices(); mutable so docs UI can tune them.
-_DEFAULT_SUBGROUP_SIZE = 32
-_DEFAULT_MAX_WORKGROUP_SIZE = (1024, 1024, 64)
-_DEFAULT_MAX_WORKGROUP_INVOCATIONS = 1024
-_DEFAULT_MAX_WORKGROUP_COUNT = (65535, 65535, 65535)
-_DEFAULT_MAX_COMPUTE_SHARED_MEMORY_SIZE = 64 * 1024
-
-_device_subgroup_size = _DEFAULT_SUBGROUP_SIZE
-_device_max_workgroup_size = _DEFAULT_MAX_WORKGROUP_SIZE
-_device_max_workgroup_invocations = _DEFAULT_MAX_WORKGROUP_INVOCATIONS
-_device_max_workgroup_count = _DEFAULT_MAX_WORKGROUP_COUNT
-_device_max_compute_shared_memory_size = _DEFAULT_MAX_COMPUTE_SHARED_MEMORY_SIZE
-
-
-# --- Internal objects ---
-
-class _Signal:
- __slots__ = ("done",)
-
- def __init__(self, done=True):
- self.done = bool(done)
-
-
-class _Context:
- __slots__ = (
- "device_indices",
- "queue_families",
- "queue_count",
- "queue_to_device",
- "stopped",
- )
-
- def __init__(self, device_indices, queue_families):
- self.device_indices = list(device_indices)
- self.queue_families = [list(fam) for fam in queue_families]
-
- normalized = []
- for fam in self.queue_families:
- normalized.append(fam if len(fam) > 0 else [0])
- self.queue_families = normalized
-
- self.queue_count = sum(len(fam) for fam in self.queue_families)
- if self.queue_count <= 0:
- self.queue_families = [[0]]
- self.queue_count = 1
-
- queue_to_device = []
- for dev_idx, fam in enumerate(self.queue_families):
- for _ in fam:
- queue_to_device.append(dev_idx)
-
- if len(queue_to_device) == 0:
- queue_to_device = [0]
-
- self.queue_to_device = queue_to_device
- self.stopped = False
-
-
-class _Buffer:
- __slots__ = (
- "context_handle",
- "size",
- "device_data",
- "staging_data",
- "signal_handles",
- )
-
- def __init__(self, context_handle, queue_count, size):
- self.context_handle = context_handle
- self.size = int(size)
-
- if queue_count <= 0:
- queue_count = 1
-
- self.device_data = [bytearray(self.size) for _ in range(queue_count)]
- self.staging_data = [bytearray(self.size) for _ in range(queue_count)]
-
- signal_handles = []
- for _ in range(queue_count):
- signal_handles.append(_new_handle(_signals, _Signal(done=True)))
- self.signal_handles = signal_handles
-
-
-class _CommandList:
- __slots__ = ("context_handle", "commands", "compute_instance_size")
-
- def __init__(self, context_handle):
- self.context_handle = context_handle
- self.commands = []
- self.compute_instance_size = 0
-
-
-class _ComputePlan:
- __slots__ = ("context_handle", "shader_source", "bindings", "pc_size", "shader_name")
-
- def __init__(self, context_handle, shader_source, bindings, pc_size, shader_name):
- self.context_handle = context_handle
- self.shader_source = shader_source
- self.bindings = list(bindings)
- self.pc_size = int(pc_size)
- self.shader_name = shader_name
-
-
-class _DescriptorSet:
- __slots__ = ("plan_handle", "buffer_bindings", "image_bindings")
-
- def __init__(self, plan_handle):
- self.plan_handle = plan_handle
- self.buffer_bindings = {}
- self.image_bindings = {}
-
-
-class _Image:
- __slots__ = (
- "context_handle",
- "extent",
- "layers",
- "format",
- "type",
- "view_type",
- "generate_mips",
- "block_size",
- "queue_data",
- )
-
- def __init__(
- self,
- context_handle,
- queue_count,
- extent,
- layers,
- format_,
- image_type,
- view_type,
- generate_mips,
- ):
- self.context_handle = context_handle
- self.extent = tuple(extent)
- self.layers = int(layers)
- self.format = int(format_)
- self.type = int(image_type)
- self.view_type = int(view_type)
- self.generate_mips = int(generate_mips)
-
- self.block_size = image_format_block_size(self.format)
-
- if queue_count <= 0:
- queue_count = 1
-
- width = max(1, int(self.extent[0]))
- height = max(1, int(self.extent[1]))
- depth = max(1, int(self.extent[2]))
- layer_count = max(1, self.layers)
- total_bytes = width * height * depth * layer_count * self.block_size
-
- self.queue_data = [bytearray(total_bytes) for _ in range(queue_count)]
-
-
-class _Sampler:
- __slots__ = (
- "context_handle",
- "mag_filter",
- "min_filter",
- "mip_mode",
- "address_mode",
- "mip_lod_bias",
- "min_lod",
- "max_lod",
- "border_color",
- )
-
- def __init__(
- self,
- context_handle,
- mag_filter,
- min_filter,
- mip_mode,
- address_mode,
- mip_lod_bias,
- min_lod,
- max_lod,
- border_color,
- ):
- self.context_handle = context_handle
- self.mag_filter = int(mag_filter)
- self.min_filter = int(min_filter)
- self.mip_mode = int(mip_mode)
- self.address_mode = int(address_mode)
- self.mip_lod_bias = float(mip_lod_bias)
- self.min_lod = float(min_lod)
- self.max_lod = float(max_lod)
- self.border_color = int(border_color)
-
-
-class _FFTPlan:
- __slots__ = (
- "context_handle",
- "dims",
- "axes",
- "buffer_size",
- "input_buffer_size",
- "kernel_num",
- )
-
- def __init__(
- self,
- context_handle,
- dims,
- axes,
- buffer_size,
- input_buffer_size,
- kernel_num,
- ):
- self.context_handle = context_handle
- self.dims = list(dims)
- self.axes = list(axes)
- self.buffer_size = int(buffer_size)
- self.input_buffer_size = int(input_buffer_size)
- self.kernel_num = int(kernel_num)
-
-
-# --- Internal helpers ---
-
-
-def _new_handle(registry, obj):
- global _next_handle
- handle = _next_handle
- _next_handle += 1
- registry[handle] = obj
- return handle
-
-
-def _to_bytes(value):
- if value is None:
- return b""
- if isinstance(value, bytes):
- return value
- if isinstance(value, bytearray):
- return bytes(value)
- if isinstance(value, memoryview):
- return value.tobytes()
- try:
- return bytes(value)
- except Exception:
- return b""
-
-
-def _normalize_extent(extent):
- values = list(extent)
- if len(values) < 3:
- values.extend([1] * (3 - len(values)))
- return (int(values[0]), int(values[1]), int(values[2]))
-
-
-def _queue_indices(ctx, queue_index, all_on_negative=False):
- if ctx is None or ctx.queue_count <= 0:
- return []
-
- if queue_index is None:
- return [0]
-
- queue_index = int(queue_index)
-
- if all_on_negative and queue_index in (-1, -2):
- return list(range(ctx.queue_count))
-
- if 0 <= queue_index < ctx.queue_count:
- return [queue_index]
-
- return []
-
-
-def _set_error(message):
- global _error_string
- _error_string = str(message)
-
-
-def _clear_error():
- global _error_string
- _error_string = None
-
-
-def _as_positive_int(name, value):
- try:
- parsed = int(value)
- except Exception as exc:
- raise ValueError("%s must be an integer" % name) from exc
-
- if parsed <= 0:
- raise ValueError("%s must be greater than zero" % name)
-
- return parsed
-
-
-def _as_positive_triplet(name, value):
- try:
- parts = list(value)
- except Exception as exc:
- raise ValueError("%s must contain exactly 3 integers" % name) from exc
-
- if len(parts) != 3:
- raise ValueError("%s must contain exactly 3 integers" % name)
-
- return (
- _as_positive_int("%s[0]" % name, parts[0]),
- _as_positive_int("%s[1]" % name, parts[1]),
- _as_positive_int("%s[2]" % name, parts[2]),
- )
-
-
-# --- API: context/init/errors/logging ---
-
-
-def reset_device_options():
- global _device_subgroup_size
- global _device_max_workgroup_size
- global _device_max_workgroup_invocations
- global _device_max_workgroup_count
- global _device_max_compute_shared_memory_size
-
- _device_subgroup_size = _DEFAULT_SUBGROUP_SIZE
- _device_max_workgroup_size = _DEFAULT_MAX_WORKGROUP_SIZE
- _device_max_workgroup_invocations = _DEFAULT_MAX_WORKGROUP_INVOCATIONS
- _device_max_workgroup_count = _DEFAULT_MAX_WORKGROUP_COUNT
- _device_max_compute_shared_memory_size = _DEFAULT_MAX_COMPUTE_SHARED_MEMORY_SIZE
-
-
-def set_device_options(
- subgroup_size=None,
- max_workgroup_size=None,
- max_workgroup_invocations=None,
- max_workgroup_count=None,
- max_compute_shared_memory_size=None,
-):
- global _device_subgroup_size
- global _device_max_workgroup_size
- global _device_max_workgroup_invocations
- global _device_max_workgroup_count
- global _device_max_compute_shared_memory_size
-
- if subgroup_size is not None:
- _device_subgroup_size = _as_positive_int("subgroup_size", subgroup_size)
-
- if max_workgroup_size is not None:
- _device_max_workgroup_size = _as_positive_triplet(
- "max_workgroup_size",
- max_workgroup_size,
- )
-
- if max_workgroup_invocations is not None:
- _device_max_workgroup_invocations = _as_positive_int(
- "max_workgroup_invocations",
- max_workgroup_invocations,
- )
-
- if max_workgroup_count is not None:
- _device_max_workgroup_count = _as_positive_triplet(
- "max_workgroup_count",
- max_workgroup_count,
- )
-
- if max_compute_shared_memory_size is not None:
- _device_max_compute_shared_memory_size = _as_positive_int(
- "max_compute_shared_memory_size",
- max_compute_shared_memory_size,
- )
-
-
-def init(debug, log_level):
- global _initialized, _debug_mode, _log_level
- _initialized = True
- _debug_mode = bool(debug)
- _log_level = int(log_level)
- _clear_error()
-
-
-def log(log_level, text, file_str, line_str):
- # Keep logging quiet in docs/brython by default.
- # Function kept for API compatibility.
- _ = log_level
- _ = text
- _ = file_str
- _ = line_str
-
-
-def set_log_level(log_level):
- global _log_level
- _log_level = int(log_level)
-
-
-def get_devices():
- if not _initialized:
- init(False, _log_level)
-
- # One plausible fake discrete GPU with compute+graphics queue families.
- device_tuple = (
- 0, # version_variant
- 1, # version_major
- 3, # version_minor
- 0, # version_patch
- 1001000, # driver_version
- 0x1BAD, # vendor_id
- 0x0001, # device_id
- 2, # device_type (Discrete GPU)
- "VKDispatch Web Dummy GPU",
- 1, # shader_buffer_float32_atomics
- 1, # shader_buffer_float32_atomic_add
- 1, # float_64_support
- 1, # float_16_support
- 1, # int_64_support
- 1, # int_16_support
- 1, # storage_buffer_16_bit_access
- 1, # uniform_and_storage_buffer_16_bit_access
- 1, # storage_push_constant_16
- 1, # storage_input_output_16
- _device_max_workgroup_size, # max_workgroup_size
- _device_max_workgroup_invocations, # max_workgroup_invocations
- _device_max_workgroup_count, # max_workgroup_count
- 8, # max_descriptor_set_count
- 256, # max_push_constant_size
- 1 << 30, # max_storage_buffer_range
- 65536, # max_uniform_buffer_range
- 16, # uniform_buffer_alignment
- _device_subgroup_size, # subgroup_size
- 0x7FFFFFFF, # supported_stages
- 0x7FFFFFFF, # supported_operations
- 1, # quad_operations_in_all_stages
- _device_max_compute_shared_memory_size, # max_compute_shared_memory_size
- [
- (8, 0x006), # compute + transfer
- (4, 0x007), # graphics + compute + transfer
- ],
- 1, # scalar_block_layout
- 1, # timeline_semaphores
- bytes((0x56, 0x4B, 0x44, 0x30, 0x57, 0x45, 0x42, 0x31, 0x44, 0x55, 0x4D, 0x4D, 0x59, 0x00, 0x00, 0x01)),
- )
-
- return [device_tuple]
-
-
-def context_create(device_indicies, queue_families):
- try:
- ctx = _Context(device_indicies, queue_families)
- return _new_handle(_contexts, ctx)
- except Exception as exc:
- _set_error("Failed to create context: %s" % exc)
- return 0
-
-
-def signal_wait(signal_ptr, wait_for_timestamp, queue_index):
- _ = wait_for_timestamp
- _ = queue_index
- signal_obj = _signals.get(int(signal_ptr))
- if signal_obj is None:
- return True
- return bool(signal_obj.done)
-
-
-def signal_insert(context, queue_index):
- _ = context
- _ = queue_index
- return _new_handle(_signals, _Signal(done=True))
-
-
-def signal_destroy(signal_ptr):
- _signals.pop(int(signal_ptr), None)
-
-
-def context_destroy(context):
- _contexts.pop(int(context), None)
-
-
-def get_error_string():
- if _error_string is None:
- return 0
- return _error_string
-
-
-def context_stop_threads(context):
- ctx = _contexts.get(int(context))
- if ctx is not None:
- ctx.stopped = True
-
-
-# --- API: buffers ---
-
-
-def buffer_create(context, size, per_device):
- _ = per_device
- ctx = _contexts.get(int(context))
- if ctx is None:
- _set_error("Invalid context handle for buffer_create")
- return 0
-
- size = int(size)
- if size < 0:
- size = 0
-
- return _new_handle(_buffers, _Buffer(int(context), ctx.queue_count, size))
-
-
-def buffer_destroy(buffer):
- obj = _buffers.pop(int(buffer), None)
- if obj is None:
- return
-
- for signal_handle in obj.signal_handles:
- _signals.pop(signal_handle, None)
-
-
-def buffer_get_queue_signal(buffer, queue_index):
- obj = _buffers.get(int(buffer))
- if obj is None:
- return _new_handle(_signals, _Signal(done=True))
-
- queue_index = int(queue_index)
- if queue_index < 0 or queue_index >= len(obj.signal_handles):
- queue_index = 0
-
- return obj.signal_handles[queue_index]
-
-
-def buffer_wait_staging_idle(buffer, queue_index):
- _ = buffer
- _ = queue_index
- return True
-
-
-def buffer_write_staging(buffer, queue_index, data, size):
- obj = _buffers.get(int(buffer))
- if obj is None:
- return
-
- queue_index = int(queue_index)
- if queue_index < 0 or queue_index >= len(obj.staging_data):
- return
-
- payload = _to_bytes(data)
- size = min(int(size), len(payload), obj.size)
- if size <= 0:
- return
-
- obj.staging_data[queue_index][:size] = payload[:size]
-
-
-def buffer_read_staging(buffer, queue_index, size):
- obj = _buffers.get(int(buffer))
- if obj is None:
- return bytes(int(size))
-
- queue_index = int(queue_index)
- if queue_index < 0 or queue_index >= len(obj.staging_data):
- return bytes(int(size))
-
- size = int(size)
- if size <= 0:
- return b""
-
- data = obj.staging_data[queue_index]
- if size <= len(data):
- return bytes(data[:size])
-
- return bytes(data) + bytes(size - len(data))
-
-
-def buffer_write(buffer, offset, size, index):
- obj = _buffers.get(int(buffer))
- if obj is None:
- return
-
- offset = int(offset)
- size = int(size)
-
- if size <= 0 or offset < 0:
- return
-
- ctx = _contexts.get(obj.context_handle)
- if ctx is None:
- return
-
- queue_indices = _queue_indices(ctx, index, all_on_negative=True)
- if len(queue_indices) == 0:
- return
-
- for queue_index in queue_indices:
- if queue_index >= len(obj.device_data) or queue_index >= len(obj.staging_data):
- continue
-
- end = min(offset + size, obj.size)
- copy_size = end - offset
- if copy_size <= 0:
- continue
-
- obj.device_data[queue_index][offset:end] = obj.staging_data[queue_index][:copy_size]
-
- signal_handle = obj.signal_handles[queue_index]
- signal_obj = _signals.get(signal_handle)
- if signal_obj is not None:
- signal_obj.done = True
-
-
-def buffer_read(buffer, offset, size, index):
- obj = _buffers.get(int(buffer))
- if obj is None:
- return
-
- offset = int(offset)
- size = int(size)
-
- if size <= 0 or offset < 0:
- return
-
- queue_index = int(index)
- if queue_index < 0 or queue_index >= len(obj.device_data):
- return
-
- end = min(offset + size, obj.size)
- copy_size = end - offset
- if copy_size <= 0:
- return
-
- obj.staging_data[queue_index][:copy_size] = obj.device_data[queue_index][offset:end]
-
- signal_handle = obj.signal_handles[queue_index]
- signal_obj = _signals.get(signal_handle)
- if signal_obj is not None:
- signal_obj.done = True
-
-
-# --- API: command lists ---
-
-
-def command_list_create(context):
- if int(context) not in _contexts:
- _set_error("Invalid context handle for command_list_create")
- return 0
-
- return _new_handle(_command_lists, _CommandList(int(context)))
-
-
-def command_list_destroy(command_list):
- _command_lists.pop(int(command_list), None)
-
-
-def command_list_get_instance_size(command_list):
- obj = _command_lists.get(int(command_list))
- if obj is None:
- return 0
-
- return int(obj.compute_instance_size)
-
-
-def command_list_reset(command_list):
- obj = _command_lists.get(int(command_list))
- if obj is None:
- return
-
- obj.commands = []
- obj.compute_instance_size = 0
-
-
-def command_list_submit(command_list, data, instance_count, index):
- _ = data
- _ = instance_count
- _ = index
-
- obj = _command_lists.get(int(command_list))
- if obj is None:
- return True
-
- # No-op fake execution path: commands are accepted but not executed.
- # Keep the command list intact (native keeps it until reset/destroy).
- _ = obj.commands
- return True
-
-
-# --- API: descriptor sets ---
-
-
-def descriptor_set_create(plan):
- if int(plan) not in _compute_plans:
- _set_error("Invalid compute plan handle for descriptor_set_create")
- return 0
-
- return _new_handle(_descriptor_sets, _DescriptorSet(int(plan)))
-
-
-def descriptor_set_destroy(descriptor_set):
- _descriptor_sets.pop(int(descriptor_set), None)
-
-
-def descriptor_set_write_buffer(
- descriptor_set,
- binding,
- object,
- offset,
- range,
- uniform,
- read_access,
- write_access,
-):
- ds = _descriptor_sets.get(int(descriptor_set))
- if ds is None:
- return
-
- ds.buffer_bindings[int(binding)] = (
- int(object),
- int(offset),
- int(range),
- int(uniform),
- int(read_access),
- int(write_access),
- )
-
-
-def descriptor_set_write_image(
- descriptor_set,
- binding,
- object,
- sampler_obj,
- read_access,
- write_access,
-):
- ds = _descriptor_sets.get(int(descriptor_set))
- if ds is None:
- return
-
- ds.image_bindings[int(binding)] = (
- int(object),
- int(sampler_obj),
- int(read_access),
- int(write_access),
- )
-
-
-# --- API: images/samplers ---
-
-
-def image_create(context, extent, layers, format, type, view_type, generate_mips):
- ctx = _contexts.get(int(context))
- if ctx is None:
- _set_error("Invalid context handle for image_create")
- return 0
-
- norm_extent = _normalize_extent(extent)
- obj = _Image(
- int(context),
- ctx.queue_count,
- norm_extent,
- int(layers),
- int(format),
- int(type),
- int(view_type),
- int(generate_mips),
- )
-
- return _new_handle(_images, obj)
-
-
-def image_destroy(image):
- _images.pop(int(image), None)
-
-
-def image_create_sampler(
- context,
- mag_filter,
- min_filter,
- mip_mode,
- address_mode,
- mip_lod_bias,
- min_lod,
- max_lod,
- border_color,
-):
- if int(context) not in _contexts:
- _set_error("Invalid context handle for image_create_sampler")
- return 0
-
- sampler = _Sampler(
- int(context),
- mag_filter,
- min_filter,
- mip_mode,
- address_mode,
- mip_lod_bias,
- min_lod,
- max_lod,
- border_color,
- )
- return _new_handle(_samplers, sampler)
-
-
-def image_destroy_sampler(sampler):
- _samplers.pop(int(sampler), None)
-
-
-def image_write(image, data, offset, extent, baseLayer, layerCount, device_index):
- _ = offset
- _ = baseLayer
-
- obj = _images.get(int(image))
- if obj is None:
- return
-
- payload = _to_bytes(data)
-
- extent = _normalize_extent(extent)
- layer_count = max(1, int(layerCount))
- region_size = max(0, extent[0] * extent[1] * extent[2] * layer_count * obj.block_size)
- if region_size <= 0:
- return
-
- copy_size = min(region_size, len(payload))
- if copy_size <= 0:
- return
-
- ctx = _contexts.get(obj.context_handle)
- if ctx is None:
- return
-
- queue_indices = _queue_indices(ctx, device_index, all_on_negative=True)
- if len(queue_indices) == 0:
- return
-
- for queue_index in queue_indices:
- if queue_index < 0 or queue_index >= len(obj.queue_data):
- continue
- obj.queue_data[queue_index][:copy_size] = payload[:copy_size]
-
-
-def image_format_block_size(format):
- return int(_IMAGE_BLOCK_SIZES.get(int(format), 4))
-
-
-def image_read(image, out_size, offset, extent, baseLayer, layerCount, device_index):
- _ = offset
- _ = extent
- _ = baseLayer
- _ = layerCount
-
- obj = _images.get(int(image))
- out_size = max(0, int(out_size))
-
- if obj is None:
- return bytes(out_size)
-
- queue_index = int(device_index)
- if queue_index < 0 or queue_index >= len(obj.queue_data):
- queue_index = 0
-
- data = obj.queue_data[queue_index]
- if out_size <= len(data):
- return bytes(data[:out_size])
-
- return bytes(data) + bytes(out_size - len(data))
-
-
-# --- API: compute stage ---
-
-
-def stage_compute_plan_create(context, shader_source, bindings, pc_size, shader_name):
- if int(context) not in _contexts:
- _set_error("Invalid context handle for stage_compute_plan_create")
- return 0
-
- source_bytes = _to_bytes(shader_source)
- name_bytes = _to_bytes(shader_name)
-
- plan = _ComputePlan(int(context), source_bytes, list(bindings), int(pc_size), name_bytes)
- return _new_handle(_compute_plans, plan)
-
-
-def stage_compute_plan_destroy(plan):
- _compute_plans.pop(int(plan), None)
-
-
-def stage_compute_record(command_list, plan, descriptor_set, blocks_x, blocks_y, blocks_z):
- cl = _command_lists.get(int(command_list))
- cp = _compute_plans.get(int(plan))
-
- if cl is None or cp is None:
- return
-
- cl.commands.append(
- {
- "type": "compute",
- "plan": int(plan),
- "descriptor_set": int(descriptor_set),
- "blocks": (int(blocks_x), int(blocks_y), int(blocks_z)),
- }
- )
- cl.compute_instance_size += max(0, int(cp.pc_size))
-
-
-# --- API: FFT stage ---
-
-
-def stage_fft_plan_create(
- context,
- dims,
- axes,
- buffer_size,
- do_r2c,
- normalize,
- pad_left,
- pad_right,
- frequency_zeropadding,
- kernel_num,
- kernel_convolution,
- conjugate_convolution,
- convolution_features,
- input_buffer_size,
- num_batches,
- single_kernel_multiple_batches,
- keep_shader_code,
-):
- _ = do_r2c
- _ = normalize
- _ = pad_left
- _ = pad_right
- _ = frequency_zeropadding
- _ = kernel_convolution
- _ = conjugate_convolution
- _ = convolution_features
- _ = num_batches
- _ = single_kernel_multiple_batches
- _ = keep_shader_code
-
- if int(context) not in _contexts:
- _set_error("Invalid context handle for stage_fft_plan_create")
- return 0
-
- plan = _FFTPlan(
- int(context),
- list(dims),
- list(axes),
- int(buffer_size),
- int(input_buffer_size),
- int(kernel_num),
- )
-
- return _new_handle(_fft_plans, plan)
-
-
-def stage_fft_plan_destroy(plan):
- _fft_plans.pop(int(plan), None)
-
-
-def stage_fft_record(command_list, plan, buffer, inverse, kernel, input_buffer):
- _ = buffer
- _ = inverse
- _ = kernel
- _ = input_buffer
-
- cl = _command_lists.get(int(command_list))
- if cl is None or int(plan) not in _fft_plans:
- return
-
- cl.commands.append(
- {
- "type": "fft",
- "plan": int(plan),
- }
- )
-
-
-__all__ = [
- "reset_device_options",
- "set_device_options",
- "init",
- "log",
- "set_log_level",
- "get_devices",
- "context_create",
- "signal_wait",
- "signal_insert",
- "signal_destroy",
- "context_destroy",
- "get_error_string",
- "context_stop_threads",
- "buffer_create",
- "buffer_destroy",
- "buffer_get_queue_signal",
- "buffer_wait_staging_idle",
- "buffer_write_staging",
- "buffer_read_staging",
- "buffer_write",
- "buffer_read",
- "command_list_create",
- "command_list_destroy",
- "command_list_get_instance_size",
- "command_list_reset",
- "command_list_submit",
- "descriptor_set_create",
- "descriptor_set_destroy",
- "descriptor_set_write_buffer",
- "descriptor_set_write_image",
- "image_create",
- "image_destroy",
- "image_create_sampler",
- "image_destroy_sampler",
- "image_write",
- "image_format_block_size",
- "image_read",
- "stage_compute_plan_create",
- "stage_compute_plan_destroy",
- "stage_compute_record",
- "stage_fft_plan_create",
- "stage_fft_plan_destroy",
- "stage_fft_record",
- "LOG_LEVEL_VERBOSE",
- "LOG_LEVEL_INFO",
- "LOG_LEVEL_WARNING",
- "LOG_LEVEL_ERROR",
- "DESCRIPTOR_TYPE_STORAGE_BUFFER",
- "DESCRIPTOR_TYPE_STORAGE_IMAGE",
- "DESCRIPTOR_TYPE_UNIFORM_BUFFER",
- "DESCRIPTOR_TYPE_UNIFORM_IMAGE",
- "DESCRIPTOR_TYPE_SAMPLER",
-]
diff --git a/docs/special_pages/brython_shader_lab.html b/docs/special_pages/shader_playground.html
similarity index 91%
rename from docs/special_pages/brython_shader_lab.html
rename to docs/special_pages/shader_playground.html
index 0e9e057c..9fb37d17 100644
--- a/docs/special_pages/brython_shader_lab.html
+++ b/docs/special_pages/shader_playground.html
@@ -401,6 +401,7 @@
@@ -676,7 +679,7 @@
VkDispatch Shader Playground
btn.classList.add("active");
/* Switch output highlighting mode */
- if (backend === "cuda") {
+ if (backend === "cuda" || backend === "opencl") {
window.cmOutput.setOption("mode", "text/x-csrc");
} else {
window.cmOutput.setOption("mode", "text/x-glsl");
@@ -690,6 +693,7 @@ VkDispatch Shader Playground
/* ── URL hash restore ── */
var deviceFields = [
{ key: "ss", id: "opt-subgroup-size" },
+ { key: "sse", id: "opt-subgroup-enabled" },
{ key: "wsx", id: "opt-wg-size-x" },
{ key: "wsy", id: "opt-wg-size-y" },
{ key: "wsz", id: "opt-wg-size-z" },
@@ -720,11 +724,11 @@ VkDispatch Shader Playground
/* Restore backend from URL */
if (params.has("be")) {
var be = params.get("be").toLowerCase();
- if (be === "cuda") {
- window.currentBackend = "cuda";
+ if (be === "cuda" || be === "opencl") {
+ window.currentBackend = be;
toggleButtons.forEach(function (b) {
b.classList.remove("active");
- if (b.getAttribute("data-backend") === "cuda") {
+ if (b.getAttribute("data-backend") === be) {
b.classList.add("active");
}
});
@@ -734,7 +738,15 @@ VkDispatch Shader Playground
deviceFields.forEach(function (f) {
if (params.has(f.key)) {
- document.getElementById(f.id).value = params.get(f.key);
+ if(f.id === "opt-subgroup-enabled") {
+ document.getElementById(f.id).checked =
+ params.get(f.key) === "1" ||
+ params.get(f.key).toLowerCase() === "true" ||
+ params.get(f.key).toLowerCase() === "yes" ||
+ params.get(f.key).toLowerCase() === "on";
+ } else {
+ document.getElementById(f.id).value = params.get(f.key);
+ }
}
});
toggleFields.forEach(function (f) {
@@ -792,43 +804,74 @@ VkDispatch Shader Playground
});
/* ── share button ── */
- document
- .getElementById("share-btn")
- .addEventListener("click", function () {
- var code = window.cmCode.getValue();
- var encoded = btoa(unescape(encodeURIComponent(code)));
-
- var hashParts = ["code=" + encoded];
-
- /* Include backend in share link */
- hashParts.push("be=" + encodeURIComponent(window.currentBackend));
-
- deviceFields.forEach(function (f) {
- var val = document.getElementById(f.id).value.trim();
- if (val !== "") {
- hashParts.push(
- encodeURIComponent(f.key) +
- "=" +
- encodeURIComponent(val)
- );
- }
- });
- toggleFields.forEach(function (f) {
- var el = document.getElementById(f.id);
- if (!el) return;
- var checked = el.checked ? "1" : "0";
+ function buildPlaygroundHash() {
+ var code = window.cmCode.getValue();
+ var encoded = btoa(unescape(encodeURIComponent(code)));
+
+ var hashParts = ["code=" + encoded];
+
+ /* Include backend in share/runtime URL */
+ hashParts.push("be=" + encodeURIComponent(window.currentBackend));
+
+ deviceFields.forEach(function (f) {
+ var val = document.getElementById(f.id).value.trim();
+
+ if(f.id === "opt-subgroup-enabled") {
+ val = document.getElementById(f.id).checked ? "1" : "0";
+ }
+
+ if (val !== "") {
hashParts.push(
encodeURIComponent(f.key) +
"=" +
- encodeURIComponent(checked)
+ encodeURIComponent(val)
);
- });
+ }
+ });
+ toggleFields.forEach(function (f) {
+ var el = document.getElementById(f.id);
+ if (!el) return;
+ var checked = el.checked ? "1" : "0";
+ hashParts.push(
+ encodeURIComponent(f.key) +
+ "=" +
+ encodeURIComponent(checked)
+ );
+ });
+
+ return hashParts.join("&");
+ }
+
+ window.updatePlaygroundUrlState = function () {
+ var hash = buildPlaygroundHash();
+ var nextUrl =
+ window.location.pathname +
+ window.location.search +
+ "#" +
+ hash;
+
+ if (window.location.hash.slice(1) !== hash) {
+ if (window.history && window.history.replaceState) {
+ window.history.replaceState(null, "", nextUrl);
+ } else {
+ window.location.hash = hash;
+ }
+ }
+
+ return hash;
+ };
+
+ document
+ .getElementById("share-btn")
+ .addEventListener("click", function () {
+ var hash = window.updatePlaygroundUrlState();
var url =
window.location.origin +
window.location.pathname +
+ window.location.search +
"#" +
- hashParts.join("&");
+ hash;
copyToClipboard(url).then(function () {
showToast("Share link copied to clipboard.");
@@ -917,13 +960,14 @@ VkDispatch Shader Playground
import sys
import traceback
-import vkdispatch_native
+import vkdispatch as vd
import vkdispatch.base.context as vd_context
import vkdispatch.base.init as vd_init
import vkdispatch.execution_pipeline.command_graph as vd_command_graph
import vkdispatch.fft.shader_factories as vd_fft_shader_factories
import vkdispatch.codegen as vc
+vd.initialize(backend="dummy")
class OutputBuffer:
def __init__(self):
@@ -961,6 +1005,7 @@ VkDispatch Shader Playground
def _read_device_options():
return {
"subgroup_size": _parse_positive_int("opt-subgroup-size", "Subgroup Size"),
+ "subgroup_enabled": document["opt-subgroup-enabled"].checked,
"max_workgroup_size": (
_parse_positive_int("opt-wg-size-x", "Max Workgroup Size X"),
_parse_positive_int("opt-wg-size-y", "Max Workgroup Size Y"),
@@ -984,16 +1029,8 @@ VkDispatch Shader Playground
def _reset_vkdispatch_runtime():
context = getattr(vd_context, "__context", None)
- if context is not None:
- if hasattr(vd_context, "set_running"):
- vd_context.set_running(False)
-
- handles_list = list(context.handles_dict.values())
- for handle in handles_list:
- handle.destroy()
-
- vkdispatch_native.context_destroy(context._handle)
- vd_context.__context = None
+ #if context is not None:
+ # vd_context.destroy_context()
vd_init.__initilized_instance = False
vd_init.__device_infos = None
@@ -1008,6 +1045,9 @@ VkDispatch Shader Playground
code = window.cmCode.getValue()
window.cmOutput.setValue("")
+ if event is not None and hasattr(window, "updatePlaygroundUrlState"):
+ window.updatePlaygroundUrlState()
+
stdout_buffer = OutputBuffer()
stderr_buffer = OutputBuffer()
@@ -1017,14 +1057,18 @@ VkDispatch Shader Playground
try:
options = _read_device_options()
- vkdispatch_native.set_device_options(
+ _reset_vkdispatch_runtime()
+
+ vd.initialize(backend="dummy")
+ vd.get_context()
+ vd.set_dummy_context_params(
subgroup_size=options["subgroup_size"],
+ subgroup_enabled=options["subgroup_enabled"],
max_workgroup_size=options["max_workgroup_size"],
max_workgroup_invocations=options["max_workgroup_invocations"],
max_workgroup_count=options["max_workgroup_count"],
- max_compute_shared_memory_size=options["max_compute_shared_memory_size"],
+ max_shared_memory=options["max_compute_shared_memory_size"],
)
- _reset_vkdispatch_runtime()
# Set codegen backend based on toggle state
backend = str(window.currentBackend)
diff --git a/docs/tutorials/reductions_and_fft.rst b/docs/tutorials/reductions_and_fft.rst
index b078503b..6b77430a 100644
--- a/docs/tutorials/reductions_and_fft.rst
+++ b/docs/tutorials/reductions_and_fft.rst
@@ -162,6 +162,14 @@ For advanced workflows (for example padded 2D cross-correlation), use ``input_ma
``output_map`` to remap FFT I/O indices and ``input_signal_range`` to skip inactive
regions.
+Map argument annotations do not determine FFT compute precision. ``read_op.register``
+and ``write_op.register`` always use the internal FFT compute type; map callbacks should
+cast user-chosen buffer values to and from that register type as needed. If both FFT I/O
+paths are mapped and ``compute_type`` is not provided, ``vd.fft`` defaults to
+``complex64`` (falling back to ``complex32`` when required by device support).
+When ``output_map`` is provided without ``input_map``, pass an explicit input buffer
+argument after the ``output_map`` arguments so read and write phases use different proxies.
+
.. code-block:: python
import vkdispatch.codegen as vc
diff --git a/examples/pytorch_cuda_graph_cuda_python.py b/examples/pytorch_cuda_graph_cuda_python.py
new file mode 100644
index 00000000..51a949f9
--- /dev/null
+++ b/examples/pytorch_cuda_graph_cuda_python.py
@@ -0,0 +1,80 @@
+#!/usr/bin/env python3
+"""Capture and replay a vkdispatch CUDA kernel inside a PyTorch CUDA Graph.
+
+This example uses:
+ - vkdispatch runtime backend: "cuda"
+ - a custom vkdispatch shader recorded into CommandGraph
+ - torch.cuda.CUDAGraph capture + replay
+ - zero-copy tensor sharing via __cuda_array_interface__
+"""
+
+import torch
+
+import vkdispatch as vd
+import vkdispatch.codegen as vc
+from vkdispatch.codegen.abreviations import Buff, Const, f32
+
+
+@vd.shader(exec_size=lambda args: args.x.size)
+def custom_shader(out: Buff[f32], x: Buff[f32], bias: Const[f32]):
+ tid = vc.global_invocation_id().x
+ out[tid] = x[tid] * 1.5 + vc.sin(x[tid]) + bias
+
+
+def main() -> None:
+ if not torch.cuda.is_available():
+ raise RuntimeError("CUDA is required for this example.")
+
+ torch.cuda.set_device(0)
+ torch.manual_seed(0)
+
+ vd.initialize(backend="cuda")
+ vd.make_context(device_ids=torch.cuda.current_device())
+
+ n = 16
+ bias = 0.25
+
+ # Static allocations are required for CUDA Graph replay.
+ x = torch.empty(n, device="cuda", dtype=torch.float32)
+ out = torch.empty_like(x)
+ x.fill_(0.0)
+
+ x_vd = vd.from_cuda_array(x)
+ out_vd = vd.from_cuda_array(out)
+
+ cmd_graph = vd.CommandGraph()
+
+ # Record one vkdispatch kernel launch into the command graph.
+    # For backend="cuda", Const/Var payloads are fixed at record time.
+ custom_shader(out=out_vd, x=x_vd, bias=bias, graph=cmd_graph)
+
+ torch.cuda.synchronize()
+ # Pre-stage internal uniform uploads outside torch capture so only dispatch is captured.
+ #cmd_graph.prepare_for_cuda_graph_capture()
+
+ graph = torch.cuda.CUDAGraph()
+ with torch.cuda.graph(graph):
+ # torch.cuda.graph(...) may switch to an internal capture stream.
+ # Bind vkdispatch to the active stream from inside that context.
+ with vd.cuda_graph_capture(torch.cuda.current_stream()):
+ print("Submitting vkdispatch CommandGraph to CUDA Graph...")
+ cmd_graph.submit()
+ print("Done recording.")
+
+ replay_inputs = [0.0, 1.0, 2.0, 3.0]
+ for i, value in enumerate(replay_inputs, start=1):
+ x.fill_(value)
+ graph.replay()
+ torch.cuda.synchronize()
+
+ expected = x * 1.5 + torch.sin(x) + bias
+ torch.testing.assert_close(out, expected, rtol=1e-5, atol=1e-5)
+ print(
+ f"replay {i} input={value:.1f} output[:8]={out[:8].detach().cpu().tolist()}"
+ )
+
+ print("Done.")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/pyproject.toml b/pyproject.toml
index fc741656..7379c159 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -2,36 +2,7 @@
requires = [
"setuptools>=59.0",
"wheel",
- "Cython"
+ "Cython",
+ "packaging"
]
build-backend = "setuptools.build_meta"
-
-[project]
-name = "vkdispatch"
-version = "0.0.30"
-authors = [
- { name="Shahar Sandhaus", email="shahar.sandhaus@gmail.com" },
-]
-description = "A Python module for orchestrating and dispatching large computations across multi-GPU systems using Vulkan."
-readme = "README.md"
-requires-python = ">=3.6"
-classifiers = [
- "Programming Language :: Python :: 3",
- "License :: OSI Approved :: MIT License",
- "Operating System :: OS Independent",
- "Development Status :: 2 - Pre-Alpha",
-]
-dependencies = [
- "setuptools>=59.0",
-]
-scripts = { vdlist = 'vkdispatch.cli:cli_entrypoint' }
-
-[project.urls]
-Homepage = "https://github.com/sharhar/vkdispatch"
-Issues = "https://github.com/sharhar/vkdispatch/issues"
-
-[project.optional-dependencies]
-cli = ["Click"]
-cuda = ["cuda-python"]
-pycuda = ["pycuda"]
-numpy = ["numpy"]
diff --git a/setup.py b/setup.py
index 38f19dfc..422495ce 100644
--- a/setup.py
+++ b/setup.py
@@ -1,239 +1,129 @@
import os
import platform
+import re
import subprocess
+from pathlib import Path
from setuptools import Extension
+from setuptools import find_packages
from setuptools import setup
from setuptools.command.build_ext import build_ext
-import re
-
-# Typically you'll put `packaging` in your setup_requires or pyproject.toml if needed.
try:
from packaging.version import Version
except ImportError:
- # As a fallback, if you absolutely can't rely on `packaging`,
- # you could use distutils: from distutils.version import LooseVersion as Version
print("Warning: 'packaging' not found; version comparisons might be less accurate.")
from distutils.version import LooseVersion as Version
-system = platform.system()
+BUILD_TARGET_FULL = "full"
+BUILD_TARGET_CORE = "core"
+BUILD_TARGET_NATIVE = "native"
+BUILD_TARGET_META = "meta"
+VALID_BUILD_TARGETS = {
+ BUILD_TARGET_FULL,
+ BUILD_TARGET_CORE,
+ BUILD_TARGET_NATIVE,
+ BUILD_TARGET_META,
+}
-proj_root = os.path.abspath(os.path.dirname(__file__))
-molten_vk_path = "./deps/MoltenVK/MoltenVK/MoltenVK/static/MoltenVK.xcframework/macos-arm64_x86_64/"
-vulkan_sdk_root = os.environ.get('VULKAN_SDK')
-platform_name_dict = {
- "Darwin": "MACOS",
- "Windows": "WINDOWS",
- "Linux": "LINUX"
-}
+def get_build_target() -> str:
+ target = os.environ.get("VKDISPATCH_BUILD_TARGET", BUILD_TARGET_FULL).strip().lower()
+ if target not in VALID_BUILD_TARGETS:
+ valid = ", ".join(sorted(VALID_BUILD_TARGETS))
+ raise RuntimeError(
+ f"Invalid VKDISPATCH_BUILD_TARGET={target!r}. Expected one of: {valid}"
+ )
+ return target
-platform_library_dirs = []
-platform_define_macros = []
-platform_link_libraries = []
-platform_extra_link_args = []
-platform_extra_compile_args = (
- ["/W3", "/GL", "/DNDEBUG", "/MD", "/EHsc", "/std:c++17"]
- if system == "Windows"
- else [
- "-O2",
- "-g",
- "-std=c++17",
- ]
-)
-include_directories = [
- proj_root + "/deps/VMA/include",
- proj_root + "/deps/volk",
- proj_root + "/deps/VkFFT/vkFFT",
-]
+BUILD_TARGET = get_build_target()
-if os.name == "posix":
- platform_extra_link_args.append("-g")
- platform_extra_link_args.append("-O0")
- platform_extra_link_args.append("-fno-omit-frame-pointer")
- platform_link_libraries.extend(["dl", "pthread"])
-
-
-if vulkan_sdk_root is None:
- include_directories.extend([
- proj_root + "/include_ext",
- proj_root + "/deps/Vulkan-Headers/include",
- proj_root + "/deps/Vulkan-Utility-Libraries/include",
- proj_root + "/deps/glslang",
- proj_root + "/deps/glslang/glslang/Include",
- ])
-
- if system == "Darwin":
- platform_library_dirs.append(molten_vk_path)
- platform_link_libraries.append("MoltenVK")
- platform_extra_link_args.extend([
- "-framework", "Metal",
- "-framework", "AVFoundation",
- "-framework", "AppKit"
- ])
- platform_extra_compile_args.append("-mmacosx-version-min=10.15")
- else:
- platform_define_macros.append(("VKDISPATCH_USE_VOLK", 1))
-else:
- include_directories.extend([
- vulkan_sdk_root + '/include',
- vulkan_sdk_root + '/include/utility',
- vulkan_sdk_root + '/include/glslang/Include',
- ])
+proj_root = Path(__file__).resolve().parent
+system = platform.system()
+molten_vk_path = "./deps/MoltenVK/MoltenVK/MoltenVK/static/MoltenVK.xcframework/macos-arm64_x86_64/"
+vulkan_sdk_root = os.environ.get("VULKAN_SDK")
- platform_define_macros.append(("VKDISPATCH_USE_VOLK", 1))
- platform_define_macros.append(("VKDISPATCH_LOADER_PATH", '"' + os.path.abspath(f"{vulkan_sdk_root}") + '/"'))
- #if os.name == "posix":
- # platform_link_libraries.append("vulkan")
- #else:
- # platform_link_libraries.append("vulkan-1")
+def read_version() -> str:
+ init_path = proj_root / "vkdispatch" / "__init__.py"
+ text = init_path.read_text(encoding="utf-8")
+ match = re.search(r'^__version__\s*=\s*"([^"]+)"', text, re.MULTILINE)
+ if not match:
+ raise RuntimeError(f"Could not find __version__ in {init_path}")
+ return match.group(1)
- platform_library_dirs.append(vulkan_sdk_root + '/lib')
- platform_link_libraries.extend([
- "glslang",
- "SPIRV",
- "MachineIndependent",
- "GenericCodeGen",
- "SPIRV-Tools-opt",
- "SPIRV-Tools-link",
- "SPIRV-Tools-reduce",
- "SPIRV-Tools",
- "glslang-default-resource-limits"
- ])
+def read_readme() -> str:
+ return (proj_root / "README.md").read_text(encoding="utf-8")
-sources = []
+VERSION = read_version()
-def append_to_sources(prefix, source_list):
- global sources
+COMMON_CLASSIFIERS = [
+ "Programming Language :: Python :: 3",
+ "License :: OSI Approved :: MIT License",
+ "Operating System :: OS Independent",
+ "Development Status :: 2 - Pre-Alpha",
+]
+
+COMMON_PROJECT_URLS = {
+ "Homepage": "https://github.com/sharhar/vkdispatch",
+ "Issues": "https://github.com/sharhar/vkdispatch/issues",
+}
+
+COMMON_EXTRAS = {
+ "cuda": ["cuda-python"],
+ "opencl": ["pyopencl", "numpy"],
+ "pycuda": ["pycuda"],
+ "numpy": ["numpy"],
+}
- for source in source_list:
- sources.append(prefix + source)
-
-
-sources.append("vkdispatch_native/wrapper.pyx")
-
-append_to_sources("vkdispatch_native/", [
- "context/init.cpp",
- "context/context.cpp",
- "context/errors.cpp",
- "context/handles.cpp",
-
- "objects/buffer.cpp",
- "objects/image.cpp",
- "objects/command_list.cpp",
- "objects/descriptor_set.cpp",
-
- "stages/stage_fft.cpp",
- "stages/stage_compute.cpp",
-
- "queue/queue.cpp",
- "queue/signal.cpp",
- "queue/work_queue.cpp",
- "queue/barrier_manager.cpp",
-
- "libs/VMAImpl.cpp",
- "libs/VolkImpl.cpp"
-])
-
-if vulkan_sdk_root is None:
- append_to_sources("deps/glslang/glslang/", [
- "CInterface/glslang_c_interface.cpp",
- "GenericCodeGen/CodeGen.cpp",
- "GenericCodeGen/Link.cpp",
- "MachineIndependent/glslang_tab.cpp",
- "MachineIndependent/attribute.cpp",
- "MachineIndependent/Constant.cpp",
- "MachineIndependent/iomapper.cpp",
- "MachineIndependent/InfoSink.cpp",
- "MachineIndependent/Initialize.cpp",
- "MachineIndependent/IntermTraverse.cpp",
- "MachineIndependent/Intermediate.cpp",
- "MachineIndependent/ParseContextBase.cpp",
- "MachineIndependent/ParseHelper.cpp",
- "MachineIndependent/PoolAlloc.cpp",
- "MachineIndependent/RemoveTree.cpp",
- "MachineIndependent/Scan.cpp",
- "MachineIndependent/ShaderLang.cpp",
- "MachineIndependent/SpirvIntrinsics.cpp",
- "MachineIndependent/SymbolTable.cpp",
- "MachineIndependent/Versions.cpp",
- "MachineIndependent/intermOut.cpp",
- "MachineIndependent/limits.cpp",
- "MachineIndependent/linkValidate.cpp",
- "MachineIndependent/parseConst.cpp",
- "MachineIndependent/reflection.cpp",
- "MachineIndependent/preprocessor/Pp.cpp",
- "MachineIndependent/preprocessor/PpAtom.cpp",
- "MachineIndependent/preprocessor/PpContext.cpp",
- "MachineIndependent/preprocessor/PpScanner.cpp",
- "MachineIndependent/preprocessor/PpTokens.cpp",
- "MachineIndependent/propagateNoContraction.cpp",
- "ResourceLimits/ResourceLimits.cpp",
- "ResourceLimits/resource_limits_c.cpp"
- ])
-
- append_to_sources("deps/glslang/SPIRV/", [
- "GlslangToSpv.cpp",
- "InReadableOrder.cpp",
- "Logger.cpp",
- "SpvBuilder.cpp",
- "SpvPostProcess.cpp",
- "doc.cpp",
- "SpvTools.cpp",
- "disassemble.cpp",
- "CInterface/spirv_c_interface.cpp"
- ])
def parse_compiler_version(version_output):
if not isinstance(version_output, str):
return None
-
- # Try to match either clang or gcc version string
- clang_match = re.search(r'clang version ([^\s]+)', version_output)
- gcc_match = re.search(r'gcc.+?([\d.]+(?:-[a-zA-Z0-9]+)?)', version_output, re.IGNORECASE)
-
+
+ clang_match = re.search(r"clang version ([^\s]+)", version_output)
+ gcc_match = re.search(
+ r"gcc.+?([\d.]+(?:-[a-zA-Z0-9]+)?)", version_output, re.IGNORECASE
+ )
+
match = clang_match or gcc_match
if not match:
return None
try:
return Version(match.group(1))
- except Exception as e:
- print(f"Invalid version: {e}")
+ except Exception as exc:
+ print(f"Invalid version: {exc}")
return None
+
def detect_unix_compiler(compiler_exe):
- """
- Given the 'compiler_exe' (like 'gcc', 'clang', etc.), returns a string
- denoting the compiler family: 'clang', 'gcc', or 'unknown'.
- """
try:
- # Run e.g. `gcc --version` or `clang --version`
- version_output = subprocess.check_output([compiler_exe, '--version'],
- stderr=subprocess.STDOUT,
- universal_newlines=True)
-
- if 'clang' in version_output:
- return 'clang', parse_compiler_version(version_output)
- elif 'gcc' in version_output or 'Free Software Foundation' in version_output:
- return 'gcc', parse_compiler_version(version_output)
- else:
- return 'unknown', None
+ version_output = subprocess.check_output(
+ [compiler_exe, "--version"],
+ stderr=subprocess.STDOUT,
+ universal_newlines=True,
+ )
+
+ if "clang" in version_output:
+ return "clang", parse_compiler_version(version_output)
+ if "gcc" in version_output or "Free Software Foundation" in version_output:
+ return "gcc", parse_compiler_version(version_output)
+ return "unknown", None
except Exception:
- return 'unknown', None
-
+ return "unknown", None
+
+
class CustomBuildExt(build_ext):
def build_extensions(self):
compiler_type = self.compiler.compiler_type
print(f"Detected compiler type: {compiler_type}")
- if compiler_type == 'unix':
+ if compiler_type == "unix":
print(f"Detected compiler: {self.compiler.compiler}")
compiler_family, version = detect_unix_compiler(self.compiler.compiler[0])
print(f"Detected compiler family: {compiler_family}")
@@ -241,50 +131,290 @@ def build_extensions(self):
if version is not None:
for ext in self.extensions:
- if compiler_family == 'clang' and version < Version('9.0'):
- ext.libraries.append('c++fs')
- elif compiler_family == 'gcc' and version < Version('9.1'):
- ext.libraries.append('stdc++fs')
+ if compiler_family == "clang" and version < Version("9.0"):
+ ext.libraries.append("c++fs")
+ elif compiler_family == "gcc" and version < Version("9.1"):
+ ext.libraries.append("stdc++fs")
else:
- print("WARNING: Unknown compiler family, not adding filesystem library")
+ print(
+ "WARNING: Unknown compiler family, not adding filesystem library"
+ )
- # Now actually build the extensions
super().build_extensions()
-setup(
- name="vkdispatch",
- packages=[
- "vkdispatch",
- "vkdispatch.base",
- "vkdispatch.backends",
- "vkdispatch._compat",
- "vkdispatch.codegen",
- "vkdispatch.codegen.backends",
- "vkdispatch.codegen.functions",
- "vkdispatch.codegen.functions.base_functions",
- "vkdispatch.codegen.variables",
- "vkdispatch.execution_pipeline",
- "vkdispatch.shader",
- "vkdispatch.reduce",
- "vkdispatch.vkfft",
- "vkdispatch.fft"
- ],
- ext_modules=[
- Extension(
- "vkdispatch_native",
- sources=sources,
- language="c++",
- define_macros=platform_define_macros,
- library_dirs=platform_library_dirs,
- libraries=platform_link_libraries,
- extra_compile_args=platform_extra_compile_args,
- extra_link_args=platform_extra_link_args,
- include_dirs=include_directories,
+
+def append_to_sources(prefix, source_list, out_sources):
+ for source in source_list:
+ out_sources.append(prefix + source)
+
+
+def build_native_extension():
+ platform_library_dirs = []
+ platform_define_macros = []
+ platform_link_libraries = []
+ platform_extra_link_args = []
+ platform_extra_compile_args = (
+ ["/W3", "/GL", "/DNDEBUG", "/MD", "/EHsc", "/std:c++17"]
+ if system == "Windows"
+ else ["-O2", "-g", "-std=c++17"]
+ )
+
+ include_directories = [
+ str(proj_root / "deps" / "VMA" / "include"),
+ str(proj_root / "deps" / "volk"),
+ str(proj_root / "deps" / "VkFFT" / "vkFFT"),
+ ]
+
+ if os.name == "posix":
+ platform_extra_link_args.extend(["-g", "-O0", "-fno-omit-frame-pointer"])
+ platform_link_libraries.extend(["dl", "pthread"])
+
+ if vulkan_sdk_root is None:
+ include_directories.extend(
+ [
+ str(proj_root / "include_ext"),
+ str(proj_root / "deps" / "Vulkan-Headers" / "include"),
+ str(proj_root / "deps" / "Vulkan-Utility-Libraries" / "include"),
+ str(proj_root / "deps" / "glslang"),
+ str(proj_root / "deps" / "glslang" / "glslang" / "Include"),
+ ]
+ )
+
+ if system == "Darwin":
+ platform_library_dirs.append(molten_vk_path)
+ platform_link_libraries.append("MoltenVK")
+ platform_extra_link_args.extend(
+ [
+ "-framework",
+ "Metal",
+ "-framework",
+ "AVFoundation",
+ "-framework",
+ "AppKit",
+ ]
+ )
+ platform_extra_compile_args.append("-mmacosx-version-min=10.15")
+ else:
+ platform_define_macros.append(("VKDISPATCH_USE_VOLK", 1))
+ else:
+ include_directories.extend(
+ [
+ vulkan_sdk_root + "/include",
+ vulkan_sdk_root + "/include/utility",
+ vulkan_sdk_root + "/include/glslang/Include",
+ ]
)
- ],
- cmdclass={
- 'build_ext': CustomBuildExt,
- },
- version="0.0.30",
- zip_safe=False,
-)
+
+ platform_define_macros.append(("VKDISPATCH_USE_VOLK", 1))
+ platform_define_macros.append(
+ ("VKDISPATCH_LOADER_PATH", '"' + os.path.abspath(vulkan_sdk_root) + '/"')
+ )
+
+ platform_library_dirs.append(vulkan_sdk_root + "/lib")
+ platform_link_libraries.extend(
+ [
+ "glslang",
+ "SPIRV",
+ "MachineIndependent",
+ "GenericCodeGen",
+ "SPIRV-Tools-opt",
+ "SPIRV-Tools-link",
+ "SPIRV-Tools-reduce",
+ "SPIRV-Tools",
+ "glslang-default-resource-limits",
+ ]
+ )
+
+ sources = []
+ sources.append("vkdispatch_native/wrapper.pyx")
+
+ append_to_sources(
+ "vkdispatch_native/",
+ [
+ "context/init.cpp",
+ "context/context.cpp",
+ "context/errors.cpp",
+ "context/handles.cpp",
+ "objects/buffer.cpp",
+ "objects/image.cpp",
+ "objects/command_list.cpp",
+ "objects/descriptor_set.cpp",
+ "stages/stage_fft.cpp",
+ "stages/stage_compute.cpp",
+ "queue/queue.cpp",
+ "queue/signal.cpp",
+ "queue/work_queue.cpp",
+ "queue/barrier_manager.cpp",
+ "libs/VMAImpl.cpp",
+ "libs/VolkImpl.cpp",
+ ],
+ sources,
+ )
+
+ if vulkan_sdk_root is None:
+ append_to_sources(
+ "deps/glslang/glslang/",
+ [
+ "CInterface/glslang_c_interface.cpp",
+ "GenericCodeGen/CodeGen.cpp",
+ "GenericCodeGen/Link.cpp",
+ "MachineIndependent/glslang_tab.cpp",
+ "MachineIndependent/attribute.cpp",
+ "MachineIndependent/Constant.cpp",
+ "MachineIndependent/iomapper.cpp",
+ "MachineIndependent/InfoSink.cpp",
+ "MachineIndependent/Initialize.cpp",
+ "MachineIndependent/IntermTraverse.cpp",
+ "MachineIndependent/Intermediate.cpp",
+ "MachineIndependent/ParseContextBase.cpp",
+ "MachineIndependent/ParseHelper.cpp",
+ "MachineIndependent/PoolAlloc.cpp",
+ "MachineIndependent/RemoveTree.cpp",
+ "MachineIndependent/Scan.cpp",
+ "MachineIndependent/ShaderLang.cpp",
+ "MachineIndependent/SpirvIntrinsics.cpp",
+ "MachineIndependent/SymbolTable.cpp",
+ "MachineIndependent/Versions.cpp",
+ "MachineIndependent/intermOut.cpp",
+ "MachineIndependent/limits.cpp",
+ "MachineIndependent/linkValidate.cpp",
+ "MachineIndependent/parseConst.cpp",
+ "MachineIndependent/reflection.cpp",
+ "MachineIndependent/preprocessor/Pp.cpp",
+ "MachineIndependent/preprocessor/PpAtom.cpp",
+ "MachineIndependent/preprocessor/PpContext.cpp",
+ "MachineIndependent/preprocessor/PpScanner.cpp",
+ "MachineIndependent/preprocessor/PpTokens.cpp",
+ "MachineIndependent/propagateNoContraction.cpp",
+ "ResourceLimits/ResourceLimits.cpp",
+ "ResourceLimits/resource_limits_c.cpp",
+ ],
+ sources,
+ )
+
+ append_to_sources(
+ "deps/glslang/SPIRV/",
+ [
+ "GlslangToSpv.cpp",
+ "InReadableOrder.cpp",
+ "Logger.cpp",
+ "SpvBuilder.cpp",
+ "SpvPostProcess.cpp",
+ "doc.cpp",
+ "SpvTools.cpp",
+ "disassemble.cpp",
+ "CInterface/spirv_c_interface.cpp",
+ ],
+ sources,
+ )
+
+ return Extension(
+ "vkdispatch_vulkan_native",
+ sources=sources,
+ language="c++",
+ define_macros=platform_define_macros,
+ library_dirs=platform_library_dirs,
+ libraries=platform_link_libraries,
+ extra_compile_args=platform_extra_compile_args,
+ extra_link_args=platform_extra_link_args,
+ include_dirs=include_directories,
+ )
+
+
+def base_setup_kwargs():
+ return {
+ "version": VERSION,
+ "author": "Shahar Sandhaus",
+ "author_email": "shahar.sandhaus@gmail.com",
+ "description": (
+ "A Python module for orchestrating and dispatching large computations "
+ "across multi-GPU systems using Vulkan."
+ ),
+ "long_description": read_readme(),
+ "long_description_content_type": "text/markdown",
+ "python_requires": ">=3.6",
+ "classifiers": COMMON_CLASSIFIERS,
+ "project_urls": COMMON_PROJECT_URLS,
+ "zip_safe": False,
+ }
+
+
+def core_packages():
+ return find_packages(include=["vkdispatch", "vkdispatch.*"])
+
+
+def setup_for_target(target: str):
+ kwargs = base_setup_kwargs()
+
+ if target == BUILD_TARGET_FULL:
+ kwargs.update(
+ {
+ "name": "vkdispatch",
+ "packages": core_packages(),
+ "install_requires": ["setuptools>=59.0"],
+ "extras_require": {
+ "cli": ["Click"],
+ **COMMON_EXTRAS,
+ },
+ "entry_points": {
+ "console_scripts": [
+ "vdlist=vkdispatch.cli:cli_entrypoint",
+ ]
+ },
+ "ext_modules": [build_native_extension()],
+ "cmdclass": {"build_ext": CustomBuildExt},
+ }
+ )
+ return kwargs
+
+ if target == BUILD_TARGET_CORE:
+ kwargs.update(
+ {
+ "name": "vkdispatch-core",
+ "packages": core_packages(),
+ "install_requires": ["setuptools>=59.0"],
+ "extras_require": dict(COMMON_EXTRAS),
+ }
+ )
+ return kwargs
+
+ if target == BUILD_TARGET_NATIVE:
+ kwargs.update(
+ {
+ "name": "vkdispatch-vulkan-native",
+ "packages": [],
+ "py_modules": [],
+ "install_requires": [],
+ "ext_modules": [build_native_extension()],
+ "cmdclass": {"build_ext": CustomBuildExt},
+ }
+ )
+ return kwargs
+
+ if target == BUILD_TARGET_META:
+ kwargs.update(
+ {
+ "name": "vkdispatch",
+ "packages": [],
+ "py_modules": [],
+ "install_requires": [
+ f"vkdispatch-core=={VERSION}",
+ f"vkdispatch-vulkan-native=={VERSION}",
+ ],
+ "extras_require": {
+ "cli": ["Click"],
+ **COMMON_EXTRAS,
+ },
+ "entry_points": {
+ "console_scripts": [
+ "vdlist=vkdispatch.cli:cli_entrypoint",
+ ]
+ },
+ }
+ )
+ return kwargs
+
+ raise AssertionError(f"Unhandled build target: {target}")
+
+
+setup(**setup_for_target(BUILD_TARGET))
diff --git a/shader_run.py b/shader_run.py
new file mode 100644
index 00000000..8c34a024
--- /dev/null
+++ b/shader_run.py
@@ -0,0 +1,89 @@
+import vkdispatch as vd
+
+from vkdispatch.base.command_list import CommandList
+from vkdispatch.base.compute_plan import ComputePlan
+from vkdispatch.base.descriptor_set import DescriptorSet
+
+import numpy as np
+
+def load_shader(path: str) -> ComputePlan:
+ shader_source = open(path, 'r').read()
+
+ return ComputePlan(
+ shader_source=shader_source,
+ binding_type_list=[1, 1, 1],
+ pc_size=0,
+ shader_name=f"shader_{path.split('/')[-1].split('.')[0]}"
+ )
+
+def make_descriptor(plan: ComputePlan, out_buff: vd.Buffer, in_buff: vd.Buffer, kern_buff: vd.Buffer):
+ descriptor_set = DescriptorSet(plan)
+
+ descriptor_set.bind_buffer(out_buff, 0)
+ descriptor_set.bind_buffer(in_buff, 1)
+ descriptor_set.bind_buffer(kern_buff, 2)
+
+ return descriptor_set
+
+def numpy_convolution(signal: np.ndarray, kernel: np.ndarray) -> np.ndarray:
+ return np.fft.ifft(
+ np.fft.fft(signal, axis=1).astype(np.complex64)
+ *
+ kernel.conjugate(),
+ axis=1
+ )
+
+BUFF_SHAPE = (4, 512, 257)
+
+np.random.seed(1337)
+
+in_data = (np.random.rand(*BUFF_SHAPE) + 1j * np.random.rand(*BUFF_SHAPE)).astype(np.complex64)
+kern_data = (np.random.rand(*BUFF_SHAPE) + 1j * np.random.rand(*BUFF_SHAPE)).astype(np.complex64)
+
+reference_result_data = numpy_convolution(in_data, kern_data[0])
+
+out_buff = vd.buffer_c64(BUFF_SHAPE)
+in_buff = vd.buffer_c64(BUFF_SHAPE)
+kern_buff = vd.buffer_c64(BUFF_SHAPE)
+
+in_buff.write(in_data)
+kern_buff.write(kern_data)
+
+block_count = (1028, 32, 1)
+
+plan_bad = load_shader("conv_bad.comp")
+plan_good = load_shader("conv_good.comp")
+
+cmd_list_bad = CommandList()
+
+cmd_list_bad.record_compute_plan(
+ plan_bad,
+ make_descriptor(plan_bad, out_buff, in_buff, kern_buff),
+ block_count
+)
+
+cmd_list_bad.submit(instance_count=1)
+
+result_data_bad = out_buff.read(0)
+
+cmd_list_good = CommandList()
+
+cmd_list_good.record_compute_plan(
+ plan_good,
+ make_descriptor(plan_good, out_buff, in_buff, kern_buff),
+ block_count
+)
+
+cmd_list_good.submit(instance_count=1)
+
+result_data_good = out_buff.read(0)
+
+for i in range(BUFF_SHAPE[0]):
+ np.save(f"result_bad_{i}.npy", result_data_bad[i])
+ np.save(f"result_good_{i}.npy", result_data_good[i])
+ np.save(f"reference_result_{i}.npy", reference_result_data[i])
+ np.save(f"diff_bad_{i}.npy", result_data_bad[i] - reference_result_data[i])
+ np.save(f"diff_good_{i}.npy", result_data_good[i] - reference_result_data[i])
+ np.save(f"diff_{i}.npy", result_data_good[i] - result_data_bad[i])
+
+assert np.allclose(result_data_good, result_data_bad, atol=1e-3)
diff --git a/test.py b/test.py
index 320b68e5..b7f21622 100644
--- a/test.py
+++ b/test.py
@@ -1,51 +1,17 @@
import vkdispatch as vd
import vkdispatch.codegen as vc
-import numpy as np
-from typing import Tuple
+vd.initialize(backend="vulkan", log_level=vd.LogLevel.INFO)
+vc.set_codegen_backend("glsl")
-vd.initialize(backend="pycuda")
+SIZE = 4096
-def make_shape(fft_size: int, data_size: int) -> Tuple[int, ...]:
- total_square_size = fft_size * fft_size
- assert data_size % total_square_size == 0, "Data size must be a multiple of fft_size squared"
- return (data_size // total_square_size, fft_size, fft_size)
+buff_shape = (2, SIZE, SIZE)
-def make_random_data(fft_size: int, run_index: int, data_size: int, seed: int = 1337) -> np.ndarray:
- shape = make_shape(fft_size, data_size)
- rng = np.random.default_rng(seed + fft_size * 1000 + run_index)
+buff = vd.Buffer(buff_shape, var_type=vd.complex64)
- real = rng.standard_normal(shape).astype(np.float32)
- imag = rng.standard_normal(shape).astype(np.float32)
- return (real + 1j * imag).astype(np.complex64)
+vd.vkfft.fft(buff, axis=1) #, print_shader=True)
-def compute_metrics(reference: np.ndarray, result: np.ndarray):
- reference64 = reference.astype(np.complex128, copy=False)
- result64 = result.astype(np.complex128, copy=False)
+vd.queue_wait_idle()
- delta = result64 - reference64
- abs_delta = np.abs(delta)
- abs_reference = np.abs(reference64)
-
- eps = 1e-12
- relative_l2 = np.linalg.norm(delta.ravel()) / max(np.linalg.norm(reference64.ravel()), eps)
- max_relative = np.max(abs_delta / np.maximum(abs_reference, eps))
- max_absolute = np.max(abs_delta)
-
- return float(relative_l2), float(max_relative), float(max_absolute)
-
-fft_size = 64
-data_size = 16 * 1024 * 1024
-
-input_data = make_random_data(fft_size, 0, data_size)
-reference = np.fft.fft(input_data)
-
-shape = make_shape(fft_size, data_size)
-
-buffer = vd.buffer_c64(shape) #Buffer(shape, var_type=vd.complex64)
-
-buffer.write(input_data)
-vd.fft.fft(buffer) #, print_shader=True)
-result_data = buffer.read(0)
-
-print(compute_metrics(reference, result_data))
\ No newline at end of file
+#print(vd.fft.fft_src(buff_shape, axis=1).code)
\ No newline at end of file
diff --git a/test2.py b/test2.py
index 813a205e..5f494e18 100644
--- a/test2.py
+++ b/test2.py
@@ -1,304 +1,109 @@
import vkdispatch as vd
import vkdispatch.codegen as vc
-vd.initialize(debug_mode=True, backend="pycuda") #, log_level=vd.LogLevel.INFO)
+#vd.initialize(debug_mode=True, backend="cuda")
+#vc.set_codegen_backend("cuda")
-vc.set_codegen_backend("cuda")
-
-import dataclasses
-import enum
-
-from typing import List
-from typing import Any
-from typing import Dict
-from typing import Tuple
-
-#vd.initialize(debug_mode=True)
-vd.make_context(use_cpu=True)
-
-from vkdispatch.base.compute_plan import ComputePlan
-from vkdispatch.base.descriptor_set import DescriptorSet
-from vkdispatch.base.command_list import CommandList
+from typing import Callable, Union, Tuple
import numpy as np
-class CommandType(enum.Enum):
- ADD_VALUE = 0
- SUB_VALUE = 1
- MULT_VALUE = 2
- DIV_VALUE = 3
- SIN_VALUE = 4
- COS_VALUE = 5
-
-valid_commands = [
- CommandType.ADD_VALUE,
- CommandType.SUB_VALUE,
-]
-
-command_type_to_str = {
- CommandType.ADD_VALUE: "ADD",
- CommandType.SUB_VALUE: "SUB",
- CommandType.MULT_VALUE: "MULT",
- CommandType.DIV_VALUE: "DIV",
- CommandType.SIN_VALUE: "SIN",
- CommandType.COS_VALUE: "COS"
-}
-
-@dataclasses.dataclass
-class ProgramCommand:
- command_type: CommandType
- value: float
+import time
+import dataclasses
@dataclasses.dataclass
-class RunConfig:
- buffer_count: int
- buffer_sizes: List[int]
-
- program_count: int
- program_commands: List[List[ProgramCommand]]
-
- def __repr__(self):
- commands_repr = ""
-
- for commands in self.program_commands:
- commands_repr += "\n"
-
- for command in commands:
- command_name = command_type_to_str[command.command_type]
-
- commands_repr += f" {command_name} {command.value}\n"
-
- return f"""RunConfig(
- buffer_count={self.buffer_count},
- buffer_sizes={self.buffer_sizes},
- program_count={self.program_count},
- program_commands=[{commands_repr}
-])"""
-
-def make_random_config() -> RunConfig:
- buffer_count = np.random.randint(10, 50)
- buffer_sizes = np.random.randint(500, 2500, size=buffer_count).tolist()
-
- program_count = np.random.randint(10, 50)
- program_commands = []
-
- for _ in range(program_count):
- command_count = np.random.randint(10, 50)
- commands = []
-
- for _ in range(command_count):
- command_type = np.random.choice(valid_commands)
- value = np.random.uniform(-10, 10)
-
- commands.append(ProgramCommand(command_type, value))
-
- program_commands.append(commands)
-
- return RunConfig(
- buffer_count=buffer_count,
- buffer_sizes=buffer_sizes,
- program_count=program_count,
- program_commands=program_commands
- )
-
-buffer_cache: Dict[int, vd.Buffer] = {}
-
-def get_buffer(index: int, config: RunConfig) -> vd.Buffer:
- global buffer_cache
+class Config:
+ data_size: int
+ iter_count: int
+ iter_batch: int
+ run_count: int
+ signal_factor: int
+ warmup: int = 10
+
+ def make_shape(self, fft_size: int) -> Tuple[int, ...]:
+ total_square_size = fft_size * fft_size
+ assert self.data_size % total_square_size == 0, "Data size must be a multiple of fft_size squared"
+ return (self.data_size // total_square_size, fft_size, fft_size)
- if index not in buffer_cache:
- buffer_cache[index] = vd.asbuffer(
- np.zeros(
- shape=(config.buffer_sizes[index],),
- dtype=np.float32
- )
- )
-
- return buffer_cache[index]
+ def make_random_data(self, fft_size: int):
+ shape = self.make_shape(fft_size)
+ return np.random.rand(*shape).astype(np.complex64)
-array_cache: Dict[int, np.ndarray] = {}
+def run_vkdispatch(config: Config,
+ fft_size: int,
+ io_count: Union[int, Callable],
+ gpu_function: Callable) -> float:
+ shape = config.make_shape(fft_size)
-def get_array(index: int, config: RunConfig) -> np.ndarray:
- global array_cache
+ buffer = vd.Buffer(shape, var_type=vd.complex64)
+ kernel = vd.Buffer(shape, var_type=vd.complex64)
- if index not in array_cache:
- array_cache[index] = np.zeros(
- shape=(config.buffer_sizes[index],),
- dtype=np.float32
- )
-
- return array_cache[index]
-
-def make_source(commands: List[ProgramCommand]):
- local_size_x = vd.get_context().max_workgroup_size[0]
-
- header = """
-#version 450
-#extension GL_ARB_separate_shader_objects : enable
-//#extension GL_EXT_debug_printf : enable
-
-layout(push_constant) uniform PushConstant {
- uint exec_count;
-} PC;
-
-layout(set = 0, binding = 0) buffer Buffer0 { float data[]; } bufOut;
-layout(set = 0, binding = 1) buffer Buffer1 { float data[]; } bufIn;
-""" + f"""
-layout(local_size_x = {local_size_x}, local_size_y = 1, local_size_z = 1) in;
-""" + """
-void main() {
- if(PC.exec_count <= gl_GlobalInvocationID.x) {
- return ;
- }
-
- uint tid = gl_GlobalInvocationID.x;
-
- float value = bufIn.data[tid];
-"""
-
- body = ""
-
- for command in commands:
- if command.command_type == CommandType.ADD_VALUE:
- body += f" value += {command.value};\n"
- elif command.command_type == CommandType.SUB_VALUE:
- body += f" value -= {command.value};\n"
- elif command.command_type == CommandType.MULT_VALUE:
- body += f" value *= {command.value};\n"
- elif command.command_type == CommandType.DIV_VALUE:
- body += f" value /= {command.value};\n"
- elif command.command_type == CommandType.SIN_VALUE:
- body += f" value = sin(value);\n"
- elif command.command_type == CommandType.COS_VALUE:
- body += f" value = cos(value);\n"
-
- ending = """
- bufOut.data[tid] = value;
-}
-"""
-
- return header + body + ending
-
-program_cache: Dict[int, ComputePlan] = {}
-
-def get_program(index: int, config: RunConfig) -> ComputePlan:
- global program_cache
-
- if index not in program_cache:
- program_cache[index] = ComputePlan(
- shader_source=make_source(config.program_commands[index]),
- binding_type_list=[1, 1],
- pc_size=4,
- shader_name=f"program_{index}"
- )
-
- return program_cache[index]
-
-descriptor_set_cache: Dict[Tuple[int, int, int], DescriptorSet] = {}
-
-def get_descriptor_set(out_buffer: int, in_buffer: int, program: ComputePlan, config: RunConfig) -> DescriptorSet:
- global descriptor_set_cache
-
- dict_key = (out_buffer, in_buffer, program._handle)
-
- if dict_key not in descriptor_set_cache:
- output_buffer = get_buffer(out_buffer, config)
- input_buffer = get_buffer(in_buffer, config)
-
- descriptor_set = DescriptorSet(program)
- descriptor_set.bind_buffer(output_buffer, 0)
- descriptor_set.bind_buffer(input_buffer, 1)
-
- descriptor_set_cache[dict_key] = descriptor_set
+ graph = vd.CommandGraph()
+ old_graph = vd.set_global_graph(graph)
+
+ gpu_function(config, fft_size, buffer, kernel)
- return descriptor_set_cache[dict_key]
+ vd.set_global_graph(old_graph)
-def clear_caches():
- global buffer_cache
- global array_cache
- global program_cache
- global descriptor_set_cache
+ for _ in range(config.warmup):
+ graph.submit(config.iter_batch)
- buffer_cache.clear()
- array_cache.clear()
- program_cache.clear()
- descriptor_set_cache.clear()
+ vd.queue_wait_idle()
-def do_vkdispatch_command(cmd_list: CommandList, out_buffer: int, in_buffer: int, program: int, config: RunConfig):
- compute_plan = get_program(program, config)
- descriptor_set = get_descriptor_set(out_buffer, in_buffer, compute_plan, config)
+ if callable(io_count):
+ io_count = io_count(buffer.size, fft_size)
- cmd_list.reset()
+ gb_byte_count = io_count * 8 * buffer.size / (1024 * 1024 * 1024)
- local_size = vd.get_context().max_workgroup_size[0]
+ start_time = time.perf_counter()
- total_exec_size = min(config.buffer_sizes[out_buffer], config.buffer_sizes[in_buffer])
+ for _ in range(config.iter_count // config.iter_batch):
+ graph.submit(config.iter_batch)
- block_count = (total_exec_size + local_size - 1) // local_size
+ vd.queue_wait_idle()
- cmd_list.record_compute_plan(compute_plan, descriptor_set, [block_count, 1, 1])
+ elapsed_time = time.perf_counter() - start_time
- cmd_list.submit(data=np.array([total_exec_size], dtype=np.uint32).tobytes())
+ buffer.destroy()
+ kernel.destroy()
+ graph.destroy()
+ vd.fft.cache_clear()
-def do_numpy_command(out_buffer: int, in_buffer: int, program: int, config: RunConfig):
- output_array = get_array(out_buffer, config)
- input_array = get_array(in_buffer, config)
+ time.sleep(1)
- total_exec_size = min(config.buffer_sizes[out_buffer], config.buffer_sizes[in_buffer])
+ vd.queue_wait_idle()
- temp_array = np.zeros(shape=(total_exec_size,), dtype=np.float32)
- temp_array[:] = input_array[:total_exec_size]
+ return gb_byte_count, elapsed_time
- commands = config.program_commands[program]
- for command in commands:
- if command.command_type == CommandType.ADD_VALUE:
- temp_array += command.value
- temp_array = temp_array.astype(np.float32)
- elif command.command_type == CommandType.SUB_VALUE:
- temp_array -= command.value
- temp_array = temp_array.astype(np.float32)
- elif command.command_type == CommandType.MULT_VALUE:
- temp_array *= command.value
- temp_array = temp_array.astype(np.float32)
- elif command.command_type == CommandType.DIV_VALUE:
- temp_array /= command.value
- temp_array = temp_array.astype(np.float32)
- elif command.command_type == CommandType.SIN_VALUE:
- temp_array = np.sin(temp_array)
- temp_array = temp_array.astype(np.float32)
- elif command.command_type == CommandType.COS_VALUE:
- temp_array = np.cos(temp_array)
- temp_array = temp_array.astype(np.float32)
+def run_test(config: Config,
+ io_count: Union[int, Callable],
+ gpu_function: Callable):
+ fft_sizes = [8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096]
- output_array[:total_exec_size] = temp_array
+ for fft_size in fft_sizes:
+ rates = []
-def test_async_commands():
- for _ in range(50):
- clear_caches()
-
- config = make_random_config()
+ for _ in range(config.run_count):
+ gb_byte_count, elapsed_time = run_vkdispatch(config, fft_size, io_count, gpu_function)
+ gb_per_second = config.iter_count * gb_byte_count / elapsed_time
- cmd_list = CommandList()
+ print(f"FFT Size: {fft_size}, Throughput: {gb_per_second:.4f} GB/s")
+ rates.append(gb_per_second)
- exec_count = np.random.randint(1, 250)
+def do_fft(config: Config,
+ fft_size: int,
+ buffer: vd.Buffer,
+ kernel: vd.Buffer):
+ vd.fft.fft(buffer)
- input_buffers = np.random.randint(0, config.buffer_count, size=exec_count)
- output_buffers = np.random.randint(0, config.buffer_count, size=exec_count)
- programs = np.random.randint(0, config.program_count, size=exec_count)
- for input_buffer, output_buffer, program in zip(input_buffers, output_buffers, programs):
- do_vkdispatch_command(cmd_list, output_buffer, input_buffer, program, config)
-
- for input_buffer, output_buffer, program in zip(input_buffers, output_buffers, programs):
- do_numpy_command(output_buffer, input_buffer, program, config)
-
- for i in range(config.buffer_count):
- numpy_buffer = get_array(i, config)
- vkbuffer = get_buffer(i, config).read(0)
-
- assert np.allclose(vkbuffer, numpy_buffer, atol=1e-3)
-
- clear_caches()
+conf = Config(
+ data_size=2**26,
+ iter_count=80,
+ iter_batch=10,
+ run_count=1,
+ signal_factor=8
+)
-test_async_commands()
\ No newline at end of file
+run_test(conf, 2, do_fft)
\ No newline at end of file
diff --git a/test3.py b/test3.py
deleted file mode 100644
index 7b29f4eb..00000000
--- a/test3.py
+++ /dev/null
@@ -1,470 +0,0 @@
-
-import pycuda.autoinit
-import pycuda.driver as cuda
-import numpy as np
-from pycuda.compiler import SourceModule
-
-import struct
-
-
-cuda_kernel = """
-// Expected local size: (8, 1, 1)
-#define VKDISPATCH_EXPECTED_LOCAL_SIZE_X 8
-#define VKDISPATCH_EXPECTED_LOCAL_SIZE_Y 1
-#define VKDISPATCH_EXPECTED_LOCAL_SIZE_Z 1
-
-#include
-#include
-#include
-
-#define VKDISPATCH_ENABLE_SUBGROUP_OPS 1
-#define VKDISPATCH_ENABLE_PRINTF 1
-
-__device__ __forceinline__ float2 operator+(float2 a, float2 b) { return make_float2(a.x + b.x, a.y + b.y); }
-__device__ __forceinline__ float2 operator-(float2 a, float2 b) { return make_float2(a.x - b.x, a.y - b.y); }
-__device__ __forceinline__ float2 operator*(float2 a, float2 b) { return make_float2(a.x * b.x, a.y * b.y); }
-__device__ __forceinline__ float2 operator*(float s, float2 v) { return make_float2(s * v.x, s * v.y); }
-__device__ __forceinline__ float2 operator*(float2 v, float s) { return make_float2(v.x * s, v.y * s); }
-
-__device__ __forceinline__ float2 vkdispatch_make_float2(float x, float y) { return make_float2(x, y); }
-__device__ __forceinline__ float2 vkdispatch_make_float2(float x) { return make_float2(x, x); }
-template __device__ __forceinline__ float2 vkdispatch_make_float2(TVec v) { return make_float2((float)v.x, (float)v.y); }
-
-__device__ __forceinline__ uint3 vkdispatch_local_invocation_id() {
- return make_uint3((unsigned int)threadIdx.x, (unsigned int)threadIdx.y, (unsigned int)threadIdx.z);
-}
-
-__device__ __forceinline__ uint3 vkdispatch_workgroup_id() {
- return make_uint3((unsigned int)blockIdx.x, (unsigned int)blockIdx.y, (unsigned int)blockIdx.z);
-}
-
-__device__ __forceinline__ unsigned int vkdispatch_local_invocation_index() {
- return (unsigned int)(threadIdx.x + blockDim.x * (threadIdx.y + blockDim.y * threadIdx.z));
-}
-
-__shared__ float2 sdata[68];
-
-struct UniformObjectBuffer {
- uint4 exec_count;
- int4 sdata_shape;
- int4 buf1_shape;
-};
-struct Buffer1 { float2* data; };
-
-extern "C" __global__ void vkdispatch_main(const UniformObjectBuffer* vkdispatch_uniform_ptr, float2* vkdispatch_binding_1_ptr) {
- const UniformObjectBuffer& UBO = *vkdispatch_uniform_ptr;
- Buffer1 buf1 = {vkdispatch_binding_1_ptr};
- unsigned int workgroup_index = ((unsigned int)(vkdispatch_workgroup_id().x));
- unsigned int tid = vkdispatch_local_invocation_id().x;
- unsigned int input_batch_offset = ((unsigned int)(0));
- unsigned int output_batch_offset = ((unsigned int)(0));
- float2 omega_register = vkdispatch_make_float2(0.0f);
- unsigned int subsequence_offset = ((unsigned int)(0));
- unsigned int io_index = ((unsigned int)(0));
- unsigned int io_index_2 = ((unsigned int)(0));
- float2 radix_register_0 = vkdispatch_make_float2(0.0f);
- float2 radix_register_1 = vkdispatch_make_float2(0.0f);
- float2 fft_reg_0 = vkdispatch_make_float2(0.0f);
- float2 fft_reg_1 = vkdispatch_make_float2(0.0f);
- float2 fft_reg_2 = vkdispatch_make_float2(0.0f);
- float2 fft_reg_3 = vkdispatch_make_float2(0.0f);
- float2 fft_reg_4 = vkdispatch_make_float2(0.0f);
- float2 fft_reg_5 = vkdispatch_make_float2(0.0f);
- float2 fft_reg_6 = vkdispatch_make_float2(0.0f);
- float2 fft_reg_7 = vkdispatch_make_float2(0.0f);
-
- /* Reading input samples from global memory into FFT registers. */
- input_batch_offset = ((workgroup_index + vkdispatch_local_invocation_id().y) << 6);
- io_index = (tid + input_batch_offset);
- fft_reg_0 = buf1.data[io_index];
- io_index = ((tid + 8) + input_batch_offset);
- fft_reg_1 = buf1.data[io_index];
- io_index = ((tid + 16) + input_batch_offset);
- fft_reg_2 = buf1.data[io_index];
- io_index = ((tid + 24) + input_batch_offset);
- fft_reg_3 = buf1.data[io_index];
- io_index = ((tid + 32) + input_batch_offset);
- fft_reg_4 = buf1.data[io_index];
- io_index = ((tid + 40) + input_batch_offset);
- fft_reg_5 = buf1.data[io_index];
- io_index = ((tid + 48) + input_batch_offset);
- fft_reg_6 = buf1.data[io_index];
- io_index = ((tid + 56) + input_batch_offset);
- fft_reg_7 = buf1.data[io_index];
-
- /*
- * FFT stage 1/2.
- * Prime group (2, 2, 2): execute 1 radix-8 sub-FFTs per invocation.
- * Register-group coverage this stage: 8.
- */
-
- /*
- * Starting mixed-radix FFT decomposition for this invocation on 8 register samples.
- * Radix factorization sequence: (2, 2, 2).
- * At each level: partition lanes into stage-local sub-sequences, apply twiddles,
- * run radix-P butterflies, then reassemble in stride-consistent order for downstream stages.
- */
- /* Radix-2 butterfly base case */
- radix_register_0 = fft_reg_4;
- fft_reg_4 = (fft_reg_0 - radix_register_0);
- fft_reg_0 = (fft_reg_0 + radix_register_0);
- /* Radix-2 butterfly base case */
- radix_register_0 = fft_reg_5;
- fft_reg_5 = (fft_reg_1 - radix_register_0);
- fft_reg_1 = (fft_reg_1 + radix_register_0);
- /* Radix-2 butterfly base case */
- radix_register_0 = fft_reg_6;
- fft_reg_6 = (fft_reg_2 - radix_register_0);
- fft_reg_2 = (fft_reg_2 + radix_register_0);
- /* Radix-2 butterfly base case */
- radix_register_0 = fft_reg_7;
- fft_reg_7 = (fft_reg_3 - radix_register_0);
- fft_reg_3 = (fft_reg_3 + radix_register_0);
- /* Radix-2 butterfly base case */
- radix_register_0 = fft_reg_2;
- fft_reg_2 = (fft_reg_0 - radix_register_0);
- fft_reg_0 = (fft_reg_0 + radix_register_0);
-
- /*
- * Applying Cooley-Tukey inter-stage twiddle factors before the next butterfly pass.
- * Twiddle domain size: N = 4. Twiddle index source: 1.
- * For each non-DC lane i>0, multiply by W_N^(i * twiddle_index).
- * This phase-aligns each sub-FFT with its parent decomposition stage.
- */
- radix_register_0.x = fft_reg_6.x;
- fft_reg_6.x = fft_reg_6.y;
- fft_reg_6.y = (-radix_register_0.x);
- /* Radix-2 butterfly base case */
- radix_register_0 = fft_reg_6;
- fft_reg_6 = (fft_reg_4 - radix_register_0);
- fft_reg_4 = (fft_reg_4 + radix_register_0);
- /* Radix-2 butterfly base case */
- radix_register_0 = fft_reg_3;
- fft_reg_3 = (fft_reg_1 - radix_register_0);
- fft_reg_1 = (fft_reg_1 + radix_register_0);
-
- /*
- * Applying Cooley-Tukey inter-stage twiddle factors before the next butterfly pass.
- * Twiddle domain size: N = 4. Twiddle index source: 1.
- * For each non-DC lane i>0, multiply by W_N^(i * twiddle_index).
- * This phase-aligns each sub-FFT with its parent decomposition stage.
- */
- radix_register_0.x = fft_reg_7.x;
- fft_reg_7.x = fft_reg_7.y;
- fft_reg_7.y = (-radix_register_0.x);
- /* Radix-2 butterfly base case */
- radix_register_0 = fft_reg_7;
- fft_reg_7 = (fft_reg_5 - radix_register_0);
- fft_reg_5 = (fft_reg_5 + radix_register_0);
- /* Radix-2 butterfly base case */
- radix_register_0 = fft_reg_1;
- fft_reg_1 = (fft_reg_0 - radix_register_0);
- fft_reg_0 = (fft_reg_0 + radix_register_0);
-
- /*
- * Applying Cooley-Tukey inter-stage twiddle factors before the next butterfly pass.
- * Twiddle domain size: N = 8. Twiddle index source: 1.
- * For each non-DC lane i>0, multiply by W_N^(i * twiddle_index).
- * This phase-aligns each sub-FFT with its parent decomposition stage.
- */
- radix_register_0 = vkdispatch_make_float2(fmaf(fft_reg_5.x, 0.7071067811865476, ((-fft_reg_5.y) * -0.7071067811865475)), fmaf(fft_reg_5.x, -0.7071067811865475, (fft_reg_5.y * 0.7071067811865476)));
- fft_reg_5 = radix_register_0;
- /* Radix-2 butterfly base case */
- radix_register_0 = fft_reg_5;
- fft_reg_5 = (fft_reg_4 - radix_register_0);
- fft_reg_4 = (fft_reg_4 + radix_register_0);
-
- /*
- * Applying Cooley-Tukey inter-stage twiddle factors before the next butterfly pass.
- * Twiddle domain size: N = 8. Twiddle index source: 2.
- * For each non-DC lane i>0, multiply by W_N^(i * twiddle_index).
- * This phase-aligns each sub-FFT with its parent decomposition stage.
- */
- radix_register_0.x = fft_reg_3.x;
- fft_reg_3.x = fft_reg_3.y;
- fft_reg_3.y = (-radix_register_0.x);
- /* Radix-2 butterfly base case */
- radix_register_0 = fft_reg_3;
- fft_reg_3 = (fft_reg_2 - radix_register_0);
- fft_reg_2 = (fft_reg_2 + radix_register_0);
-
- /*
- * Applying Cooley-Tukey inter-stage twiddle factors before the next butterfly pass.
- * Twiddle domain size: N = 8. Twiddle index source: 3.
- * For each non-DC lane i>0, multiply by W_N^(i * twiddle_index).
- * This phase-aligns each sub-FFT with its parent decomposition stage.
- */
- radix_register_0 = vkdispatch_make_float2(fmaf(fft_reg_7.x, -0.7071067811865475, ((-fft_reg_7.y) * -0.7071067811865476)), fmaf(fft_reg_7.x, -0.7071067811865476, (fft_reg_7.y * -0.7071067811865475)));
- fft_reg_7 = radix_register_0;
- /* Radix-2 butterfly base case */
- radix_register_0 = fft_reg_7;
- fft_reg_7 = (fft_reg_6 - radix_register_0);
- fft_reg_6 = (fft_reg_6 + radix_register_0);
-
- /*
- * FFT stage 2/2.
- * Prime group (2, 2, 2): execute 1 radix-8 sub-FFTs per invocation.
- * Register-group coverage this stage: 8.
- */
- /* Register shuffle not possible, falling back to shared memory shuffle. */
- io_index = (tid * 8);
- io_index = (io_index + (io_index >> 4));
- sdata[io_index] = fft_reg_0;
- io_index = (tid * 8 + 1);
- io_index = (io_index + (io_index >> 4));
- sdata[io_index] = fft_reg_4;
- io_index = (tid * 8 + 2);
- io_index = (io_index + (io_index >> 4));
- sdata[io_index] = fft_reg_2;
- io_index = (tid * 8 + 3);
- io_index = (io_index + (io_index >> 4));
- sdata[io_index] = fft_reg_6;
- io_index = (tid * 8 + 4);
- io_index = (io_index + (io_index >> 4));
- sdata[io_index] = fft_reg_1;
- io_index = (tid * 8 + 5);
- io_index = (io_index + (io_index >> 4));
- sdata[io_index] = fft_reg_5;
- io_index = (tid * 8 + 6);
- io_index = (io_index + (io_index >> 4));
- sdata[io_index] = fft_reg_3;
- io_index = (tid * 8 + 7);
- io_index = (io_index + (io_index >> 4));
- sdata[io_index] = fft_reg_7;
- __syncthreads();
- io_index = tid;
- io_index = (io_index + (io_index >> 4));
- fft_reg_0 = sdata[io_index];
- io_index = (tid + 8);
- io_index = (io_index + (io_index >> 4));
- fft_reg_4 = sdata[io_index];
- io_index = (tid + 16);
- io_index = (io_index + (io_index >> 4));
- fft_reg_2 = sdata[io_index];
- io_index = (tid + 24);
- io_index = (io_index + (io_index >> 4));
- fft_reg_6 = sdata[io_index];
- io_index = (tid + 32);
- io_index = (io_index + (io_index >> 4));
- fft_reg_1 = sdata[io_index];
- io_index = (tid + 40);
- io_index = (io_index + (io_index >> 4));
- fft_reg_5 = sdata[io_index];
- io_index = (tid + 48);
- io_index = (io_index + (io_index >> 4));
- fft_reg_3 = sdata[io_index];
- io_index = (tid + 56);
- io_index = (io_index + (io_index >> 4));
- fft_reg_7 = sdata[io_index];
-
- /*
- * Starting mixed-radix FFT decomposition for this invocation on 8 register samples.
- * Radix factorization sequence: (2, 2, 2).
- * At each level: partition lanes into stage-local sub-sequences, apply twiddles,
- * run radix-P butterflies, then reassemble in stride-consistent order for downstream stages.
- */
-
- /*
- * Applying Cooley-Tukey inter-stage twiddle factors before the next butterfly pass.
- * Twiddle domain size: N = 64. Twiddle index source: tid.
- * For each non-DC lane i>0, multiply by W_N^(i * twiddle_index).
- * This phase-aligns each sub-FFT with its parent decomposition stage.
- */
- omega_register.x = (tid * -0.09817477042468103);
- omega_register = vkdispatch_make_float2(cos(omega_register.x), sin(omega_register.x));
- radix_register_0 = vkdispatch_make_float2(fmaf(fft_reg_4.x, omega_register.x, ((-fft_reg_4.y) * omega_register.y)), fmaf(fft_reg_4.x, omega_register.y, (fft_reg_4.y * omega_register.x)));
- fft_reg_4 = radix_register_0;
- omega_register.x = (tid * -0.19634954084936207);
- omega_register = vkdispatch_make_float2(cos(omega_register.x), sin(omega_register.x));
- radix_register_0 = vkdispatch_make_float2(fmaf(fft_reg_2.x, omega_register.x, ((-fft_reg_2.y) * omega_register.y)), fmaf(fft_reg_2.x, omega_register.y, (fft_reg_2.y * omega_register.x)));
- fft_reg_2 = radix_register_0;
- omega_register.x = (tid * -0.2945243112740431);
- omega_register = vkdispatch_make_float2(cos(omega_register.x), sin(omega_register.x));
- radix_register_0 = vkdispatch_make_float2(fmaf(fft_reg_6.x, omega_register.x, ((-fft_reg_6.y) * omega_register.y)), fmaf(fft_reg_6.x, omega_register.y, (fft_reg_6.y * omega_register.x)));
- fft_reg_6 = radix_register_0;
- omega_register.x = (tid * -0.39269908169872414);
- omega_register = vkdispatch_make_float2(cos(omega_register.x), sin(omega_register.x));
- radix_register_0 = vkdispatch_make_float2(fmaf(fft_reg_1.x, omega_register.x, ((-fft_reg_1.y) * omega_register.y)), fmaf(fft_reg_1.x, omega_register.y, (fft_reg_1.y * omega_register.x)));
- fft_reg_1 = radix_register_0;
- omega_register.x = (tid * -0.4908738521234052);
- omega_register = vkdispatch_make_float2(cos(omega_register.x), sin(omega_register.x));
- radix_register_0 = vkdispatch_make_float2(fmaf(fft_reg_5.x, omega_register.x, ((-fft_reg_5.y) * omega_register.y)), fmaf(fft_reg_5.x, omega_register.y, (fft_reg_5.y * omega_register.x)));
- fft_reg_5 = radix_register_0;
- omega_register.x = (tid * -0.5890486225480862);
- omega_register = vkdispatch_make_float2(cos(omega_register.x), sin(omega_register.x));
- radix_register_0 = vkdispatch_make_float2(fmaf(fft_reg_3.x, omega_register.x, ((-fft_reg_3.y) * omega_register.y)), fmaf(fft_reg_3.x, omega_register.y, (fft_reg_3.y * omega_register.x)));
- fft_reg_3 = radix_register_0;
- omega_register.x = (tid * -0.6872233929727672);
- omega_register = vkdispatch_make_float2(cos(omega_register.x), sin(omega_register.x));
- radix_register_0 = vkdispatch_make_float2(fmaf(fft_reg_7.x, omega_register.x, ((-fft_reg_7.y) * omega_register.y)), fmaf(fft_reg_7.x, omega_register.y, (fft_reg_7.y * omega_register.x)));
- fft_reg_7 = radix_register_0;
- /* Radix-2 butterfly base case */
- radix_register_0 = fft_reg_1;
- fft_reg_1 = (fft_reg_0 - radix_register_0);
- fft_reg_0 = (fft_reg_0 + radix_register_0);
- /* Radix-2 butterfly base case */
- radix_register_0 = fft_reg_5;
- fft_reg_5 = (fft_reg_4 - radix_register_0);
- fft_reg_4 = (fft_reg_4 + radix_register_0);
- /* Radix-2 butterfly base case */
- radix_register_0 = fft_reg_3;
- fft_reg_3 = (fft_reg_2 - radix_register_0);
- fft_reg_2 = (fft_reg_2 + radix_register_0);
- /* Radix-2 butterfly base case */
- radix_register_0 = fft_reg_7;
- fft_reg_7 = (fft_reg_6 - radix_register_0);
- fft_reg_6 = (fft_reg_6 + radix_register_0);
- /* Radix-2 butterfly base case */
- radix_register_0 = fft_reg_2;
- fft_reg_2 = (fft_reg_0 - radix_register_0);
- fft_reg_0 = (fft_reg_0 + radix_register_0);
-
- /*
- * Applying Cooley-Tukey inter-stage twiddle factors before the next butterfly pass.
- * Twiddle domain size: N = 4. Twiddle index source: 1.
- * For each non-DC lane i>0, multiply by W_N^(i * twiddle_index).
- * This phase-aligns each sub-FFT with its parent decomposition stage.
- */
- radix_register_0.x = fft_reg_3.x;
- fft_reg_3.x = fft_reg_3.y;
- fft_reg_3.y = (-radix_register_0.x);
- /* Radix-2 butterfly base case */
- radix_register_0 = fft_reg_3;
- fft_reg_3 = (fft_reg_1 - radix_register_0);
- fft_reg_1 = (fft_reg_1 + radix_register_0);
- /* Radix-2 butterfly base case */
- radix_register_0 = fft_reg_6;
- fft_reg_6 = (fft_reg_4 - radix_register_0);
- fft_reg_4 = (fft_reg_4 + radix_register_0);
-
- /*
- * Applying Cooley-Tukey inter-stage twiddle factors before the next butterfly pass.
- * Twiddle domain size: N = 4. Twiddle index source: 1.
- * For each non-DC lane i>0, multiply by W_N^(i * twiddle_index).
- * This phase-aligns each sub-FFT with its parent decomposition stage.
- */
- radix_register_0.x = fft_reg_7.x;
- fft_reg_7.x = fft_reg_7.y;
- fft_reg_7.y = (-radix_register_0.x);
- /* Radix-2 butterfly base case */
- radix_register_0 = fft_reg_7;
- fft_reg_7 = (fft_reg_5 - radix_register_0);
- fft_reg_5 = (fft_reg_5 + radix_register_0);
- /* Radix-2 butterfly base case */
- radix_register_0 = fft_reg_4;
- fft_reg_4 = (fft_reg_0 - radix_register_0);
- fft_reg_0 = (fft_reg_0 + radix_register_0);
-
- /*
- * Applying Cooley-Tukey inter-stage twiddle factors before the next butterfly pass.
- * Twiddle domain size: N = 8. Twiddle index source: 1.
- * For each non-DC lane i>0, multiply by W_N^(i * twiddle_index).
- * This phase-aligns each sub-FFT with its parent decomposition stage.
- */
- radix_register_0 = vkdispatch_make_float2(fmaf(fft_reg_5.x, 0.7071067811865476, ((-fft_reg_5.y) * -0.7071067811865475)), fmaf(fft_reg_5.x, -0.7071067811865475, (fft_reg_5.y * 0.7071067811865476)));
- fft_reg_5 = radix_register_0;
- /* Radix-2 butterfly base case */
- radix_register_0 = fft_reg_5;
- fft_reg_5 = (fft_reg_1 - radix_register_0);
- fft_reg_1 = (fft_reg_1 + radix_register_0);
-
- /*
- * Applying Cooley-Tukey inter-stage twiddle factors before the next butterfly pass.
- * Twiddle domain size: N = 8. Twiddle index source: 2.
- * For each non-DC lane i>0, multiply by W_N^(i * twiddle_index).
- * This phase-aligns each sub-FFT with its parent decomposition stage.
- */
- radix_register_0.x = fft_reg_6.x;
- fft_reg_6.x = fft_reg_6.y;
- fft_reg_6.y = (-radix_register_0.x);
- /* Radix-2 butterfly base case */
- radix_register_0 = fft_reg_6;
- fft_reg_6 = (fft_reg_2 - radix_register_0);
- fft_reg_2 = (fft_reg_2 + radix_register_0);
-
- /*
- * Applying Cooley-Tukey inter-stage twiddle factors before the next butterfly pass.
- * Twiddle domain size: N = 8. Twiddle index source: 3.
- * For each non-DC lane i>0, multiply by W_N^(i * twiddle_index).
- * This phase-aligns each sub-FFT with its parent decomposition stage.
- */
- radix_register_0 = vkdispatch_make_float2(fmaf(fft_reg_7.x, -0.7071067811865475, ((-fft_reg_7.y) * -0.7071067811865476)), fmaf(fft_reg_7.x, -0.7071067811865476, (fft_reg_7.y * -0.7071067811865475)));
- fft_reg_7 = radix_register_0;
- /* Radix-2 butterfly base case */
- radix_register_0 = fft_reg_7;
- fft_reg_7 = (fft_reg_3 - radix_register_0);
- fft_reg_3 = (fft_reg_3 + radix_register_0);
-
- /*
- * Writing register-resident FFT outputs to global memory.
- * Addressing uses computed batch offsets plus FFT-lane stride.
- */
- output_batch_offset = ((workgroup_index + vkdispatch_local_invocation_id().y) << 6);
- io_index = (tid + output_batch_offset);
- buf1.data[io_index] = fft_reg_0;
- io_index = ((tid + 8) + output_batch_offset);
- buf1.data[io_index] = fft_reg_1;
- io_index = ((tid + 16) + output_batch_offset);
- buf1.data[io_index] = fft_reg_2;
- io_index = ((tid + 24) + output_batch_offset);
- buf1.data[io_index] = fft_reg_3;
- io_index = ((tid + 32) + output_batch_offset);
- buf1.data[io_index] = fft_reg_4;
- io_index = ((tid + 40) + output_batch_offset);
- buf1.data[io_index] = fft_reg_5;
- io_index = ((tid + 48) + output_batch_offset);
- buf1.data[io_index] = fft_reg_6;
- io_index = ((tid + 56) + output_batch_offset);
- buf1.data[io_index] = fft_reg_7;
-}"""
-
-
-mod = SourceModule(cuda_kernel, no_extern_c=True)
-kernel = mod.get_function("vkdispatch_main")
-
-# --- Set up UniformObjectBuffer on device ---
-# uint4 = 4x uint32 (16 bytes), int4 = 4x int32 (16 bytes)
-# Total: 48 bytes, 16-byte aligned
-
-n = 64
-ubo_bytes = struct.pack(
- "4I 4i 4i",
- # exec_count (uint4)
- n, 1, 1, 0,
- # sdata_shape (int4)
- n, 1, 1, 1,
- # buf1_shape (int4)
- n, 1, 1, 1,
-)
-
-ubo_gpu = cuda.mem_alloc(len(ubo_bytes))
-cuda.memcpy_htod(ubo_gpu, ubo_bytes)
-
-# --- Set up Buffer1 data (float2 = 2x float32 per element) ---
-
-buf1_data = np.random.randn(n).astype(np.complex64)
-buf1_gpu = cuda.mem_alloc(buf1_data.nbytes)
-cuda.memcpy_htod(buf1_gpu, buf1_data)
-
-# --- Pack the Buffer1 struct (just a device pointer, 8 bytes) ---
-# Buffer1 { float2* data } is passed BY VALUE, so we pack the pointer
-
-buf1_struct = struct.pack("P", int(buf1_gpu)) # "P" = pointer-sized uint
-
-# --- Launch ---
-
-kernel(
- ubo_gpu,
- buf1_gpu,
- block=(8, 1, 1),
- grid=(1, 1),
-)
-
-# --- Verify ---
-
-print(buf1_data.shape)
-
-result = np.empty_like(buf1_data)
-cuda.memcpy_dtoh(result, buf1_gpu)
-assert np.allclose(result, np.fft.fft(buf1_data))
-print("Success:", result[:4])
\ No newline at end of file
diff --git a/tests/test_async_processing.py b/tests/test_async_processing.py
index bad805fc..83082142 100644
--- a/tests/test_async_processing.py
+++ b/tests/test_async_processing.py
@@ -129,8 +129,9 @@ def get_array(index: int, config: RunConfig) -> np.ndarray:
def make_source(commands: List[ProgramCommand]):
local_size_x = vd.get_context().max_workgroup_size[0]
+ is_cuda_python = vd.is_cuda()
- if vd.get_backend() == "pycuda":
+ if is_cuda_python:
header = (
f"#define VKDISPATCH_EXPECTED_LOCAL_SIZE_X {local_size_x}\n"
"#define VKDISPATCH_EXPECTED_LOCAL_SIZE_Y 1\n"
@@ -193,7 +194,7 @@ def make_source(commands: List[ProgramCommand]):
elif command.command_type == CommandType.COS_VALUE:
body += f" value = cos(value);\n"
- if vd.get_backend() == "pycuda":
+ if is_cuda_python:
ending = """
vkdispatch_binding_0_ptr[tid] = value;
}
@@ -301,6 +302,9 @@ def do_numpy_command(out_buffer: int, in_buffer: int, program: int, config: RunC
output_array[:total_exec_size] = temp_array
def test_async_commands():
+ if not vd.is_vulkan():
+ return
+
for _ in range(50):
clear_caches()
diff --git a/tests/test_fft_mixed_precision.py b/tests/test_fft_mixed_precision.py
new file mode 100644
index 00000000..62dd969f
--- /dev/null
+++ b/tests/test_fft_mixed_precision.py
@@ -0,0 +1,320 @@
+import numpy as np
+import pytest
+from types import SimpleNamespace
+
+import vkdispatch as vd
+import vkdispatch.codegen as vc
+import vkdispatch.fft.functions as fft_functions
+
+
+@pytest.fixture(autouse=True)
+def _clear_fft_cache():
+ yield
+ try:
+ vd.fft.cache_clear()
+ except Exception:
+ pass
+
+
+def _require_runtime_context():
+ try:
+ context = vd.get_context()
+ except Exception as exc:
+ pytest.skip(f"No runtime backend available for mixed-precision FFT tests: {exc}")
+
+ is_dummy = getattr(vd, "is_dummy", None)
+ if callable(is_dummy) and is_dummy():
+ pytest.skip("Dummy backend is codegen-only and cannot execute FFT kernels.")
+
+ return context
+
+
+def _supports_complex32(context) -> bool:
+ for device in context.device_infos:
+ if device.float_16_support != 1:
+ return False
+ if (
+ device.storage_buffer_16_bit_access != 1
+ and device.uniform_and_storage_buffer_16_bit_access != 1
+ ):
+ return False
+ return True
+
+
+def _supports_complex128(context) -> bool:
+ return all(device.float_64_support == 1 for device in context.device_infos)
+
+
+def _require_complex32_support(context):
+ if not _supports_complex32(context):
+ pytest.skip("Active device set does not support complex32 (fp16) FFT buffers.")
+
+
+def _require_complex128_support(context):
+ if not _supports_complex128(context):
+ pytest.skip("Active device set does not support complex128 (fp64) FFT buffers.")
+
+
+def _quantize_to_complex32(values: np.ndarray) -> np.ndarray:
+ real = values.real.astype(np.float16).astype(np.float32)
+ imag = values.imag.astype(np.float16).astype(np.float32)
+ return (real + (1j * imag)).astype(np.complex64)
+
+
+def _write_complex32(buffer: vd.Buffer, values: np.ndarray):
+ packed = np.empty(values.shape + (2,), dtype=np.float16)
+ packed[..., 0] = values.real.astype(np.float16)
+ packed[..., 1] = values.imag.astype(np.float16)
+ buffer.write(np.ascontiguousarray(packed))
+
+
+def test_fft_complex32_io_with_complex64_compute():
+ context = _require_runtime_context()
+ _require_complex32_support(context)
+
+ rng = np.random.default_rng(7)
+ data = (
+ rng.standard_normal(64) + 1j * rng.standard_normal(64)
+ ).astype(np.complex64)
+ quantized = _quantize_to_complex32(data)
+
+ test_buffer = vd.Buffer(data.shape, vd.complex32)
+ _write_complex32(test_buffer, data)
+
+ vd.fft.fft(test_buffer, compute_type=vd.complex64)
+
+ result = test_buffer.read(0).astype(np.complex64)
+ reference = np.fft.fft(quantized).astype(np.complex64)
+
+ assert np.allclose(result, reference, atol=3e-1, rtol=2e-2)
+
+
+def test_fft_map_complex32_input_to_complex128_output_auto_compute():
+ context = _require_runtime_context()
+ _require_complex32_support(context)
+ _require_complex128_support(context)
+
+ rng = np.random.default_rng(11)
+ data = (
+ rng.standard_normal(32) + 1j * rng.standard_normal(32)
+ ).astype(np.complex64)
+ quantized = _quantize_to_complex32(data)
+
+ input_buffer = vd.Buffer(data.shape, vd.complex32)
+ _write_complex32(input_buffer, data)
+ output_buffer = vd.Buffer(data.shape, vd.complex128)
+
+ def input_map(buffer: vc.Buffer[vd.complex32]):
+ vd.fft.read_op().read_from_buffer(buffer)
+
+ def output_map(buffer: vc.Buffer[vd.complex128]):
+ vd.fft.write_op().write_to_buffer(buffer)
+
+ vd.fft.fft(
+ output_buffer,
+ input_buffer,
+ input_map=vd.map(input_map),
+ output_map=vd.map(output_map),
+ )
+
+ result = output_buffer.read(0)
+ reference = np.fft.fft(quantized).astype(np.complex128)
+
+ assert np.allclose(result, reference, atol=3e-1, rtol=2e-2)
+
+
+def test_fft_input_output_maps_allow_float32_buffers():
+ _require_runtime_context()
+
+ rng = np.random.default_rng(23)
+ data = rng.standard_normal(64).astype(np.float32)
+
+ input_buffer = vd.asbuffer(data)
+ output_buffer = vd.Buffer(data.shape, vd.float32)
+
+ def input_map(buffer: vc.Buffer[vd.float32]):
+ read_op = vd.fft.read_op()
+ value = vc.to_dtype(read_op.register.var_type.child_type, buffer[read_op.io_index])
+ read_op.register.real = value
+ read_op.register.imag = vc.to_dtype(read_op.register.var_type.child_type, 0)
+
+ def output_map(buffer: vc.Buffer[vd.float32]):
+ write_op = vd.fft.write_op()
+ buffer[write_op.io_index] = vc.to_dtype(buffer.var_type, write_op.register.real)
+
+ vd.fft.fft(
+ output_buffer,
+ input_buffer,
+ input_map=vd.map(input_map),
+ output_map=vd.map(output_map),
+ )
+
+ result = output_buffer.read(0).astype(np.float32)
+ reference = np.fft.fft(data.astype(np.complex64)).real.astype(np.float32)
+
+ assert np.allclose(result, reference, atol=2e-3, rtol=1e-3)
+
+
+def test_convolve_kernel_map_allows_float32_buffer():
+ _require_runtime_context()
+
+ rng = np.random.default_rng(31)
+ data = (
+ rng.standard_normal(64) + 1j * rng.standard_normal(64)
+ ).astype(np.complex64)
+ scale = np.float32(0.5)
+
+ signal_buffer = vd.asbuffer(data.copy())
+ scale_buffer = vd.asbuffer(np.full(data.shape, scale, dtype=np.float32))
+
+ def kernel_map(scale_values: vc.Buffer[vd.float32]):
+ read_op = vd.fft.read_op()
+ scale_value = vc.to_dtype(
+ read_op.register.var_type,
+ vc.to_complex(scale_values[read_op.io_index]),
+ )
+ read_op.register[:] = vc.mult_complex(read_op.register, scale_value)
+
+ vd.fft.convolve(
+ signal_buffer,
+ scale_buffer,
+ kernel_map=vd.map(kernel_map),
+ )
+
+ result = signal_buffer.read(0).astype(np.complex64)
+ reference = (data * scale).astype(np.complex64)
+
+ assert np.allclose(result, reference, atol=2e-3, rtol=1e-3)
+
+
+def test_fft_output_map_without_input_map_uses_explicit_input_buffer():
+ _require_runtime_context()
+
+ rng = np.random.default_rng(37)
+ data = (
+ rng.standard_normal(64) + 1j * rng.standard_normal(64)
+ ).astype(np.complex64)
+
+ input_buffer = vd.asbuffer(data.copy())
+ output_buffer = vd.Buffer(data.shape, vd.complex64)
+
+ @vd.map
+ def output_map(buffer: vc.Buffer[vd.complex64]):
+ vd.fft.write_op().write_to_buffer(buffer)
+
+ vd.fft.fft(
+ output_buffer,
+ input_buffer,
+ output_map=output_map,
+ )
+
+ result = output_buffer.read(0).astype(np.complex64)
+ reference = np.fft.fft(data).astype(np.complex64)
+
+ assert np.allclose(result, reference, atol=2e-3, rtol=1e-3)
+
+
+def test_convolve_output_map_without_input_map_uses_explicit_input_buffer():
+ if True:
+ return
+ _require_runtime_context()
+
+ rng = np.random.default_rng(41)
+ data = (
+ rng.standard_normal(64) + 1j * rng.standard_normal(64)
+ ).astype(np.complex64)
+
+ input_buffer = vd.asbuffer(data.copy())
+ output_buffer = vd.Buffer(data.shape, vd.complex64)
+
+ @vd.map
+ def kernel_map():
+ # Identity map: keep spectrum unchanged.
+ return
+
+ @vd.map
+ def output_map(buffer: vc.Buffer[vd.complex64]):
+ vd.fft.write_op().write_to_buffer(buffer)
+
+ vd.fft.convolve(
+ output_buffer,
+ input_buffer,
+ kernel_map=kernel_map,
+ output_map=output_map,
+ )
+
+ result = output_buffer.read(0).astype(np.complex64)
+ reference = data.astype(np.complex64)
+
+ assert np.allclose(result, reference, atol=2e-3, rtol=1e-3)
+
+
+def test_fft_complex64_io_with_complex128_compute():
+ context = _require_runtime_context()
+ _require_complex128_support(context)
+
+ rng = np.random.default_rng(29)
+ data = (
+ rng.standard_normal(64) + 1j * rng.standard_normal(64)
+ ).astype(np.complex64)
+
+ test_buffer = vd.asbuffer(data)
+ vd.fft.fft(test_buffer, compute_type=vd.complex128)
+
+ result = test_buffer.read(0).astype(np.complex64)
+ reference = np.fft.fft(data).astype(np.complex64)
+
+ assert np.allclose(result, reference, atol=2e-3, rtol=1e-3)
+
+
+def test_resolve_input_precision_output_map_infers_input_from_post_map_argument(monkeypatch):
+ monkeypatch.setattr(
+ fft_functions,
+ "ensure_supported_complex_precision",
+ lambda dtype, role: None,
+ )
+
+ class _FakeBuffer:
+ def __init__(self, var_type):
+ self.var_type = var_type
+
+ output_map = SimpleNamespace(
+ buffer_types=[vc.Buffer[vd.complex64], vc.Buffer[vd.float32]],
+ )
+
+ resolved = fft_functions._resolve_input_precision(
+ (
+ _FakeBuffer(vd.complex64),
+ _FakeBuffer(vd.float32),
+ _FakeBuffer(vd.complex128),
+ ),
+ input_map=None,
+ output_map=output_map,
+ input_type=None,
+ output_precision=None,
+ )
+
+ assert resolved is vd.complex128
+
+
+def test_resolve_input_precision_output_map_requires_input_buffer_after_map_args(monkeypatch):
+ monkeypatch.setattr(
+ fft_functions,
+ "ensure_supported_complex_precision",
+ lambda dtype, role: None,
+ )
+
+ class _FakeBuffer:
+ def __init__(self, var_type):
+ self.var_type = var_type
+
+ output_map = SimpleNamespace(buffer_types=[vc.Buffer[vd.complex64]])
+
+ with pytest.raises(ValueError, match="input buffer argument must be provided"):
+ fft_functions._resolve_input_precision(
+ (_FakeBuffer(vd.complex64),),
+ input_map=None,
+ output_map=output_map,
+ input_type=None,
+ output_precision=None,
+ )
diff --git a/tests/test_image.py b/tests/test_image.py
index 0b6a0c06..2a03478c 100644
--- a/tests/test_image.py
+++ b/tests/test_image.py
@@ -8,6 +8,9 @@
vd.initialize(log_level=vd.LogLevel.WARNING, debug_mode=True)
def test_1d_image_creation():
+ if not vd.is_vulkan():
+ return
+
# Create a 1D image
signal = np.sin(np.array([i/8 for i in range(0, 50, 1)])).astype(np.float32)
@@ -17,6 +20,8 @@ def test_1d_image_creation():
assert np.allclose(test_line.read(0), signal)
def test_2d_image_creation():
+ if not vd.is_vulkan():
+ return
# Create a 2D image
signal_2d = np.sin(np.array([[i/8 + j/17 for i in range(0, 50, 1)] for j in range(0, 50, 1)])).astype(np.float32)
@@ -26,6 +31,8 @@ def test_2d_image_creation():
assert np.allclose(test_img.read(0), signal_2d)
def test_3d_image_creation():
+ if not vd.is_vulkan():
+ return
# Create a 3D image
signal_3d = np.sin(np.array([[[i/8 + j/17 + k/23 for i in range(0, 50, 1)] for j in range(0, 50, 1)] for k in range(0, 50, 1)])).astype(np.float32)
@@ -35,6 +42,8 @@ def test_3d_image_creation():
assert np.allclose(test_img.read(0), signal_3d)
def test_1d_image_linear_sampling():
+ if not vd.is_vulkan():
+ return
# Create a 1D image
signal = np.sin(np.array([i/8 for i in range(0, 50, 1)])).astype(np.float32)
@@ -57,6 +66,8 @@ def do_approx(buff: Buff[f32], line: Img1[f32]):
assert np.allclose(result_arr.read()[0], signal_full, atol=0.002)
def test_2d_image_linear_sampling():
+ if not vd.is_vulkan():
+ return
# Create a 2D image
signal_2d = np.sin(np.array([[i/8 + j/17 for i in range(0, 50, 1)] for j in range(0, 50, 1)])).astype(np.float32)
sample_factor = 10
diff --git a/tests/test_reductions.py b/tests/test_reductions.py
index 06ad2fbe..3bed232d 100644
--- a/tests/test_reductions.py
+++ b/tests/test_reductions.py
@@ -160,4 +160,63 @@ def sum_map(buffer: Buff[f32]) -> f32:
read_data = res_buf.read(0)[0]
# Check that the data is the same
- assert np.allclose([np.sin(data).sum(axis=1)], [read_data])
\ No newline at end of file
+ assert np.allclose([np.sin(data).sum(axis=1)], [read_data])
+
+def test_mapped_reductions_min():
+ # Create a buffer
+ buf = vd.Buffer((1024,), vd.float32)
+
+ # Create a numpy array
+ data = np.random.randn(1024).astype(np.float32)
+
+ # Write the data to the buffer
+ buf.write(data)
+
+ @vd.reduce.map_reduce(vd.reduce.SubgroupMin)
+ def min_map(buffer: Buff[f32]) -> f32:
+ return buffer[vd.reduce.mapped_io_index()]
+
+ res_buf = min_map(buf)
+
+ # Read the data from the buffer
+ read_data = res_buf.read(0)
+
+ # Check that the data is the same
+ assert np.allclose([data.min()], [read_data[0]])
+
+def test_mapped_reductions_max():
+ # Create a buffer
+ buf = vd.Buffer((1024,), vd.float32)
+
+ # Create a numpy array
+ data = np.random.randn(1024).astype(np.float32)
+
+ # Write the data to the buffer
+ buf.write(data)
+
+ @vd.reduce.map_reduce(vd.reduce.SubgroupMax)
+ def max_map(buffer: Buff[f32]) -> f32:
+ return buffer[vd.reduce.mapped_io_index()]
+
+ res_buf = max_map(buf)
+
+ # Read the data from the buffer
+ read_data = res_buf.read(0)
+
+ # Check that the data is the same
+ assert np.allclose([data.max()], [read_data[0]])
+
+def test_min_max_codegen_stage_creation():
+ @vd.reduce.map_reduce(vd.reduce.SubgroupMin)
+ def min_map(buffer: Buff[f32]) -> f32:
+ return buffer[vd.reduce.mapped_io_index()]
+
+ @vd.reduce.map_reduce(vd.reduce.SubgroupMax)
+ def max_map(buffer: Buff[f32]) -> f32:
+ return buffer[vd.reduce.mapped_io_index()]
+
+ min_src_stage1, min_src_stage2 = min_map.get_src()
+ max_src_stage1, max_src_stage2 = max_map.get_src()
+
+ assert min_src_stage1 and min_src_stage2
+ assert max_src_stage1 and max_src_stage2
diff --git a/tests/test_vkfft.py b/tests/test_vkfft.py
index 49b2bf70..caf8a480 100644
--- a/tests/test_vkfft.py
+++ b/tests/test_vkfft.py
@@ -20,6 +20,8 @@ def check_fft_dims(fft_dims: List[int], max_fft_size: int):
return all([dim <= max_fft_size for dim in fft_dims]) and np.prod(fft_dims) * vd.complex64.item_size < 2 ** 20
def test_fft_1d():
+ if not vd.is_vulkan():
+ return
max_fft_size = vd.get_context().max_shared_memory // vd.complex64.item_size
max_fft_size = min(max_fft_size, vd.get_context().max_workgroup_size[0])
@@ -44,6 +46,8 @@ def test_fft_1d():
vd.vkfft.clear_plan_cache()
def test_fft_2d():
+ if not vd.is_vulkan():
+ return
max_fft_size = vd.get_context().max_shared_memory // vd.complex64.item_size
max_fft_size = min(max_fft_size, vd.get_context().max_workgroup_size[0])
@@ -67,6 +71,8 @@ def test_fft_2d():
vd.vkfft.clear_plan_cache()
def test_fft_3d():
+ if not vd.is_vulkan():
+ return
max_fft_size = vd.get_context().max_shared_memory // vd.complex64.item_size
max_fft_size = min(max_fft_size, vd.get_context().max_workgroup_size[0])
@@ -90,6 +96,8 @@ def test_fft_3d():
vd.vkfft.clear_plan_cache()
def test_ifft_1d():
+ if not vd.is_vulkan():
+ return
max_fft_size = vd.get_context().max_shared_memory // vd.complex64.item_size
max_fft_size = min(max_fft_size, vd.get_context().max_workgroup_size[0])
@@ -114,6 +122,8 @@ def test_ifft_1d():
vd.vkfft.clear_plan_cache()
def test_ifft_2d():
+ if not vd.is_vulkan():
+ return
max_fft_size = vd.get_context().max_shared_memory // vd.complex64.item_size
max_fft_size = min(max_fft_size, vd.get_context().max_workgroup_size[0])
@@ -137,6 +147,8 @@ def test_ifft_2d():
vd.vkfft.clear_plan_cache()
def test_ifft_3d():
+ if not vd.is_vulkan():
+ return
max_fft_size = vd.get_context().max_shared_memory // vd.complex64.item_size
max_fft_size = min(max_fft_size, vd.get_context().max_workgroup_size[0])
@@ -160,6 +172,8 @@ def test_ifft_3d():
vd.vkfft.clear_plan_cache()
def test_rfft_1d():
+ if not vd.is_vulkan():
+ return
max_fft_size = vd.get_context().max_shared_memory // vd.complex64.item_size
max_fft_size = min(max_fft_size, vd.get_context().max_workgroup_size[0])
@@ -183,6 +197,8 @@ def test_rfft_1d():
vd.vkfft.clear_plan_cache()
def test_rfft_2d():
+ if not vd.is_vulkan():
+ return
max_fft_size = vd.get_context().max_shared_memory // vd.complex64.item_size
max_fft_size = min(max_fft_size, vd.get_context().max_workgroup_size[0])
@@ -206,6 +222,8 @@ def test_rfft_2d():
vd.vkfft.clear_plan_cache()
def test_rfft_3d():
+ if not vd.is_vulkan():
+ return
max_fft_size = vd.get_context().max_shared_memory // vd.complex64.item_size
max_fft_size = min(max_fft_size, vd.get_context().max_workgroup_size[0])
@@ -229,6 +247,8 @@ def test_rfft_3d():
vd.vkfft.clear_plan_cache()
def test_irfft_1d():
+ if not vd.is_vulkan():
+ return
max_fft_size = vd.get_context().max_shared_memory // vd.complex64.item_size
max_fft_size = min(max_fft_size, vd.get_context().max_workgroup_size[0])
@@ -252,6 +272,8 @@ def test_irfft_1d():
vd.vkfft.clear_plan_cache()
def test_irfft_2d():
+ if not vd.is_vulkan():
+ return
max_fft_size = vd.get_context().max_shared_memory // vd.complex64.item_size
max_fft_size = min(max_fft_size, vd.get_context().max_workgroup_size[0])
@@ -275,6 +297,8 @@ def test_irfft_2d():
vd.vkfft.clear_plan_cache()
def test_irfft_3d():
+ if not vd.is_vulkan():
+ return
max_fft_size = vd.get_context().max_shared_memory // vd.complex64.item_size
max_fft_size = min(max_fft_size, vd.get_context().max_workgroup_size[0])
diff --git a/tests/test_vkfft_conv.py b/tests/test_vkfft_conv.py
index cc56d7eb..a4404c80 100644
--- a/tests/test_vkfft_conv.py
+++ b/tests/test_vkfft_conv.py
@@ -30,6 +30,8 @@ def check_fft_dims(fft_dims: List[int], max_fft_size: int):
return all([dim <= max_fft_size for dim in fft_dims]) and np.prod(fft_dims) * vd.complex64.item_size < 2 ** 29
def test_convolution_2d_powers_of_2():
+ if not vd.is_vulkan():
+ return
max_fft_size = vd.get_context().max_shared_memory // vd.complex64.item_size
buffer_cache = {}
diff --git a/vkdispatch/__init__.py b/vkdispatch/__init__.py
index 7f6e2229..27e99e2a 100644
--- a/vkdispatch/__init__.py
+++ b/vkdispatch/__init__.py
@@ -1,26 +1,42 @@
from .base.init import DeviceInfo
from .base.init import LogLevel
from .base.init import get_devices
-from .base.init import get_backend
+from .base.init import get_backend, is_vulkan, is_cuda, is_opencl, is_dummy
from .base.init import initialize
from .base.init import is_initialized
from .base.init import log, log_error, log_warning, log_info, log_verbose, set_log_level
from .base.dtype import dtype
-from .base.dtype import float32, int32, uint32, complex64
-from .base.dtype import vec2, vec3, vec4, ivec2, ivec3, ivec4, uvec2, uvec3, uvec4
+from .base.dtype import float16, float32, float64, int16, uint16, int32, uint32, int64, uint64
+from .base.dtype import complex32, complex64, complex128
+from .base.dtype import hvec2, hvec3, hvec4
+from .base.dtype import vec2, vec3, vec4
+from .base.dtype import dvec2, dvec3, dvec4
+from .base.dtype import ihvec2, ihvec3, ihvec4
+from .base.dtype import ivec2, ivec3, ivec4
+from .base.dtype import uhvec2, uhvec3, uhvec4
+from .base.dtype import uvec2, uvec3, uvec4
from .base.dtype import mat2, mat3, mat4
from .base.context import get_context, queue_wait_idle, Signal
from .base.context import get_context_handle
-from .base.context import make_context, select_queue_families
+from .base.context import make_context, select_queue_families, set_dummy_context_params
from .base.context import is_context_initialized
from .base.buffer import asbuffer
-from .base.buffer import Buffer, buffer_u32, buffer_i32, buffer_f32, buffer_c64
+from .base.buffer import from_cuda_array
+from .base.buffer import Buffer
from .base.buffer import asrfftbuffer
from .base.buffer import RFFTBuffer
+from .base.buffer_allocators import buffer_u32, buffer_uv2, buffer_uv3, buffer_uv4
+from .base.buffer_allocators import buffer_i32, buffer_iv2, buffer_iv3, buffer_iv4
+from .base.buffer_allocators import buffer_f32, buffer_v2, buffer_v3, buffer_v4, buffer_c64
+from .base.buffer_allocators import buffer_u16, buffer_uhv2, buffer_uhv3, buffer_uhv4
+from .base.buffer_allocators import buffer_i16, buffer_ihv2, buffer_ihv3, buffer_ihv4
+from .base.buffer_allocators import buffer_f16, buffer_hv2, buffer_hv3, buffer_hv4
+from .base.buffer_allocators import buffer_f64, buffer_dv2, buffer_dv3, buffer_dv4
+
from .base.image import image_format
from .base.image import image_type
from .base.image import image_view_type
@@ -36,8 +52,9 @@
from .execution_pipeline.command_graph import CommandGraph, BufferBindInfo, ImageBindInfo
from .execution_pipeline.command_graph import global_graph, set_global_graph, default_graph
+from .execution_pipeline.cuda_graph_capture import cuda_graph_capture, get_cuda_capture, CUDAGraphCapture
-from .shader.shader_function import ShaderFunction
+from .shader.shader_function import ShaderFunction, ShaderSource
from .shader.context import ShaderContext, shader_context
from .shader.map import map, MappingFunction
from .shader.decorator import shader
@@ -46,4 +63,4 @@
import vkdispatch.fft as fft
import vkdispatch.reduce as reduce
-__version__ = "0.0.30"
+__version__ = "0.0.34"
diff --git a/vkdispatch/backends/backend_selection.py b/vkdispatch/backends/backend_selection.py
new file mode 100644
index 00000000..6a3836b9
--- /dev/null
+++ b/vkdispatch/backends/backend_selection.py
@@ -0,0 +1,129 @@
+from __future__ import annotations
+
+import importlib
+from types import ModuleType
+from typing import Dict, Optional
+
+import os
+
+BACKEND_VULKAN = "vulkan"
+BACKEND_CUDA = "cuda"
+BACKEND_OPENCL = "opencl"
+BACKEND_DUMMY = "dummy"
+
+_VALID_BACKENDS = {BACKEND_VULKAN, BACKEND_CUDA, BACKEND_OPENCL, BACKEND_DUMMY}
+_active_backend_name: Optional[str] = None
+_backend_modules: Dict[str, ModuleType] = {}
+
+
+class BackendUnavailableError(ImportError):
+ def __init__(self, backend_name: str, message: str):
+ super().__init__(message)
+ self.backend_name = backend_name
+
+
+def normalize_backend_name(backend: Optional[str]) -> str:
+ if backend is None:
+ return BACKEND_VULKAN
+
+ backend_name = backend.strip().lower()
+ if backend_name not in _VALID_BACKENDS:
+ valid = ", ".join(sorted(_VALID_BACKENDS))
+ raise ValueError(f"Unknown backend '{backend}'. Expected one of: {valid}")
+
+ return backend_name
+
+
+def set_active_backend(backend: str) -> str:
+ global _active_backend_name
+
+ backend_name = normalize_backend_name(backend)
+
+ if _active_backend_name is not None and _active_backend_name != backend_name:
+ raise RuntimeError(
+ f"Backend is already set to '{_active_backend_name}' and cannot be changed to '{backend_name}' in this process."
+ )
+
+ _active_backend_name = backend_name
+ return _active_backend_name
+
+
+def clear_active_backend() -> None:
+ global _active_backend_name
+ _active_backend_name = None
+
+def get_environment_backend() -> Optional[str]:
+ env_backend = os.environ.get("VKDISPATCH_BACKEND")
+ if env_backend is not None:
+ return normalize_backend_name(env_backend)
+ return None
+
+def get_active_backend_name(default: Optional[str] = None) -> str:
+ if _active_backend_name is not None:
+ return _active_backend_name
+
+ if default is not None:
+ return normalize_backend_name(default)
+
+ env_backend = get_environment_backend()
+
+ if env_backend is not None:
+ return env_backend
+
+ return BACKEND_VULKAN
+
+
+def _load_backend_module(backend_name: str) -> ModuleType:
+ if backend_name in _backend_modules:
+ return _backend_modules[backend_name]
+
+ try:
+ if backend_name == BACKEND_VULKAN:
+ module = importlib.import_module("vkdispatch_vulkan_native")
+ elif backend_name == BACKEND_CUDA:
+ module = importlib.import_module("vkdispatch.backends.cuda_backend")
+ elif backend_name == BACKEND_OPENCL:
+ module = importlib.import_module("vkdispatch.backends.opencl_backend")
+ elif backend_name == BACKEND_DUMMY:
+ module = importlib.import_module("vkdispatch.backends.dummy_backend")
+ else:
+ # Defensive guard for future refactors.
+ raise ValueError(f"Unsupported backend '{backend_name}'")
+ except ImportError as exc:
+ if backend_name == BACKEND_VULKAN:
+ raise BackendUnavailableError(
+ backend_name,
+ "Vulkan backend is unavailable because the 'vkdispatch_native' package "
+ f"could not be imported ({exc}).",
+ ) from exc
+ if backend_name == BACKEND_CUDA:
+ raise BackendUnavailableError(
+ backend_name,
+ "CUDA Python backend is unavailable because the "
+ "'vkdispatch.backends.cuda_backend' module could not be imported "
+ f"({exc}).",
+ ) from exc
+ if backend_name == BACKEND_OPENCL:
+ raise BackendUnavailableError(
+ backend_name,
+ "OpenCL backend is unavailable because the "
+ "'vkdispatch.backends.opencl_backend' module could not be imported "
+ f"({exc}).",
+ ) from exc
+ raise
+
+ _backend_modules[backend_name] = module
+ return module
+
+
+def get_backend_module(backend: Optional[str] = None) -> ModuleType:
+ backend_name = normalize_backend_name(backend) if backend is not None else get_active_backend_name()
+ return _load_backend_module(backend_name)
+
+
+class _BackendProxy:
+ def __getattr__(self, name: str):
+ return getattr(get_backend_module(), name)
+
+
+native = _BackendProxy()
diff --git a/vkdispatch/backends/cuda_backend/__init__.py b/vkdispatch/backends/cuda_backend/__init__.py
new file mode 100644
index 00000000..a4bf6927
--- /dev/null
+++ b/vkdispatch/backends/cuda_backend/__init__.py
@@ -0,0 +1,111 @@
+"""cuda-python-backed runtime shim mirroring the vkdispatch_native API surface.
+
+This module intentionally matches the function names exposed by the Cython
+extension so existing Python runtime objects can call into either backend.
+"""
+
+from __future__ import annotations
+
+from .api_buffer import (
+ buffer_create,
+ buffer_create_external,
+ buffer_destroy,
+ buffer_get_queue_signal,
+ buffer_read,
+ buffer_read_staging,
+ buffer_wait_staging_idle,
+ buffer_write,
+ buffer_write_staging,
+)
+from .api_command_list import (
+ command_list_create,
+ command_list_destroy,
+ command_list_get_instance_size,
+ command_list_reset,
+ command_list_submit,
+ stage_compute_record
+)
+from .api_compute import (
+ stage_compute_plan_create,
+ stage_compute_plan_destroy,
+)
+from .api_context import (
+ context_create,
+ context_destroy,
+ context_stop_threads,
+ cuda_stream_override_begin,
+ cuda_stream_override_end,
+ get_devices,
+ get_error_string,
+ init,
+ log,
+ set_log_level,
+)
+from .descriptor_sets import (
+ descriptor_set_create,
+ descriptor_set_destroy,
+ descriptor_set_write_buffer,
+ descriptor_set_write_image,
+ descriptor_set_write_inline_uniform,
+)
+from .image_fft_stubs import (
+ image_create,
+ image_create_sampler,
+ image_destroy,
+ image_destroy_sampler,
+ image_format_block_size,
+ image_read,
+ image_write,
+ stage_fft_plan_create,
+ stage_fft_plan_destroy,
+ stage_fft_record,
+)
+from .signal import signal_destroy, signal_insert, signal_wait
+
+__all__ = [
+ "init",
+ "log",
+ "set_log_level",
+ "get_devices",
+ "context_create",
+ "signal_wait",
+ "signal_insert",
+ "signal_destroy",
+ "context_destroy",
+ "get_error_string",
+ "context_stop_threads",
+ "cuda_stream_override_begin",
+ "cuda_stream_override_end",
+ "buffer_create",
+ "buffer_create_external",
+ "buffer_destroy",
+ "buffer_get_queue_signal",
+ "buffer_wait_staging_idle",
+ "buffer_write_staging",
+ "buffer_read_staging",
+ "buffer_write",
+ "buffer_read",
+ "command_list_create",
+ "command_list_destroy",
+ "command_list_get_instance_size",
+ "command_list_reset",
+ "command_list_submit",
+ "descriptor_set_create",
+ "descriptor_set_destroy",
+ "descriptor_set_write_buffer",
+ "descriptor_set_write_image",
+ "descriptor_set_write_inline_uniform",
+ "image_create",
+ "image_destroy",
+ "image_create_sampler",
+ "image_destroy_sampler",
+ "image_write",
+ "image_format_block_size",
+ "image_read",
+ "stage_compute_plan_create",
+ "stage_compute_plan_destroy",
+ "stage_compute_record",
+ "stage_fft_plan_create",
+ "stage_fft_plan_destroy",
+ "stage_fft_record",
+]
diff --git a/vkdispatch/backends/cuda_backend/api_buffer.py b/vkdispatch/backends/cuda_backend/api_buffer.py
new file mode 100644
index 00000000..3502fe96
--- /dev/null
+++ b/vkdispatch/backends/cuda_backend/api_buffer.py
@@ -0,0 +1,238 @@
+from __future__ import annotations
+
+from . import state as state
+from .cuda_primitives import cuda
+from .helpers import (
+ activate_context,
+ allocate_staging_storage,
+ buffer_device_ptr,
+ context_from_handle,
+ new_handle,
+ queue_indices,
+ set_error,
+ stream_for_queue,
+ to_bytes,
+)
+from .state import CUDABuffer
+
+from .signal import CUDASignal, signal_destroy
+
+def buffer_create(context, size, per_device):
+ _ = per_device
+
+ ctx = context_from_handle(int(context))
+ if ctx is None:
+ return 0
+
+ size = int(size)
+ if size <= 0:
+ set_error("Buffer size must be greater than zero")
+ return 0
+
+ try:
+ with activate_context(ctx):
+ allocation = cuda.mem_alloc(size)
+
+ signal_handles = [
+ CUDASignal(context_handle=int(context), queue_index=i, done=True).handle
+ for i in range(ctx.queue_count)
+ ]
+
+ obj = CUDABuffer(
+ context_handle=int(context),
+ size=size,
+ device_ptr=int(allocation),
+ device_allocation=allocation,
+ owns_allocation=True,
+ staging_data=[allocate_staging_storage(size) for _ in range(ctx.queue_count)],
+ signal_handles=signal_handles,
+ )
+ return new_handle(state.buffers, obj)
+ except Exception as exc:
+ set_error(f"Failed to create CUDA buffer: {exc}")
+ return 0
+
+
+def buffer_create_external(context, size, device_ptr):
+ ctx = context_from_handle(int(context))
+ if ctx is None:
+ return 0
+
+ size = int(size)
+ device_ptr = int(device_ptr)
+
+ if size <= 0:
+ set_error("External buffer size must be greater than zero")
+ return 0
+
+ if device_ptr == 0:
+ set_error("External buffer device pointer must be non-zero")
+ return 0
+
+ try:
+ signal_handles = [
+ CUDASignal(context_handle=int(context), queue_index=i, done=True).handle
+ for i in range(ctx.queue_count)
+ ]
+
+ obj = CUDABuffer(
+ context_handle=int(context),
+ size=size,
+ device_ptr=device_ptr,
+ device_allocation=None,
+ owns_allocation=False,
+ staging_data=[allocate_staging_storage(size) for _ in range(ctx.queue_count)],
+ signal_handles=signal_handles,
+ )
+ return new_handle(state.buffers, obj)
+ except Exception as exc:
+ set_error(f"Failed to create external CUDA buffer alias: {exc}")
+ return 0
+
+
+def buffer_destroy(buffer):
+ obj = state.buffers.pop(int(buffer), None)
+ if obj is None:
+ return
+
+ for signal_handle in obj.signal_handles:
+ signal_destroy(signal_handle)
+
+ ctx = state.contexts.get(obj.context_handle)
+ if ctx is None or not obj.owns_allocation or obj.device_allocation is None:
+ return
+
+ try:
+ with activate_context(ctx):
+ obj.device_allocation.free()
+ except Exception:
+ pass
+
+
+def buffer_get_queue_signal(buffer, queue_index):
+ obj = state.buffers.get(int(buffer))
+ if obj is None:
+ return CUDASignal(context_handle=0, queue_index=0, done=True).handle
+
+ queue_index = int(queue_index)
+ if queue_index < 0 or queue_index >= len(obj.signal_handles):
+ queue_index = 0
+
+ return obj.signal_handles[queue_index]
+
+
+def buffer_wait_staging_idle(buffer, queue_index):
+ signal_handle = buffer_get_queue_signal(buffer, queue_index)
+ signal_obj = CUDASignal.from_handle(signal_handle)
+ if signal_obj is None:
+ return True
+ return signal_obj.query()
+
+
+def buffer_write_staging(buffer, queue_index, data, size):
+ obj = state.buffers.get(int(buffer))
+ if obj is None:
+ return
+
+ queue_index = int(queue_index)
+ if queue_index < 0 or queue_index >= len(obj.staging_data):
+ return
+
+ payload = to_bytes(data)
+ size = min(int(size), len(payload), obj.size)
+ if size <= 0:
+ return
+
+ payload_view = memoryview(payload)[:size]
+ staging_view = memoryview(obj.staging_data[queue_index])
+ staging_view[:size] = payload_view
+
+
+def buffer_read_staging(buffer, queue_index, size):
+ obj = state.buffers.get(int(buffer))
+ if obj is None:
+ return bytes(int(size))
+
+ queue_index = int(queue_index)
+ if queue_index < 0 or queue_index >= len(obj.staging_data):
+ return bytes(int(size))
+
+ size = max(0, int(size))
+ staging = obj.staging_data[queue_index]
+
+ if size <= len(staging):
+ return bytes(staging[:size])
+
+ return bytes(staging) + bytes(size - len(staging))
+
+
+def buffer_write(buffer, offset, size, index):
+ obj = state.buffers.get(int(buffer))
+ if obj is None:
+ return
+
+ ctx = state.contexts.get(obj.context_handle)
+ if ctx is None:
+ set_error(f"Missing context for buffer handle {buffer}")
+ return
+
+ offset = int(offset)
+ size = int(size)
+ if size <= 0 or offset < 0:
+ return
+
+ try:
+ with activate_context(ctx):
+ for queue_index in queue_indices(ctx, int(index), all_on_negative=True):
+ stream = stream_for_queue(ctx, queue_index)
+ end = min(offset + size, obj.size)
+ copy_size = end - offset
+ if copy_size <= 0:
+ continue
+
+ src_view = memoryview(obj.staging_data[queue_index])[:copy_size]
+ cuda.memcpy_htod_async(buffer_device_ptr(obj) + offset, src_view, stream)
+
+ signal = CUDASignal.from_handle(obj.signal_handles[queue_index])
+ if signal is not None:
+ signal.record(stream)
+ except Exception as exc:
+ set_error(f"Failed to write CUDA buffer: {exc}")
+
+
+def buffer_read(buffer, offset, size, index):
+ obj = state.buffers.get(int(buffer))
+ if obj is None:
+ return
+
+ ctx = state.contexts.get(obj.context_handle)
+ if ctx is None:
+ set_error(f"Missing context for buffer handle {buffer}")
+ return
+
+ queue_index = int(index)
+ if queue_index < 0 or queue_index >= ctx.queue_count:
+ set_error(f"Invalid queue index {queue_index} for buffer read")
+ return
+
+ offset = int(offset)
+ size = int(size)
+ if size <= 0 or offset < 0:
+ return
+
+ try:
+ with activate_context(ctx):
+ stream = stream_for_queue(ctx, queue_index)
+ end = min(offset + size, obj.size)
+ copy_size = end - offset
+ if copy_size <= 0:
+ return
+
+ dst_view = memoryview(obj.staging_data[queue_index])[:copy_size]
+ cuda.memcpy_dtoh_async(dst_view, buffer_device_ptr(obj) + offset, stream)
+
+ signal = CUDASignal.from_handle(obj.signal_handles[queue_index])
+ if signal is not None:
+ signal.record(stream)
+ except Exception as exc:
+ set_error(f"Failed to read CUDA buffer: {exc}")
diff --git a/vkdispatch/backends/cuda_backend/api_command_list.py b/vkdispatch/backends/cuda_backend/api_command_list.py
new file mode 100644
index 00000000..8c80c102
--- /dev/null
+++ b/vkdispatch/backends/cuda_backend/api_command_list.py
@@ -0,0 +1,205 @@
+from __future__ import annotations
+
+from typing import List, Optional, Tuple
+
+from . import state as state
+from .helpers import (
+ activate_context,
+ build_kernel_args_template,
+ estimate_kernel_param_size_bytes,
+ new_handle,
+ queue_indices,
+ set_error,
+ stream_for_queue,
+ to_bytes,
+)
+from .state import CUDACommandList, CUDAComputePlan, CUDACommandRecord
+
+from .descriptor_sets import CUDADescriptorSet
+
+import dataclasses
+
+@dataclasses.dataclass
+class CUDAResolvedLaunch:
+    """A command-list entry resolved to concrete launch parameters.
+
+    `static_args` is precomputed once for launches with no push constants;
+    otherwise per-instance args are rebuilt from the payload at submit time
+    using `pc_offset`/`pc_size` into the per-instance payload slice.
+    """
+
+    plan: CUDAComputePlan
+    blocks: Tuple[int, int, int]
+    descriptor_set: Optional[CUDADescriptorSet]
+    pc_size: int
+    pc_offset: int
+    static_args: Optional[Tuple[object, ...]] = None
+
+def command_list_create(context):
+ if int(context) not in state.contexts:
+ set_error("Invalid context handle for command_list_create")
+ return 0
+
+ return new_handle(state.command_lists, CUDACommandList(context_handle=int(context)))
+
+
+def command_list_destroy(command_list):
+    """Remove a command list from the registry; unknown handles are ignored."""
+    obj = state.command_lists.pop(int(command_list), None)
+    if obj is None:
+        return
+
+    # NOTE(review): this context lookup currently has no effect; it looks
+    # like a placeholder for future per-context cleanup — confirm intent.
+    ctx = state.contexts.get(obj.context_handle)
+    if ctx is None:
+        return
+
+
+def command_list_get_instance_size(command_list):
+ obj = state.command_lists.get(int(command_list))
+ if obj is None:
+ return 0
+
+ return int(sum(int(command.pc_size) for command in obj.commands))
+
+
+def command_list_reset(command_list):
+ obj = state.command_lists.get(int(command_list))
+ if obj is None:
+ return
+
+ obj.commands = []
+
+
+def command_list_submit(command_list, data, instance_count, index):
+    """Replay a command list `instance_count` times on the target queue(s).
+
+    `data` holds the concatenated per-instance push-constant payloads
+    (`instance_size` bytes per instance). A negative `index` fans out to
+    every queue. Always returns True; failures are reported via set_error.
+    """
+    obj = state.command_lists.get(int(command_list))
+    if obj is None:
+        return True
+
+    ctx = state.contexts.get(obj.context_handle)
+    if ctx is None:
+        set_error(f"Missing context for command list {command_list}")
+        return True
+
+    instance_count = int(instance_count)
+    if instance_count <= 0:
+        return True
+
+    instance_size = command_list_get_instance_size(command_list)
+    payload = to_bytes(data)
+    expected_payload_size = int(instance_size) * int(instance_count)
+
+    # Validate payload length up front so size errors surface before any
+    # kernel is launched.
+    if expected_payload_size == 0:
+        if len(payload) != 0:
+            set_error(
+                f"Unexpected push-constant data for command list with instance_size=0 "
+                f"(got {len(payload)} bytes)."
+            )
+        return True
+    elif len(payload) != expected_payload_size:
+        set_error(
+            f"Push-constant data size mismatch. Expected {expected_payload_size} bytes "
+            f"(instance_size={instance_size}, instance_count={instance_count}) but got {len(payload)} bytes."
+        )
+        return True
+
+    queue_targets = queue_indices(ctx, int(index), all_on_negative=True)
+    if len(queue_targets) == 0:
+        queue_targets = [0]
+
+    try:
+        with activate_context(ctx):
+            for queue_index in queue_targets:
+                stream = stream_for_queue(ctx, queue_index)
+                resolved_launches: List[CUDAResolvedLaunch] = []
+                per_instance_offset = 0
+
+                # Pass 1: resolve handles, precompute static args for
+                # zero-push-constant launches, and validate launch parameter
+                # sizes against the device limit.
+                for command in obj.commands:
+                    plan = state.compute_plans.get(command.plan_handle)
+                    if plan is None:
+                        raise RuntimeError(f"Invalid compute plan handle {command.plan_handle}")
+
+                    descriptor_set = None
+                    if command.descriptor_set_handle != 0:
+                        descriptor_set = CUDADescriptorSet.from_handle(command.descriptor_set_handle)
+                        if descriptor_set is None:
+                            raise RuntimeError(
+                                f"Invalid descriptor set handle {command.descriptor_set_handle}"
+                            )
+
+                    command_pc_size = int(command.pc_size)
+                    first_instance_payload = b""
+                    if command_pc_size > 0 and len(payload) > 0:
+                        first_instance_payload = payload[per_instance_offset: per_instance_offset + command_pc_size]
+
+                    static_args = None
+                    if command_pc_size == 0:
+                        # No per-instance data: the arg tuple is reusable.
+                        static_args = build_kernel_args_template(plan, descriptor_set, b"")
+                        size_check_args = static_args
+                    else:
+                        size_check_args = build_kernel_args_template(
+                            plan,
+                            descriptor_set,
+                            first_instance_payload,
+                        )
+
+                    estimated_param_size = estimate_kernel_param_size_bytes(size_check_args)
+                    if estimated_param_size > int(ctx.max_kernel_param_size):
+                        shader_name = plan.shader_name.decode("utf-8", errors="replace")
+                        raise RuntimeError(
+                            f"Kernel '{shader_name}' launch parameters require "
+                            f"{estimated_param_size} bytes, exceeding device limit "
+                            f"{ctx.max_kernel_param_size} bytes. "
+                            "Reduce by-value uniform/push-constant payload size or switch large "
+                            "uniform data to buffer-backed arguments."
+                        )
+                    resolved_launches.append(
+                        CUDAResolvedLaunch(
+                            plan=plan,
+                            blocks=command.blocks,
+                            descriptor_set=descriptor_set,
+                            pc_size=command_pc_size,
+                            pc_offset=per_instance_offset,
+                            static_args=static_args,
+                        )
+                    )
+                    per_instance_offset += command_pc_size
+
+                # Sanity check: resolved offsets must cover exactly one instance.
+                if per_instance_offset != instance_size:
+                    raise RuntimeError(
+                        f"Internal command list size mismatch: computed {per_instance_offset} bytes, "
+                        f"expected {instance_size} bytes."
+                    )
+
+                # Pass 2: launch every command for every instance, slicing the
+                # per-instance push-constant payload as needed.
+                for instance_index in range(instance_count):
+                    instance_base_offset = instance_index * instance_size
+                    for launch in resolved_launches:
+                        if launch.static_args is not None:
+                            args = launch.static_args
+                        else:
+                            pc_start = instance_base_offset + launch.pc_offset
+                            pc_end = pc_start + launch.pc_size
+                            pc_payload = payload[pc_start:pc_end]
+                            args = build_kernel_args_template(
+                                launch.plan,
+                                launch.descriptor_set,
+                                pc_payload,
+                            )
+
+                        launch.plan.function(
+                            *args,
+                            block=launch.plan.local_size,
+                            grid=launch.blocks,
+                            stream=stream,
+                        )
+    except Exception as exc:
+        set_error(f"Failed to submit CUDA command list: {exc}")
+
+    return True
+
+def stage_compute_record(command_list, plan, descriptor_set, blocks_x, blocks_y, blocks_z):
+ cl = state.command_lists.get(int(command_list))
+ cp = state.compute_plans.get(int(plan))
+ if cl is None or cp is None:
+ set_error("Invalid command list or compute plan handle for stage_compute_record")
+ return
+
+ cl.commands.append(
+ CUDACommandRecord(
+ plan_handle=int(plan),
+ descriptor_set_handle=int(descriptor_set),
+ blocks=(int(blocks_x), int(blocks_y), int(blocks_z)),
+ pc_size=int(cp.pc_size),
+ )
+ )
diff --git a/vkdispatch/backends/cuda_backend/api_compute.py b/vkdispatch/backends/cuda_backend/api_compute.py
new file mode 100644
index 00000000..8db48b43
--- /dev/null
+++ b/vkdispatch/backends/cuda_backend/api_compute.py
@@ -0,0 +1,76 @@
+from __future__ import annotations
+
+from . import state as state
+from .cuda_primitives import SourceModule, cuda
+from .helpers import (
+ activate_context,
+ context_from_handle,
+ new_handle,
+ parse_kernel_params,
+ parse_local_size,
+ set_error,
+ to_bytes,
+)
+from .state import CUDAComputePlan
+
+
+def _nvrtc_compile_options(ctx):
+    """Build NVRTC options for `ctx`, targeting the device's SM architecture.
+
+    Falls back to just `-w` (suppress warnings) when the compute capability
+    cannot be queried; compilation then uses NVRTC's default architecture.
+    """
+    options = ["-w"]
+
+    try:
+        dev = cuda.Device(ctx.device_index)
+        cc_major, cc_minor = dev.compute_capability()
+        options.append(f"--gpu-architecture=sm_{int(cc_major)}{int(cc_minor)}")
+    except Exception:
+        # Best effort: architecture selection is an optimization, not a requirement.
+        pass
+
+    return options
+
+
+def stage_compute_plan_create(context, shader_source, bindings, pc_size, shader_name):
+    """Compile CUDA source into a compute plan and register it.
+
+    The kernel entry point must be named `vkdispatch_main`. Launch
+    metadata (parameter list, expected local size) is parsed from the
+    source text itself. Returns the new plan handle, or 0 on failure
+    with an error recorded.
+    """
+    ctx = context_from_handle(int(context))
+    if ctx is None:
+        return 0
+
+    source_bytes = to_bytes(shader_source)
+    shader_name_bytes = to_bytes(shader_name)
+    source_text = source_bytes.decode("utf-8", errors="replace")
+
+    try:
+        with activate_context(ctx):
+            module = SourceModule(
+                source_text,
+                no_extern_c=True,
+                options=_nvrtc_compile_options(ctx),
+            )
+            function = module.get_function("vkdispatch_main")
+    except Exception as exc:
+        set_error(f"Failed to compile CUDA kernel '{shader_name_bytes.decode(errors='ignore')}': {exc}")
+        return 0
+
+    try:
+        # Metadata is recovered from the source text, not the compiled module.
+        params = parse_kernel_params(source_text)
+        local_size = parse_local_size(source_text)
+    except Exception as exc:
+        set_error(f"Failed to parse CUDA kernel metadata: {exc}")
+        return 0
+
+    plan = CUDAComputePlan(
+        context_handle=int(context),
+        shader_source=source_bytes,
+        bindings=[int(x) for x in bindings],
+        shader_name=shader_name_bytes,
+        module=module,
+        function=function,
+        local_size=local_size,
+        params=params,
+        pc_size=int(pc_size),
+    )
+
+    return new_handle(state.compute_plans, plan)
+
+
+def stage_compute_plan_destroy(plan):
+ if plan is None:
+ return
+ state.compute_plans.pop(int(plan), None)
diff --git a/vkdispatch/backends/cuda_backend/api_context.py b/vkdispatch/backends/cuda_backend/api_context.py
new file mode 100644
index 00000000..7232b2c5
--- /dev/null
+++ b/vkdispatch/backends/cuda_backend/api_context.py
@@ -0,0 +1,250 @@
+from __future__ import annotations
+
+import hashlib
+
+from . import state as state
+from .cuda_primitives import cuda
+from .helpers import (
+ activate_context,
+ clear_error,
+ coerce_stream_handle,
+ new_handle,
+ query_max_kernel_param_size,
+ set_error,
+ stream_override_stack,
+)
+from .state import CUDAContext
+
+
+def init(debug, log_level):
+    """Initialize the CUDA driver once and record debug/log settings.
+
+    Safe to call repeatedly: settings and the error slot are refreshed on
+    every call, but cuda.init() runs only the first time.
+    """
+    state.debug_mode = bool(debug)
+    state.log_level = int(log_level)
+    clear_error()
+
+    if state.initialized:
+        return
+
+    cuda.init()
+    state.initialized = True
+
+
+def log(log_level, text, file_str, line_str):
+ _ = log_level
+ _ = text
+ _ = file_str
+ _ = line_str
+
+
+def set_log_level(log_level):
+    """Store the requested log verbosity in backend state."""
+    state.log_level = int(log_level)
+
+
+def get_devices():
+    """Enumerate CUDA devices as native-backend-compatible property tuples.
+
+    The tuple layout mirrors the native (Vulkan) binding's device-info
+    record; Vulkan-only fields are filled with fixed placeholder values
+    (see the inline comments). Returns [] with an error recorded when
+    enumeration fails.
+    """
+    if not state.initialized:
+        init(False, state.log_level)
+
+    try:
+        device_count = cuda.Device.count()
+    except Exception as exc:
+        set_error(f"Failed to enumerate CUDA devices: {exc}")
+        return []
+
+    driver_version = 0
+    try:
+        driver_version = int(cuda.get_driver_version())
+    except Exception:
+        # Driver version is informational only; 0 means "unknown".
+        driver_version = 0
+
+    devices = []
+
+    for index in range(device_count):
+        dev = cuda.Device(index)
+        attrs = dev.get_attributes()
+        cc_major, cc_minor = dev.compute_capability()
+        total_memory = int(dev.total_memory())
+
+        max_workgroup_size = (
+            int(attrs.get(cuda.device_attribute.MAX_BLOCK_DIM_X, 0)),
+            int(attrs.get(cuda.device_attribute.MAX_BLOCK_DIM_Y, 0)),
+            int(attrs.get(cuda.device_attribute.MAX_BLOCK_DIM_Z, 0)),
+        )
+
+        max_workgroup_count = (
+            int(attrs.get(cuda.device_attribute.MAX_GRID_DIM_X, 0)),
+            int(attrs.get(cuda.device_attribute.MAX_GRID_DIM_Y, 0)),
+            int(attrs.get(cuda.device_attribute.MAX_GRID_DIM_Z, 0)),
+        )
+
+        subgroup_size = int(attrs.get(cuda.device_attribute.WARP_SIZE, 0))
+        max_shared_memory = int(
+            attrs.get(cuda.device_attribute.MAX_SHARED_MEMORY_PER_BLOCK, 0)
+        )
+
+        try:
+            bus_id = str(dev.pci_bus_id())
+        except Exception:
+            # Synthesize a stable fallback identifier per index.
+            bus_id = f"cuda-device-{index}"
+
+        # Deterministic 16-byte pseudo-UUID derived from the PCI bus id.
+        uuid_bytes = hashlib.md5(bus_id.encode("utf-8")).digest()
+
+        devices.append(
+            (
+                0,  # Vulkan variant
+                int(cc_major),  # major
+                int(cc_minor),  # minor
+                0,  # patch
+                driver_version,
+                0,  # vendor id unknown in this API layer
+                index,  # device id
+                2,  # discrete gpu
+                str(dev.name()),
+                1,  # shader_buffer_float32_atomics
+                1,  # shader_buffer_float32_atomic_add
+                1,  # float64 support
+                1 if (cc_major > 5 or (cc_major == 5 and cc_minor >= 3)) else 0,  # float16 support
+                1,  # int64
+                1,  # int16
+                1,  # storage_buffer_16_bit_access
+                1,  # uniform_and_storage_buffer_16_bit_access
+                1,  # storage_push_constant_16
+                1,  # storage_input_output_16
+                max_workgroup_size,
+                int(attrs.get(cuda.device_attribute.MAX_THREADS_PER_BLOCK, 0)),
+                max_workgroup_count,
+                8,  # max descriptor sets (virtualized for parity)
+                4096,  # max push constant size
+                min(total_memory, (1 << 31) - 1),
+                65536,
+                16,
+                subgroup_size,
+                0x7FFFFFFF,  # supported stages (virtualized for parity)
+                0x7FFFFFFF,  # supported operations (virtualized for parity)
+                1,
+                max_shared_memory,
+                [(1, 0x002)],  # compute queue
+                1,  # scalar block layout
+                1,  # timeline semaphores equivalent
+                uuid_bytes,
+            )
+        )
+
+    return devices
+
+
+def context_create(device_indicies, queue_families):
+ if not state.initialized:
+ init(False, state.log_level)
+
+ try:
+ device_ids = [int(x) for x in device_indicies]
+ except Exception:
+ set_error("context_create expected a list of integer device indices")
+ return 0
+
+ if len(device_ids) != 1:
+ set_error("CUDA Python backend currently supports exactly one device")
+ return 0
+
+ if len(queue_families) != 1 or len(queue_families[0]) != 1:
+ set_error("CUDA Python backend currently supports exactly one queue")
+ return 0
+
+ device_index = device_ids[0]
+
+ cuda_context = None
+ context_pushed = False
+
+ try:
+ if device_index < 0 or device_index >= cuda.Device.count():
+ set_error(f"Invalid CUDA device index {device_index}")
+ return 0
+
+ dev = cuda.Device(device_index)
+ cc_major, _cc_minor = dev.compute_capability()
+ max_kernel_param_size = query_max_kernel_param_size(dev.device_raw, cc_major)
+ uses_primary_context = False
+
+ if hasattr(dev, "retain_primary_context"):
+ cuda_context = dev.retain_primary_context()
+ uses_primary_context = True
+ cuda_context.push()
+ else: # pragma: no cover - fallback for older CUDA Python
+ cuda_context = dev.make_context()
+ context_pushed = True
+ stream = cuda.Stream()
+
+ ctx = CUDAContext(
+ device_index=device_index,
+ cuda_context=cuda_context,
+ streams=[stream],
+ queue_count=1,
+ queue_to_device=[0],
+ max_kernel_param_size=int(max_kernel_param_size),
+ uses_primary_context=uses_primary_context,
+ stopped=False,
+ )
+ handle = new_handle(state.contexts, ctx)
+
+ # Leave no context current after creation.
+ cuda.Context.pop()
+ context_pushed = False
+ return handle
+ except Exception as exc:
+ if context_pushed:
+ try:
+ cuda.Context.pop()
+ except Exception:
+ pass
+
+ if cuda_context is not None:
+ try:
+ cuda_context.detach()
+ except Exception:
+ pass
+
+ set_error(f"Failed to create CUDA Python context: {exc}")
+ return 0
+
+
+def context_destroy(context):
+    """Tear down a context: drain its streams, then release the CUcontext.
+
+    Unknown handles are ignored. Stream-sync and detach failures are
+    swallowed so destruction is always best-effort.
+    """
+    ctx = state.contexts.pop(int(context), None)
+    if ctx is None:
+        return
+
+    try:
+        with activate_context(ctx):
+            for stream in ctx.streams:
+                stream.synchronize()
+    except Exception:
+        pass
+
+    try:
+        ctx.cuda_context.detach()
+    except Exception:
+        pass
+
+
+def context_stop_threads(context):
+    """Set the context's cooperative shutdown flag; unknown handles ignored."""
+    ctx = state.contexts.get(int(context))
+    if ctx is not None:
+        ctx.stopped = True
+
+
+def get_error_string():
+    """Return the last recorded error string, or 0 when there is none.
+
+    NOTE(review): returning 0 (not None) appears to mirror the native
+    binding's NULL-pointer convention — confirm against the caller.
+    """
+    if state.error_string is None:
+        return 0
+    return state.error_string
+
+
+def cuda_stream_override_begin(stream_obj):
+    """Push an external CUDA stream onto the override stack.
+
+    Coercion failures are recorded via set_error and leave the stack
+    unchanged. The stack is presumably consumed by stream_for_queue in
+    helpers — confirm there.
+    """
+    try:
+        stack = stream_override_stack()
+        stack.append(coerce_stream_handle(stream_obj))
+    except Exception as exc:
+        set_error(f"Failed to activate external CUDA stream override: {exc}")
+
+
+def cuda_stream_override_end():
+ stack = stream_override_stack()
+ if len(stack) > 0:
+ stack.pop()
diff --git a/vkdispatch/backends/cuda_backend/bindings.py b/vkdispatch/backends/cuda_backend/bindings.py
new file mode 100644
index 00000000..be7d82ee
--- /dev/null
+++ b/vkdispatch/backends/cuda_backend/bindings.py
@@ -0,0 +1,326 @@
+from __future__ import annotations
+
+import ctypes
+import importlib.util
+import os
+from pathlib import Path
+import shutil
+import sys
+from typing import List, Optional
+
+try:
+ import numpy as np
+except Exception as exc: # pragma: no cover - import failure path
+ raise ImportError(
+ "The CUDA Python backend requires both 'cuda-python' and 'numpy' to be installed."
+ ) from exc
+
+try:
+ from cuda.bindings import driver, nvrtc
+except Exception:
+ try:
+ from cuda import cuda as driver # type: ignore
+ from cuda import nvrtc # type: ignore
+ except Exception as exc: # pragma: no cover - import failure path
+ raise ImportError(
+ "The CUDA Python backend requires the NVIDIA cuda-python package "
+ "(`pip install cuda-python`)."
+ ) from exc
+
+
+def to_int(value) -> int:
+ if isinstance(value, int):
+ return int(value)
+
+ if hasattr(value, "value"):
+ try:
+ return int(value.value)
+ except Exception:
+ pass
+
+ return int(value)
+
+
+def drv_call(names, *args):
+    """Invoke the first available driver symbol from `names` with `args`.
+
+    cuda-python renames some entry points across versions (e.g. `_v2`
+    suffixes), so callers may pass a list of candidates. A TypeError is
+    treated as a signature mismatch and the next candidate is tried; any
+    other exception propagates to the caller.
+    """
+    if isinstance(names, str):
+        names = [names]
+
+    last_error = None
+    for name in names:
+        fn = getattr(driver, name, None)
+        if fn is not None:
+            try:
+                return fn(*args)
+            except TypeError as exc:
+                last_error = exc
+                continue
+
+    if last_error is not None:
+        raise RuntimeError(f"CUDA Driver call failed for {names}: {last_error}") from last_error
+    raise RuntimeError(f"CUDA Driver symbol not found: {names}")
+
+
+def nvrtc_call(names, *args):
+    """NVRTC counterpart of drv_call: try each candidate symbol in order.
+
+    TypeError means a signature mismatch for this binding version and
+    moves on to the next candidate; other exceptions propagate.
+    """
+    if isinstance(names, str):
+        names = [names]
+
+    last_error = None
+    for name in names:
+        fn = getattr(nvrtc, name, None)
+        if fn is not None:
+            try:
+                return fn(*args)
+            except TypeError as exc:
+                last_error = exc
+                continue
+
+    if last_error is not None:
+        raise RuntimeError(f"NVRTC call failed for {names}: {last_error}") from last_error
+    raise RuntimeError(f"NVRTC symbol not found: {names}")
+
+
+def status_success(status) -> bool:
+    """True when a driver/NVRTC status represents success.
+
+    Primary check is numeric zero; enum-like objects that cannot be
+    coerced to int are matched by their string suffix instead.
+    """
+    try:
+        return to_int(status) == 0
+    except Exception:
+        return str(status).endswith("CUDA_SUCCESS") or str(status).endswith("NVRTC_SUCCESS")
+
+
+def drv_error_string(status) -> str:
+    """Best-effort human-readable rendering of a CUDA driver status.
+
+    Tries cuGetErrorName/cuGetErrorString; falls back to str(status) when
+    either lookup fails or itself reports a non-success status.
+    """
+    try:
+        name_res = drv_call("cuGetErrorName", status)
+        string_res = drv_call("cuGetErrorString", status)
+        # Non-tuple results have no status to inspect; treat as failure (1).
+        _name_status = name_res[0] if isinstance(name_res, tuple) else 1
+        _string_status = string_res[0] if isinstance(string_res, tuple) else 1
+        if status_success(_name_status) and status_success(_string_status):
+            name = name_res[1] if isinstance(name_res, tuple) and len(name_res) > 1 else name_res
+            text = string_res[1] if isinstance(string_res, tuple) and len(string_res) > 1 else string_res
+            if isinstance(name, (bytes, bytearray)):
+                name = name.decode("utf-8", errors="replace")
+            if isinstance(text, (bytes, bytearray)):
+                text = text.decode("utf-8", errors="replace")
+            return f"{name}: {text}"
+    except Exception:
+        pass
+
+    return str(status)
+
+
+def drv_check(result, op_name: str):
+ if isinstance(result, tuple):
+ status = result[0]
+ payload = result[1:]
+ else:
+ status = result
+ payload = ()
+
+ if not status_success(status):
+ raise RuntimeError(f"{op_name} failed ({drv_error_string(status)})")
+
+ if len(payload) == 0:
+ return None
+
+ if len(payload) == 1:
+ return payload[0]
+
+ return payload
+
+
+def nvrtc_check(result, op_name: str):
+    """Validate an NVRTC call result and unwrap its payload.
+
+    Same unwrapping as drv_check, but errors are formatted with the raw
+    status object (no name/string lookup is wired up for NVRTC here).
+    """
+    if isinstance(result, tuple):
+        status = result[0]
+        payload = result[1:]
+    else:
+        status = result
+        payload = ()
+
+    if not status_success(status):
+        raise RuntimeError(f"{op_name} failed ({status})")
+
+    if len(payload) == 0:
+        return None
+
+    if len(payload) == 1:
+        return payload[0]
+
+    return payload
+
+
+def nvrtc_read_bytes(program, size_api: str, read_api: str) -> bytes:
+    """Read a sized NVRTC output (log, PTX, CUBIN) across binding variants.
+
+    Different cuda-python releases return the data directly, fill a
+    caller-provided buffer, or return (status, data) tuples; this probes
+    each convention in turn and normalizes the result to exactly `size`
+    bytes (NUL-padded if the binding returned fewer).
+    """
+    raw_size = nvrtc_check(nvrtc_call(size_api, program), size_api)
+    size = int(to_int(raw_size))
+    if size <= 0:
+        return b""
+
+    def _normalize_output(data) -> Optional[bytes]:
+        # Flatten whatever the binding returned into exactly `size` bytes,
+        # or None when no byte-like payload can be found.
+        if data is None:
+            return None
+
+        if isinstance(data, memoryview):
+            data = data.tobytes()
+        elif isinstance(data, str):
+            data = data.encode("utf-8", errors="replace")
+
+        if isinstance(data, (bytes, bytearray)):
+            raw = bytes(data)
+            if len(raw) >= size:
+                return raw[:size]
+            return raw + (b"\x00" * (size - len(raw)))
+
+        if isinstance(data, (tuple, list)):
+            # Recurse into containers; first byte-like member wins.
+            for item in data:
+                normalized = _normalize_output(item)
+                if normalized is not None:
+                    return normalized
+
+        return None
+
+    # Variant 1: the read API returns the data directly.
+    try:
+        direct_data = nvrtc_check(nvrtc_call(read_api, program), read_api)
+        normalized = _normalize_output(direct_data)
+        if normalized is not None:
+            return normalized
+    except Exception:
+        pass
+
+    # Variant 2: the read API fills a caller-provided buffer.
+    # NOTE(review): `out_bytes` is immutable, so a binding cannot actually
+    # write into it; it can only serve to be rejected (advancing the loop
+    # to a writable candidate) — confirm whether it can be dropped.
+    out_c = ctypes.create_string_buffer(size)
+    out_bytearray = bytearray(size)
+    out_bytes = bytes(size)
+
+    for out_candidate in (out_bytes, out_bytearray, out_c):
+        try:
+            call_result = nvrtc_check(nvrtc_call(read_api, program, out_candidate), read_api)
+            normalized_result = _normalize_output(call_result)
+            if normalized_result is not None:
+                return normalized_result
+
+            if isinstance(out_candidate, bytearray):
+                return bytes(out_candidate)
+
+            if out_candidate is out_c:
+                return bytes(out_c.raw)
+        except Exception:
+            continue
+
+    # Last resort: whatever (possibly zeroed) content the ctypes buffer holds.
+    return bytes(out_c.raw)
+
+
+def discover_cuda_include_dirs() -> List[str]:
+    """Find directories containing cuda_runtime.h for NVRTC include paths.
+
+    Probes, in order: CUDA environment variables, the toolkit inferred
+    from nvcc's location, common Unix install paths, conda layouts, and
+    NVIDIA pip-wheel locations. Only directories that actually contain
+    cuda_runtime.h are returned, deduplicated, in discovery order.
+    """
+    include_dirs: List[str] = []
+    seen = set()
+
+    def add_dir(path_like) -> None:
+        # Record the directory only if it resolves and holds the header.
+        if path_like is None:
+            return
+        try:
+            resolved = str(Path(path_like).resolve())
+        except Exception:
+            resolved = str(path_like)
+        if resolved in seen:
+            return
+        header_path = Path(resolved) / "cuda_runtime.h"
+        if header_path.exists():
+            seen.add(resolved)
+            include_dirs.append(resolved)
+
+    # Standard CUDA environment variables.
+    for env_name in (
+        "CUDA_HOME",
+        "CUDA_PATH",
+        "CUDA_ROOT",
+        "CUDA_TOOLKIT_ROOT_DIR",
+        "CUDAToolkit_ROOT",
+    ):
+        root = os.environ.get(env_name)
+        if root:
+            add_dir(Path(root) / "include")
+
+    # CUDA toolkit from nvcc location.
+    nvcc_path = shutil.which("nvcc")
+    if nvcc_path:
+        try:
+            # nvcc lives in <root>/bin, so the toolkit root is two levels up.
+            nvcc_root = Path(nvcc_path).resolve().parent.parent
+            add_dir(nvcc_root / "include")
+        except Exception:
+            pass
+
+    # Common Unix install locations.
+    add_dir("/usr/local/cuda/include")
+    add_dir("/opt/cuda/include")
+    add_dir("/usr/include")
+
+    # Conda cudatoolkit layouts.
+    conda_prefix = os.environ.get("CONDA_PREFIX")
+    if conda_prefix:
+        add_dir(Path(conda_prefix) / "include")
+        add_dir(Path(conda_prefix) / "targets" / "x86_64-linux" / "include")
+        add_dir(Path(conda_prefix) / "Library" / "include")
+
+    # NVIDIA pip wheel layout.
+    for base in sys.path:
+        add_dir(Path(base) / "nvidia" / "cuda_runtime" / "include")
+
+    # Some environments expose this namespace package.
+    try:
+        spec = importlib.util.find_spec("nvidia.cuda_runtime")
+        if spec is not None and spec.submodule_search_locations:
+            for entry in spec.submodule_search_locations:
+                add_dir(Path(entry) / "include")
+    except Exception:
+        pass
+
+    return include_dirs
+
+
+def prepare_nvrtc_options(options: List[bytes]) -> List[bytes]:
+ normalized: List[bytes] = []
+ has_include_path = False
+
+ for opt in options:
+ as_str = opt.decode("utf-8", errors="replace")
+ if as_str.startswith("-I") or as_str.startswith("--include-path"):
+ has_include_path = True
+ normalized.append(opt)
+
+ if not has_include_path:
+ for include_dir in discover_cuda_include_dirs():
+ normalized.append(f"--include-path={include_dir}".encode("utf-8"))
+
+ return normalized
+
+
+def as_driver_handle(type_name: str, value):
+    """Wrap `value` in the driver handle type named `type_name` when possible.
+
+    Returns `value` unchanged if the type does not exist in this binding,
+    if the value is already of that type, or if construction fails.
+    """
+    handle_type = getattr(driver, type_name, None)
+    if handle_type is None:
+        return value
+
+    try:
+        if isinstance(value, handle_type):
+            return value
+    except Exception:
+        pass
+
+    try:
+        return handle_type(to_int(value))
+    except Exception:
+        return value
+
+
+def writable_host_ptr(view: memoryview):
+ byte_view = view.cast("B")
+ try:
+ c_buffer = (ctypes.c_ubyte * len(byte_view)).from_buffer(byte_view)
+ return ctypes.addressof(c_buffer), c_buffer
+ except Exception:
+ copied = ctypes.create_string_buffer(byte_view.tobytes())
+ return ctypes.addressof(copied), copied
+
+
+def readonly_host_ptr(view: memoryview):
+ byte_view = view.cast("B")
+ try:
+ c_buffer = (ctypes.c_ubyte * len(byte_view)).from_buffer(byte_view)
+ return ctypes.addressof(c_buffer), c_buffer
+ except Exception:
+ copied = ctypes.create_string_buffer(byte_view.tobytes())
+ return ctypes.addressof(copied), copied
diff --git a/vkdispatch/backends/cuda_backend/constants.py b/vkdispatch/backends/cuda_backend/constants.py
new file mode 100644
index 00000000..246346be
--- /dev/null
+++ b/vkdispatch/backends/cuda_backend/constants.py
@@ -0,0 +1,23 @@
+from __future__ import annotations
+
+import re
+
+# Log level constants mirrored from native bindings.
+LOG_LEVEL_VERBOSE = 0
+LOG_LEVEL_INFO = 1
+LOG_LEVEL_WARNING = 2
+LOG_LEVEL_ERROR = 3
+
+# Descriptor type enum values mirrored from vkdispatch_native/stages_extern.pxd.
+DESCRIPTOR_TYPE_STORAGE_BUFFER = 1
+DESCRIPTOR_TYPE_STORAGE_IMAGE = 2
+DESCRIPTOR_TYPE_UNIFORM_BUFFER = 3
+DESCRIPTOR_TYPE_UNIFORM_IMAGE = 4
+DESCRIPTOR_TYPE_SAMPLER = 5
+
+# Regexes for recovering launch metadata from generated kernel source:
+# expected local size macros, the vkdispatch_main signature, and the
+# binding/sampler parameter naming scheme.
+LOCAL_X_RE = re.compile(r"#define\s+VKDISPATCH_EXPECTED_LOCAL_SIZE_X\s+(\d+)")
+LOCAL_Y_RE = re.compile(r"#define\s+VKDISPATCH_EXPECTED_LOCAL_SIZE_Y\s+(\d+)")
+LOCAL_Z_RE = re.compile(r"#define\s+VKDISPATCH_EXPECTED_LOCAL_SIZE_Z\s+(\d+)")
+KERNEL_SIGNATURE_RE = re.compile(r"vkdispatch_main\s*\(([^)]*)\)", re.S)
+BINDING_PARAM_RE = re.compile(r"vkdispatch_binding_(\d+)_ptr$")
+SAMPLER_PARAM_RE = re.compile(r"vkdispatch_sampler_(\d+)$")
diff --git a/vkdispatch/backends/cuda_backend/cuda_primitives.py b/vkdispatch/backends/cuda_backend/cuda_primitives.py
new file mode 100644
index 00000000..8a3af54a
--- /dev/null
+++ b/vkdispatch/backends/cuda_backend/cuda_primitives.py
@@ -0,0 +1,571 @@
+from __future__ import annotations
+
+import ctypes
+from dataclasses import dataclass
+from typing import List, Optional
+
+from .bindings import (
+ np,
+ driver,
+ as_driver_handle,
+ discover_cuda_include_dirs,
+ drv_call,
+ drv_check,
+ nvrtc_call,
+ nvrtc_check,
+ nvrtc_read_bytes,
+ prepare_nvrtc_options,
+ readonly_host_ptr,
+ status_success,
+ to_int,
+ writable_host_ptr,
+)
+
+
+@dataclass
+class _ByValueKernelArg:
+    """Kernel argument passed by value as a raw byte payload."""
+
+    payload: bytes  # bytes copied into the launch parameter buffer
+    raw_name: str   # source parameter name — consumers not visible in this file
+
+
+class _DeviceAllocation:
+    """Owns a raw device pointer and frees it at most once."""
+
+    def __init__(self, ptr: int):
+        self.ptr = int(ptr)
+        self.freed = False
+
+    def __int__(self):
+        return int(self.ptr)
+
+    def free(self):
+        """Release the device memory; repeated calls are no-ops."""
+        if self.freed:
+            return
+
+        drv_check(
+            drv_call(
+                ["cuMemFree", "cuMemFree_v2"],
+                as_driver_handle("CUdeviceptr", self.ptr),
+            ),
+            "cuMemFree",
+        )
+        self.freed = True
+
+
+class _ContextHandle:
+    """Thin wrapper over a CUcontext supporting push and one-time release.
+
+    `uses_primary_context` selects the matching teardown path:
+    cuDevicePrimaryCtxRelease for retained primary contexts,
+    cuCtxDestroy for explicitly created ones.
+    """
+
+    def __init__(self, context_raw, device_index: int, uses_primary_context: bool):
+        self.context_raw = context_raw
+        self.device_index = int(device_index)
+        self.uses_primary_context = bool(uses_primary_context)
+        self._detached = False
+
+    def push(self):
+        """Make this context current on the calling thread."""
+        drv_check(
+            drv_call(
+                "cuCtxPushCurrent",
+                as_driver_handle("CUcontext", self.context_raw),
+            ),
+            "cuCtxPushCurrent",
+        )
+
+    def detach(self):
+        """Release the context exactly once; later calls are no-ops."""
+        if self._detached:
+            return
+
+        if self.uses_primary_context:
+            dev = drv_check(drv_call("cuDeviceGet", int(self.device_index)), "cuDeviceGet")
+            drv_check(drv_call("cuDevicePrimaryCtxRelease", dev), "cuDevicePrimaryCtxRelease")
+        else:
+            drv_check(
+                drv_call(
+                    ["cuCtxDestroy", "cuCtxDestroy_v2"],
+                    as_driver_handle("CUcontext", self.context_raw),
+                ),
+                "cuCtxDestroy",
+            )
+        self._detached = True
+
+
+class _StreamHandle:
+    """CUDA stream wrapper; either owns a new stream or wraps a foreign handle.
+
+    Accepts `handle=`, `ptr=`, or a single positional handle to support
+    multiple caller conventions; with no handle a new stream is created
+    and marked as owned.
+    """
+
+    def __init__(self, handle: Optional[int] = None, ptr: Optional[int] = None, *args, **kwargs):
+        _ = kwargs
+        # Accept a lone positional argument as the handle as well.
+        if handle is None and ptr is None and len(args) == 1:
+            handle = int(args[0])
+        if handle is None and ptr is not None:
+            handle = int(ptr)
+
+        if handle is None:
+            stream_raw = drv_check(drv_call("cuStreamCreate", 0), "cuStreamCreate")
+            self.handle = int(to_int(stream_raw))
+            self.owned = True
+        else:
+            # Wrapping an external stream: do not take ownership.
+            self.handle = int(handle)
+            self.owned = False
+
+    def synchronize(self):
+        """Block until all work queued on this stream has completed."""
+        drv_check(
+            drv_call(
+                "cuStreamSynchronize",
+                as_driver_handle("CUstream", self.handle),
+            ),
+            "cuStreamSynchronize",
+        )
+
+    def __int__(self):
+        return int(self.handle)
+
+    @property
+    def ptr(self):
+        # Accessor alias for interop callers expecting a `.ptr` attribute.
+        return int(self.handle)
+
+    @property
+    def cuda_stream(self):
+        # Accessor alias for interop callers expecting `.cuda_stream`.
+        return int(self.handle)
+
+
+class _EventHandle:
+    """CUDA event wrapper used for completion queries and waits."""
+
+    def __init__(self):
+        self.event_raw = drv_check(drv_call("cuEventCreate", 0), "cuEventCreate")
+
+    def record(self, stream_obj: Optional["_StreamHandle"]):
+        """Record this event on `stream_obj` (default stream when None)."""
+        stream_handle = 0 if stream_obj is None else int(stream_obj)
+        drv_check(
+            drv_call(
+                "cuEventRecord",
+                self.event_raw,
+                as_driver_handle("CUstream", stream_handle),
+            ),
+            "cuEventRecord",
+        )
+
+    def query(self) -> bool:
+        """Return True if the event has completed, False while pending.
+
+        CUDA_ERROR_NOT_READY is reported as "not done" rather than
+        raised; any other non-success status also yields False.
+        """
+        res = drv_call("cuEventQuery", self.event_raw)
+        status = res[0] if isinstance(res, tuple) else res
+
+        if status_success(status):
+            return True
+
+        status_text = str(status)
+        if "NOT_READY" in status_text:
+            return False
+
+        if to_int(status) != 0:
+            return False
+
+        return True
+
+    def synchronize(self):
+        """Block until the event completes."""
+        drv_check(drv_call("cuEventSynchronize", self.event_raw), "cuEventSynchronize")
+
+
+class _KernelFunction:
+    """Callable wrapper around a CUfunction that launches via cuLaunchKernel.
+
+    Because cuda-python's cuLaunchKernel argument encodings vary across
+    releases, __call__ brute-forces combinations of function, stream, and
+    kernel-parameter representations until one launch succeeds.
+    """
+
+    def __init__(self, function_raw):
+        self.function_raw = function_raw
+
+    def __call__(self, *args, block, grid, stream=None):
+        """Launch the kernel on `grid` x `block`, optionally on `stream`.
+
+        Entries of `args` are either _ByValueKernelArg payloads (passed by
+        value) or integers (passed as 64-bit scalars). Raises RuntimeError
+        when every launch variant fails.
+        """
+        arg_values = []
+
+        def _dedupe(values):
+            # Preserve order while dropping duplicate candidate encodings.
+            out = []
+            seen = set()
+            for value in values:
+                key = f"{type(value).__name__}:{repr(value)}"
+                if key in seen:
+                    continue
+                seen.add(key)
+                out.append(value)
+            return out
+
+        # Marshal every argument into ctypes storage. `arg_values` keeps the
+        # storage alive; `arg_ptr_values` holds the addresses cuLaunchKernel
+        # dereferences.
+        arg_ptr_values = []
+        for arg in args:
+            if isinstance(arg, _ByValueKernelArg):
+                payload = arg.payload
+                if len(payload) == 0:
+                    # Zero-length ctypes arrays are awkward; pad to one byte.
+                    payload = b"\x00"
+
+                payload_storage = (ctypes.c_ubyte * len(payload)).from_buffer_copy(payload)
+                arg_values.append(payload_storage)
+                arg_ptr_values.append(ctypes.addressof(payload_storage))
+                continue
+
+            scalar_storage = ctypes.c_uint64(int(arg))
+            arg_values.append(scalar_storage)
+            arg_ptr_values.append(ctypes.addressof(scalar_storage))
+
+        arg_ptr_array = None
+        if len(arg_ptr_values) > 0:
+            arg_ptr_array = (ctypes.c_void_p * len(arg_ptr_values))(
+                *[ctypes.c_void_p(ptr) for ptr in arg_ptr_values]
+            )
+
+        # Candidate encodings for the kernelParams argument.
+        kernel_param_variants = [None, 0, ctypes.c_void_p(0)]
+        if arg_ptr_array is not None:
+            array_ptr = ctypes.cast(arg_ptr_array, ctypes.POINTER(ctypes.c_void_p))
+            kernel_param_variants = _dedupe(
+                [
+                    arg_ptr_array,
+                    array_ptr,
+                    ctypes.cast(array_ptr, ctypes.c_void_p),
+                    ctypes.cast(array_ptr, ctypes.c_void_p).value,
+                    tuple(arg_ptr_values),
+                    list(arg_ptr_values),
+                ]
+            )
+
+        stream_handle = 0 if stream is None else int(stream)
+        stream_variants = _dedupe(
+            [
+                stream_handle,
+                as_driver_handle("CUstream", stream_handle),
+            ]
+        )
+
+        function_candidates = [
+            self.function_raw,
+            as_driver_handle("CUfunction", self.function_raw),
+        ]
+        try:
+            function_candidates.append(to_int(self.function_raw))
+        except Exception:
+            pass
+        function_variants = _dedupe(function_candidates)
+
+        extra_variants = [None, 0, ctypes.c_void_p(0)]
+        last_error = None
+
+        # Try every combination; the first form includes the trailing
+        # `extra` argument, the second omits it for bindings without it.
+        for function_handle in function_variants:
+            for stream_value in stream_variants:
+                for kernel_params in kernel_param_variants:
+                    for extra in extra_variants:
+                        try:
+                            drv_check(
+                                drv_call(
+                                    "cuLaunchKernel",
+                                    function_handle,
+                                    int(grid[0]),
+                                    int(grid[1]),
+                                    int(grid[2]),
+                                    int(block[0]),
+                                    int(block[1]),
+                                    int(block[2]),
+                                    0,
+                                    stream_value,
+                                    kernel_params,
+                                    extra,
+                                ),
+                                "cuLaunchKernel",
+                            )
+                            return
+                        except Exception as exc:
+                            last_error = exc
+
+                    try:
+                        drv_check(
+                            drv_call(
+                                "cuLaunchKernel",
+                                function_handle,
+                                int(grid[0]),
+                                int(grid[1]),
+                                int(grid[2]),
+                                int(block[0]),
+                                int(block[1]),
+                                int(block[2]),
+                                0,
+                                stream_value,
+                                kernel_params,
+                            ),
+                            "cuLaunchKernel",
+                        )
+                        return
+                    except Exception as exc:
+                        last_error = exc
+                        continue
+
+        if last_error is None:
+            raise RuntimeError("cuLaunchKernel failed with no diagnostic.")
+        raise RuntimeError(f"cuLaunchKernel failed: {last_error}") from last_error
+
+
+class SourceModule:
+    """Compile CUDA C++ source with NVRTC and load it as a driver module.
+
+    Mirrors a PyCUDA-style SourceModule interface; `no_extern_c` is
+    accepted for compatibility but ignored. Prefers a CUBIN payload and
+    falls back to PTX when the binding cannot produce one.
+    """
+
+    def __init__(self, source: str, no_extern_c: bool = True, options: Optional[List[str]] = None):
+        _ = no_extern_c
+        if options is None:
+            options = []
+
+        program_name = b"vkdispatch.cu"
+        source_bytes = source.encode("utf-8")
+        program = nvrtc_check(
+            nvrtc_call(
+                "nvrtcCreateProgram",
+                source_bytes,
+                program_name,
+                0,
+                [],
+                [],
+            ),
+            "nvrtcCreateProgram",
+        )
+
+        cubin = b""
+        ptx = b""
+        build_log = b""
+
+        try:
+            encoded_options = [opt.encode("utf-8") if isinstance(opt, str) else bytes(opt) for opt in options]
+            encoded_options = prepare_nvrtc_options(encoded_options)
+            compile_result = nvrtc_call("nvrtcCompileProgram", program, len(encoded_options), encoded_options)
+            compile_status = compile_result[0] if isinstance(compile_result, tuple) else compile_result
+
+            # Always fetch the log so failures can be reported verbatim.
+            build_log = nvrtc_read_bytes(program, "nvrtcGetProgramLogSize", "nvrtcGetProgramLog")
+            if not status_success(compile_status):
+                clean_build_log = build_log.rstrip(b"\x00").decode("utf-8", errors="replace")
+                # Missing-header failures get an actionable hint appended.
+                if 'could not open source file "cuda_runtime.h"' in clean_build_log:
+                    discovered = discover_cuda_include_dirs()
+                    hint = (
+                        " NVRTC could not find CUDA headers. "
+                        f"Discovered include dirs: {discovered if len(discovered) > 0 else 'none'}. "
+                        "Set CUDA_HOME/CUDA_PATH to your toolkit root or ensure nvcc is on PATH."
+                    )
+                else:
+                    hint = ""
+                raise RuntimeError(
+                    f"NVRTC compilation failed: {clean_build_log}{hint}"
+                )
+
+            try:
+                cubin = nvrtc_read_bytes(program, "nvrtcGetCUBINSize", "nvrtcGetCUBIN")
+            except Exception:
+                cubin = b""
+
+            if len(cubin) == 0:
+                # No CUBIN available (e.g. virtual architecture): fall back to PTX.
+                try:
+                    ptx = nvrtc_read_bytes(program, "nvrtcGetPTXSize", "nvrtcGetPTX")
+                except Exception:
+                    ptx = b""
+        finally:
+            try:
+                nvrtc_check(nvrtc_call("nvrtcDestroyProgram", program), "nvrtcDestroyProgram")
+            except Exception:
+                pass
+
+        image_data = cubin
+        if len(image_data) == 0:
+            image_data = ptx
+
+        if len(image_data) == 0:
+            raise RuntimeError("NVRTC compilation succeeded but produced neither a CUBIN nor a PTX payload.")
+
+        # cuModuleLoadData treats PTX as NUL-terminated text; CUBIN is sized binary.
+        if len(cubin) == 0 and not image_data.endswith(b"\x00"):
+            image_data += b"\x00"
+
+        self.module_raw = drv_check(
+            drv_call(["cuModuleLoadDataEx", "cuModuleLoadData"], image_data),
+            "cuModuleLoadData",
+        )
+
+    def get_function(self, name: str):
+        """Look up kernel `name` in the loaded module and wrap it."""
+        func_raw = drv_check(
+            drv_call("cuModuleGetFunction", self.module_raw, name.encode("utf-8")),
+            "cuModuleGetFunction",
+        )
+        return _KernelFunction(func_raw)
+
+
class _CudaDevice:
    """Minimal pycuda.driver-compatible facade over the raw CUDA driver API.

    Bound to the module-level name ``cuda`` below so backend code can use
    the familiar pycuda-style surface (Device, Stream, Event, mem_alloc,
    memcpy_htod_async, ...).
    """

    class device_attribute:
        # Each constant resolves to the cuda-python driver enum member, or 0
        # when this cuda-python build does not expose it (the getattr
        # fallbacks make import-time failures impossible).
        MAX_BLOCK_DIM_X = getattr(
            getattr(driver, "CUdevice_attribute", object()),
            "CU_DEVICE_ATTRIBUTE_MAX_BLOCK_DIM_X",
            0,
        )
        MAX_BLOCK_DIM_Y = getattr(
            getattr(driver, "CUdevice_attribute", object()),
            "CU_DEVICE_ATTRIBUTE_MAX_BLOCK_DIM_Y",
            0,
        )
        MAX_BLOCK_DIM_Z = getattr(
            getattr(driver, "CUdevice_attribute", object()),
            "CU_DEVICE_ATTRIBUTE_MAX_BLOCK_DIM_Z",
            0,
        )
        MAX_THREADS_PER_BLOCK = getattr(
            getattr(driver, "CUdevice_attribute", object()),
            "CU_DEVICE_ATTRIBUTE_MAX_THREADS_PER_BLOCK",
            0,
        )
        MAX_GRID_DIM_X = getattr(
            getattr(driver, "CUdevice_attribute", object()),
            "CU_DEVICE_ATTRIBUTE_MAX_GRID_DIM_X",
            0,
        )
        MAX_GRID_DIM_Y = getattr(
            getattr(driver, "CUdevice_attribute", object()),
            "CU_DEVICE_ATTRIBUTE_MAX_GRID_DIM_Y",
            0,
        )
        MAX_GRID_DIM_Z = getattr(
            getattr(driver, "CUdevice_attribute", object()),
            "CU_DEVICE_ATTRIBUTE_MAX_GRID_DIM_Z",
            0,
        )
        WARP_SIZE = getattr(
            getattr(driver, "CUdevice_attribute", object()),
            "CU_DEVICE_ATTRIBUTE_WARP_SIZE",
            0,
        )
        MAX_SHARED_MEMORY_PER_BLOCK = getattr(
            getattr(driver, "CUdevice_attribute", object()),
            "CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK",
            0,
        )

    class Device:
        """Wrapper around a CUdevice ordinal."""

        def __init__(self, index: int):
            self.index = int(index)
            self.device_raw = drv_check(drv_call("cuDeviceGet", self.index), "cuDeviceGet")

        @staticmethod
        def count():
            """Number of CUDA devices visible to the driver."""
            return int(drv_check(drv_call("cuDeviceGetCount"), "cuDeviceGetCount"))

        def get_attributes(self):
            """Query the common launch-limit attributes.

            Unavailable attributes map to 0 so callers can index the result
            unconditionally.
            """
            attrs = {}
            for attr_name in (
                "MAX_BLOCK_DIM_X",
                "MAX_BLOCK_DIM_Y",
                "MAX_BLOCK_DIM_Z",
                "MAX_THREADS_PER_BLOCK",
                "MAX_GRID_DIM_X",
                "MAX_GRID_DIM_Y",
                "MAX_GRID_DIM_Z",
                "WARP_SIZE",
                "MAX_SHARED_MEMORY_PER_BLOCK",
            ):
                attr_enum = getattr(_CudaDevice.device_attribute, attr_name)
                try:
                    val = drv_check(
                        drv_call("cuDeviceGetAttribute", attr_enum, self.device_raw),
                        "cuDeviceGetAttribute",
                    )
                    attrs[attr_enum] = int(val)
                except Exception:
                    # Keep the key present even when the query fails.
                    attrs[attr_enum] = 0
            return attrs

        def compute_capability(self):
            """Return (major, minor) compute capability of this device."""
            major_enum = getattr(
                getattr(driver, "CUdevice_attribute", object()),
                "CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR",
                0,
            )
            minor_enum = getattr(
                getattr(driver, "CUdevice_attribute", object()),
                "CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR",
                0,
            )
            major = drv_check(drv_call("cuDeviceGetAttribute", major_enum, self.device_raw), "cuDeviceGetAttribute")
            minor = drv_check(drv_call("cuDeviceGetAttribute", minor_enum, self.device_raw), "cuDeviceGetAttribute")
            return int(major), int(minor)

        def total_memory(self):
            """Total device memory in bytes (tries unversioned and _v2 symbols)."""
            return int(drv_check(drv_call(["cuDeviceTotalMem", "cuDeviceTotalMem_v2"], self.device_raw), "cuDeviceTotalMem"))

        def pci_bus_id(self):
            """PCI bus id string, or a synthetic id when the query fails."""
            try:
                bus_id = drv_check(drv_call("cuDeviceGetPCIBusId", 64, self.device_raw), "cuDeviceGetPCIBusId")
                if isinstance(bus_id, (bytes, bytearray)):
                    return bus_id.decode("utf-8", errors="replace").rstrip("\x00")
                return str(bus_id)
            except Exception:
                return f"cuda-device-{self.index}"

        def name(self):
            """Human-readable device name, or a synthetic one on failure."""
            try:
                name = drv_check(drv_call("cuDeviceGetName", 128, self.device_raw), "cuDeviceGetName")
                if isinstance(name, (bytes, bytearray)):
                    return name.decode("utf-8", errors="replace").rstrip("\x00")
                return str(name)
            except Exception:
                return f"CUDA Device {self.index}"

        def retain_primary_context(self):
            """Retain the device's primary context (shared with other CUDA users)."""
            ctx_raw = drv_check(drv_call("cuDevicePrimaryCtxRetain", self.device_raw), "cuDevicePrimaryCtxRetain")
            return _ContextHandle(ctx_raw, self.index, True)

        def make_context(self):
            """Create a fresh standalone context (flags=0); caller owns its lifetime."""
            ctx_raw = drv_check(
                drv_call(["cuCtxCreate", "cuCtxCreate_v2"], 0, self.device_raw),
                "cuCtxCreate",
            )
            return _ContextHandle(ctx_raw, self.index, False)

    class Context:
        @staticmethod
        def pop():
            # Two calling styles are attempted — presumably some binding
            # variants take no out-parameter while others require one.
            # TODO(review): confirm which cuda-python versions need which.
            try:
                drv_check(drv_call("cuCtxPopCurrent"), "cuCtxPopCurrent")
                return
            except Exception:
                pass

            popped = ctypes.c_void_p()
            drv_check(drv_call("cuCtxPopCurrent", popped), "cuCtxPopCurrent")

    # pycuda-compatible aliases for the primitive wrappers defined above.
    Stream = _StreamHandle
    ExternalStream = _StreamHandle
    Event = _EventHandle
    DeviceAllocation = _DeviceAllocation
    device_attribute = device_attribute  # redundant self-rebinding, kept as-is

    @staticmethod
    def init():
        """Initialize the CUDA driver (cuInit with flags=0)."""
        drv_check(drv_call("cuInit", 0), "cuInit")

    @staticmethod
    def get_driver_version():
        """Installed CUDA driver version as reported by cuDriverGetVersion."""
        return int(drv_check(drv_call("cuDriverGetVersion"), "cuDriverGetVersion"))

    @staticmethod
    def mem_alloc(size: int):
        """Allocate ``size`` bytes of device memory; returns a _DeviceAllocation."""
        ptr = drv_check(
            drv_call(["cuMemAlloc", "cuMemAlloc_v2"], int(size)),
            "cuMemAlloc",
        )
        return _DeviceAllocation(int(to_int(ptr)))

    @staticmethod
    def memcpy_htod_async(dst_ptr, src_obj, stream_obj):
        """Async host-to-device copy of the full contents of ``src_obj``."""
        src_view = memoryview(src_obj).cast("B")
        host_ptr, _keepalive = readonly_host_ptr(src_view)
        # None selects the legacy default stream (handle 0).
        stream_handle = 0 if stream_obj is None else int(stream_obj)
        drv_check(
            drv_call(
                ["cuMemcpyHtoDAsync", "cuMemcpyHtoDAsync_v2"],
                as_driver_handle("CUdeviceptr", int(dst_ptr)),
                host_ptr,
                len(src_view),
                as_driver_handle("CUstream", stream_handle),
            ),
            "cuMemcpyHtoDAsync",
        )

    @staticmethod
    def memcpy_dtoh_async(dst_obj, src_ptr, stream_obj):
        """Async device-to-host copy filling the whole of ``dst_obj``."""
        dst_view = memoryview(dst_obj).cast("B")
        host_ptr, _keepalive = writable_host_ptr(dst_view)
        # None selects the legacy default stream (handle 0).
        stream_handle = 0 if stream_obj is None else int(stream_obj)
        drv_check(
            drv_call(
                ["cuMemcpyDtoHAsync", "cuMemcpyDtoHAsync_v2"],
                host_ptr,
                as_driver_handle("CUdeviceptr", int(src_ptr)),
                len(dst_view),
                as_driver_handle("CUstream", stream_handle),
            ),
            "cuMemcpyDtoHAsync",
        )

    @staticmethod
    def pagelocked_empty(size: int, dtype):
        # NOTE(review): despite the pycuda-style name this returns ordinary
        # pageable memory, not pinned host memory — confirm acceptability.
        return np.empty(int(size), dtype=dtype)


# Module-level pycuda-style entry point used throughout the backend.
cuda = _CudaDevice
diff --git a/vkdispatch/backends/cuda_backend/descriptor_sets.py b/vkdispatch/backends/cuda_backend/descriptor_sets.py
new file mode 100644
index 00000000..10670708
--- /dev/null
+++ b/vkdispatch/backends/cuda_backend/descriptor_sets.py
@@ -0,0 +1,105 @@
+from __future__ import annotations
+
+from . import state as state
+from .helpers import set_error, to_bytes, buffer_device_ptr
+
+from .handle import CUDAHandle, HandleRegistry
+from typing import Dict, Tuple, Optional
+
+_descriptor_sets: HandleRegistry = HandleRegistry()
+
class CUDADescriptorSet(CUDAHandle):
    """Per-plan binding table: buffers, images, and an inline uniform blob."""

    plan_handle: int
    buffer_bindings: Dict[int, Tuple[int, int, int, int, int, int]]
    image_bindings: Dict[int, Tuple[int, int, int, int]]
    inline_uniform_payload: bytes

    def __init__(self, plan_handle: int):
        # Registering with the module registry assigns self.handle.
        super().__init__(_descriptor_sets)

        self.plan_handle = plan_handle
        self.buffer_bindings = {}
        self.image_bindings = {}
        self.inline_uniform_payload = b""

    @staticmethod
    def from_handle(handle: int) -> Optional["CUDADescriptorSet"]:
        """Look up a descriptor set by its integer handle (None if unknown)."""
        return _descriptor_sets.get(int(handle))

    def resolve_buffer_pointer(self, binding: int) -> int:
        """Device pointer for `binding`, with the recorded offset applied."""
        entry = self.buffer_bindings.get(binding)
        if entry is None:
            raise RuntimeError(f"Missing descriptor buffer binding {binding}")

        buffer_handle = entry[0]
        byte_offset = entry[1]

        backing = state.buffers.get(int(buffer_handle))
        if backing is None:
            raise RuntimeError(f"Invalid buffer handle {buffer_handle} for binding {binding}")

        return buffer_device_ptr(backing) + int(byte_offset)
+
def descriptor_set_create(plan):
    """Create a descriptor set bound to an existing compute plan.

    Returns the new handle, or 0 (with the error string set) when the plan
    handle is unknown.
    """
    plan_handle = int(plan)
    if plan_handle in state.compute_plans:
        return CUDADescriptorSet(plan_handle=plan_handle).handle

    set_error("Invalid compute plan handle for descriptor_set_create")
    return 0
+
+
def descriptor_set_destroy(descriptor_set):
    # Dropping the registry entry is sufficient; descriptor sets hold no
    # GPU resources of their own.
    _descriptor_sets.pop(descriptor_set)
+
+
def descriptor_set_write_buffer(
    descriptor_set,
    binding,
    object,
    offset,
    range,
    uniform,
    read_access,
    write_access,
):
    """Record a buffer binding (handle, offset, range, flags) on the set."""
    target = CUDADescriptorSet.from_handle(descriptor_set)
    if target is None:
        set_error("Invalid descriptor set handle for descriptor_set_write_buffer")
        return

    entry = (
        int(object),
        int(offset),
        int(range),
        int(uniform),
        int(read_access),
        int(write_access),
    )
    target.buffer_bindings[int(binding)] = entry
+
+
def descriptor_set_write_image(
    descriptor_set,
    binding,
    object,
    sampler_obj,
    read_access,
    write_access,
):
    """Unsupported: image bindings are not implemented on this backend."""
    del descriptor_set, binding, object, sampler_obj, read_access, write_access
    set_error("CUDA Python backend does not support image objects yet")
+
+
def descriptor_set_write_inline_uniform(descriptor_set, payload):
    """Store `payload` (coerced to bytes) as the set's inline uniform data."""
    target = CUDADescriptorSet.from_handle(descriptor_set)
    if target is None:
        set_error("Invalid descriptor set handle for descriptor_set_write_inline_uniform")
        return

    try:
        target.inline_uniform_payload = to_bytes(payload)
    except Exception as exc:
        set_error(f"Failed to store inline uniform payload: {exc}")
diff --git a/vkdispatch/backends/cuda_backend/handle.py b/vkdispatch/backends/cuda_backend/handle.py
new file mode 100644
index 00000000..5f5e5082
--- /dev/null
+++ b/vkdispatch/backends/cuda_backend/handle.py
@@ -0,0 +1,26 @@
+from typing import Dict, Optional
+
+from . import state as state
+
class HandleRegistry:
    """Maps integer handles to backend objects.

    Handle values come from the process-wide counter in `state`, so handles
    are unique across all registries.
    """

    def __init__(self):
        # Storage keyed by handles minted in new_handle().
        self.registry: Dict[int, object] = {}

    def new_handle(self, obj: object) -> int:
        """Mint a fresh handle, register `obj` under it, and return it."""
        minted = state.next_handle
        state.next_handle = minted + 1
        self.registry[minted] = obj
        return minted

    def get(self, handle: int) -> Optional[object]:
        """Return the registered object, or None for unknown handles."""
        return self.registry.get(int(handle))

    def pop(self, handle: int) -> Optional[object]:
        """Remove and return the registered object (None if absent)."""
        return self.registry.pop(int(handle), None)
+
+
class CUDAHandle:
    """Base class for backend objects addressable by an integer handle."""

    # Integer id assigned by the registry at construction time.
    handle: int

    def __init__(self, registry: HandleRegistry):
        # Registration mints the handle and keeps the instance alive.
        self.handle = registry.new_handle(self)
\ No newline at end of file
diff --git a/vkdispatch/backends/cuda_backend/helpers.py b/vkdispatch/backends/cuda_backend/helpers.py
new file mode 100644
index 00000000..5dad2743
--- /dev/null
+++ b/vkdispatch/backends/cuda_backend/helpers.py
@@ -0,0 +1,380 @@
+from __future__ import annotations
+
+from contextlib import contextmanager
+import re
+import sys
+from typing import Dict, List, Optional, Tuple, Any
+
+from . import state as state
+from .bindings import driver, np, drv_call, drv_check, to_int
+from .constants import (
+ BINDING_PARAM_RE,
+ KERNEL_SIGNATURE_RE,
+ LOCAL_X_RE,
+ LOCAL_Y_RE,
+ LOCAL_Z_RE,
+ SAMPLER_PARAM_RE,
+)
+from .cuda_primitives import _ByValueKernelArg, cuda
+from .state import CUDABuffer, CUDAComputePlan, CUDAContext, CUDAKernelParam
+
+#from .api_descriptor import CUDADescriptorSet
+
def new_handle(registry: Dict[int, object], obj: object) -> int:
    """Mint a fresh process-wide handle and register `obj` under it.

    NOTE(review): duplicates HandleRegistry.new_handle — consider unifying.
    """
    minted = state.next_handle
    state.next_handle = minted + 1
    registry[minted] = obj
    return minted
+
+
def to_bytes(value) -> bytes:
    """Coerce `value` to an immutable bytes object; None becomes b""."""
    if value is None:
        return b""
    if isinstance(value, memoryview):
        return value.tobytes()
    if isinstance(value, bytes):
        return value
    # bytearray and any other bytes-convertible object (buffers, iterables
    # of ints) go through the bytes() constructor.
    return bytes(value)
+
+
def set_error(message: str) -> None:
    """Record `message` as the backend's last error string."""
    state.error_string = str(message)
+
+
def clear_error() -> None:
    """Reset the backend's last error string."""
    state.error_string = None
+
+
def coerce_stream_handle(stream_obj) -> Optional[int]:
    """Best-effort extraction of a raw CUDA stream handle from `stream_obj`.

    Accepts None (returned unchanged), plain ints, objects implementing the
    __cuda_stream__ protocol, objects with .cuda_stream/.ptr/.handle, and
    wrappers exposing a nested .stream. Raises TypeError when nothing works.
    """
    if stream_obj is None:
        return None

    if isinstance(stream_obj, int):
        return int(stream_obj)

    # __cuda_stream__ protocol: may be a callable or a plain value, and the
    # result may be a tuple whose first element is the handle.
    proto = getattr(stream_obj, "__cuda_stream__", None)
    if proto is not None:
        try:
            raw = proto() if callable(proto) else proto
            if isinstance(raw, tuple) and len(raw) > 0:
                raw = raw[0]
            return int(raw)
        except Exception:
            pass

    # Common attribute names used by stream wrapper classes.
    for candidate in ("cuda_stream", "ptr", "handle"):
        if hasattr(stream_obj, candidate):
            try:
                return int(getattr(stream_obj, candidate))
            except Exception:
                pass

    # Wrappers sometimes hold the real stream in a .stream attribute.
    inner = getattr(stream_obj, "stream", None)
    if inner is not None and inner is not stream_obj:
        try:
            return coerce_stream_handle(inner)
        except Exception:
            pass

    try:
        return int(stream_obj)
    except Exception as exc:
        raise TypeError(
            "Unable to extract a CUDA stream handle from the provided object. "
            "Pass an int handle or an object with __cuda_stream__/.cuda_stream/.ptr/.handle."
        ) from exc
+
+
def stream_override_stack() -> List[Optional[int]]:
    """Return (creating on first use) this thread's stream-override stack."""
    existing = getattr(state.stream_override, "stack", None)
    if existing is None:
        existing = []
        state.stream_override.stack = existing
    return existing
+
+
def get_stream_override_handle() -> Optional[int]:
    """Top of this thread's override stack, or None when no override is active."""
    stack = getattr(state.stream_override, "stack", None)
    return stack[-1] if stack else None
+
+
def wrap_external_stream(handle: int):
    """Wrap a raw stream handle in a cuda-python stream object (cached).

    Returns None for handle 0 (the default stream). Raises RuntimeError when
    no supported constructor signature works on this cuda-python version.
    """
    handle = int(handle)

    # Cache hits avoid repeatedly probing constructor signatures.
    if handle in state.external_stream_cache:
        return state.external_stream_cache[handle]

    if handle == 0:
        return None

    # Different cuda-python releases accept different constructor spellings;
    # probe them in order of preference.
    ctor_attempts = [
        lambda: cuda.Stream(handle=handle),
        lambda: cuda.Stream(ptr=handle),
        lambda: cuda.Stream(int(handle)),
    ]

    # Prefer a dedicated ExternalStream wrapper when the binding provides one.
    external_cls = getattr(cuda, "ExternalStream", None)
    if external_cls is not None:
        ctor_attempts.insert(0, lambda: external_cls(handle))

    last_error = None
    for ctor in ctor_attempts:
        try:
            stream_obj = ctor()
            state.external_stream_cache[handle] = stream_obj
            return stream_obj
        except Exception as exc:  # pragma: no cover - depends on cuda-python version
            last_error = exc

    raise RuntimeError(
        f"Failed to wrap external CUDA stream handle {handle} with CUDA Python. "
        "This CUDA Python version may not support external stream wrappers."
    ) from last_error
+
+
def stream_for_queue(ctx: "CUDAContext", queue_index: int):
    """Stream for `queue_index`, honoring any active per-thread override."""
    override = get_stream_override_handle()
    if override is not None:
        return wrap_external_stream(int(override))
    return ctx.streams[queue_index]
+
+
def buffer_device_ptr(buffer_obj: "CUDABuffer") -> int:
    """Raw device pointer of the buffer as a plain int."""
    return int(buffer_obj.device_ptr)
+
+
def queue_indices(ctx: "CUDAContext", queue_index: int, *, all_on_negative: bool = False) -> List[int]:
    """Translate a queue selector into a list of concrete queue indices.

    None and -1 select queue 0; a negative selector expands to every queue
    when `all_on_negative` is set; out-of-range selectors yield [].
    """
    count = ctx.queue_count
    if count <= 0:
        return []

    if queue_index is None:
        return [0]

    selector = int(queue_index)

    if selector < 0:
        if all_on_negative:
            # Any negative value means "all queues" in this mode.
            return list(range(count))
        return [0] if selector == -1 else []

    return [selector] if selector < count else []
+
+
def context_from_handle(context_handle: int) -> Optional[CUDAContext]:
    """Resolve a context handle; records an error and returns None on failure."""
    ctx = state.contexts.get(int(context_handle))
    if ctx is None:
        set_error(f"Invalid context handle {context_handle}")
    return ctx
+
+
@contextmanager
def activate_context(ctx: CUDAContext):
    """Make `ctx` current for the duration of the with-block.

    The context is always popped, even when the body raises.
    """
    ctx.cuda_context.push()
    try:
        yield
    finally:
        cuda.Context.pop()
+
def allocate_staging_storage(size: int):
    """Allocate `size` bytes of host staging memory.

    Falls back to a plain bytearray when the pagelocked path is unavailable.
    """
    try:
        # Pagelocked host memory improves async HtoD/DtoH throughput and overlap.
        return cuda.pagelocked_empty(int(size), np.uint8)
    except Exception:
        return bytearray(int(size))
+
+
def fallback_max_kernel_param_size(compute_capability_major: int) -> int:
    """Conservative kernel-parameter-space size when the driver can't be queried.

    CUDA kernels support at least 4 KiB of launch parameters on legacy
    devices; Volta-class (CC >= 7) devices commonly expose a larger
    ~32 KiB argument space.
    """
    if int(compute_capability_major) >= 7:
        return 32764
    return 4096
+
+
def query_max_kernel_param_size(device_raw, compute_capability_major: int) -> int:
    """Ask the driver for the kernel launch-parameter size limit.

    Probes several enum spellings (they vary across cuda-python versions);
    falls back to a conservative per-architecture default with a warning when
    none can be queried.
    """
    attr_names = (
        "CU_DEVICE_ATTRIBUTE_MAX_PARAMETER_SIZE",
        "CU_DEVICE_ATTRIBUTE_MAX_PARAMETER_SIZE_SUPPORTED",
        "CU_DEVICE_ATTRIBUTE_MAX_KERNEL_PARAMETER_SIZE",
    )

    attr_enum_container = getattr(driver, "CUdevice_attribute", None)
    if attr_enum_container is not None:
        for attr_name in attr_names:
            attr_enum = getattr(attr_enum_container, attr_name, None)
            if attr_enum is None:
                continue

            try:
                queried_value = drv_check(
                    drv_call("cuDeviceGetAttribute", attr_enum, device_raw),
                    "cuDeviceGetAttribute",
                )
                queried_size = int(to_int(queried_value))
                # Only trust strictly positive answers.
                if queried_size > 0:
                    return queried_size
            except Exception:
                continue

    print(
        "Warning: Unable to query max kernel parameter size from CUDA driver. Falling back to a conservative default.",
        file=sys.stderr,
    )

    return fallback_max_kernel_param_size(compute_capability_major)
+
+
def parse_local_size(source: str) -> Tuple[int, int, int]:
    """Extract the (x, y, z) workgroup size from kernel source.

    Each axis defaults to 1 when its pattern is absent.
    """
    dims = []
    for pattern in (LOCAL_X_RE, LOCAL_Y_RE, LOCAL_Z_RE):
        found = pattern.search(source)
        dims.append(int(found.group(1)) if found else 1)

    return (dims[0], dims[1], dims[2])
+
+
def parse_kernel_params(source: str) -> List[CUDAKernelParam]:
    """Classify the parameters of the vkdispatch_main kernel in `source`.

    Recognized parameter names map to kinds: vkdispatch_uniform_ptr ->
    "uniform" (binding 0), vkdispatch_uniform_value -> "uniform_value",
    vkdispatch_pc_value -> "push_constant_value", binding-pattern names ->
    "storage", sampler-pattern names -> "sampler"; anything else is
    "unknown". Raises RuntimeError when the signature cannot be found or a
    declaration cannot be parsed.
    """
    signature_match = KERNEL_SIGNATURE_RE.search(source)
    if signature_match is None:
        raise RuntimeError("Could not find vkdispatch_main kernel signature in CUDA source")

    signature_blob = signature_match.group(1).strip()
    if len(signature_blob) == 0:
        return []

    params: List[CUDAKernelParam] = []

    # Split the comma-separated declarations; the identifier at the end of
    # each declaration is the parameter name.
    for raw_decl in [part.strip() for part in signature_blob.split(",") if len(part.strip()) > 0]:
        name_match = re.search(r"([A-Za-z_][A-Za-z0-9_]*)\s*$", raw_decl)
        if name_match is None:
            raise RuntimeError(f"Unable to parse kernel parameter declaration '{raw_decl}'")

        param_name = name_match.group(1)

        if param_name == "vkdispatch_uniform_ptr":
            params.append(CUDAKernelParam("uniform", 0, param_name))
            continue

        if param_name == "vkdispatch_uniform_value":
            params.append(CUDAKernelParam("uniform_value", None, param_name))
            continue

        if param_name == "vkdispatch_pc_value":
            params.append(CUDAKernelParam("push_constant_value", None, param_name))
            continue

        binding_match = BINDING_PARAM_RE.match(param_name)
        if binding_match is not None:
            params.append(CUDAKernelParam("storage", int(binding_match.group(1)), param_name))
            continue

        sampler_match = SAMPLER_PARAM_RE.match(param_name)
        if sampler_match is not None:
            params.append(CUDAKernelParam("sampler", int(sampler_match.group(1)), param_name))
            continue

        # Unrecognized names are preserved so callers can report them.
        params.append(CUDAKernelParam("unknown", None, param_name))

    return params
+
def build_kernel_args_template(
    plan: CUDAComputePlan,
    descriptor_set: Optional[Any],  # CUDADescriptorSet
    push_constant_payload: bytes = b"",
) -> Tuple[object, ...]:
    """Materialize the launch-argument tuple for `plan`'s parameter list.

    Pointer-style parameters resolve through the descriptor set; by-value
    parameters are wrapped in _ByValueKernelArg. Raises RuntimeError for
    missing descriptor sets, missing/mis-sized payloads, sampler parameters,
    and unrecognized parameter kinds.
    """
    args: List[object] = []

    for param in plan.params:
        if param.kind == "uniform":
            if descriptor_set is None:
                raise RuntimeError("Kernel requires a descriptor set but none was provided")

            # Uniform pointer is always resolved from binding 0.
            args.append(np.uintp(descriptor_set.resolve_buffer_pointer(0)))
            continue

        if param.kind == "uniform_value":
            if descriptor_set is None:
                raise RuntimeError("Kernel requires a descriptor set but none was provided")

            if len(descriptor_set.inline_uniform_payload) == 0:
                raise RuntimeError(
                    "Missing inline uniform payload for CUDA by-value uniform parameter "
                    f"'{param.raw_name}'."
                )

            args.append(_ByValueKernelArg(descriptor_set.inline_uniform_payload, param.raw_name))
            continue

        if param.kind == "push_constant_value":
            if plan.pc_size <= 0:
                raise RuntimeError(
                    f"Kernel parameter '{param.raw_name}' expects push-constant data, but this compute plan has pc_size={plan.pc_size}."
                )

            if len(push_constant_payload) == 0:
                raise RuntimeError(
                    "Missing push-constant payload for CUDA by-value push-constant parameter "
                    f"'{param.raw_name}'."
                )

            # The payload must match the plan's declared size exactly.
            if len(push_constant_payload) != int(plan.pc_size):
                raise RuntimeError(
                    f"Push-constant payload size mismatch for parameter '{param.raw_name}'. "
                    f"Expected {plan.pc_size} bytes but got {len(push_constant_payload)} bytes."
                )

            args.append(_ByValueKernelArg(push_constant_payload, param.raw_name))
            continue

        if param.kind == "storage":
            if descriptor_set is None:
                raise RuntimeError("Kernel requires a descriptor set but none was provided")

            if param.binding is None:
                raise RuntimeError("Storage parameter has no binding index")

            args.append(np.uintp(descriptor_set.resolve_buffer_pointer(param.binding)))
            continue

        if param.kind == "sampler":
            raise RuntimeError("CUDA Python backend does not support sampled image bindings yet")

        raise RuntimeError(
            f"Unsupported kernel parameter '{param.raw_name}'. "
            "Expected vkdispatch_uniform_ptr / vkdispatch_uniform_value / vkdispatch_pc_value / vkdispatch_binding__ptr."
        )

    return tuple(args)
+
+
def align_up(value: int, alignment: int) -> int:
    """Round `value` up to the next multiple of `alignment` (no-op for <= 1)."""
    if alignment <= 1:
        return value
    # Ceiling division via negation, then scale back up.
    return -(-value // alignment) * alignment
+
+
def estimate_kernel_param_size_bytes(args: Tuple[object, ...]) -> int:
    """Rough total of kernel launch-parameter bytes consumed by `args`."""
    size = 0

    for entry in args:
        if isinstance(entry, _ByValueKernelArg):
            # Kernel params are aligned by argument type. Use a conservative
            # 16-byte alignment for by-value structs.
            size = align_up(size, 16) + len(entry.payload)
        else:
            # Pointer-like arguments occupy an 8-byte, 8-aligned slot.
            size = align_up(size, 8) + 8

    return size
diff --git a/vkdispatch/backends/cuda_backend/image_fft_stubs.py b/vkdispatch/backends/cuda_backend/image_fft_stubs.py
new file mode 100644
index 00000000..7b21e627
--- /dev/null
+++ b/vkdispatch/backends/cuda_backend/image_fft_stubs.py
@@ -0,0 +1,132 @@
+from __future__ import annotations
+
+from . import state as state
+from .helpers import set_error
+
+
def image_create(context, extent, layers, format, type, view_type, generate_mips):
    """Unsupported stub: image objects are not implemented; returns 0."""
    del context, extent, layers, format, type, view_type, generate_mips
    set_error("CUDA Python backend does not support image objects yet")
    return 0
+
+
def image_destroy(image):
    """Unsupported stub: image objects are not implemented."""
    del image
    set_error("CUDA Python backend does not support image objects yet")
+
+
def image_create_sampler(
    context,
    mag_filter,
    min_filter,
    mip_mode,
    address_mode,
    mip_lod_bias,
    min_lod,
    max_lod,
    border_color,
):
    """Unsupported stub: image samplers are not implemented; returns 0."""
    del context, mag_filter, min_filter, mip_mode, address_mode
    del mip_lod_bias, min_lod, max_lod, border_color
    set_error("CUDA Python backend does not support image samplers yet")
    return 0
+
+
def image_destroy_sampler(sampler):
    """Unsupported stub: image samplers are not implemented."""
    del sampler
    set_error("CUDA Python backend does not support image samplers yet")
+
+
def image_write(image, data, offset, extent, baseLayer, layerCount, device_index):
    """Unsupported stub: image writes are not implemented."""
    del image, data, offset, extent, baseLayer, layerCount, device_index
    set_error("CUDA Python backend does not support image writes yet")
+
+
def image_format_block_size(format):
    """Unsupported stub: block-size queries are not implemented (returns None)."""
    del format
    set_error("CUDA Python backend does not support image format block size queries yet")
+
+
def image_read(image, out_size, offset, extent, baseLayer, layerCount, device_index):
    """Unsupported stub: records an error and returns out_size zero bytes."""
    del image, offset, extent, baseLayer, layerCount, device_index
    set_error("CUDA Python backend does not support image reads yet")
    # Return a correctly-sized zero-filled buffer so callers don't crash.
    return bytes(max(0, int(out_size)))
+
+
def stage_fft_plan_create(
    context,
    dims,
    axes,
    buffer_size,
    do_r2c,
    normalize,
    pad_left,
    pad_right,
    frequency_zeropadding,
    kernel_num,
    kernel_convolution,
    conjugate_convolution,
    convolution_features,
    input_buffer_size,
    num_batches,
    single_kernel_multiple_batches,
    keep_shader_code,
):
    """Unsupported stub: FFT plans are not implemented; returns 0."""
    del context, dims, axes, buffer_size, do_r2c, normalize
    del pad_left, pad_right, frequency_zeropadding, kernel_num
    del kernel_convolution, conjugate_convolution, convolution_features
    del input_buffer_size, num_batches, single_kernel_multiple_batches
    del keep_shader_code
    set_error("CUDA Python backend does not support FFT plans yet")
    return 0
+
+
def stage_fft_plan_destroy(plan):
    """Unsupported stub: FFT plans are not implemented."""
    del plan
    set_error("CUDA Python backend does not support FFT plans yet")
+
+
def stage_fft_record(command_list, plan, buffer, inverse, kernel, input_buffer):
    """Unsupported stub: FFT stages are not implemented."""
    del command_list, plan, buffer, inverse, kernel, input_buffer
    set_error("CUDA Python backend does not support FFT stages yet")
diff --git a/vkdispatch/backends/cuda_backend/signal.py b/vkdispatch/backends/cuda_backend/signal.py
new file mode 100644
index 00000000..6dfbca35
--- /dev/null
+++ b/vkdispatch/backends/cuda_backend/signal.py
@@ -0,0 +1,116 @@
+from __future__ import annotations
+
+from . import state as state
+from .helpers import (
+ activate_context,
+ context_from_handle,
+ queue_indices,
+ set_error,
+ stream_for_queue,
+)
+
+from typing import Optional
+
+from .cuda_primitives import cuda
+from .handle import CUDAHandle, HandleRegistry
+
+_signals: HandleRegistry = HandleRegistry()
+
class CUDASignal(CUDAHandle):
    """Host-visible completion marker backed by an optional CUDA event."""

    context_handle: int
    queue_index: int
    event: Optional["cuda.Event"] = None
    submitted: bool = True
    done: bool = True

    def __init__(self,
                 context_handle: int,
                 queue_index: int,
                 event: Optional["cuda.Event"] = None,
                 submitted: bool = True,
                 done: bool = True):
        # Registration with the module registry assigns self.handle.
        super().__init__(_signals)

        self.context_handle = context_handle
        self.queue_index = queue_index
        self.event = event
        self.submitted = submitted
        self.done = done

    @staticmethod
    def from_handle(handle: int) -> Optional["CUDASignal"]:
        """Look up a signal by handle (None when unknown)."""
        return _signals.get(handle)

    def record(self, stream: "cuda.Stream"):
        """Record the (lazily created) event on `stream` and mark pending."""
        self.submitted = True
        self.done = False
        if self.event is None:
            self.event = cuda.Event()
        self.event.record(stream)

    def query(self) -> bool:
        """Non-blocking completion poll; caches the result in self.done."""
        if self.event is None:
            return bool(self.done)

        try:
            finished = self.event.query()
        except Exception:
            # A failed query is treated as "not done yet".
            return False

        self.done = bool(finished)
        return self.done
+
def signal_wait(signal_ptr, wait_for_timestamp, queue_index):
    """Wait on (or poll) a signal; returns True when it is complete.

    NOTE: queue_index is currently unused — waiting is driven entirely by
    the signal's recorded event.
    """
    signal_obj = CUDASignal.from_handle(signal_ptr)
    if signal_obj is None:
        # Unknown handles are treated as already complete, not as errors.
        return True

    if not bool(wait_for_timestamp):
        # CUDA Python records signals synchronously on submission; host-side "recorded" waits
        # should therefore complete immediately once an event exists.
        if signal_obj.event is None:
            return bool(signal_obj.done)
        return bool(signal_obj.submitted)

    if signal_obj.done:
        return True

    if signal_obj.event is None:
        return bool(signal_obj.done)

    ctx = state.contexts.get(signal_obj.context_handle)
    if ctx is None:
        # Without a context we cannot synchronize; fall back to polling.
        return signal_obj.query()

    try:
        with activate_context(ctx):
            signal_obj.event.synchronize()
        signal_obj.done = True
        return True
    except Exception:
        # Synchronization failed; report the best-effort polled state.
        return signal_obj.query()
+
+
def signal_insert(context, queue_index):
    """Create a signal and record it on the selected queue.

    Returns the signal handle, or 0 (with the error string set) on failure.
    """
    ctx = context_from_handle(int(context))
    if ctx is None:
        return 0

    targets = queue_indices(ctx, int(queue_index)) or [0]
    queue = targets[0]

    signal = CUDASignal(context_handle=int(context), queue_index=queue, submitted=False, done=False)

    try:
        with activate_context(ctx):
            signal.record(stream_for_queue(ctx, queue))
    except Exception as exc:
        set_error(f"Failed to insert signal: {exc}")
        return 0

    return signal.handle
+
+
def signal_destroy(signal_ptr):
    # Unregister the handle; the wrapped CUDA event (if any) is dropped
    # along with the signal object.
    _signals.pop(signal_ptr)
diff --git a/vkdispatch/backends/cuda_backend/state.py b/vkdispatch/backends/cuda_backend/state.py
new file mode 100644
index 00000000..21e7af25
--- /dev/null
+++ b/vkdispatch/backends/cuda_backend/state.py
@@ -0,0 +1,87 @@
+from __future__ import annotations
+
+from dataclasses import dataclass, field
+import threading
+from typing import Dict, List, Optional, Tuple
+
+from .constants import LOG_LEVEL_WARNING
+from .cuda_primitives import SourceModule, cuda
+
+#from .api_descriptor import CUDADescriptorSet
+
+# --- Runtime state ---
+
initialized = False                 # backend initialization flag
debug_mode = False                  # extra diagnostics when enabled
log_level = LOG_LEVEL_WARNING       # current log verbosity threshold
error_string: Optional[str] = None  # last backend error (written via helpers.set_error)
next_handle = 1                     # process-wide counter used to mint object handles

# Handle -> object registries for each backend object kind.
contexts: Dict[int, "CUDAContext"] = {}
buffers: Dict[int, "CUDABuffer"] = {}
command_lists: Dict[int, "CUDACommandList"] = {}
compute_plans: Dict[int, "CUDAComputePlan"] = {}
external_stream_cache: Dict[int, object] = {}  # raw stream handle -> wrapper object
stream_override = threading.local()            # holds a per-thread stream-override stack
+
+
+# --- Internal objects ---
+
@dataclass
class CUDAContext:
    """Backend context: one CUDA context plus its work queues (streams)."""

    device_index: int                   # CUDA device ordinal this context targets
    cuda_context: "cuda.Context"
    streams: List["cuda.Stream"]        # one stream per logical queue
    queue_count: int
    queue_to_device: List[int]          # queue index -> device index
    max_kernel_param_size: int          # launch-parameter size limit in bytes
    uses_primary_context: bool = False  # True when retained rather than created
    stopped: bool = False
+
+
@dataclass
class CUDABuffer:
    """Device allocation plus host-side staging bookkeeping."""

    context_handle: int
    size: int                       # allocation size in bytes
    device_ptr: int                 # raw CUdeviceptr as a plain int
    device_allocation: Optional["cuda.DeviceAllocation"]  # owning wrapper, if any
    owns_allocation: bool           # False when wrapping externally-owned memory
    staging_data: List[object]      # host staging buffers (presumably per queue — confirm)
    signal_handles: List[int]       # signals associated with pending transfers (verify)
+
+
@dataclass
class CUDACommandRecord:
    """One recorded compute dispatch inside a command list."""

    plan_handle: int
    descriptor_set_handle: int
    blocks: Tuple[int, int, int]  # grid dimensions in workgroups
    pc_size: int                  # push-constant bytes expected at submit time
+
+
@dataclass
class CUDACommandList:
    """Ordered sequence of recorded commands for one context."""

    context_handle: int
    commands: List[CUDACommandRecord] = field(default_factory=list)
+
+
@dataclass
class CUDAKernelParam:
    """Classification of one vkdispatch_main kernel parameter."""

    kind: str               # "uniform" | "uniform_value" | "push_constant_value" | "storage" | "sampler" | "unknown"
    binding: Optional[int]  # descriptor binding index, where applicable
    raw_name: str           # parameter name as written in the kernel source
+
+
@dataclass
class CUDAComputePlan:
    """Compiled kernel plus the metadata needed to launch it."""

    context_handle: int
    shader_source: bytes              # CUDA C source the module was built from
    bindings: List[int]
    shader_name: bytes
    module: SourceModule              # NVRTC-compiled, driver-loaded module
    function: object                  # kernel function handle
    local_size: Tuple[int, int, int]  # workgroup (block) dimensions
    params: List[CUDAKernelParam]     # parsed vkdispatch_main parameter list
    pc_size: int                      # push-constant payload size in bytes
+
+
+
diff --git a/vkdispatch/backends/dummy_backend.py b/vkdispatch/backends/dummy_backend.py
new file mode 100644
index 00000000..420a59f8
--- /dev/null
+++ b/vkdispatch/backends/dummy_backend.py
@@ -0,0 +1,545 @@
+"""Brython-friendly pure-Python shim for ``vkdispatch_native``.
+
+This module mirrors the Cython-exposed API used by ``vkdispatch`` and provides
+dummy metadata helpers for docs/codegen flows.
+
+Runtime GPU operations are intentionally denied so the dummy backend fails fast
+when used outside codegen-only scripts.
+"""
+
+# --- Runtime state ---
+
+_initialized = False
+_debug_mode = False
+_log_level = 2
+_error_string = None
+_next_handle = 1
+
+_contexts = {}
+_signals = {}
+
+# Device limits exposed through get_devices(); mutable so docs UI can tune them.
+_DEFAULT_SUBGROUP_SIZE = 32
+_DEFAULT_MAX_WORKGROUP_SIZE = (1024, 1024, 64)
+_DEFAULT_MAX_WORKGROUP_INVOCATIONS = 1024
+_DEFAULT_MAX_WORKGROUP_COUNT = (65535, 65535, 65535)
+_DEFAULT_MAX_COMPUTE_SHARED_MEMORY_SIZE = 64 * 1024
+
+_device_subgroup_size = _DEFAULT_SUBGROUP_SIZE
+_device_subgroup_enabled = True
+_device_max_workgroup_size = _DEFAULT_MAX_WORKGROUP_SIZE
+_device_max_workgroup_invocations = _DEFAULT_MAX_WORKGROUP_INVOCATIONS
+_device_max_workgroup_count = _DEFAULT_MAX_WORKGROUP_COUNT
+_device_max_compute_shared_memory_size = _DEFAULT_MAX_COMPUTE_SHARED_MEMORY_SIZE
+
+
+# --- Internal objects ---
+
+class _Signal:
+ __slots__ = ("done",)
+
+ def __init__(self, done=True):
+ self.done = bool(done)
+
+
+class _Context:
+ __slots__ = (
+ "device_indices",
+ "queue_families",
+ "queue_count",
+ "queue_to_device",
+ "stopped",
+ )
+
+ def __init__(self, device_indices, queue_families):
+ self.device_indices = list(device_indices)
+ self.queue_families = [list(fam) for fam in queue_families]
+
+ normalized = []
+ for fam in self.queue_families:
+ normalized.append(fam if len(fam) > 0 else [0])
+ self.queue_families = normalized
+
+ self.queue_count = sum(len(fam) for fam in self.queue_families)
+ if self.queue_count <= 0:
+ self.queue_families = [[0]]
+ self.queue_count = 1
+
+ queue_to_device = []
+ for dev_idx, fam in enumerate(self.queue_families):
+ for _ in fam:
+ queue_to_device.append(dev_idx)
+
+ if len(queue_to_device) == 0:
+ queue_to_device = [0]
+
+ self.queue_to_device = queue_to_device
+ self.stopped = False
+
+# --- Internal helpers ---
+
+def _new_handle(registry, obj):
+ global _next_handle
+ handle = _next_handle
+ _next_handle += 1
+ registry[handle] = obj
+ return handle
+
+def _set_error(message):
+ global _error_string
+ _error_string = str(message)
+
+
+def _clear_error():
+ global _error_string
+ _error_string = None
+
+
+_DUMMY_CODEGEN_ONLY_ERROR = (
+ "The 'dummy' backend is codegen-only and does not support runtime GPU "
+ "operations. Use backend='vulkan', backend='pycuda', or backend='cuda-python' for execution."
+)
+
+
+def _deny_runtime_native_call(function_name):
+ raise RuntimeError(f"{_DUMMY_CODEGEN_ONLY_ERROR} (native call: {function_name})")
+
+
+def _as_positive_int(name, value):
+ try:
+ parsed = int(value)
+ except Exception as exc:
+ raise ValueError("%s must be an integer" % name) from exc
+
+ if parsed <= 0:
+ raise ValueError("%s must be greater than zero" % name)
+
+ return parsed
+
+
+def _as_positive_triplet(name, value):
+ try:
+ parts = list(value)
+ except Exception as exc:
+ raise ValueError("%s must contain exactly 3 integers" % name) from exc
+
+ if len(parts) != 3:
+ raise ValueError("%s must contain exactly 3 integers" % name)
+
+ return (
+ _as_positive_int("%s[0]" % name, parts[0]),
+ _as_positive_int("%s[1]" % name, parts[1]),
+ _as_positive_int("%s[2]" % name, parts[2]),
+ )
+
+
+# --- API: context/init/errors/logging ---
+
+
+def reset_device_options():
+ global _device_subgroup_size
+ global _device_subgroup_enabled
+ global _device_max_workgroup_size
+ global _device_max_workgroup_invocations
+ global _device_max_workgroup_count
+ global _device_max_compute_shared_memory_size
+
+ _device_subgroup_size = _DEFAULT_SUBGROUP_SIZE
+ _device_subgroup_enabled = True
+ _device_max_workgroup_size = _DEFAULT_MAX_WORKGROUP_SIZE
+ _device_max_workgroup_invocations = _DEFAULT_MAX_WORKGROUP_INVOCATIONS
+ _device_max_workgroup_count = _DEFAULT_MAX_WORKGROUP_COUNT
+ _device_max_compute_shared_memory_size = _DEFAULT_MAX_COMPUTE_SHARED_MEMORY_SIZE
+
+
+def set_device_options(
+ subgroup_size=None,
+ subgroup_enabled=None,
+ max_workgroup_size=None,
+ max_workgroup_invocations=None,
+ max_workgroup_count=None,
+ max_compute_shared_memory_size=None,
+):
+ global _device_subgroup_size
+ global _device_subgroup_enabled
+ global _device_max_workgroup_size
+ global _device_max_workgroup_invocations
+ global _device_max_workgroup_count
+ global _device_max_compute_shared_memory_size
+
+ if subgroup_size is not None:
+ _device_subgroup_size = _as_positive_int("subgroup_size", subgroup_size)
+
+ if subgroup_enabled is not None:
+ if not isinstance(subgroup_enabled, bool):
+ raise ValueError("subgroup_enabled must be a boolean value")
+ _device_subgroup_enabled = subgroup_enabled
+
+ if max_workgroup_size is not None:
+ _device_max_workgroup_size = _as_positive_triplet(
+ "max_workgroup_size",
+ max_workgroup_size,
+ )
+
+ if max_workgroup_invocations is not None:
+ _device_max_workgroup_invocations = _as_positive_int(
+ "max_workgroup_invocations",
+ max_workgroup_invocations,
+ )
+
+ if max_workgroup_count is not None:
+ _device_max_workgroup_count = _as_positive_triplet(
+ "max_workgroup_count",
+ max_workgroup_count,
+ )
+
+ if max_compute_shared_memory_size is not None:
+ _device_max_compute_shared_memory_size = _as_positive_int(
+ "max_compute_shared_memory_size",
+ max_compute_shared_memory_size,
+ )
+
+
+def init(debug, log_level):
+ global _initialized, _debug_mode, _log_level
+ _initialized = True
+ _debug_mode = bool(debug)
+ _log_level = int(log_level)
+ _clear_error()
+
+
+def log(log_level, text, file_str, line_str):
+ # Keep logging quiet in docs/brython by default.
+ # Function kept for API compatibility.
+ _ = log_level
+ _ = text
+ _ = file_str
+ _ = line_str
+
+
+def set_log_level(log_level):
+ global _log_level
+ _log_level = int(log_level)
+
+
+def get_devices():
+ if not _initialized:
+ init(False, _log_level)
+
+ # One plausible fake discrete GPU with compute+graphics queue families.
+ device_tuple = (
+ 0, # version_variant
+ 1, # version_major
+ 3, # version_minor
+ 0, # version_patch
+ 1001000, # driver_version
+ 0x1BAD, # vendor_id
+ 0x0001, # device_id
+ 2, # device_type (Discrete GPU)
+ "VKDispatch Web Dummy GPU",
+ 1, # shader_buffer_float32_atomics
+ 1, # shader_buffer_float32_atomic_add
+ 1, # float_64_support
+ 1, # float_16_support
+ 1, # int_64_support
+ 1, # int_16_support
+ 1, # storage_buffer_16_bit_access
+ 1, # uniform_and_storage_buffer_16_bit_access
+ 1, # storage_push_constant_16
+ 1, # storage_input_output_16
+ _device_max_workgroup_size, # max_workgroup_size
+ _device_max_workgroup_invocations, # max_workgroup_invocations
+ _device_max_workgroup_count, # max_workgroup_count
+ 8, # max_descriptor_set_count
+ 256, # max_push_constant_size
+ 1 << 30, # max_storage_buffer_range
+ 65536, # max_uniform_buffer_range
+ 0, # uniform_buffer_alignment
+ _device_subgroup_size, # subgroup_size
+ 0x7FFFFFFF if _device_subgroup_enabled else 0, # supported_stages
+ 0x7FFFFFFF if _device_subgroup_enabled else 0, # supported_operations
+ 1, # quad_operations_in_all_stages
+ _device_max_compute_shared_memory_size, # max_compute_shared_memory_size
+ [
+ (8, 0x006), # compute + transfer
+ (4, 0x007), # graphics + compute + transfer
+ ],
+ 1, # scalar_block_layout
+ 1, # timeline_semaphores
+ bytes((0x56, 0x4B, 0x44, 0x30, 0x57, 0x45, 0x42, 0x31, 0x44, 0x55, 0x4D, 0x4D, 0x59, 0x00, 0x00, 0x01)),
+ )
+
+ return [device_tuple]
+
+
+def context_create(device_indicies, queue_families):
+ try:
+ ctx = _Context(device_indicies, queue_families)
+ return _new_handle(_contexts, ctx)
+ except Exception as exc:
+ _set_error("Failed to create context: %s" % exc)
+ return 0
+
+
+def signal_wait(signal_ptr, wait_for_timestamp, queue_index):
+ _ = wait_for_timestamp
+ _ = queue_index
+ signal_obj = _signals.get(int(signal_ptr))
+ if signal_obj is None:
+ return True
+ return bool(signal_obj.done)
+
+
+def signal_insert(context, queue_index):
+ _ = context
+ _ = queue_index
+ return _new_handle(_signals, _Signal(done=True))
+
+
+def signal_destroy(signal_ptr):
+ _signals.pop(int(signal_ptr), None)
+
+
+def context_destroy(context):
+ _contexts.pop(int(context), None)
+
+
+def get_error_string():
+ if _error_string is None:
+ return 0
+ return _error_string
+
+
+def context_stop_threads(context):
+ ctx = _contexts.get(int(context))
+ if ctx is not None:
+ ctx.stopped = True
+
+
+# --- API: buffers ---
+
+
+def buffer_create(context, size, per_device):
+ _deny_runtime_native_call("buffer_create")
+
+
+def buffer_destroy(buffer):
+ _deny_runtime_native_call("buffer_destroy")
+
+
+def buffer_get_queue_signal(buffer, queue_index):
+ _deny_runtime_native_call("buffer_get_queue_signal")
+
+
+def buffer_wait_staging_idle(buffer, queue_index):
+ _deny_runtime_native_call("buffer_wait_staging_idle")
+
+
+def buffer_write_staging(buffer, queue_index, data, size):
+ _deny_runtime_native_call("buffer_write_staging")
+
+
+def buffer_read_staging(buffer, queue_index, size):
+ _deny_runtime_native_call("buffer_read_staging")
+
+
+def buffer_write(buffer, offset, size, index):
+ _deny_runtime_native_call("buffer_write")
+
+
+def buffer_read(buffer, offset, size, index):
+ _deny_runtime_native_call("buffer_read")
+
+
+# --- API: command lists ---
+
+
+def command_list_create(context):
+ _deny_runtime_native_call("command_list_create")
+
+
+def command_list_destroy(command_list):
+ _deny_runtime_native_call("command_list_destroy")
+
+
+def command_list_get_instance_size(command_list):
+ _deny_runtime_native_call("command_list_get_instance_size")
+
+
+def command_list_reset(command_list):
+ _deny_runtime_native_call("command_list_reset")
+
+
+def command_list_submit(command_list, data, instance_count, index):
+ _deny_runtime_native_call("command_list_submit")
+
+
+# --- API: descriptor sets ---
+
+
+def descriptor_set_create(plan):
+ _deny_runtime_native_call("descriptor_set_create")
+
+
+def descriptor_set_destroy(descriptor_set):
+ _deny_runtime_native_call("descriptor_set_destroy")
+
+
+def descriptor_set_write_buffer(
+ descriptor_set,
+ binding,
+ object,
+ offset,
+ range,
+ uniform,
+ read_access,
+ write_access,
+):
+ _deny_runtime_native_call("descriptor_set_write_buffer")
+
+
+def descriptor_set_write_image(
+ descriptor_set,
+ binding,
+ object,
+ sampler_obj,
+ read_access,
+ write_access,
+):
+ _deny_runtime_native_call("descriptor_set_write_image")
+
+
+# --- API: images/samplers ---
+
+
+def image_create(context, extent, layers, format, type, view_type, generate_mips):
+ _deny_runtime_native_call("image_create")
+
+
+def image_destroy(image):
+ _deny_runtime_native_call("image_destroy")
+
+
+def image_create_sampler(
+ context,
+ mag_filter,
+ min_filter,
+ mip_mode,
+ address_mode,
+ mip_lod_bias,
+ min_lod,
+ max_lod,
+ border_color,
+):
+ _deny_runtime_native_call("image_create_sampler")
+
+
+def image_destroy_sampler(sampler):
+ _deny_runtime_native_call("image_destroy_sampler")
+
+
+def image_write(image, data, offset, extent, baseLayer, layerCount, device_index):
+ _deny_runtime_native_call("image_write")
+
+
+def image_format_block_size(format):
+ _deny_runtime_native_call("image_format_block_size")
+
+
+def image_read(image, out_size, offset, extent, baseLayer, layerCount, device_index):
+ _deny_runtime_native_call("image_read")
+
+
+# --- API: compute stage ---
+
+
+def stage_compute_plan_create(context, shader_source, bindings, pc_size, shader_name):
+ _deny_runtime_native_call("stage_compute_plan_create")
+
+
+def stage_compute_plan_destroy(plan):
+ _deny_runtime_native_call("stage_compute_plan_destroy")
+
+
+def stage_compute_record(command_list, plan, descriptor_set, blocks_x, blocks_y, blocks_z):
+ _deny_runtime_native_call("stage_compute_record")
+
+
+# --- API: FFT stage ---
+
+
+def stage_fft_plan_create(
+ context,
+ dims,
+ axes,
+ buffer_size,
+ do_r2c,
+ normalize,
+ pad_left,
+ pad_right,
+ frequency_zeropadding,
+ kernel_num,
+ kernel_convolution,
+ conjugate_convolution,
+ convolution_features,
+ input_buffer_size,
+ num_batches,
+ single_kernel_multiple_batches,
+ keep_shader_code,
+):
+ _deny_runtime_native_call("stage_fft_plan_create")
+
+
+def stage_fft_plan_destroy(plan):
+ _deny_runtime_native_call("stage_fft_plan_destroy")
+
+
+def stage_fft_record(command_list, plan, buffer, inverse, kernel, input_buffer):
+ _deny_runtime_native_call("stage_fft_record")
+
+
+__all__ = [
+ "reset_device_options",
+ "set_device_options",
+ "init",
+ "log",
+ "set_log_level",
+ "get_devices",
+ "context_create",
+ "signal_wait",
+ "signal_insert",
+ "signal_destroy",
+ "context_destroy",
+ "get_error_string",
+ "context_stop_threads",
+ "buffer_create",
+ "buffer_destroy",
+ "buffer_get_queue_signal",
+ "buffer_wait_staging_idle",
+ "buffer_write_staging",
+ "buffer_read_staging",
+ "buffer_write",
+ "buffer_read",
+ "command_list_create",
+ "command_list_destroy",
+ "command_list_get_instance_size",
+ "command_list_reset",
+ "command_list_submit",
+ "descriptor_set_create",
+ "descriptor_set_destroy",
+ "descriptor_set_write_buffer",
+ "descriptor_set_write_image",
+ "image_create",
+ "image_destroy",
+ "image_create_sampler",
+ "image_destroy_sampler",
+ "image_write",
+ "image_format_block_size",
+ "image_read",
+ "stage_compute_plan_create",
+ "stage_compute_plan_destroy",
+ "stage_compute_record",
+ "stage_fft_plan_create",
+ "stage_fft_plan_destroy",
+ "stage_fft_record"
+]
diff --git a/vkdispatch/backends/opencl_backend.py b/vkdispatch/backends/opencl_backend.py
new file mode 100644
index 00000000..eed638a3
--- /dev/null
+++ b/vkdispatch/backends/opencl_backend.py
@@ -0,0 +1,2080 @@
+"""pyopencl-backed runtime shim mirroring the vkdispatch_native API surface.
+
+This module intentionally matches the function names exposed by the Cython
+extension so existing Python runtime objects can call into either backend.
+"""
+
+from __future__ import annotations
+
+from dataclasses import dataclass, field
+import hashlib
+import re
+import threading
+from typing import Dict, List, Optional, Tuple
+
+import os
+import sys
+
+try:
+ import numpy as np
+except Exception as exc: # pragma: no cover - import failure path
+ raise ImportError(
+ "The OpenCL Python backend requires both 'pyopencl' and 'numpy' to be installed."
+ ) from exc
+
+try:
+ import pyopencl as cl
+except Exception as exc: # pragma: no cover - import failure path
+ raise ImportError(
+ "The OpenCL runtime backend requires the 'pyopencl' package "
+ "(`pip install pyopencl`)."
+ ) from exc
+
+
+# Log level constants mirrored from native bindings.
+LOG_LEVEL_VERBOSE = 0
+LOG_LEVEL_INFO = 1
+LOG_LEVEL_WARNING = 2
+LOG_LEVEL_ERROR = 3
+
+# Descriptor type enum values mirrored from vkdispatch_native/stages_extern.pxd.
+DESCRIPTOR_TYPE_STORAGE_BUFFER = 1
+DESCRIPTOR_TYPE_STORAGE_IMAGE = 2
+DESCRIPTOR_TYPE_UNIFORM_BUFFER = 3
+DESCRIPTOR_TYPE_UNIFORM_IMAGE = 4
+DESCRIPTOR_TYPE_SAMPLER = 5
+
+# Image format block sizes for formats exposed in vkdispatch.base.image.image_format.
+_IMAGE_BLOCK_SIZES = {
+ 13: 1,
+ 14: 1,
+ 20: 2,
+ 21: 2,
+ 27: 3,
+ 28: 3,
+ 41: 4,
+ 42: 4,
+ 74: 2,
+ 75: 2,
+ 76: 2,
+ 81: 4,
+ 82: 4,
+ 83: 4,
+ 88: 6,
+ 89: 6,
+ 90: 6,
+ 95: 8,
+ 96: 8,
+ 97: 8,
+ 98: 4,
+ 99: 4,
+ 100: 4,
+ 101: 8,
+ 102: 8,
+ 103: 8,
+ 104: 12,
+ 105: 12,
+ 106: 12,
+ 107: 16,
+ 108: 16,
+ 109: 16,
+ 110: 8,
+ 111: 8,
+ 112: 8,
+ 113: 16,
+ 114: 16,
+ 115: 16,
+ 116: 24,
+ 117: 24,
+ 118: 24,
+ 119: 32,
+ 120: 32,
+ 121: 32,
+}
+
+_LOCAL_X_RE = re.compile(r"#define\s+VKDISPATCH_EXPECTED_LOCAL_SIZE_X\s+(\d+)")
+_LOCAL_Y_RE = re.compile(r"#define\s+VKDISPATCH_EXPECTED_LOCAL_SIZE_Y\s+(\d+)")
+_LOCAL_Z_RE = re.compile(r"#define\s+VKDISPATCH_EXPECTED_LOCAL_SIZE_Z\s+(\d+)")
+_REQD_LOCAL_RE = re.compile(r"reqd_work_group_size\(\s*(\d+)\s*,\s*(\d+)\s*,\s*(\d+)\s*\)")
+_KERNEL_SIGNATURE_RE = re.compile(r"vkdispatch_main\s*\(([^)]*)\)", re.S)
+_BINDING_PARAM_RE = re.compile(r"vkdispatch_binding_(\d+)_ptr$")
+_SAMPLER_PARAM_RE = re.compile(r"vkdispatch_sampler_(\d+)$")
+# Matches the generated `typedef struct PushConstant { ... } PushConstant;`
+# declaration; the struct body is captured in the named group "body"
+# (consumed by _extract_push_constant_struct_body via group("body")).
+_PUSH_CONSTANT_STRUCT_RE = re.compile(
+    r"typedef\s+struct\s+PushConstant\s*\{(?P<body>.*?)\}\s*PushConstant\s*;",
+    re.S,
+)
+# Matches one field declaration inside the PushConstant body, e.g.
+# `float4 offsets[3]`; captures "type", "name" and the optional array
+# "count" (consumed by _parse_push_constant_struct_fields).
+_PUSH_CONSTANT_FIELD_RE = re.compile(
+    r"(?P<type>[A-Za-z_][A-Za-z0-9_]*)\s+"
+    r"(?P<name>[A-Za-z_][A-Za-z0-9_]*)"
+    r"(?:\s*\[\s*(?P<count>\d+)\s*\])?$"
+)
+_VECTOR_TYPE_RE = re.compile(r"([A-Za-z_][A-Za-z0-9_]*?)([2-4])$")
+_OPENCL_VERSION_RE = re.compile(r"OpenCL\s+(\d+)\.(\d+)")
+_DIGIT_RE = re.compile(r"(\d+)")
+_OPENCL_MAX_INFLIGHT_SUBMISSIONS = 4
+_OPENCL_SUBGROUP_PROBE_SOURCE = """
+__kernel void vkdispatch_subgroup_probe(__global uint *out) {
+ size_t gid = get_global_id(0);
+ if (gid == 0) {
+ out[0] = 0u;
+ }
+}
+"""
+
+
+# --- Runtime state ---
+
+_initialized = False
+_debug_mode = False
+_log_level = LOG_LEVEL_WARNING
+_error_string: Optional[str] = None
+_next_handle = 1
+
+_contexts: Dict[int, "_Context"] = {}
+_signals: Dict[int, "_Signal"] = {}
+_buffers: Dict[int, "_Buffer"] = {}
+_command_lists: Dict[int, "_CommandList"] = {}
+_compute_plans: Dict[int, "_ComputePlan"] = {}
+_descriptor_sets: Dict[int, "_DescriptorSet"] = {}
+_images: Dict[int, object] = {}
+_samplers: Dict[int, object] = {}
+_fft_plans: Dict[int, object] = {}
+_subgroup_size_cache: Dict[Tuple[int, int, str, str], int] = {}
+
+_marker_helpers = threading.local()
+
+
+# --- Internal objects ---
+
+
+@dataclass(frozen=True)
+class _DeviceEntry:
+ logical_index: int
+ platform_index: int
+ device_index: int
+ platform: object
+ device: object
+
+
+@dataclass
+class _Signal:
+ context_handle: int
+ queue_index: int
+ event: Optional[object] = None
+ submitted: bool = True
+ done: bool = True
+
+
+@dataclass
+class _Context:
+ device_index: int
+ cl_context: object
+ queues: List[object]
+ queue_count: int
+ queue_to_device: List[int]
+ sub_buffer_alignment: int
+ submission_events: List[List[object]] = field(default_factory=list)
+ stopped: bool = False
+
+
+@dataclass
+class _Buffer:
+ context_handle: int
+ size: int
+ cl_buffer: object
+ staging_data: List[bytearray]
+ signal_handles: List[int]
+
+
+@dataclass
+class _CommandRecord:
+ plan_handle: int
+ descriptor_set_handle: int
+ blocks: Tuple[int, int, int]
+ pc_size: int
+
+
+@dataclass
+class _CommandList:
+ context_handle: int
+ commands: List[_CommandRecord] = field(default_factory=list)
+
+
+@dataclass
+class _KernelParam:
+ kind: str
+ binding: Optional[int]
+ raw_name: str
+
+
+@dataclass(frozen=True)
+class _PushConstantTypeLayout:
+ host_elem_size: int
+ opencl_elem_size: int
+ opencl_align: int
+
+
+@dataclass(frozen=True)
+class _PushConstantFieldDecl:
+ type_name: str
+ field_name: str
+ count: int
+
+
+@dataclass(frozen=True)
+class _PushConstantFieldLayout:
+ type_name: str
+ field_name: str
+ count: int
+ host_offset: int
+ opencl_offset: int
+ host_elem_size: int
+ opencl_elem_size: int
+
+
+@dataclass(frozen=True)
+class _PushConstantLayout:
+ fields: Tuple[_PushConstantFieldLayout, ...]
+ host_size: int
+ opencl_size: int
+ opencl_alignment: int
+ needs_repack: bool
+
+
+@dataclass
+class _ComputePlan:
+ context_handle: int
+ shader_source: bytes
+ bindings: List[int]
+ shader_name: bytes
+ program: object
+ kernel: object
+ local_size: Tuple[int, int, int]
+ params: List[_KernelParam]
+ pc_size: int
+ pc_layout: Optional[_PushConstantLayout] = None
+
+
+@dataclass
+class _DescriptorSet:
+ plan_handle: int
+ buffer_bindings: Dict[int, Tuple[int, int, int, int, int, int]] = field(default_factory=dict)
+ image_bindings: Dict[int, Tuple[int, int, int, int]] = field(default_factory=dict)
+
+
+# --- Helper utilities ---
+
+
+def _new_handle(registry: Dict[int, object], obj: object) -> int:
+ global _next_handle
+ handle = _next_handle
+ _next_handle += 1
+ registry[handle] = obj
+ return handle
+
+
+def _to_bytes(value) -> bytes:
+ if value is None:
+ return b""
+ if isinstance(value, bytes):
+ return value
+ if isinstance(value, bytearray):
+ return bytes(value)
+ if isinstance(value, memoryview):
+ return value.tobytes()
+ return bytes(value)
+
+
+def _set_error(message: str) -> None:
+ global _error_string
+ _error_string = str(message)
+
+
+def _clear_error() -> None:
+ global _error_string
+ _error_string = None
+
+def _enumerate_opencl_devices() -> List[_DeviceEntry]:
+ entries: List[_DeviceEntry] = []
+
+ if (
+ sys.platform.startswith("linux")
+ and "OCL_ICD_VENDORS" not in os.environ
+ and "OPENCL_VENDOR_PATH" not in os.environ
+ and os.path.isdir("/etc/OpenCL/vendors")
+ ):
+ os.environ["OCL_ICD_VENDORS"] = "/etc/OpenCL/vendors"
+
+ try:
+ platforms = cl.get_platforms()
+ except Exception as exc:
+ raise RuntimeError(
+ f"Failed to get OpenCL Platform: {exc}"
+ ) from exc
+
+ logical_index = 0
+ for platform_index, platform in enumerate(platforms):
+ try:
+ devices = platform.get_devices()
+ except Exception:
+ continue
+
+ for device_index, device in enumerate(devices):
+ entries.append(
+ _DeviceEntry(
+ logical_index=logical_index,
+ platform_index=platform_index,
+ device_index=device_index,
+ platform=platform,
+ device=device,
+ )
+ )
+ logical_index += 1
+
+ return entries
+
+
+def _coerce_int(value, fallback: int = 0) -> int:
+ try:
+ return int(value)
+ except Exception:
+ return int(fallback)
+
+
+def _align_up(value: int, alignment: int) -> int:
+ if alignment <= 1:
+ return int(value)
+ return ((int(value) + alignment - 1) // alignment) * alignment
+
+
+def _opencl_version_components(version_text: str) -> Tuple[int, int]:
+ if not isinstance(version_text, str):
+ return (0, 0)
+
+ match = _OPENCL_VERSION_RE.search(version_text)
+ if match is None:
+ return (0, 0)
+
+ return (_coerce_int(match.group(1), 0), _coerce_int(match.group(2), 0))
+
+
+def _driver_version_number(driver_text: str) -> int:
+ if not isinstance(driver_text, str):
+ return 0
+
+ pieces = _DIGIT_RE.findall(driver_text)
+ if len(pieces) == 0:
+ return 0
+
+ folded = 0
+ weight = 1_000_000
+ for token in pieces[:3]:
+ folded += _coerce_int(token, 0) * weight
+ weight = max(1, weight // 1000)
+ return folded
+
+
+def _device_type_to_vkdispatch(device_type: int) -> int:
+ if device_type & getattr(cl.device_type, "GPU", 0):
+ return 2
+ if device_type & getattr(cl.device_type, "ACCELERATOR", 0):
+ return 3
+ if device_type & getattr(cl.device_type, "CPU", 0):
+ return 4
+ return 0
+
+
+def _device_uuid(entry: _DeviceEntry, device_name: str, driver_version: str) -> bytes:
+ platform_vendor = ""
+ platform_name = ""
+ try:
+ platform_vendor = str(entry.platform.vendor)
+ except Exception:
+ platform_vendor = ""
+ try:
+ platform_name = str(entry.platform.name)
+ except Exception:
+ platform_name = ""
+
+ seed = (
+ f"opencl:{entry.platform_index}:{entry.device_index}:"
+ f"{platform_vendor}:"
+ f"{platform_name}:"
+ f"{device_name}:{driver_version}"
+ )
+ return hashlib.md5(seed.encode("utf-8")).digest()
+
+
+def _device_attr(device, attr_name: str, default):
+ try:
+ return getattr(device, attr_name)
+ except Exception:
+ return default
+
+
+def _release_opencl_object(obj: object) -> None:
+ release = getattr(obj, "release", None)
+ if callable(release):
+ try:
+ release()
+ except Exception:
+ pass
+
+
+def _device_identity_key(entry: _DeviceEntry, device_name: str, driver_version: str) -> Tuple[int, int, str, str]:
+ return (int(entry.platform_index), int(entry.device_index), str(device_name), str(driver_version))
+
+
+def _kernel_preferred_workgroup_multiple(device) -> Optional[int]:
+ ctx = None
+ program = None
+ kernel = None
+
+ try:
+ ctx = cl.Context(devices=[device])
+ program = cl.Program(ctx, _OPENCL_SUBGROUP_PROBE_SOURCE).build()
+ kernel = cl.Kernel(program, "vkdispatch_subgroup_probe")
+ multiple = kernel.get_work_group_info(
+ cl.kernel_work_group_info.PREFERRED_WORK_GROUP_SIZE_MULTIPLE,
+ device,
+ )
+ multiple_int = _coerce_int(multiple, 0)
+ if multiple_int > 0:
+ return multiple_int
+ except Exception:
+ return None
+ finally:
+ _release_opencl_object(kernel)
+ _release_opencl_object(program)
+ _release_opencl_object(ctx)
+
+ return None
+
+
+def _round_down_power_of_two(value: int) -> int:
+ value = int(value)
+ if value <= 1:
+ return 1
+ return 1 << (value.bit_length() - 1)
+
+
+def _vendor_subgroup_fallback(
+ *,
+ device_type: int,
+ vendor_text: str,
+ platform_name: str,
+ device_name: str,
+ max_workgroup_invocations: int,
+) -> int:
+ if device_type == 4:
+ return 1
+
+ combined = " ".join(
+ token.lower()
+ for token in (vendor_text, platform_name, device_name)
+ if isinstance(token, str) and len(token) > 0
+ )
+
+ if "nvidia" in combined:
+ return 32
+
+ if "advanced micro devices" in combined or " amd" in f" {combined}" or "radeon" in combined:
+ return 64
+
+ if "apple" in combined or "m1" in combined or "m2" in combined or "m3" in combined or "m4" in combined:
+ return 32
+
+ if "intel" in combined:
+ return 16 if device_type == 2 else 1
+
+ if device_type == 2:
+ bounded = min(max(1, int(max_workgroup_invocations)), 64)
+ if bounded >= 32:
+ return 32
+ return _round_down_power_of_two(bounded)
+
+ return 1
+
+
+def _estimate_subgroup_size(
+ entry: _DeviceEntry,
+ device,
+ *,
+ device_name: str,
+ driver_version: str,
+ device_type: int,
+ max_workgroup_invocations: int,
+) -> int:
+ cache_key = _device_identity_key(entry, device_name, driver_version)
+ cached = _subgroup_size_cache.get(cache_key)
+ if cached is not None:
+ return cached
+
+ platform_name = str(_device_attr(entry.platform, "name", ""))
+ vendor_text = str(_device_attr(device, "vendor", _device_attr(entry.platform, "vendor", "")))
+
+ subgroup_size = _kernel_preferred_workgroup_multiple(device)
+ if subgroup_size is None:
+ subgroup_size = _vendor_subgroup_fallback(
+ device_type=device_type,
+ vendor_text=vendor_text,
+ platform_name=platform_name,
+ device_name=device_name,
+ max_workgroup_invocations=max_workgroup_invocations,
+ )
+
+ subgroup_size = max(1, int(subgroup_size))
+ _subgroup_size_cache[cache_key] = subgroup_size
+ return subgroup_size
+
+
+def _context_from_handle(context_handle: int) -> Optional[_Context]:
+ ctx = _contexts.get(int(context_handle))
+ if ctx is None:
+ _set_error(f"Invalid context handle {context_handle}")
+ return ctx
+
+
+def _queue_indices(ctx: _Context, queue_index: int, *, all_on_negative: bool = False) -> List[int]:
+ if ctx.queue_count <= 0:
+ return []
+
+ if queue_index is None:
+ return [0]
+
+ queue_index = int(queue_index)
+
+ if all_on_negative and queue_index < 0:
+ return list(range(ctx.queue_count))
+
+ if queue_index == -1:
+ return [0]
+
+ if 0 <= queue_index < ctx.queue_count:
+ return [queue_index]
+
+ return []
+
+
+def _record_signal(signal: _Signal, event_obj: Optional[object]) -> None:
+ if signal.event is not None and signal.event is not event_obj:
+ try:
+ signal.event.release()
+ except Exception:
+ pass
+ signal.submitted = True
+ signal.done = event_obj is None
+ signal.event = event_obj
+
+
+def _query_event_done(event_obj: Optional[object]) -> bool:
+ if event_obj is None:
+ return True
+
+ try:
+ complete = int(getattr(getattr(cl, "command_execution_status", object()), "COMPLETE", 0))
+ status = _coerce_int(event_obj.command_execution_status, 0)
+ return status == complete
+ except Exception:
+ return False
+
+
+def _query_signal(signal: _Signal) -> bool:
+ signal.done = _query_event_done(signal.event) if signal.event is not None else bool(signal.done)
+ return signal.done
+
+
+def _wait_signal(signal: _Signal) -> bool:
+ if signal.event is None:
+ return bool(signal.done)
+
+ try:
+ signal.event.wait()
+ signal.done = True
+ return True
+ except Exception:
+ return _query_signal(signal)
+
+
+def _parse_local_size(source: str) -> Tuple[int, int, int]:
+ x_match = _LOCAL_X_RE.search(source)
+ y_match = _LOCAL_Y_RE.search(source)
+ z_match = _LOCAL_Z_RE.search(source)
+
+ if x_match is not None and y_match is not None and z_match is not None:
+ return (
+ _coerce_int(x_match.group(1), 1),
+ _coerce_int(y_match.group(1), 1),
+ _coerce_int(z_match.group(1), 1),
+ )
+
+ reqd_match = _REQD_LOCAL_RE.search(source)
+ if reqd_match is not None:
+ return (
+ _coerce_int(reqd_match.group(1), 1),
+ _coerce_int(reqd_match.group(2), 1),
+ _coerce_int(reqd_match.group(3), 1),
+ )
+
+ return (1, 1, 1)
+
+
+def _opencl_device_launch_limits(logical_device_index: int) -> Tuple[Tuple[int, int, int], int]:
+ entries = _enumerate_opencl_devices()
+ if logical_device_index < 0 or logical_device_index >= len(entries):
+ raise RuntimeError(
+ f"OpenCL device index {logical_device_index} is out of range for launch validation"
+ )
+
+ device = entries[logical_device_index].device
+ max_work_item_sizes = tuple(
+ _coerce_int(x, 1)
+ for x in _device_attr(device, "max_work_item_sizes", (1, 1, 1))
+ )
+
+ if len(max_work_item_sizes) < 3:
+ max_work_item_sizes = (max_work_item_sizes + (1, 1, 1))[:3]
+ else:
+ max_work_item_sizes = max_work_item_sizes[:3]
+
+ max_workgroup_size = (
+ max(1, int(max_work_item_sizes[0])),
+ max(1, int(max_work_item_sizes[1])),
+ max(1, int(max_work_item_sizes[2])),
+ )
+ max_workgroup_invocations = max(
+ 1,
+ _coerce_int(_device_attr(device, "max_work_group_size", 1), 1),
+ )
+
+ return max_workgroup_size, max_workgroup_invocations
+
+
+def _validate_local_size_for_enqueue(ctx: _Context, local_size: Tuple[int, int, int]) -> None:
+ max_workgroup_size, max_workgroup_invocations = _opencl_device_launch_limits(ctx.device_index)
+ local_x, local_y, local_z = (max(1, int(dim)) for dim in local_size)
+ local_invocations = local_x * local_y * local_z
+
+ violations = []
+ if local_x > max_workgroup_size[0]:
+ violations.append(f"x={local_x} exceeds {max_workgroup_size[0]}")
+ if local_y > max_workgroup_size[1]:
+ violations.append(f"y={local_y} exceeds {max_workgroup_size[1]}")
+ if local_z > max_workgroup_size[2]:
+ violations.append(f"z={local_z} exceeds {max_workgroup_size[2]}")
+ if local_invocations > max_workgroup_invocations:
+ violations.append(
+ f"total invocations={local_invocations} exceeds {max_workgroup_invocations}"
+ )
+
+ if violations:
+ raise RuntimeError(
+ "OpenCL local size is invalid for the active device: "
+ f"requested ({local_x}, {local_y}, {local_z}), "
+ f"device limits {max_workgroup_size} with max_work_group_size="
+ f"{max_workgroup_invocations} ({'; '.join(violations)})"
+ )
+
+
+_PUSH_CONSTANT_SCALAR_LAYOUTS: Dict[str, Tuple[int, int]] = {
+ "char": (1, 1),
+ "uchar": (1, 1),
+ "short": (2, 2),
+ "ushort": (2, 2),
+ "int": (4, 4),
+ "uint": (4, 4),
+ "long": (8, 8),
+ "ulong": (8, 8),
+ "half": (2, 2),
+ "float": (4, 4),
+ "double": (8, 8),
+}
+
+_PUSH_CONSTANT_MATRIX_LAYOUTS: Dict[str, _PushConstantTypeLayout] = {
+ "vkdispatch_mat2": _PushConstantTypeLayout(host_elem_size=16, opencl_elem_size=16, opencl_align=8),
+ "vkdispatch_mat3": _PushConstantTypeLayout(host_elem_size=36, opencl_elem_size=36, opencl_align=1),
+ "vkdispatch_mat4": _PushConstantTypeLayout(host_elem_size=64, opencl_elem_size=64, opencl_align=16),
+ "vkdispatch_packed_float3": _PushConstantTypeLayout(host_elem_size=12, opencl_elem_size=12, opencl_align=1),
+}
+
+
+def _extract_push_constant_struct_body(source: str) -> Optional[str]:
+ struct_match = _PUSH_CONSTANT_STRUCT_RE.search(source)
+ if struct_match is None:
+ return None
+ return struct_match.group("body")
+
+
+def _parse_push_constant_struct_fields(body: str) -> List[_PushConstantFieldDecl]:
+    """Parse semicolon-separated struct field declarations into decl records.
+
+    Raises RuntimeError for declarations the field regex cannot match or for
+    non-positive array sizes.
+    """
+    fields: List[_PushConstantFieldDecl] = []
+
+    for raw_decl in body.split(";"):
+        # Collapse interior whitespace so the regex sees a normalized decl.
+        decl = " ".join(raw_decl.strip().split())
+        if len(decl) == 0:
+            continue
+
+        field_match = _PUSH_CONSTANT_FIELD_RE.fullmatch(decl)
+        if field_match is None:
+            raise RuntimeError(f"Unable to parse PushConstant field declaration '{decl}'")
+
+        type_name = field_match.group("type")
+        field_name = field_match.group("name")
+        count_token = field_match.group("count")
+        # A missing array suffix means a single element.
+        count = 1 if count_token is None else _coerce_int(count_token, 0)
+
+        if count <= 0:
+            raise RuntimeError(f"Invalid PushConstant array size for field '{field_name}'")
+
+        fields.append(_PushConstantFieldDecl(type_name=type_name, field_name=field_name, count=count))
+
+    return fields
+
+
+def _push_constant_type_layout(type_name: str) -> _PushConstantTypeLayout:
+    """Resolve host/OpenCL element size and alignment for a field type name.
+
+    Lookup order: matrix helper types, then scalars, then `<scalar><N>`
+    vector types. Raises RuntimeError for unsupported type names.
+    """
+    matrix_layout = _PUSH_CONSTANT_MATRIX_LAYOUTS.get(type_name)
+    if matrix_layout is not None:
+        return matrix_layout
+
+    scalar_layout = _PUSH_CONSTANT_SCALAR_LAYOUTS.get(type_name)
+    if scalar_layout is not None:
+        size, align = scalar_layout
+        return _PushConstantTypeLayout(host_elem_size=size, opencl_elem_size=size, opencl_align=align)
+
+    vector_match = _VECTOR_TYPE_RE.fullmatch(type_name)
+    if vector_match is not None:
+        scalar_name = vector_match.group(1)
+        lane_count = _coerce_int(vector_match.group(2), 0)
+        scalar_info = _PUSH_CONSTANT_SCALAR_LAYOUTS.get(scalar_name)
+        if scalar_info is None:
+            raise RuntimeError(f"Unsupported PushConstant vector scalar type '{scalar_name}'")
+
+        scalar_size, _scalar_align = scalar_info
+        host_elem_size = scalar_size * lane_count
+
+        if lane_count == 3:
+            # 3-lane vectors occupy 4 lanes on the OpenCL side.
+            opencl_elem_size = scalar_size * 4
+            opencl_align = scalar_size * 4
+        else:
+            opencl_elem_size = host_elem_size
+            opencl_align = opencl_elem_size
+
+        return _PushConstantTypeLayout(
+            host_elem_size=host_elem_size,
+            opencl_elem_size=opencl_elem_size,
+            opencl_align=opencl_align,
+        )
+
+    raise RuntimeError(f"Unsupported PushConstant field type '{type_name}'")
+
+
+def _compute_push_constant_layout(field_decls: List[_PushConstantFieldDecl]) -> _PushConstantLayout:
+    """Lay out fields for both the packed host ABI and the aligned OpenCL ABI.
+
+    The host layout is tightly packed; the OpenCL layout aligns each field
+    and pads the struct to its strictest member alignment. Any divergence
+    between the two sets needs_repack so the payload is rewritten on submit.
+    """
+    host_offset = 0
+    opencl_offset = 0
+    max_opencl_align = 1
+    needs_repack = False
+    field_layouts: List[_PushConstantFieldLayout] = []
+
+    for field_decl in field_decls:
+        type_layout = _push_constant_type_layout(field_decl.type_name)
+
+        opencl_offset = _align_up(opencl_offset, type_layout.opencl_align)
+
+        if type_layout.opencl_align > max_opencl_align:
+            max_opencl_align = type_layout.opencl_align
+
+        # Offset or element-size divergence between ABIs forces a repack.
+        if host_offset != opencl_offset:
+            needs_repack = True
+        if type_layout.host_elem_size != type_layout.opencl_elem_size:
+            needs_repack = True
+
+        field_layouts.append(
+            _PushConstantFieldLayout(
+                type_name=field_decl.type_name,
+                field_name=field_decl.field_name,
+                count=field_decl.count,
+                host_offset=host_offset,
+                opencl_offset=opencl_offset,
+                host_elem_size=type_layout.host_elem_size,
+                opencl_elem_size=type_layout.opencl_elem_size,
+            )
+        )
+
+        host_offset += type_layout.host_elem_size * field_decl.count
+        opencl_offset += type_layout.opencl_elem_size * field_decl.count
+
+    # Round the OpenCL struct size up to its strictest member alignment.
+    opencl_size = _align_up(opencl_offset, max_opencl_align)
+    if opencl_size != host_offset:
+        needs_repack = True
+
+    return _PushConstantLayout(
+        fields=tuple(field_layouts),
+        host_size=host_offset,
+        opencl_size=opencl_size,
+        opencl_alignment=max_opencl_align,
+        needs_repack=needs_repack,
+    )
+
+
+def _build_push_constant_layout(source: str, expected_host_size: int) -> Optional[_PushConstantLayout]:
+    """Parse the PushConstant struct from `source` and validate its host size.
+
+    Returns None when no push constants are expected (expected_host_size <= 0).
+    Raises RuntimeError when the struct is missing, empty, or its computed
+    host size disagrees with `expected_host_size`.
+    """
+    expected_host_size = int(expected_host_size)
+    if expected_host_size <= 0:
+        return None
+
+    body = _extract_push_constant_struct_body(source)
+    if body is None:
+        raise RuntimeError("Could not find PushConstant struct declaration in OpenCL source")
+
+    field_decls = _parse_push_constant_struct_fields(body)
+    if len(field_decls) == 0:
+        raise RuntimeError("PushConstant struct declaration is empty")
+
+    layout = _compute_push_constant_layout(field_decls)
+    if layout.host_size != expected_host_size:
+        raise RuntimeError(
+            f"PushConstant host layout mismatch. Expected {expected_host_size} bytes "
+            f"but parsed {layout.host_size} bytes from OpenCL source."
+        )
+
+    return layout
+
+
+def _repack_push_constant_payload(
+    push_constant_payload: bytes,
+    layout: Optional[_PushConstantLayout],
+) -> bytes:
+    """Re-layout a packed host payload into the OpenCL struct ABI.
+
+    Returns the payload unchanged when no layout is known or no repack is
+    needed; otherwise copies each field element to its OpenCL offset,
+    leaving padding bytes zeroed.
+    """
+    payload = _to_bytes(push_constant_payload)
+
+    if layout is None or not layout.needs_repack:
+        return payload
+
+    if len(payload) != int(layout.host_size):
+        raise RuntimeError(
+            f"PushConstant payload length mismatch for repack. "
+            f"Expected {layout.host_size} bytes but got {len(payload)} bytes."
+        )
+
+    out = bytearray(int(layout.opencl_size))
+
+    for field in layout.fields:
+        if field.host_elem_size > field.opencl_elem_size:
+            raise RuntimeError(
+                f"PushConstant field '{field.field_name}' host element size ({field.host_elem_size}) "
+                f"exceeds OpenCL ABI element size ({field.opencl_elem_size})."
+            )
+
+        for element_index in range(int(field.count)):
+            host_start = field.host_offset + (element_index * field.host_elem_size)
+            host_end = host_start + field.host_elem_size
+            opencl_start = field.opencl_offset + (element_index * field.opencl_elem_size)
+            # Only host_elem_size bytes are copied; trailing ABI padding stays zero.
+            opencl_end = opencl_start + field.host_elem_size
+            out[opencl_start:opencl_end] = payload[host_start:host_end]
+
+    return bytes(out)
+
+
+def _parse_kernel_params(source: str) -> List[_KernelParam]:
+    """Classify vkdispatch_main's parameters by their conventional names.
+
+    Recognizes the uniform pointer, the by-value push-constant parameter,
+    numbered binding pointers, and numbered samplers; anything else is tagged
+    "unknown" for the arg builder to reject at dispatch time.
+    """
+    signature_match = _KERNEL_SIGNATURE_RE.search(source)
+    if signature_match is None:
+        raise RuntimeError("Could not find vkdispatch_main kernel signature in OpenCL source")
+
+    signature_blob = signature_match.group(1).strip()
+    if len(signature_blob) == 0:
+        return []
+
+    params: List[_KernelParam] = []
+
+    for raw_decl in [part.strip() for part in signature_blob.split(",") if len(part.strip()) > 0]:
+        # The parameter name is the last identifier in the declaration.
+        name_match = re.search(r"([A-Za-z_][A-Za-z0-9_]*)\s*$", raw_decl)
+        if name_match is None:
+            raise RuntimeError(f"Unable to parse kernel parameter declaration '{raw_decl}'")
+
+        param_name = name_match.group(1)
+
+        if param_name == "vkdispatch_uniform_ptr":
+            params.append(_KernelParam("uniform", 0, param_name))
+            continue
+
+        if param_name == "vkdispatch_pc_value":
+            params.append(_KernelParam("push_constant_value", None, param_name))
+            continue
+
+        binding_match = _BINDING_PARAM_RE.match(param_name)
+        if binding_match is not None:
+            params.append(_KernelParam("storage", _coerce_int(binding_match.group(1), 0), param_name))
+            continue
+
+        sampler_match = _SAMPLER_PARAM_RE.match(param_name)
+        if sampler_match is not None:
+            params.append(_KernelParam("sampler", _coerce_int(sampler_match.group(1), 0), param_name))
+            continue
+
+        params.append(_KernelParam("unknown", None, param_name))
+
+    return params
+
+
+def _buffer_access_flags(read_access: int, write_access: int) -> int:
+    """Map read/write intent flags to pyopencl mem_flags for a sub-buffer."""
+    read_enabled = int(read_access) != 0
+    write_enabled = int(write_access) != 0
+
+    if read_enabled and not write_enabled:
+        return int(cl.mem_flags.READ_ONLY)
+    if write_enabled and not read_enabled:
+        return int(cl.mem_flags.WRITE_ONLY)
+    # Both flags set, or neither, falls back to full read/write access.
+    return int(cl.mem_flags.READ_WRITE)
+
+
+def _resolve_descriptor_buffer(
+    descriptor_set: _DescriptorSet,
+    binding: int,
+    ctx: _Context,
+    keepalive: List[object],
+):
+    """Return the cl.Buffer (or aligned sub-buffer) bound at `binding`.
+
+    A requested_range <= 0 means "to the end of the buffer". Whole-buffer
+    bindings return the parent cl.Buffer directly; partial bindings create a
+    sub-buffer (appended to `keepalive` so it outlives the enqueue) and
+    require the offset to be aligned to the device's sub-buffer alignment.
+    Raises RuntimeError for missing bindings, bad handles, or out-of-range
+    offset/size combinations.
+    """
+    binding_info = descriptor_set.buffer_bindings.get(int(binding))
+    if binding_info is None:
+        raise RuntimeError(f"Missing descriptor buffer binding {binding}")
+
+    buffer_handle, offset, requested_range, _uniform, read_access, write_access = binding_info
+
+    buffer_obj = _buffers.get(int(buffer_handle))
+    if buffer_obj is None:
+        raise RuntimeError(f"Invalid buffer handle {buffer_handle} for binding {binding}")
+
+    offset = int(offset)
+    requested_range = int(requested_range)
+
+    if offset < 0:
+        raise RuntimeError(f"Negative descriptor offset {offset} for binding {binding}")
+
+    max_size = int(buffer_obj.size)
+    if offset > max_size:
+        raise RuntimeError(f"Descriptor offset {offset} exceeds buffer size {max_size} for binding {binding}")
+
+    sub_size = max_size - offset if requested_range <= 0 else requested_range
+    if sub_size < 0:
+        raise RuntimeError(f"Invalid descriptor range {sub_size} for binding {binding}")
+
+    if offset + sub_size > max_size:
+        raise RuntimeError(
+            f"Descriptor range (offset={offset}, size={sub_size}) exceeds buffer size {max_size} for binding {binding}"
+        )
+
+    # Full-buffer binding: no sub-buffer needed.
+    if offset == 0 and sub_size == max_size:
+        return buffer_obj.cl_buffer
+
+    if (offset % ctx.sub_buffer_alignment) != 0:
+        raise RuntimeError(
+            f"Descriptor offset {offset} for binding {binding} is not aligned to "
+            f"{ctx.sub_buffer_alignment} bytes required by this OpenCL device"
+        )
+
+    sub_buffer = buffer_obj.cl_buffer.get_sub_region(
+        int(offset),
+        int(sub_size),
+        _buffer_access_flags(read_access, write_access),
+    )
+    keepalive.append(sub_buffer)
+    return sub_buffer
+
+
+def _build_kernel_args(
+    plan: _ComputePlan,
+    descriptor_set: Optional[_DescriptorSet],
+    ctx: _Context,
+    push_constant_payload: bytes = b"",
+) -> Tuple[List[object], List[object]]:
+    """Build the kernel argument list for a plan, in parameter order.
+
+    Returns (args, keepalive) where `keepalive` holds sub-buffers that must
+    stay referenced until the kernel is enqueued. Raises RuntimeError for
+    missing descriptor sets, payload size mismatches, sampler parameters
+    (unsupported), or unrecognized parameter names.
+    """
+    args: List[object] = []
+    keepalive: List[object] = []
+
+    for param in plan.params:
+        if param.kind == "uniform":
+            if descriptor_set is None:
+                raise RuntimeError("Kernel requires a descriptor set but none was provided")
+            # The uniform pointer always resolves through binding 0.
+            args.append(_resolve_descriptor_buffer(descriptor_set, 0, ctx, keepalive))
+            continue
+
+        if param.kind == "storage":
+            if descriptor_set is None:
+                raise RuntimeError("Kernel requires a descriptor set but none was provided")
+            if param.binding is None:
+                raise RuntimeError("Storage parameter has no binding index")
+            args.append(_resolve_descriptor_buffer(descriptor_set, int(param.binding), ctx, keepalive))
+            continue
+
+        if param.kind == "push_constant_value":
+            if int(plan.pc_size) <= 0:
+                raise RuntimeError(
+                    f"Kernel parameter '{param.raw_name}' expects push-constant data, but this compute plan has pc_size={plan.pc_size}."
+                )
+
+            if len(push_constant_payload) == 0:
+                raise RuntimeError(
+                    "Missing push-constant payload for OpenCL by-value push-constant parameter "
+                    f"'{param.raw_name}'."
+                )
+
+            if len(push_constant_payload) != int(plan.pc_size):
+                raise RuntimeError(
+                    f"Push-constant payload size mismatch for parameter '{param.raw_name}'. "
+                    f"Expected {plan.pc_size} bytes but got {len(push_constant_payload)} bytes."
+                )
+
+            # Repack from the packed host ABI into the OpenCL struct ABI.
+            args.append(_repack_push_constant_payload(push_constant_payload, plan.pc_layout))
+            continue
+
+        if param.kind == "sampler":
+            raise RuntimeError("OpenCL backend does not support image/sampler bindings")
+
+        raise RuntimeError(
+            f"Unsupported kernel parameter '{param.raw_name}'. "
+            "Expected vkdispatch_uniform_ptr / vkdispatch_pc_value / vkdispatch_binding__ptr."
+        )
+
+    return args, keepalive
+
+
+def _marker_wait_functions() -> List[object]:
+    """Return available pyopencl marker/barrier enqueue functions.
+
+    Probes for the functions by name (availability varies across pyopencl
+    versions) and caches the result on `_marker_helpers.funcs`.
+    """
+    cached = getattr(_marker_helpers, "funcs", None)
+    if cached is not None:
+        return cached
+
+    funcs: List[object] = []
+    for fn_name in (
+        "enqueue_marker",
+        "enqueue_marker_with_wait_list",
+        "enqueue_barrier_with_wait_list",
+    ):
+        fn = getattr(cl, fn_name, None)
+        if fn is not None:
+            funcs.append(fn)
+
+    _marker_helpers.funcs = funcs
+    return funcs
+
+
+def _insert_queue_marker_event(queue) -> Optional[object]:
+    """Enqueue a marker on `queue` and return its event, or None if unsupported.
+
+    Tries each candidate function in turn; a TypeError triggers a retry with
+    an explicit empty wait list, and any other failure moves on to the next
+    candidate.
+    """
+    for marker_fn in _marker_wait_functions():
+        try:
+            event_obj = marker_fn(queue)
+            if event_obj is not None:
+                return event_obj
+        except TypeError:
+            try:
+                event_obj = marker_fn(queue, wait_for=[])
+                if event_obj is not None:
+                    return event_obj
+            except Exception:
+                continue
+        except Exception:
+            continue
+
+    return None
+
+
+def _release_event(event_obj: Optional[object]) -> None:
+    """Best-effort release of an OpenCL event; ignores None and release errors."""
+    if event_obj is None:
+        return
+
+    try:
+        event_obj.release()
+    except Exception:
+        pass
+
+
+def _prune_submission_events(ctx: _Context, queue_index: int) -> int:
+    """Release completed submission events for a queue; return the count
+    of events still pending."""
+    pending_events: List[object] = []
+
+    for event_obj in ctx.submission_events[queue_index]:
+        if _query_event_done(event_obj):
+            _release_event(event_obj)
+            continue
+
+        pending_events.append(event_obj)
+
+    ctx.submission_events[queue_index] = pending_events
+    return len(pending_events)
+
+
+def _reserve_submission_slot(ctx: _Context, queue_index: int) -> bool:
+    """Return True if the queue is below the in-flight submission cap."""
+    return _prune_submission_events(ctx, queue_index) < _OPENCL_MAX_INFLIGHT_SUBMISSIONS
+
+
+def _track_submission_completion(ctx: _Context, queue_index: int) -> None:
+    """Record a marker event for the latest submission on a queue.
+
+    If no marker mechanism is available, falls back to a blocking
+    queue.finish() and prunes the event list instead.
+    """
+    queue = ctx.queues[queue_index]
+    marker_event = _insert_queue_marker_event(queue)
+
+    if marker_event is None:
+        queue.finish()
+        _prune_submission_events(ctx, queue_index)
+        return
+
+    ctx.submission_events[queue_index].append(marker_event)
+    queue.flush()
+
+
+# --- API: context/init/logging ---
+
+
+def init(debug, log_level):
+    """Initialize backend state and clear the error slot; debug/log settings
+    are updated on every call, the initialized flag only flips once."""
+    global _initialized, _debug_mode, _log_level
+
+    _debug_mode = bool(debug)
+    _log_level = int(log_level)
+    _clear_error()
+
+    if _initialized:
+        return
+
+    _initialized = True
+
+
+def log(log_level, text, file_str, line_str):
+    """No-op logging hook kept for parity with the native API surface."""
+    _ = log_level
+    _ = text
+    _ = file_str
+    _ = line_str
+
+
+def set_log_level(log_level):
+    """Update the module-wide log level."""
+    global _log_level
+    _log_level = int(log_level)
+
+
+def get_devices():
+    """Enumerate OpenCL devices as vkdispatch device-property tuples.
+
+    Each tuple mirrors the field order the native Vulkan binding returns;
+    positional comments below mark the fields whose meaning is not obvious.
+    Capabilities are derived from device extensions and clGetDeviceInfo-style
+    attributes, with conservative fallbacks for attributes a driver omits.
+    """
+    if not _initialized:
+        init(False, _log_level)
+
+    entries = _enumerate_opencl_devices()
+    devices = []
+
+    for entry in entries:
+        device = entry.device
+        opencl_version = _device_attr(device, "version", "")
+        version_major, version_minor = _opencl_version_components(opencl_version)
+        version_patch = 0
+
+        driver_version = str(_device_attr(device, "driver_version", ""))
+        driver_version_num = _driver_version_number(driver_version)
+
+        vendor_id = _coerce_int(_device_attr(device, "vendor_id", 0), 0)
+        device_id = int(entry.logical_index)
+        device_type = _device_type_to_vkdispatch(_coerce_int(_device_attr(device, "type", 0), 0))
+        device_name = str(_device_attr(device, "name", f"OpenCL Device {entry.logical_index}"))
+
+        # Feature detection from the extension string plus fp-config probes.
+        extensions = str(_device_attr(device, "extensions", ""))
+        float32_atomic_support = (
+            "cl_ext_float_atomics" in extensions
+            or "cl_khr_float_atomics" in extensions
+        )
+        float64_support = "cl_khr_fp64" in extensions or _coerce_int(_device_attr(device, "double_fp_config", 0), 0) != 0
+        float16_support = "cl_khr_fp16" in extensions or _coerce_int(_device_attr(device, "half_fp_config", 0), 0) != 0
+        int64_support = _coerce_int(_device_attr(device, "address_bits", 0), 0) >= 64
+        int16_support = _coerce_int(_device_attr(device, "preferred_vector_width_short", 0), 0) > 0
+
+        # Normalize max_work_item_sizes to exactly three dimensions.
+        max_work_item_sizes = tuple(
+            _coerce_int(x, 1)
+            for x in _device_attr(device, "max_work_item_sizes", (1, 1, 1))
+        )
+        if len(max_work_item_sizes) < 3:
+            max_work_item_sizes = (
+                max_work_item_sizes + (1, 1, 1)
+            )[:3]
+        else:
+            max_work_item_sizes = max_work_item_sizes[:3]
+
+        max_workgroup_size = (
+            max(1, int(max_work_item_sizes[0])),
+            max(1, int(max_work_item_sizes[1])),
+            max(1, int(max_work_item_sizes[2])),
+        )
+        max_workgroup_invocations = max(1, _coerce_int(_device_attr(device, "max_work_group_size", 1), 1))
+
+        # OpenCL imposes no workgroup-count limit; report INT32_MAX per axis.
+        max_workgroup_count = (2 ** 31 - 1, 2 ** 31 - 1, 2 ** 31 - 1)
+
+        max_storage_buffer_range = max(
+            1,
+            min(
+                _coerce_int(_device_attr(device, "max_mem_alloc_size", 1), 1),
+                (1 << 31) - 1,
+            ),
+        )
+        max_uniform_buffer_range = max(1, _coerce_int(_device_attr(device, "max_constant_buffer_size", 65536), 65536))
+        # mem_base_addr_align is reported in bits; convert to bytes.
+        uniform_alignment = max(
+            1,
+            _coerce_int(_device_attr(device, "mem_base_addr_align", 8), 8) // 8,
+        )
+        max_push_constant_size = max(0, _coerce_int(_device_attr(device, "max_parameter_size", 0), 0))
+
+        subgroup_size = _estimate_subgroup_size(
+            entry,
+            device,
+            device_name=device_name,
+            driver_version=driver_version,
+            device_type=device_type,
+            max_workgroup_invocations=max_workgroup_invocations,
+        )
+
+        max_compute_shared_memory_size = max(
+            1,
+            _coerce_int(_device_attr(device, "local_mem_size", 1), 1),
+        )
+
+        uuid_bytes = _device_uuid(entry, device_name, driver_version)
+
+        devices.append(
+            (
+                0,  # Vulkan variant
+                int(version_major),
+                int(version_minor),
+                int(version_patch),
+                int(driver_version_num),
+                int(vendor_id),
+                int(device_id),
+                int(device_type),
+                str(device_name),
+                1 if float32_atomic_support else 0,
+                1 if float32_atomic_support else 0,
+                1 if float64_support else 0,
+                1 if float16_support else 0,
+                1 if int64_support else 0,
+                1 if int16_support else 0,
+                1 if int16_support else 0,  # storage_buffer_16_bit_access
+                1 if int16_support else 0,  # uniform_and_storage_buffer_16_bit_access
+                0,  # storage_push_constant_16
+                1 if int16_support else 0,  # storage_input_output_16
+                max_workgroup_size,
+                int(max_workgroup_invocations),
+                max_workgroup_count,
+                8,  # max descriptor sets (virtualized for parity)
+                int(max_push_constant_size),
+                int(max_storage_buffer_range),
+                int(max_uniform_buffer_range),
+                int(uniform_alignment),
+                subgroup_size,  # subgroup size
+                0,  # subgroup stages
+                0,  # subgroup operations
+                0,  # quad operations in all stages
+                int(max_compute_shared_memory_size),
+                [(1, 0x006)],  # compute + transfer queue
+                1,  # scalar block layout equivalent
+                0,  # timeline semaphores equivalent
+                uuid_bytes,
+            )
+        )
+
+    return devices
+
+
+def context_create(device_indicies, queue_families):
+    """Create an OpenCL context handle for exactly one device and one queue.
+
+    Returns the new context handle, or 0 with the error string set when
+    inputs are malformed, the device index is invalid, or context/queue
+    creation fails. The MVP restricts configuration to a single device with
+    a single queue family holding a single queue.
+    """
+    if not _initialized:
+        init(False, _log_level)
+
+    try:
+        device_ids = [int(x) for x in device_indicies]
+    except Exception:
+        _set_error("context_create expected a list of integer device indices")
+        return 0
+
+    if len(device_ids) != 1:
+        _set_error("OpenCL backend currently supports exactly one device")
+        return 0
+
+    try:
+        normalized_families = [[int(x) for x in family] for family in queue_families]
+    except Exception:
+        _set_error("context_create expected queue_families to be a nested integer list")
+        return 0
+
+    if len(normalized_families) != 1 or len(normalized_families[0]) != 1:
+        _set_error("OpenCL backend currently supports exactly one queue")
+        return 0
+
+    entries = _enumerate_opencl_devices()
+    if len(entries) == 0:
+        # Preserve any more specific error set during enumeration.
+        if _error_string is None:
+            _set_error("No OpenCL devices were found")
+        return 0
+
+    logical_device_index = int(device_ids[0])
+    if logical_device_index < 0 or logical_device_index >= len(entries):
+        _set_error(
+            f"Invalid OpenCL device index {logical_device_index}. "
+            f"Expected range [0, {len(entries) - 1}]"
+        )
+        return 0
+
+    entry = entries[logical_device_index]
+
+    try:
+        cl_context = cl.Context(devices=[entry.device])
+        queue = cl.CommandQueue(cl_context, device=entry.device)
+        # mem_base_addr_align is reported in bits; convert to bytes.
+        sub_buffer_alignment = max(
+            1,
+            _coerce_int(_device_attr(entry.device, "mem_base_addr_align", 8), 8) // 8,
+        )
+        ctx = _Context(
+            device_index=logical_device_index,
+            cl_context=cl_context,
+            queues=[queue],
+            queue_count=1,
+            queue_to_device=[0],
+            sub_buffer_alignment=sub_buffer_alignment,
+            submission_events=[[]],
+            stopped=False,
+        )
+        return _new_handle(_contexts, ctx)
+    except Exception as exc:
+        _set_error(f"Failed to create OpenCL context: {exc}")
+        return 0
+
+
+def context_destroy(context):
+    """Tear down a context: release tracked events, finish and release all
+    queues, then release the cl.Context. Unknown handles are ignored."""
+    ctx = _contexts.pop(int(context), None)
+    if ctx is None:
+        return
+
+    for queue_events in ctx.submission_events:
+        for event_obj in queue_events:
+            _release_event(event_obj)
+        queue_events.clear()
+
+    for queue in ctx.queues:
+        try:
+            queue.finish()
+        except Exception:
+            pass
+        try:
+            queue.release()
+        except Exception:
+            pass
+
+    try:
+        ctx.cl_context.release()
+    except Exception:
+        pass
+
+
+def context_stop_threads(context):
+    """Flag the context as stopped; unknown handles are ignored."""
+    ctx = _contexts.get(int(context))
+    if ctx is not None:
+        ctx.stopped = True
+
+
+def get_error_string():
+    """Return the last recorded error string, or 0 (native NULL) if none."""
+    if _error_string is None:
+        return 0
+    return _error_string
+
+
+# --- API: signals ---
+
+
+def signal_wait(signal_ptr, wait_for_timestamp, queue_index):
+    """Query or wait on a signal; unknown handles report True (complete).
+
+    With wait_for_timestamp false this is a non-blocking status query:
+    `done` when no event was recorded, `submitted` otherwise. With it true,
+    blocks via _wait_signal. `queue_index` is unused in this backend.
+    """
+    _ = queue_index
+
+    signal_obj = _signals.get(int(signal_ptr))
+    if signal_obj is None:
+        return True
+
+    if not bool(wait_for_timestamp):
+        if signal_obj.event is None:
+            return bool(signal_obj.done)
+        return bool(signal_obj.submitted)
+
+    return _wait_signal(signal_obj)
+
+
+def signal_insert(context, queue_index):
+    """Insert a completion signal on a context queue and return its handle.
+
+    Prefers a non-blocking marker event; when markers are unavailable the
+    queue is finished synchronously and the signal is marked done up front.
+    Returns 0 with the error string set on failure.
+    """
+    ctx = _context_from_handle(int(context))
+    if ctx is None:
+        return 0
+
+    selected = _queue_indices(ctx, int(queue_index))
+    if len(selected) == 0:
+        selected = [0]
+
+    signal = _Signal(context_handle=int(context), queue_index=selected[0], submitted=False, done=False)
+    handle = _new_handle(_signals, signal)
+
+    try:
+        event_obj = _insert_queue_marker_event(ctx.queues[selected[0]])
+        if event_obj is None:
+            # No marker support: block until the queue drains, then the
+            # signal is trivially complete.
+            ctx.queues[selected[0]].finish()
+            signal.done = True
+            signal.submitted = True
+        else:
+            _record_signal(signal, event_obj)
+    except Exception as exc:
+        # Unregister the handle so a failed insert does not leak a stale
+        # _Signal entry in the registry.
+        _signals.pop(handle, None)
+        _set_error(f"Failed to insert signal: {exc}")
+        return 0
+
+    return handle
+
+
+def signal_destroy(signal_ptr):
+    """Remove a signal from the registry and best-effort release its event."""
+    signal_obj = _signals.pop(int(signal_ptr), None)
+    if signal_obj is None:
+        return
+
+    try:
+        if signal_obj.event is not None:
+            signal_obj.event.release()
+    except Exception:
+        pass
+
+
+# --- API: buffers ---
+
+
+def buffer_create(context, size, per_device):
+    """Allocate a READ_WRITE device buffer with per-queue staging storage.
+
+    Each queue gets a zero-filled host staging bytearray and a pre-completed
+    signal handle. Returns the buffer handle, or 0 with the error string set
+    on invalid size/context or allocation failure. `per_device` is ignored.
+    """
+    _ = per_device
+
+    ctx = _context_from_handle(int(context))
+    if ctx is None:
+        return 0
+
+    size = int(size)
+    if size <= 0:
+        _set_error("Buffer size must be greater than zero")
+        return 0
+
+    try:
+        cl_buffer = cl.Buffer(ctx.cl_context, cl.mem_flags.READ_WRITE, size=size)
+        signal_handles = [
+            _new_handle(_signals, _Signal(context_handle=int(context), queue_index=i, done=True))
+            for i in range(ctx.queue_count)
+        ]
+        obj = _Buffer(
+            context_handle=int(context),
+            size=size,
+            cl_buffer=cl_buffer,
+            staging_data=[bytearray(size) for _ in range(ctx.queue_count)],
+            signal_handles=signal_handles,
+        )
+        return _new_handle(_buffers, obj)
+    except Exception as exc:
+        _set_error(f"Failed to create OpenCL buffer: {exc}")
+        return 0
+
+
+def buffer_create_external(context, size, device_ptr):
+    """Unsupported in the MVP: always sets an error and returns 0."""
+    _ = context
+    _ = size
+    _ = device_ptr
+    _set_error("OpenCL backend does not support external buffer aliases in MVP")
+    return 0
+
+
+def buffer_destroy(buffer):
+    """Destroy a buffer: tear down its per-queue signals and release the
+    cl.Buffer. Unknown handles are ignored."""
+    obj = _buffers.pop(int(buffer), None)
+    if obj is None:
+        return
+
+    for signal_handle in obj.signal_handles:
+        signal_destroy(signal_handle)
+
+    try:
+        obj.cl_buffer.release()
+    except Exception:
+        pass
+
+
+def buffer_get_queue_signal(buffer, queue_index):
+    """Return the signal handle tracking a buffer's queue activity.
+
+    Unknown buffers get a fresh already-done signal; out-of-range queue
+    indices fall back to queue 0.
+    """
+    obj = _buffers.get(int(buffer))
+    if obj is None:
+        return _new_handle(_signals, _Signal(context_handle=0, queue_index=0, done=True))
+
+    queue_index = int(queue_index)
+    if queue_index < 0 or queue_index >= len(obj.signal_handles):
+        queue_index = 0
+
+    return obj.signal_handles[queue_index]
+
+
+def buffer_wait_staging_idle(buffer, queue_index):
+    """Non-blocking check that the buffer's queue signal has completed;
+    missing signals report True."""
+    signal_handle = buffer_get_queue_signal(buffer, queue_index)
+    signal_obj = _signals.get(int(signal_handle))
+    if signal_obj is None:
+        return True
+    return _query_signal(signal_obj)
+
+
+def buffer_write_staging(buffer, queue_index, data, size):
+    """Copy host data into a buffer's per-queue staging area.
+
+    The copy is clamped to min(size, len(data), buffer size) and starts at
+    staging byte 0. Invalid handles/indices are silently ignored.
+    """
+    obj = _buffers.get(int(buffer))
+    if obj is None:
+        return
+
+    queue_index = int(queue_index)
+    if queue_index < 0 or queue_index >= len(obj.staging_data):
+        return
+
+    payload = _to_bytes(data)
+    size = min(int(size), len(payload), obj.size)
+    if size <= 0:
+        return
+
+    obj.staging_data[queue_index][:size] = payload[:size]
+
+
+def buffer_read_staging(buffer, queue_index, size):
+    """Return up to `size` bytes from a buffer's per-queue staging area.
+
+    Requests beyond the staging length are zero-padded; unknown handles or
+    out-of-range queue indices yield all-zero bytes of the requested size.
+    """
+    # Clamp before any use: the previous code called bytes(int(size)) in the
+    # early-return paths, which raises ValueError for a negative size.
+    size = max(0, int(size))
+
+    obj = _buffers.get(int(buffer))
+    if obj is None:
+        return bytes(size)
+
+    queue_index = int(queue_index)
+    if queue_index < 0 or queue_index >= len(obj.staging_data):
+        return bytes(size)
+
+    staging = obj.staging_data[queue_index]
+
+    if size <= len(staging):
+        return bytes(staging[:size])
+
+    # Zero-pad reads that extend beyond the staging area.
+    return bytes(staging) + bytes(size - len(staging))
+
+
+def buffer_write(buffer, offset, size, index):
+    """Enqueue async copies from staging areas into the device buffer.
+
+    A negative `index` fans out to every queue. The copy range is clamped to
+    the buffer size, and each queue's buffer signal is re-armed with the
+    copy's event. Errors are recorded via _set_error rather than raised.
+    """
+    obj = _buffers.get(int(buffer))
+    if obj is None:
+        return
+
+    ctx = _contexts.get(obj.context_handle)
+    if ctx is None:
+        _set_error(f"Missing context for buffer handle {buffer}")
+        return
+
+    offset = int(offset)
+    size = int(size)
+    if size <= 0 or offset < 0:
+        return
+
+    try:
+        for queue_index in _queue_indices(ctx, int(index), all_on_negative=True):
+            queue = ctx.queues[queue_index]
+            end = min(offset + size, obj.size)
+            copy_size = end - offset
+            if copy_size <= 0:
+                continue
+
+            # NOTE(review): staging is read from byte 0 while the device write
+            # lands at `offset` — confirm this asymmetry is intended.
+            host_src = np.frombuffer(obj.staging_data[queue_index], dtype=np.uint8, count=copy_size)
+            event_obj = cl.enqueue_copy(
+                queue,
+                obj.cl_buffer,
+                host_src,
+                dst_offset=offset,
+                is_blocking=False,
+            )
+
+            signal_obj = _signals.get(obj.signal_handles[queue_index])
+            if signal_obj is not None:
+                _record_signal(signal_obj, event_obj)
+    except Exception as exc:
+        _set_error(f"Failed to write OpenCL buffer: {exc}")
+
+
+def buffer_read(buffer, offset, size, index):
+    """Enqueue an async copy from the device buffer into one queue's staging.
+
+    Unlike buffer_write, `index` must name a single valid queue. The copy is
+    clamped to the buffer size and the queue's buffer signal is re-armed with
+    the copy's event. Errors are recorded via _set_error rather than raised.
+    """
+    obj = _buffers.get(int(buffer))
+    if obj is None:
+        return
+
+    ctx = _contexts.get(obj.context_handle)
+    if ctx is None:
+        _set_error(f"Missing context for buffer handle {buffer}")
+        return
+
+    queue_index = int(index)
+    if queue_index < 0 or queue_index >= ctx.queue_count:
+        _set_error(f"Invalid queue index {queue_index} for buffer read")
+        return
+
+    offset = int(offset)
+    size = int(size)
+    if size <= 0 or offset < 0:
+        return
+
+    try:
+        queue = ctx.queues[queue_index]
+        end = min(offset + size, obj.size)
+        copy_size = end - offset
+        if copy_size <= 0:
+            return
+
+        # NOTE(review): the device read at `offset` lands at staging byte 0 —
+        # confirm this mirrors buffer_write's intended asymmetry.
+        host_dst = np.frombuffer(obj.staging_data[queue_index], dtype=np.uint8, count=copy_size)
+        event_obj = cl.enqueue_copy(
+            queue,
+            host_dst,
+            obj.cl_buffer,
+            src_offset=offset,
+            is_blocking=False,
+        )
+
+        signal_obj = _signals.get(obj.signal_handles[queue_index])
+        if signal_obj is not None:
+            _record_signal(signal_obj, event_obj)
+    except Exception as exc:
+        _set_error(f"Failed to read OpenCL buffer: {exc}")
+
+
+# --- API: command lists ---
+
+
+def command_list_create(context):
+    """Create an empty command list bound to a context; returns its handle,
+    or 0 with the error string set for an unknown context."""
+    if int(context) not in _contexts:
+        _set_error("Invalid context handle for command_list_create")
+        return 0
+
+    return _new_handle(_command_lists, _CommandList(context_handle=int(context)))
+
+
+def command_list_destroy(command_list):
+    """Drop a command list from the registry; unknown handles are ignored."""
+    _command_lists.pop(int(command_list), None)
+
+
+def command_list_get_instance_size(command_list):
+    """Return the per-instance push-constant byte size: the sum of pc_size
+    over all recorded commands (0 for unknown handles)."""
+    obj = _command_lists.get(int(command_list))
+    if obj is None:
+        return 0
+
+    return int(sum(int(command.pc_size) for command in obj.commands))
+
+
+def command_list_reset(command_list):
+    """Clear all recorded commands; unknown handles are ignored."""
+    obj = _command_lists.get(int(command_list))
+    if obj is None:
+        return
+
+    obj.commands = []
+
+
+def command_list_submit(command_list, data, instance_count, index):
+    """Replay the recorded commands `instance_count` times on target queues.
+
+    `data` carries instance_count * instance_size bytes of packed push
+    constants, sliced per command per instance. Returns False only when an
+    in-flight submission slot is unavailable (caller should retry); all
+    other failures set the error string and return True.
+    """
+    obj = _command_lists.get(int(command_list))
+    if obj is None:
+        return True
+
+    ctx = _contexts.get(obj.context_handle)
+    if ctx is None:
+        _set_error(f"Missing context for command list {command_list}")
+        return True
+
+    instance_count = int(instance_count)
+    if instance_count <= 0:
+        return True
+
+    instance_size = command_list_get_instance_size(command_list)
+    payload = _to_bytes(data)
+    expected_payload_size = int(instance_size) * int(instance_count)
+
+    # Validate payload length up front so slicing below cannot go wrong.
+    if expected_payload_size == 0:
+        if len(payload) != 0:
+            _set_error(
+                f"Unexpected push-constant data for command list with instance_size=0 "
+                f"(got {len(payload)} bytes)."
+            )
+        return True
+    elif len(payload) != expected_payload_size:
+        _set_error(
+            f"Push-constant data size mismatch. Expected {expected_payload_size} bytes "
+            f"(instance_size={instance_size}, instance_count={instance_count}) but got {len(payload)} bytes."
+        )
+        return True
+
+    queue_targets = _queue_indices(ctx, int(index), all_on_negative=True)
+    if len(queue_targets) == 0:
+        queue_targets = [0]
+
+    try:
+        # Check every target queue has a free in-flight slot before enqueuing
+        # anything, so a rejection leaves no partial work behind.
+        for queue_index in queue_targets:
+            if not _reserve_submission_slot(ctx, queue_index):
+                return False
+
+        for queue_index in queue_targets:
+            queue = ctx.queues[queue_index]
+            for instance_index in range(instance_count):
+                instance_base_offset = instance_index * instance_size
+                per_instance_offset = 0
+                for command in obj.commands:
+                    plan = _compute_plans.get(command.plan_handle)
+                    if plan is None:
+                        raise RuntimeError(f"Invalid compute plan handle {command.plan_handle}")
+
+                    descriptor_set = None
+                    if command.descriptor_set_handle != 0:
+                        descriptor_set = _descriptor_sets.get(command.descriptor_set_handle)
+                        if descriptor_set is None:
+                            raise RuntimeError(
+                                f"Invalid descriptor set handle {command.descriptor_set_handle}"
+                            )
+
+                    # Slice this command's push constants out of the payload.
+                    command_pc_size = int(command.pc_size)
+                    pc_payload = b""
+                    if command_pc_size > 0 and len(payload) > 0:
+                        pc_start = instance_base_offset + per_instance_offset
+                        pc_end = pc_start + command_pc_size
+                        pc_payload = payload[pc_start:pc_end]
+
+                    args, _keepalive = _build_kernel_args(
+                        plan,
+                        descriptor_set,
+                        ctx,
+                        pc_payload,
+                    )
+
+                    for arg_index, arg_value in enumerate(args):
+                        plan.kernel.set_arg(arg_index, arg_value)
+
+                    local_x = max(1, int(plan.local_size[0]))
+                    local_y = max(1, int(plan.local_size[1]))
+                    local_z = max(1, int(plan.local_size[2]))
+                    _validate_local_size_for_enqueue(ctx, (local_x, local_y, local_z))
+
+                    blocks_x = max(1, int(command.blocks[0]))
+                    blocks_y = max(1, int(command.blocks[1]))
+                    blocks_z = max(1, int(command.blocks[2]))
+
+                    # Vulkan dispatches workgroup counts; OpenCL wants the
+                    # global size, i.e. blocks * local size per axis.
+                    global_size = (
+                        blocks_x * local_x,
+                        blocks_y * local_y,
+                        blocks_z * local_z,
+                    )
+
+                    cl.enqueue_nd_range_kernel(
+                        queue,
+                        plan.kernel,
+                        global_size,
+                        (local_x, local_y, local_z),
+                    )
+
+                    per_instance_offset += command_pc_size
+
+                # Sanity check: the per-command slices must exactly cover the
+                # instance stride.
+                if per_instance_offset != instance_size:
+                    raise RuntimeError(
+                        f"Internal command list size mismatch: computed {per_instance_offset} bytes, "
+                        f"expected {instance_size} bytes."
+                    )
+
+            _track_submission_completion(ctx, queue_index)
+    except Exception as exc:
+        _set_error(f"Failed to submit OpenCL command list: {exc}")
+
+    return True
+
+
+# --- API: descriptor sets ---
+
+
+def descriptor_set_create(plan):
+    """Create an empty descriptor set bound to a compute plan; returns its
+    handle, or 0 with the error string set for an unknown plan."""
+    if int(plan) not in _compute_plans:
+        _set_error("Invalid compute plan handle for descriptor_set_create")
+        return 0
+
+    return _new_handle(_descriptor_sets, _DescriptorSet(plan_handle=int(plan)))
+
+
+def descriptor_set_destroy(descriptor_set):
+    """Drop a descriptor set from the registry; unknown handles are ignored."""
+    _descriptor_sets.pop(int(descriptor_set), None)
+
+
+def descriptor_set_write_buffer(
+    descriptor_set,
+    binding,
+    object,
+    offset,
+    range,
+    uniform,
+    read_access,
+    write_access,
+):
+    """Record a buffer binding on a descriptor set.
+
+    Stores (buffer_handle, offset, range, uniform, read_access, write_access)
+    keyed by binding index; resolution to actual cl.Buffer objects is
+    deferred until submission (_resolve_descriptor_buffer).
+    """
+    ds = _descriptor_sets.get(int(descriptor_set))
+    if ds is None:
+        _set_error("Invalid descriptor set handle for descriptor_set_write_buffer")
+        return
+
+    ds.buffer_bindings[int(binding)] = (
+        int(object),
+        int(offset),
+        int(range),
+        int(uniform),
+        int(read_access),
+        int(write_access),
+    )
+
+
+def descriptor_set_write_image(
+    descriptor_set,
+    binding,
+    object,
+    sampler_obj,
+    read_access,
+    write_access,
+):
+    """Unsupported in the MVP: always sets an error; arguments are ignored."""
+    _ = descriptor_set
+    _ = binding
+    _ = object
+    _ = sampler_obj
+    _ = read_access
+    _ = write_access
+    _set_error("OpenCL backend does not support image objects in MVP")
+
+
+# --- API: compute stage ---
+
+
+def stage_compute_plan_create(context, shader_source, bindings, pc_size, shader_name):
+    """Compile an OpenCL program and wrap it in a compute plan.
+
+    Builds the program, fetches the `vkdispatch_main` kernel, then parses the
+    source for the parameter list, the expected local size, and the push
+    constant layout. Returns the plan handle, or 0 with the error string set
+    on compile or metadata-parse failure.
+    """
+    ctx = _context_from_handle(int(context))
+    if ctx is None:
+        return 0
+
+    source_bytes = _to_bytes(shader_source)
+    shader_name_bytes = _to_bytes(shader_name)
+    source_text = source_bytes.decode("utf-8", errors="replace")
+    pc_size = int(pc_size)
+
+    try:
+        program = cl.Program(ctx.cl_context, source_text).build()
+        kernel = cl.Kernel(program, "vkdispatch_main")
+    except Exception as exc:
+        kernel_name = shader_name_bytes.decode("utf-8", errors="replace")
+        _set_error(f"Failed to compile OpenCL kernel '{kernel_name}': {exc}")
+        return 0
+
+    try:
+        params = _parse_kernel_params(source_text)
+        local_size = _parse_local_size(source_text)
+        pc_layout = _build_push_constant_layout(source_text, pc_size)
+    except Exception as exc:
+        _set_error(f"Failed to parse OpenCL kernel metadata: {exc}")
+        return 0
+
+    plan = _ComputePlan(
+        context_handle=int(context),
+        shader_source=source_bytes,
+        bindings=[int(x) for x in bindings],
+        shader_name=shader_name_bytes,
+        program=program,
+        kernel=kernel,
+        local_size=local_size,
+        params=params,
+        pc_size=pc_size,
+        pc_layout=pc_layout,
+    )
+
+    return _new_handle(_compute_plans, plan)
+
+
+def stage_compute_plan_destroy(plan):
+    """Destroy a compute plan, best-effort releasing its kernel and program.
+    Unknown handles are ignored."""
+    plan_obj = _compute_plans.pop(int(plan), None)
+    if plan_obj is None:
+        return
+
+    try:
+        plan_obj.kernel.release()
+    except Exception:
+        pass
+
+    try:
+        plan_obj.program.release()
+    except Exception:
+        pass
+
+
+def stage_compute_record(command_list, plan, descriptor_set, blocks_x, blocks_y, blocks_z):
+    """Append a dispatch record (plan, descriptor set, workgroup counts) to a
+    command list; execution is deferred to command_list_submit."""
+    cl_obj = _command_lists.get(int(command_list))
+    cp_obj = _compute_plans.get(int(plan))
+    if cl_obj is None or cp_obj is None:
+        _set_error("Invalid command list or compute plan handle for stage_compute_record")
+        return
+
+    cl_obj.commands.append(
+        _CommandRecord(
+            plan_handle=int(plan),
+            descriptor_set_handle=int(descriptor_set),
+            blocks=(int(blocks_x), int(blocks_y), int(blocks_z)),
+            pc_size=int(cp_obj.pc_size),
+        )
+    )
+
+
+# --- API: images/samplers (MVP unsupported) ---
+
+
+def image_create(context, extent, layers, format, type, view_type, generate_mips):
+    """Unsupported in the MVP: sets an error and returns 0."""
+    _ = context
+    _ = extent
+    _ = layers
+    _ = format
+    _ = type
+    _ = view_type
+    _ = generate_mips
+    _set_error("OpenCL backend does not support image objects in MVP")
+    return 0
+
+
+def image_destroy(image):
+    """Drop an image handle from the registry; unknown handles are ignored."""
+    _images.pop(int(image), None)
+
+
+def image_create_sampler(
+    context,
+    mag_filter,
+    min_filter,
+    mip_mode,
+    address_mode,
+    mip_lod_bias,
+    min_lod,
+    max_lod,
+    border_color,
+):
+    """Unsupported in the MVP: sets an error and returns 0."""
+    _ = context
+    _ = mag_filter
+    _ = min_filter
+    _ = mip_mode
+    _ = address_mode
+    _ = mip_lod_bias
+    _ = min_lod
+    _ = max_lod
+    _ = border_color
+    _set_error("OpenCL backend does not support image samplers in MVP")
+    return 0
+
+
+def image_destroy_sampler(sampler):
+    """Drop a sampler handle from the registry; unknown handles are ignored."""
+    _samplers.pop(int(sampler), None)
+
+
+def image_write(image, data, offset, extent, baseLayer, layerCount, device_index):
+    """Unsupported in the MVP: sets an error; arguments are ignored."""
+    _ = image
+    _ = data
+    _ = offset
+    _ = extent
+    _ = baseLayer
+    _ = layerCount
+    _ = device_index
+    _set_error("OpenCL backend does not support image writes in MVP")
+
+
+def image_format_block_size(format):
+    """Return the per-texel byte size for a vkdispatch image format enum,
+    defaulting to 4 for unknown formats."""
+    return int(_IMAGE_BLOCK_SIZES.get(int(format), 4))
+
+
+def image_read(image, out_size, offset, extent, baseLayer, layerCount, device_index):
+    """Unsupported in the MVP: sets an error and returns `out_size` zero bytes."""
+    _ = image
+    _ = offset
+    _ = extent
+    _ = baseLayer
+    _ = layerCount
+    _ = device_index
+    _set_error("OpenCL backend does not support image reads in MVP")
+    return bytes(max(0, int(out_size)))
+
+
+# --- API: FFT stage (MVP unsupported) ---
+
+
+def stage_fft_plan_create(
+    context,
+    dims,
+    axes,
+    buffer_size,
+    do_r2c,
+    normalize,
+    pad_left,
+    pad_right,
+    frequency_zeropadding,
+    kernel_num,
+    kernel_convolution,
+    conjugate_convolution,
+    convolution_features,
+    input_buffer_size,
+    num_batches,
+    single_kernel_multiple_batches,
+    keep_shader_code,
+):
+    """Unsupported in the MVP: sets an error and returns 0."""
+    _ = context
+    _ = dims
+    _ = axes
+    _ = buffer_size
+    _ = do_r2c
+    _ = normalize
+    _ = pad_left
+    _ = pad_right
+    _ = frequency_zeropadding
+    _ = kernel_num
+    _ = kernel_convolution
+    _ = conjugate_convolution
+    _ = convolution_features
+    _ = input_buffer_size
+    _ = num_batches
+    _ = single_kernel_multiple_batches
+    _ = keep_shader_code
+    _set_error("OpenCL backend does not support FFT plans in MVP")
+    return 0
+
+
+def stage_fft_plan_destroy(plan):
+    """Drop an FFT plan handle from the registry; unknown handles are ignored."""
+    _fft_plans.pop(int(plan), None)
+
+
+def stage_fft_record(command_list, plan, buffer, inverse, kernel, input_buffer):
+    """Unsupported in the MVP: sets an error; arguments are ignored."""
+    _ = command_list
+    _ = plan
+    _ = buffer
+    _ = inverse
+    _ = kernel
+    _ = input_buffer
+    _set_error("OpenCL backend does not support FFT stages in MVP")
+
+
+# Public surface mirroring the native Cython extension's exported names.
+__all__ = [
+    "LOG_LEVEL_VERBOSE",
+    "LOG_LEVEL_INFO",
+    "LOG_LEVEL_WARNING",
+    "LOG_LEVEL_ERROR",
+    "DESCRIPTOR_TYPE_STORAGE_BUFFER",
+    "DESCRIPTOR_TYPE_STORAGE_IMAGE",
+    "DESCRIPTOR_TYPE_UNIFORM_BUFFER",
+    "DESCRIPTOR_TYPE_UNIFORM_IMAGE",
+    "DESCRIPTOR_TYPE_SAMPLER",
+    "init",
+    "log",
+    "set_log_level",
+    "get_devices",
+    "context_create",
+    "signal_wait",
+    "signal_insert",
+    "signal_destroy",
+    "context_destroy",
+    "get_error_string",
+    "context_stop_threads",
+    "buffer_create",
+    "buffer_create_external",
+    "buffer_destroy",
+    "buffer_get_queue_signal",
+    "buffer_wait_staging_idle",
+    "buffer_write_staging",
+    "buffer_read_staging",
+    "buffer_write",
+    "buffer_read",
+    "command_list_create",
+    "command_list_destroy",
+    "command_list_get_instance_size",
+    "command_list_reset",
+    "command_list_submit",
+    "descriptor_set_create",
+    "descriptor_set_destroy",
+    "descriptor_set_write_buffer",
+    "descriptor_set_write_image",
+    "image_create",
+    "image_destroy",
+    "image_create_sampler",
+    "image_destroy_sampler",
+    "image_write",
+    "image_format_block_size",
+    "image_read",
+    "stage_compute_plan_create",
+    "stage_compute_plan_destroy",
+    "stage_compute_record",
+    "stage_fft_plan_create",
+    "stage_fft_plan_destroy",
+    "stage_fft_record",
+]
diff --git a/vkdispatch/backends/pycuda_native.py b/vkdispatch/backends/pycuda_native.py
deleted file mode 100644
index 5bf4068d..00000000
--- a/vkdispatch/backends/pycuda_native.py
+++ /dev/null
@@ -1,1371 +0,0 @@
-"""PyCUDA-backed runtime shim mirroring the vkdispatch_native API surface.
-
-This module intentionally matches the function names exposed by the Cython
-extension so existing Python runtime objects can call into either backend.
-"""
-
-from __future__ import annotations
-
-from contextlib import contextmanager
-from dataclasses import dataclass, field
-import hashlib
-import re
-from typing import Dict, List, Optional, Tuple
-
-try:
- import numpy as np
- import pycuda.driver as cuda
- from pycuda.compiler import SourceModule
-except Exception as exc: # pragma: no cover - import failure path
- raise ImportError(
- "The PyCUDA backend requires both 'pycuda' and 'numpy' to be installed."
- ) from exc
-
-
-# Log level constants mirrored from native bindings.
-LOG_LEVEL_VERBOSE = 0
-LOG_LEVEL_INFO = 1
-LOG_LEVEL_WARNING = 2
-LOG_LEVEL_ERROR = 3
-
-# Descriptor type enum values mirrored from vkdispatch_native/stages_extern.pxd.
-DESCRIPTOR_TYPE_STORAGE_BUFFER = 1
-DESCRIPTOR_TYPE_STORAGE_IMAGE = 2
-DESCRIPTOR_TYPE_UNIFORM_BUFFER = 3
-DESCRIPTOR_TYPE_UNIFORM_IMAGE = 4
-DESCRIPTOR_TYPE_SAMPLER = 5
-
-# Image format block sizes for formats exposed in vkdispatch.base.image.image_format.
-_IMAGE_BLOCK_SIZES = {
- 13: 1,
- 14: 1,
- 20: 2,
- 21: 2,
- 27: 3,
- 28: 3,
- 41: 4,
- 42: 4,
- 74: 2,
- 75: 2,
- 76: 2,
- 81: 4,
- 82: 4,
- 83: 4,
- 88: 6,
- 89: 6,
- 90: 6,
- 95: 8,
- 96: 8,
- 97: 8,
- 98: 4,
- 99: 4,
- 100: 4,
- 101: 8,
- 102: 8,
- 103: 8,
- 104: 12,
- 105: 12,
- 106: 12,
- 107: 16,
- 108: 16,
- 109: 16,
- 110: 8,
- 111: 8,
- 112: 8,
- 113: 16,
- 114: 16,
- 115: 16,
- 116: 24,
- 117: 24,
- 118: 24,
- 119: 32,
- 120: 32,
- 121: 32,
-}
-
-_LOCAL_X_RE = re.compile(r"#define\s+VKDISPATCH_EXPECTED_LOCAL_SIZE_X\s+(\d+)")
-_LOCAL_Y_RE = re.compile(r"#define\s+VKDISPATCH_EXPECTED_LOCAL_SIZE_Y\s+(\d+)")
-_LOCAL_Z_RE = re.compile(r"#define\s+VKDISPATCH_EXPECTED_LOCAL_SIZE_Z\s+(\d+)")
-_KERNEL_SIGNATURE_RE = re.compile(r"vkdispatch_main\s*\(([^)]*)\)", re.S)
-_BINDING_PARAM_RE = re.compile(r"vkdispatch_binding_(\d+)_ptr$")
-_SAMPLER_PARAM_RE = re.compile(r"vkdispatch_sampler_(\d+)$")
-
-
-# --- Runtime state ---
-
-_initialized = False
-_debug_mode = False
-_log_level = LOG_LEVEL_WARNING
-_error_string: Optional[str] = None
-_next_handle = 1
-
-_contexts: Dict[int, "_Context"] = {}
-_signals: Dict[int, "_Signal"] = {}
-_buffers: Dict[int, "_Buffer"] = {}
-_command_lists: Dict[int, "_CommandList"] = {}
-_compute_plans: Dict[int, "_ComputePlan"] = {}
-_descriptor_sets: Dict[int, "_DescriptorSet"] = {}
-_images: Dict[int, object] = {}
-_samplers: Dict[int, object] = {}
-_fft_plans: Dict[int, object] = {}
-
-
-# --- Internal objects ---
-
-
-@dataclass
-class _Signal:
- context_handle: int
- queue_index: int
- event: Optional["cuda.Event"] = None
- submitted: bool = True
- done: bool = True
-
-
-@dataclass
-class _Context:
- device_index: int
- pycuda_context: "cuda.Context"
- streams: List["cuda.Stream"]
- queue_count: int
- queue_to_device: List[int]
- stopped: bool = False
-
-
-@dataclass
-class _Buffer:
- context_handle: int
- size: int
- device_allocation: "cuda.DeviceAllocation"
- staging_data: List[object]
- signal_handles: List[int]
-
-
-@dataclass
-class _CommandRecord:
- plan_handle: int
- descriptor_set_handle: int
- blocks: Tuple[int, int, int]
- pc_size: int
-
-
-@dataclass
-class _CommandList:
- context_handle: int
- commands: List[_CommandRecord] = field(default_factory=list)
- compute_instance_size: int = 0
- pc_scratch: Optional["cuda.DeviceAllocation"] = None
- pc_scratch_size: int = 0
-
-
-@dataclass
-class _KernelParam:
- kind: str
- binding: Optional[int]
- raw_name: str
-
-
-@dataclass
-class _ComputePlan:
- context_handle: int
- shader_source: bytes
- bindings: List[int]
- pc_size: int
- shader_name: bytes
- module: SourceModule
- function: object
- local_size: Tuple[int, int, int]
- params: List[_KernelParam]
-
-
-@dataclass
-class _DescriptorSet:
- plan_handle: int
- buffer_bindings: Dict[int, Tuple[int, int, int, int, int, int]] = field(default_factory=dict)
- image_bindings: Dict[int, Tuple[int, int, int, int]] = field(default_factory=dict)
-
-
-@dataclass
-class _ResolvedLaunch:
- plan: _ComputePlan
- blocks: Tuple[int, int, int]
- pc_offset: int
- pc_size: int
- args: Tuple[object, ...]
- pc_scratch: Optional["cuda.DeviceAllocation"] = None
-
-
-# --- Helper utilities ---
-
-
-def _new_handle(registry: Dict[int, object], obj: object) -> int:
- global _next_handle
- handle = _next_handle
- _next_handle += 1
- registry[handle] = obj
- return handle
-
-
-def _to_bytes(value) -> bytes:
- if value is None:
- return b""
- if isinstance(value, bytes):
- return value
- if isinstance(value, bytearray):
- return bytes(value)
- if isinstance(value, memoryview):
- return value.tobytes()
- return bytes(value)
-
-
-def _set_error(message: str) -> None:
- global _error_string
- _error_string = str(message)
-
-
-def _clear_error() -> None:
- global _error_string
- _error_string = None
-
-
-def _queue_indices(ctx: _Context, queue_index: int, *, all_on_negative: bool = False) -> List[int]:
- if ctx.queue_count <= 0:
- return []
-
- if queue_index is None:
- return [0]
-
- queue_index = int(queue_index)
-
- if all_on_negative and queue_index < 0:
- return list(range(ctx.queue_count))
-
- if queue_index == -1:
- return [0]
-
- if 0 <= queue_index < ctx.queue_count:
- return [queue_index]
-
- return []
-
-
-def _context_from_handle(context_handle: int) -> Optional[_Context]:
- ctx = _contexts.get(int(context_handle))
- if ctx is None:
- _set_error(f"Invalid context handle {context_handle}")
- return ctx
-
-
-@contextmanager
-def _activate_context(ctx: _Context):
- ctx.pycuda_context.push()
- try:
- yield
- finally:
- cuda.Context.pop()
-
-
-def _record_signal(signal: _Signal, stream: "cuda.Stream") -> None:
- signal.submitted = True
- signal.done = False
- if signal.event is None:
- signal.event = cuda.Event()
- signal.event.record(stream)
-
-
-def _query_signal(signal: _Signal) -> bool:
- if signal.event is None:
- return bool(signal.done)
-
- try:
- done = signal.event.query()
- except Exception:
- return False
-
- signal.done = bool(done)
- return signal.done
-
-
-def _allocate_staging_storage(size: int):
- try:
- # Pagelocked host memory improves async HtoD/DtoH throughput and overlap.
- return cuda.pagelocked_empty(int(size), np.uint8)
- except Exception:
- return bytearray(int(size))
-
-
-def _parse_local_size(source: str) -> Tuple[int, int, int]:
- x_match = _LOCAL_X_RE.search(source)
- y_match = _LOCAL_Y_RE.search(source)
- z_match = _LOCAL_Z_RE.search(source)
-
- x = int(x_match.group(1)) if x_match else 1
- y = int(y_match.group(1)) if y_match else 1
- z = int(z_match.group(1)) if z_match else 1
-
- return (x, y, z)
-
-
-def _parse_kernel_params(source: str) -> List[_KernelParam]:
- signature_match = _KERNEL_SIGNATURE_RE.search(source)
- if signature_match is None:
- raise RuntimeError("Could not find vkdispatch_main kernel signature in CUDA source")
-
- signature_blob = signature_match.group(1).strip()
- if len(signature_blob) == 0:
- return []
-
- params: List[_KernelParam] = []
-
- for raw_decl in [part.strip() for part in signature_blob.split(",") if len(part.strip()) > 0]:
- name_match = re.search(r"([A-Za-z_][A-Za-z0-9_]*)\s*$", raw_decl)
- if name_match is None:
- raise RuntimeError(f"Unable to parse kernel parameter declaration '{raw_decl}'")
-
- param_name = name_match.group(1)
-
- if param_name == "vkdispatch_uniform_ptr":
- params.append(_KernelParam("uniform", 0, param_name))
- continue
-
- if param_name == "vkdispatch_pc_ptr":
- params.append(_KernelParam("push_constant", None, param_name))
- continue
-
- binding_match = _BINDING_PARAM_RE.match(param_name)
- if binding_match is not None:
- params.append(_KernelParam("storage", int(binding_match.group(1)), param_name))
- continue
-
- sampler_match = _SAMPLER_PARAM_RE.match(param_name)
- if sampler_match is not None:
- params.append(_KernelParam("sampler", int(sampler_match.group(1)), param_name))
- continue
-
- params.append(_KernelParam("unknown", None, param_name))
-
- return params
-
-
-def _resolve_buffer_pointer(descriptor_set: _DescriptorSet, binding: int) -> int:
- binding_info = descriptor_set.buffer_bindings.get(binding)
- if binding_info is None:
- raise RuntimeError(f"Missing descriptor buffer binding {binding}")
-
- buffer_handle, offset, _range, _uniform, _read_access, _write_access = binding_info
-
- buffer_obj = _buffers.get(int(buffer_handle))
- if buffer_obj is None:
- raise RuntimeError(f"Invalid buffer handle {buffer_handle} for binding {binding}")
-
- return int(buffer_obj.device_allocation) + int(offset)
-
-
-def _ensure_pc_scratch(command_list: _CommandList, required_size: int) -> "cuda.DeviceAllocation":
- if required_size <= 0:
- required_size = 1
-
- if command_list.pc_scratch is not None and command_list.pc_scratch_size >= required_size:
- return command_list.pc_scratch
-
- command_list.pc_scratch = cuda.mem_alloc(required_size)
- command_list.pc_scratch_size = required_size
- return command_list.pc_scratch
-
-
-def _build_kernel_args(
- plan: _ComputePlan,
- descriptor_set: Optional[_DescriptorSet],
- command_list: _CommandList,
- pc_data: bytes,
- stream: "cuda.Stream",
-) -> List[object]:
- args: List[object] = []
-
- for param in plan.params:
- if param.kind == "uniform":
- if descriptor_set is None:
- raise RuntimeError("Kernel requires a descriptor set but none was provided")
-
- args.append(np.uintp(_resolve_buffer_pointer(descriptor_set, 0)))
- continue
-
- if param.kind == "storage":
- if descriptor_set is None:
- raise RuntimeError("Kernel requires a descriptor set but none was provided")
-
- if param.binding is None:
- raise RuntimeError("Storage parameter has no binding index")
-
- args.append(np.uintp(_resolve_buffer_pointer(descriptor_set, param.binding)))
- continue
-
- if param.kind == "push_constant":
- pc_scratch = _ensure_pc_scratch(command_list, len(pc_data))
-
- if len(pc_data) > 0:
- cuda.memcpy_htod_async(pc_scratch, pc_data, stream)
-
- args.append(np.uintp(int(pc_scratch)))
- continue
-
- if param.kind == "sampler":
- raise RuntimeError("PyCUDA backend does not support sampled image bindings yet")
-
- raise RuntimeError(
- f"Unsupported kernel parameter '{param.raw_name}'. "
- "Expected vkdispatch_uniform_ptr / vkdispatch_binding__ptr / vkdispatch_pc_ptr."
- )
-
- return args
-
-
-def _build_kernel_args_template(
- plan: _ComputePlan,
- descriptor_set: Optional[_DescriptorSet],
- command_list: _CommandList,
- pc_size: int,
-) -> Tuple[Tuple[object, ...], Optional["cuda.DeviceAllocation"]]:
- args: List[object] = []
- pc_scratch: Optional["cuda.DeviceAllocation"] = None
-
- for param in plan.params:
- if param.kind == "uniform":
- if descriptor_set is None:
- raise RuntimeError("Kernel requires a descriptor set but none was provided")
-
- args.append(np.uintp(_resolve_buffer_pointer(descriptor_set, 0)))
- continue
-
- if param.kind == "storage":
- if descriptor_set is None:
- raise RuntimeError("Kernel requires a descriptor set but none was provided")
-
- if param.binding is None:
- raise RuntimeError("Storage parameter has no binding index")
-
- args.append(np.uintp(_resolve_buffer_pointer(descriptor_set, param.binding)))
- continue
-
- if param.kind == "push_constant":
- if pc_scratch is None:
- pc_scratch = _ensure_pc_scratch(command_list, int(pc_size))
- args.append(np.uintp(int(pc_scratch)))
- continue
-
- if param.kind == "sampler":
- raise RuntimeError("PyCUDA backend does not support sampled image bindings yet")
-
- raise RuntimeError(
- f"Unsupported kernel parameter '{param.raw_name}'. "
- "Expected vkdispatch_uniform_ptr / vkdispatch_binding__ptr / vkdispatch_pc_ptr."
- )
-
- return tuple(args), pc_scratch
-
-
-# --- API: context/init/logging ---
-
-
-def init(debug, log_level):
- global _initialized, _debug_mode, _log_level
-
- _debug_mode = bool(debug)
- _log_level = int(log_level)
- _clear_error()
-
- if _initialized:
- return
-
- cuda.init()
- _initialized = True
-
-
-def log(log_level, text, file_str, line_str):
- _ = log_level
- _ = text
- _ = file_str
- _ = line_str
-
-
-def set_log_level(log_level):
- global _log_level
- _log_level = int(log_level)
-
-
-def get_devices():
- if not _initialized:
- init(False, _log_level)
-
- try:
- device_count = cuda.Device.count()
- except Exception as exc:
- _set_error(f"Failed to enumerate CUDA devices: {exc}")
- return []
-
- driver_version = 0
- try:
- driver_version = int(cuda.get_driver_version())
- except Exception:
- driver_version = 0
-
- devices = []
-
- for index in range(device_count):
- dev = cuda.Device(index)
- attrs = dev.get_attributes()
- cc_major, cc_minor = dev.compute_capability()
- total_memory = int(dev.total_memory())
-
- max_workgroup_size = (
- int(attrs.get(cuda.device_attribute.MAX_BLOCK_DIM_X, 1024)),
- int(attrs.get(cuda.device_attribute.MAX_BLOCK_DIM_Y, 1024)),
- int(attrs.get(cuda.device_attribute.MAX_BLOCK_DIM_Z, 64)),
- )
-
- max_workgroup_count = (
- int(attrs.get(cuda.device_attribute.MAX_GRID_DIM_X, 65535)),
- int(attrs.get(cuda.device_attribute.MAX_GRID_DIM_Y, 65535)),
- int(attrs.get(cuda.device_attribute.MAX_GRID_DIM_Z, 65535)),
- )
-
- subgroup_size = int(attrs.get(cuda.device_attribute.WARP_SIZE, 32))
- max_shared_memory = int(
- attrs.get(cuda.device_attribute.MAX_SHARED_MEMORY_PER_BLOCK, 48 * 1024)
- )
-
- try:
- bus_id = str(dev.pci_bus_id())
- except Exception:
- bus_id = f"cuda-device-{index}"
-
- uuid_bytes = hashlib.md5(bus_id.encode("utf-8")).digest()
-
- devices.append(
- (
- 0, # Vulkan variant
- int(cc_major), # major
- int(cc_minor), # minor
- 0, # patch
- driver_version,
- 0, # vendor id unknown in this API layer
- index, # device id
- 2, # discrete gpu
- str(dev.name()),
- 1, # shader_buffer_float32_atomics
- 1, # shader_buffer_float32_atomic_add
- 1, # float64 support
- 1 if (cc_major > 5 or (cc_major == 5 and cc_minor >= 3)) else 0, # float16 support
- 1, # int64
- 1, # int16
- 1, # storage_buffer_16_bit_access
- 1, # uniform_and_storage_buffer_16_bit_access
- 1, # storage_push_constant_16
- 1, # storage_input_output_16
- max_workgroup_size,
- int(attrs.get(cuda.device_attribute.MAX_THREADS_PER_BLOCK, 1024)),
- max_workgroup_count,
- 8, # max descriptor sets (virtualized for parity)
- 4096, # max push constant size
- min(total_memory, (1 << 31) - 1),
- 65536,
- 16,
- subgroup_size,
- 0x7FFFFFFF, # supported stages (virtualized for parity)
- 0x7FFFFFFF, # supported operations (virtualized for parity)
- 1,
- max_shared_memory,
- [(1, 0x002)], # compute queue
- 1, # scalar block layout
- 1, # timeline semaphores equivalent
- uuid_bytes,
- )
- )
-
- return devices
-
-
-def context_create(device_indicies, queue_families):
- if not _initialized:
- init(False, _log_level)
-
- try:
- device_ids = [int(x) for x in device_indicies]
- except Exception:
- _set_error("context_create expected a list of integer device indices")
- return 0
-
- if len(device_ids) != 1:
- _set_error("PyCUDA backend currently supports exactly one device")
- return 0
-
- if len(queue_families) != 1 or len(queue_families[0]) != 1:
- _set_error("PyCUDA backend currently supports exactly one queue")
- return 0
-
- device_index = device_ids[0]
-
- pycuda_context = None
- context_pushed = False
-
- try:
- if device_index < 0 or device_index >= cuda.Device.count():
- _set_error(f"Invalid CUDA device index {device_index}")
- return 0
-
- dev = cuda.Device(device_index)
- pycuda_context = dev.make_context()
- context_pushed = True
- stream = cuda.Stream()
-
- ctx = _Context(
- device_index=device_index,
- pycuda_context=pycuda_context,
- streams=[stream],
- queue_count=1,
- queue_to_device=[0],
- stopped=False,
- )
- handle = _new_handle(_contexts, ctx)
-
- # Leave no context current after creation.
- cuda.Context.pop()
- context_pushed = False
- return handle
- except Exception as exc:
- if context_pushed:
- try:
- cuda.Context.pop()
- except Exception:
- pass
-
- if pycuda_context is not None:
- try:
- pycuda_context.detach()
- except Exception:
- pass
-
- _set_error(f"Failed to create PyCUDA context: {exc}")
- return 0
-
-
-def context_destroy(context):
- ctx = _contexts.pop(int(context), None)
- if ctx is None:
- return
-
- try:
- with _activate_context(ctx):
- for stream in ctx.streams:
- stream.synchronize()
- except Exception:
- pass
-
- try:
- ctx.pycuda_context.detach()
- except Exception:
- pass
-
-
-def context_stop_threads(context):
- ctx = _contexts.get(int(context))
- if ctx is not None:
- ctx.stopped = True
-
-
-def get_error_string():
- if _error_string is None:
- return 0
- return _error_string
-
-
-# --- API: signals ---
-
-
-def signal_wait(signal_ptr, wait_for_timestamp, queue_index):
- signal_obj = _signals.get(int(signal_ptr))
- if signal_obj is None:
- return True
-
- if not bool(wait_for_timestamp):
- # PyCUDA records signals synchronously on submission; host-side "recorded" waits
- # should therefore complete immediately once an event exists.
- if signal_obj.event is None:
- return bool(signal_obj.done)
- return bool(signal_obj.submitted)
-
- if signal_obj.done:
- return True
-
- if signal_obj.event is None:
- return bool(signal_obj.done)
-
- ctx = _contexts.get(signal_obj.context_handle)
- if ctx is None:
- return _query_signal(signal_obj)
-
- try:
- with _activate_context(ctx):
- signal_obj.event.synchronize()
- signal_obj.done = True
- return True
- except Exception:
- return _query_signal(signal_obj)
-
-
-def signal_insert(context, queue_index):
- ctx = _context_from_handle(int(context))
- if ctx is None:
- return 0
-
- selected = _queue_indices(ctx, int(queue_index))
- if len(selected) == 0:
- selected = [0]
-
- signal = _Signal(context_handle=int(context), queue_index=selected[0], submitted=False, done=False)
- handle = _new_handle(_signals, signal)
-
- try:
- with _activate_context(ctx):
- _record_signal(signal, ctx.streams[selected[0]])
- except Exception as exc:
- _set_error(f"Failed to insert signal: {exc}")
- return 0
-
- return handle
-
-
-def signal_destroy(signal_ptr):
- _signals.pop(int(signal_ptr), None)
-
-
-# --- API: buffers ---
-
-
-def buffer_create(context, size, per_device):
- _ = per_device
-
- ctx = _context_from_handle(int(context))
- if ctx is None:
- return 0
-
- size = int(size)
- if size <= 0:
- _set_error("Buffer size must be greater than zero")
- return 0
-
- try:
- with _activate_context(ctx):
- allocation = cuda.mem_alloc(size)
-
- signal_handles = [
- _new_handle(_signals, _Signal(context_handle=int(context), queue_index=i, done=True))
- for i in range(ctx.queue_count)
- ]
-
- obj = _Buffer(
- context_handle=int(context),
- size=size,
- device_allocation=allocation,
- staging_data=[_allocate_staging_storage(size) for _ in range(ctx.queue_count)],
- signal_handles=signal_handles,
- )
- return _new_handle(_buffers, obj)
- except Exception as exc:
- _set_error(f"Failed to create CUDA buffer: {exc}")
- return 0
-
-
-def buffer_destroy(buffer):
- obj = _buffers.pop(int(buffer), None)
- if obj is None:
- return
-
- for signal_handle in obj.signal_handles:
- _signals.pop(signal_handle, None)
-
- ctx = _contexts.get(obj.context_handle)
- if ctx is None:
- return
-
- try:
- with _activate_context(ctx):
- obj.device_allocation.free()
- except Exception:
- pass
-
-
-def buffer_get_queue_signal(buffer, queue_index):
- obj = _buffers.get(int(buffer))
- if obj is None:
- return _new_handle(_signals, _Signal(context_handle=0, queue_index=0, done=True))
-
- queue_index = int(queue_index)
- if queue_index < 0 or queue_index >= len(obj.signal_handles):
- queue_index = 0
-
- return obj.signal_handles[queue_index]
-
-
-def buffer_wait_staging_idle(buffer, queue_index):
- signal_handle = buffer_get_queue_signal(buffer, queue_index)
- signal_obj = _signals.get(int(signal_handle))
- if signal_obj is None:
- return True
- return _query_signal(signal_obj)
-
-
-def buffer_write_staging(buffer, queue_index, data, size):
- obj = _buffers.get(int(buffer))
- if obj is None:
- return
-
- queue_index = int(queue_index)
- if queue_index < 0 or queue_index >= len(obj.staging_data):
- return
-
- payload = _to_bytes(data)
- size = min(int(size), len(payload), obj.size)
- if size <= 0:
- return
-
- payload_view = memoryview(payload)[:size]
- staging_view = memoryview(obj.staging_data[queue_index])
- staging_view[:size] = payload_view
-
-
-def buffer_read_staging(buffer, queue_index, size):
- obj = _buffers.get(int(buffer))
- if obj is None:
- return bytes(int(size))
-
- queue_index = int(queue_index)
- if queue_index < 0 or queue_index >= len(obj.staging_data):
- return bytes(int(size))
-
- size = max(0, int(size))
- staging = obj.staging_data[queue_index]
-
- if size <= len(staging):
- return bytes(staging[:size])
-
- return bytes(staging) + bytes(size - len(staging))
-
-
-def buffer_write(buffer, offset, size, index):
- obj = _buffers.get(int(buffer))
- if obj is None:
- return
-
- ctx = _contexts.get(obj.context_handle)
- if ctx is None:
- _set_error(f"Missing context for buffer handle {buffer}")
- return
-
- offset = int(offset)
- size = int(size)
- if size <= 0 or offset < 0:
- return
-
- try:
- with _activate_context(ctx):
- for queue_index in _queue_indices(ctx, int(index), all_on_negative=True):
- stream = ctx.streams[queue_index]
- end = min(offset + size, obj.size)
- copy_size = end - offset
- if copy_size <= 0:
- continue
-
- src_view = memoryview(obj.staging_data[queue_index])[:copy_size]
- cuda.memcpy_htod_async(int(obj.device_allocation) + offset, src_view, stream)
-
- signal = _signals.get(obj.signal_handles[queue_index])
- if signal is not None:
- _record_signal(signal, stream)
- except Exception as exc:
- _set_error(f"Failed to write CUDA buffer: {exc}")
-
-
-def buffer_read(buffer, offset, size, index):
- obj = _buffers.get(int(buffer))
- if obj is None:
- return
-
- ctx = _contexts.get(obj.context_handle)
- if ctx is None:
- _set_error(f"Missing context for buffer handle {buffer}")
- return
-
- queue_index = int(index)
- if queue_index < 0 or queue_index >= ctx.queue_count:
- _set_error(f"Invalid queue index {queue_index} for buffer read")
- return
-
- offset = int(offset)
- size = int(size)
- if size <= 0 or offset < 0:
- return
-
- try:
- with _activate_context(ctx):
- stream = ctx.streams[queue_index]
- end = min(offset + size, obj.size)
- copy_size = end - offset
- if copy_size <= 0:
- return
-
- dst_view = memoryview(obj.staging_data[queue_index])[:copy_size]
- cuda.memcpy_dtoh_async(dst_view, int(obj.device_allocation) + offset, stream)
-
- signal = _signals.get(obj.signal_handles[queue_index])
- if signal is not None:
- _record_signal(signal, stream)
- except Exception as exc:
- _set_error(f"Failed to read CUDA buffer: {exc}")
-
-
-# --- API: command lists ---
-
-
-def command_list_create(context):
- if int(context) not in _contexts:
- _set_error("Invalid context handle for command_list_create")
- return 0
-
- return _new_handle(_command_lists, _CommandList(context_handle=int(context)))
-
-
-def command_list_destroy(command_list):
- obj = _command_lists.pop(int(command_list), None)
- if obj is None:
- return
-
- ctx = _contexts.get(obj.context_handle)
- if ctx is None or obj.pc_scratch is None:
- return
-
- try:
- with _activate_context(ctx):
- obj.pc_scratch.free()
- except Exception:
- pass
-
-
-def command_list_get_instance_size(command_list):
- obj = _command_lists.get(int(command_list))
- if obj is None:
- return 0
- return int(obj.compute_instance_size)
-
-
-def command_list_reset(command_list):
- obj = _command_lists.get(int(command_list))
- if obj is None:
- return
-
- obj.commands = []
- obj.compute_instance_size = 0
-
-
-def command_list_submit(command_list, data, instance_count, index):
- obj = _command_lists.get(int(command_list))
- if obj is None:
- return True
-
- ctx = _contexts.get(obj.context_handle)
- if ctx is None:
- _set_error(f"Missing context for command list {command_list}")
- return True
-
- payload = _to_bytes(data) if data is not None else b""
- instance_count = int(instance_count)
- if instance_count <= 0:
- return True
-
- instance_size = int(obj.compute_instance_size)
-
- if instance_size > 0 and len(payload) < instance_size * instance_count:
- _set_error(
- f"Instance payload is too small ({len(payload)} bytes) for "
- f"{instance_count} instances of size {instance_size}"
- )
- return True
-
- queue_targets = _queue_indices(ctx, int(index), all_on_negative=True)
- if len(queue_targets) == 0:
- queue_targets = [0]
-
- try:
- with _activate_context(ctx):
- payload_view = memoryview(payload) if payload else None
-
- for queue_index in queue_targets:
- stream = ctx.streams[queue_index]
- resolved_launches: List[_ResolvedLaunch] = []
- pc_offset = 0
-
- for command in obj.commands:
- plan = _compute_plans.get(command.plan_handle)
- if plan is None:
- raise RuntimeError(f"Invalid compute plan handle {command.plan_handle}")
-
- descriptor_set = None
- if command.descriptor_set_handle != 0:
- descriptor_set = _descriptor_sets.get(command.descriptor_set_handle)
- if descriptor_set is None:
- raise RuntimeError(
- f"Invalid descriptor set handle {command.descriptor_set_handle}"
- )
-
- pc_size = int(command.pc_size)
- args, pc_scratch = _build_kernel_args_template(plan, descriptor_set, obj, pc_size)
- resolved_launches.append(
- _ResolvedLaunch(
- plan=plan,
- blocks=command.blocks,
- pc_offset=pc_offset,
- pc_size=pc_size,
- args=args,
- pc_scratch=pc_scratch,
- )
- )
- pc_offset += pc_size
-
- for instance in range(instance_count):
- instance_base = instance * instance_size
-
- for launch in resolved_launches:
- if launch.pc_scratch is not None and launch.pc_size > 0:
- start = instance_base + launch.pc_offset
- end = start + launch.pc_size
- cuda.memcpy_htod_async(
- launch.pc_scratch,
- payload_view[start:end],
- stream,
- )
-
- launch.plan.function(
- *launch.args,
- block=launch.plan.local_size,
- grid=launch.blocks,
- stream=stream,
- )
- except Exception as exc:
- _set_error(f"Failed to submit CUDA command list: {exc}")
-
- return True
-
-
-# --- API: descriptor sets ---
-
-
-def descriptor_set_create(plan):
- if int(plan) not in _compute_plans:
- _set_error("Invalid compute plan handle for descriptor_set_create")
- return 0
-
- return _new_handle(_descriptor_sets, _DescriptorSet(plan_handle=int(plan)))
-
-
-def descriptor_set_destroy(descriptor_set):
- _descriptor_sets.pop(int(descriptor_set), None)
-
-
-def descriptor_set_write_buffer(
- descriptor_set,
- binding,
- object,
- offset,
- range,
- uniform,
- read_access,
- write_access,
-):
- ds = _descriptor_sets.get(int(descriptor_set))
- if ds is None:
- _set_error("Invalid descriptor set handle for descriptor_set_write_buffer")
- return
-
- ds.buffer_bindings[int(binding)] = (
- int(object),
- int(offset),
- int(range),
- int(uniform),
- int(read_access),
- int(write_access),
- )
-
-
-def descriptor_set_write_image(
- descriptor_set,
- binding,
- object,
- sampler_obj,
- read_access,
- write_access,
-):
- ds = _descriptor_sets.get(int(descriptor_set))
- if ds is None:
- _set_error("Invalid descriptor set handle for descriptor_set_write_image")
- return
-
- ds.image_bindings[int(binding)] = (
- int(object),
- int(sampler_obj),
- int(read_access),
- int(write_access),
- )
-
-
-# --- API: compute stage ---
-
-
-def stage_compute_plan_create(context, shader_source, bindings, pc_size, shader_name):
- ctx = _context_from_handle(int(context))
- if ctx is None:
- return 0
-
- source_bytes = _to_bytes(shader_source)
- shader_name_bytes = _to_bytes(shader_name)
- source_text = source_bytes.decode("utf-8", errors="replace")
-
- try:
- with _activate_context(ctx):
- module = SourceModule(
- source_text,
- no_extern_c=True,
- options=["-w"]
- )
- function = module.get_function("vkdispatch_main")
- except Exception as exc:
- _set_error(f"Failed to compile CUDA kernel '{shader_name_bytes.decode(errors='ignore')}': {exc}")
- return 0
-
- try:
- params = _parse_kernel_params(source_text)
- local_size = _parse_local_size(source_text)
- except Exception as exc:
- _set_error(f"Failed to parse CUDA kernel metadata: {exc}")
- return 0
-
- plan = _ComputePlan(
- context_handle=int(context),
- shader_source=source_bytes,
- bindings=[int(x) for x in bindings],
- pc_size=int(pc_size),
- shader_name=shader_name_bytes,
- module=module,
- function=function,
- local_size=local_size,
- params=params,
- )
-
- return _new_handle(_compute_plans, plan)
-
-
-def stage_compute_plan_destroy(plan):
- if plan is None:
- return
- _compute_plans.pop(int(plan), None)
-
-
-def stage_compute_record(command_list, plan, descriptor_set, blocks_x, blocks_y, blocks_z):
- cl = _command_lists.get(int(command_list))
- cp = _compute_plans.get(int(plan))
- if cl is None or cp is None:
- _set_error("Invalid command list or compute plan handle for stage_compute_record")
- return
-
- cl.commands.append(
- _CommandRecord(
- plan_handle=int(plan),
- descriptor_set_handle=int(descriptor_set),
- blocks=(int(blocks_x), int(blocks_y), int(blocks_z)),
- pc_size=int(cp.pc_size),
- )
- )
- cl.compute_instance_size += int(cp.pc_size)
-
-
-# --- API: images/samplers (not yet implemented on PyCUDA backend) ---
-
-
-def image_create(context, extent, layers, format, type, view_type, generate_mips):
- _ = context
- _ = extent
- _ = layers
- _ = format
- _ = type
- _ = view_type
- _ = generate_mips
- _set_error("PyCUDA backend does not support image objects yet")
- return 0
-
-
-def image_destroy(image):
- _images.pop(int(image), None)
-
-
-def image_create_sampler(
- context,
- mag_filter,
- min_filter,
- mip_mode,
- address_mode,
- mip_lod_bias,
- min_lod,
- max_lod,
- border_color,
-):
- _ = context
- _ = mag_filter
- _ = min_filter
- _ = mip_mode
- _ = address_mode
- _ = mip_lod_bias
- _ = min_lod
- _ = max_lod
- _ = border_color
- _set_error("PyCUDA backend does not support image samplers yet")
- return 0
-
-
-def image_destroy_sampler(sampler):
- _samplers.pop(int(sampler), None)
-
-
-def image_write(image, data, offset, extent, baseLayer, layerCount, device_index):
- _ = image
- _ = data
- _ = offset
- _ = extent
- _ = baseLayer
- _ = layerCount
- _ = device_index
- _set_error("PyCUDA backend does not support image writes yet")
-
-
-def image_format_block_size(format):
- return int(_IMAGE_BLOCK_SIZES.get(int(format), 4))
-
-
-def image_read(image, out_size, offset, extent, baseLayer, layerCount, device_index):
- _ = image
- _ = offset
- _ = extent
- _ = baseLayer
- _ = layerCount
- _ = device_index
- _set_error("PyCUDA backend does not support image reads yet")
- return bytes(max(0, int(out_size)))
-
-
-# --- API: FFT stage (not yet implemented on PyCUDA backend) ---
-
-
-def stage_fft_plan_create(
- context,
- dims,
- axes,
- buffer_size,
- do_r2c,
- normalize,
- pad_left,
- pad_right,
- frequency_zeropadding,
- kernel_num,
- kernel_convolution,
- conjugate_convolution,
- convolution_features,
- input_buffer_size,
- num_batches,
- single_kernel_multiple_batches,
- keep_shader_code,
-):
- _ = context
- _ = dims
- _ = axes
- _ = buffer_size
- _ = do_r2c
- _ = normalize
- _ = pad_left
- _ = pad_right
- _ = frequency_zeropadding
- _ = kernel_num
- _ = kernel_convolution
- _ = conjugate_convolution
- _ = convolution_features
- _ = input_buffer_size
- _ = num_batches
- _ = single_kernel_multiple_batches
- _ = keep_shader_code
- _set_error("PyCUDA backend does not support FFT plans yet")
- return 0
-
-
-def stage_fft_plan_destroy(plan):
- _fft_plans.pop(int(plan), None)
-
-
-def stage_fft_record(command_list, plan, buffer, inverse, kernel, input_buffer):
- _ = command_list
- _ = plan
- _ = buffer
- _ = inverse
- _ = kernel
- _ = input_buffer
- _set_error("PyCUDA backend does not support FFT stages yet")
-
-
-__all__ = [
- "LOG_LEVEL_VERBOSE",
- "LOG_LEVEL_INFO",
- "LOG_LEVEL_WARNING",
- "LOG_LEVEL_ERROR",
- "DESCRIPTOR_TYPE_STORAGE_BUFFER",
- "DESCRIPTOR_TYPE_STORAGE_IMAGE",
- "DESCRIPTOR_TYPE_UNIFORM_BUFFER",
- "DESCRIPTOR_TYPE_UNIFORM_IMAGE",
- "DESCRIPTOR_TYPE_SAMPLER",
- "init",
- "log",
- "set_log_level",
- "get_devices",
- "context_create",
- "signal_wait",
- "signal_insert",
- "signal_destroy",
- "context_destroy",
- "get_error_string",
- "context_stop_threads",
- "buffer_create",
- "buffer_destroy",
- "buffer_get_queue_signal",
- "buffer_wait_staging_idle",
- "buffer_write_staging",
- "buffer_read_staging",
- "buffer_write",
- "buffer_read",
- "command_list_create",
- "command_list_destroy",
- "command_list_get_instance_size",
- "command_list_reset",
- "command_list_submit",
- "descriptor_set_create",
- "descriptor_set_destroy",
- "descriptor_set_write_buffer",
- "descriptor_set_write_image",
- "image_create",
- "image_destroy",
- "image_create_sampler",
- "image_destroy_sampler",
- "image_write",
- "image_format_block_size",
- "image_read",
- "stage_compute_plan_create",
- "stage_compute_plan_destroy",
- "stage_compute_record",
- "stage_fft_plan_create",
- "stage_fft_plan_destroy",
- "stage_fft_record",
-]
diff --git a/vkdispatch/base/backend.py b/vkdispatch/base/backend.py
deleted file mode 100644
index cf652eb1..00000000
--- a/vkdispatch/base/backend.py
+++ /dev/null
@@ -1,79 +0,0 @@
-from __future__ import annotations
-
-import importlib
-from types import ModuleType
-from typing import Dict, Optional
-
-BACKEND_VULKAN = "vulkan"
-BACKEND_PYCUDA = "pycuda"
-
-_VALID_BACKENDS = {BACKEND_VULKAN, BACKEND_PYCUDA}
-_active_backend_name: Optional[str] = None
-_backend_modules: Dict[str, ModuleType] = {}
-
-
-def normalize_backend_name(backend: Optional[str]) -> str:
- if backend is None:
- return BACKEND_VULKAN
-
- backend_name = backend.strip().lower()
- if backend_name not in _VALID_BACKENDS:
- valid = ", ".join(sorted(_VALID_BACKENDS))
- raise ValueError(f"Unknown backend '{backend}'. Expected one of: {valid}")
-
- return backend_name
-
-
-def set_active_backend(backend: str) -> str:
- global _active_backend_name
-
- backend_name = normalize_backend_name(backend)
-
- if _active_backend_name is not None and _active_backend_name != backend_name:
- raise RuntimeError(
- f"Backend is already set to '{_active_backend_name}' and cannot be changed to '{backend_name}' in this process."
- )
-
- _active_backend_name = backend_name
- return _active_backend_name
-
-
-def clear_active_backend() -> None:
- global _active_backend_name
- _active_backend_name = None
-
-
-def get_active_backend_name(default: Optional[str] = BACKEND_VULKAN) -> str:
- if _active_backend_name is not None:
- return _active_backend_name
-
- return normalize_backend_name(default)
-
-
-def _load_backend_module(backend_name: str) -> ModuleType:
- if backend_name in _backend_modules:
- return _backend_modules[backend_name]
-
- if backend_name == BACKEND_VULKAN:
- module = importlib.import_module("vkdispatch_native")
- elif backend_name == BACKEND_PYCUDA:
- module = importlib.import_module("vkdispatch.backends.pycuda_native")
- else:
- # Defensive guard for future refactors.
- raise ValueError(f"Unsupported backend '{backend_name}'")
-
- _backend_modules[backend_name] = module
- return module
-
-
-def get_backend_module(backend: Optional[str] = None) -> ModuleType:
- backend_name = normalize_backend_name(backend) if backend is not None else get_active_backend_name()
- return _load_backend_module(backend_name)
-
-
-class _BackendProxy:
- def __getattr__(self, name: str):
- return getattr(get_backend_module(), name)
-
-
-native = _BackendProxy()
diff --git a/vkdispatch/base/brython_utils.py b/vkdispatch/base/brython_utils.py
deleted file mode 100644
index fa4e7b6b..00000000
--- a/vkdispatch/base/brython_utils.py
+++ /dev/null
@@ -1,4 +0,0 @@
-import sys
-
-def is_brython() -> bool:
- return sys.implementation.name == "Brython"
\ No newline at end of file
diff --git a/vkdispatch/base/buffer.py b/vkdispatch/base/buffer.py
index 8c2ff2a8..6f49b622 100644
--- a/vkdispatch/base/buffer.py
+++ b/vkdispatch/base/buffer.py
@@ -1,22 +1,42 @@
from typing import Tuple
from typing import List
from typing import Union
+from typing import Optional
+from contextlib import nullcontext
+from .init import is_cuda
from .dtype import dtype
from .context import Handle, Signal
from .errors import check_for_errors
-from .dtype import complex64, uint32, int32, float32
+from .dtype import complex64
+from . import dtype as dtypes
-from .._compat import numpy_compat as npc
+from ..compat import numpy_compat as npc
from .dtype import to_numpy_dtype, from_numpy_dtype
-from .backend import native
+from ..backends.backend_selection import native
import typing
_ArgType = typing.TypeVar('_ArgType', bound=dtype)
+import dataclasses
+
+def _suspend_cuda_capture_if_needed():
+ if not is_cuda():
+ return nullcontext()
+
+ from ..execution_pipeline.cuda_graph_capture import suspend_cuda_capture
+ return suspend_cuda_capture()
+
+@dataclasses.dataclass
+class ExternalBufferInfo:
+ writable: bool
+ iface: dict
+ keepalive: bool
+ cuda_ptr: int
+
class Buffer(Handle, typing.Generic[_ArgType]):
"""
Represents a contiguous block of memory on the GPU (or shared across multiple devices).
@@ -37,8 +57,14 @@ class Buffer(Handle, typing.Generic[_ArgType]):
size: int
mem_size: int
signals: List[Signal]
-
- def __init__(self, shape: Tuple[int, ...], var_type: dtype) -> None:
+ is_external: bool
+ owns_memory: bool
+ is_writable: bool
+ cuda_ptr: typing.Optional[int]
+ cuda_source: typing.Any
+ cuda_array_stream: typing.Optional[typing.Any]
+
+    def __init__(self, shape: Tuple[int, ...], var_type: dtype, external_buffer: typing.Optional[ExternalBufferInfo] = None) -> None:
super().__init__()
if isinstance(shape, int):
@@ -49,7 +75,6 @@ def __init__(self, shape: Tuple[int, ...], var_type: dtype) -> None:
self.var_type: dtype = var_type
self.shape: Tuple[int] = shape
- #self.size: int = int(np.prod(shape))
size = 1
for dim in shape:
@@ -71,11 +96,25 @@ def __init__(self, shape: Tuple[int, ...], var_type: dtype) -> None:
self.shader_shape = tuple(shader_shape_internal)
self.signals = []
-
- handle = native.buffer_create(
- self.context._handle, self.mem_size, 0
- )
- check_for_errors()
+ self.is_external = external_buffer is not None
+ self.owns_memory = external_buffer is None
+ self.is_writable = True if external_buffer is None else external_buffer.writable
+ self.cuda_ptr = None if external_buffer is None else external_buffer.cuda_ptr
+ self.cuda_source = None if external_buffer is None else (external_buffer.iface if external_buffer.keepalive else None)
+ self.cuda_array_stream = None if external_buffer is None else external_buffer.iface.get("stream")
+
+ with _suspend_cuda_capture_if_needed():
+ if external_buffer is not None:
+ handle = native.buffer_create_external(
+ self.context._handle,
+ self.mem_size,
+ self.cuda_ptr,
+ )
+ else:
+ handle = native.buffer_create(
+ self.context._handle, self.mem_size, 0
+ )
+ check_for_errors()
self.signals = [
Signal(
@@ -88,6 +127,17 @@ def __init__(self, shape: Tuple[int, ...], var_type: dtype) -> None:
self.register_handle(handle)
+ def __repr__(self):
+ return f"""Buffer {self._handle}:
+ shape={self.shape}
+ var_type={self.var_type.name}
+ mem_size={self.mem_size} bytes
+ is_external={self.is_external}
+ writable={self.is_writable}
+ cuda_ptr={self.cuda_ptr}
+ cuda_iface={self.cuda_source}
+"""
+
def _destroy(self) -> None:
"""Destroy the buffer and all child handles."""
@@ -100,31 +150,33 @@ def __del__(self) -> None:
self.destroy()
def _wait_staging_idle(self, index: int):
- is_idle = native.buffer_wait_staging_idle(self._handle, index)
- check_for_errors()
+ with _suspend_cuda_capture_if_needed():
+ is_idle = native.buffer_wait_staging_idle(self._handle, index)
+ check_for_errors()
return is_idle
def _do_writes(self, data: bytes, index: int = None):
indicies = [index] if index is not None else range(self.context.queue_count)
completed_stages = [0] * len(indicies)
- while not all(stage == 1 for stage in completed_stages):
- for i in range(len(indicies)):
- if completed_stages[i] == 1:
- continue
+ with _suspend_cuda_capture_if_needed():
+ while not all(stage == 1 for stage in completed_stages):
+ for i in range(len(indicies)):
+ if completed_stages[i] == 1:
+ continue
- queue_index = indicies[i]
+ queue_index = indicies[i]
- if not self.signals[queue_index].try_wait(True, queue_index):
- continue
+ if not self.signals[queue_index].try_wait(True, queue_index):
+ continue
- completed_stages[i] = 1
+ completed_stages[i] = 1
- native.buffer_write_staging(self._handle, queue_index, data, len(data))
- check_for_errors()
+ native.buffer_write_staging(self._handle, queue_index, data, len(data))
+ check_for_errors()
- native.buffer_write(self._handle, 0, len(data), queue_index)
- check_for_errors()
+ native.buffer_write(self._handle, 0, len(data), queue_index)
+ check_for_errors()
def write(self, data: Union[bytes, bytearray, memoryview, typing.Any], index: int = None) -> None:
"""
@@ -143,6 +195,9 @@ def write(self, data: Union[bytes, bytearray, memoryview, typing.Any], index: in
assert isinstance(index, int), "Index must be an integer or None!"
assert index >= 0 and index < self.context.queue_count, "Index must be valid!"
+ if not getattr(self, "is_writable", True):
+ raise ValueError("Cannot write to a read-only buffer alias.")
+
true_data_object = None
if npc.is_array_like(data):
@@ -167,29 +222,30 @@ def _do_reads(self, var_type: dtype, shape: List[int], index: int = None) -> byt
mem_size = int(npc.prod(shape)) * var_type.item_size
- while not all(stage == 2 for stage in completed_stages):
- for i in range(len(indicies)):
- if completed_stages[i] == 2:
- continue
+ with _suspend_cuda_capture_if_needed():
+ while not all(stage == 2 for stage in completed_stages):
+ for i in range(len(indicies)):
+ if completed_stages[i] == 2:
+ continue
- queue_index = indicies[i]
+ queue_index = indicies[i]
- if completed_stages[i] == 0:
- if self.signals[queue_index].try_wait(False, queue_index):
- completed_stages[i] = 1
- native.buffer_read(self._handle, 0, mem_size, queue_index)
- check_for_errors()
- else:
- continue
+ if completed_stages[i] == 0:
+ if self.signals[queue_index].try_wait(False, queue_index):
+ completed_stages[i] = 1
+ native.buffer_read(self._handle, 0, mem_size, queue_index)
+ check_for_errors()
+ else:
+ continue
- if completed_stages[i] == 1:
- if self.signals[queue_index].try_wait(True, queue_index):
- completed_stages[i] = 2
- else:
- continue
+ if completed_stages[i] == 1:
+ if self.signals[queue_index].try_wait(True, queue_index):
+ completed_stages[i] = 2
+ else:
+ continue
- bytes_list[i] = native.buffer_read_staging(self._handle, queue_index, mem_size)
- check_for_errors()
+ bytes_list[i] = native.buffer_read_staging(self._handle, queue_index, mem_size)
+ check_for_errors()
host_arrays = []
@@ -231,6 +287,9 @@ def read(self, index: Union[int, None] = None):
def asbuffer(array: typing.Any) -> Buffer:
"""Cast an array-like object to a buffer object."""
+ if hasattr(array, "__cuda_array_interface__"):
+ return from_cuda_array(array)
+
if not npc.is_array_like(array):
raise TypeError("Expected an array-like object")
@@ -239,62 +298,151 @@ def asbuffer(array: typing.Any) -> Buffer:
return buffer
-def buffer_u32(shape: Tuple[int, ...]) -> Buffer:
- """Create a buffer of unsigned 32-bit integers with the specified shape."""
- return Buffer(shape, uint32)
-def buffer_i32(shape: Tuple[int, ...]) -> Buffer:
- """Create a buffer of signed 32-bit integers with the specified shape."""
- return Buffer(shape, int32)
+def from_cuda_array(
+ obj: typing.Any,
+ var_type: typing.Optional[dtype] = None,
+ require_contiguous: bool = True,
+ writable: typing.Optional[bool] = None,
+ keepalive: bool = True,
+) -> Buffer:
+ assert is_cuda(), "__cuda_array_interface__ is only supported with CUDA backends."
+
+ if not hasattr(obj, "__cuda_array_interface__"):
+ raise TypeError("Expected an object with __cuda_array_interface__")
+
+ npc.require_numpy("from_cuda_array")
+ np = npc.numpy_module()
+
+ iface = obj.__cuda_array_interface__
+ if not isinstance(iface, dict):
+ raise TypeError("__cuda_array_interface__ must be a dictionary")
+
+ if "shape" not in iface or "typestr" not in iface or "data" not in iface:
+ raise ValueError("__cuda_array_interface__ is missing required fields (shape/typestr/data)")
-def buffer_f32(shape: Tuple[int, ...]) -> Buffer:
- """Create a buffer of 32-bit floating-point numbers with the specified shape."""
- return Buffer(shape, float32)
+ shape = tuple(int(dim) for dim in iface["shape"])
+ if len(shape) == 0:
+ shape = (1,)
-def buffer_c64(shape: Tuple[int, ...]) -> Buffer:
- """Create a buffer of 64-bit complex numbers with the specified shape."""
- return Buffer(shape, complex64)
+ data_entry = iface["data"]
+ if not (isinstance(data_entry, tuple) and len(data_entry) >= 2):
+ raise ValueError("__cuda_array_interface__['data'] must be a tuple (ptr, read_only)")
+
+ ptr = int(data_entry[0])
+ source_read_only = bool(data_entry[1])
+
+ inferred_np_dtype = np.dtype(iface["typestr"])
+ inferred_var_type = from_numpy_dtype(inferred_np_dtype)
+ if var_type is None:
+ var_type = inferred_var_type
+
+    if var_type != inferred_var_type:
+ raise ValueError(
+ f"CAI dtype ({inferred_np_dtype}) does not match requested vd dtype ({var_type.name})."
+ )
+
+ if require_contiguous:
+ strides = iface.get("strides")
+ if strides is not None:
+ expected_strides = []
+ stride = int(inferred_np_dtype.itemsize)
+ for dim in reversed(shape):
+ expected_strides.insert(0, stride)
+ stride *= int(dim)
+ if tuple(int(x) for x in strides) != tuple(expected_strides):
+ raise ValueError("Only contiguous C-order CUDA arrays are supported in from_cuda_array().")
+
+ buffer_writable = (not source_read_only) if writable is None else bool(writable)
+ if buffer_writable and source_read_only:
+ raise ValueError("Requested writable=True for a read-only CUDA array.")
+
+ external_buffer_info = ExternalBufferInfo(
+ writable=buffer_writable,
+ iface=iface,
+ keepalive=keepalive,
+ cuda_ptr=ptr
+ )
+
+ return Buffer(shape, var_type, external_buffer=external_buffer_info)
class RFFTBuffer(Buffer):
- def __init__(self, shape: Tuple[int, ...]):
- super().__init__(tuple(shape[:-1]) + (shape[-1] // 2 + 1,), complex64)
+ real_shape: Tuple[int, ...]
+ fourier_shape: Tuple[int, ...]
+ real_type: dtype
+
+ def __init__(self, shape: Tuple[int, ...], fourier_type: dtype = complex64):
+ if not dtypes.is_complex(fourier_type):
+ raise ValueError("RFFTBuffer fourier_type must be complex32, complex64, or complex128")
+
+ if not dtypes.is_float_dtype(fourier_type.child_type):
+ raise ValueError("RFFTBuffer fourier_type must use a floating-point scalar")
+
+ super().__init__(tuple(shape[:-1]) + (shape[-1] // 2 + 1,), fourier_type)
self.real_shape = shape
self.fourier_shape = self.shape
-
+ self.real_type = fourier_type.child_type
+
def read_real(self, index: Union[int, None] = None):
npc.require_numpy("RFFTBuffer.read_real")
np = npc.numpy_module()
- return self.read(index).view(np.float32)[..., :self.real_shape[-1]]
+
+ packed_shape = list(self.shape[:-1]) + [self.shape[-1] * 2]
+ packed_data = self._do_reads(self.real_type, packed_shape, index)
+
+ if index is None:
+ packed_data = np.array(packed_data)
+
+ return packed_data[..., :self.real_shape[-1]]
def read_fourier(self, index: Union[int, None] = None):
return self.read(index)
-
+
def write_real(self, data, index: int = None):
npc.require_numpy("RFFTBuffer.write_real")
np = npc.numpy_module()
assert data.shape == self.real_shape, "Data shape must match real shape!"
- assert not np.issubdtype(data.dtype, np.complexfloating) , "Data dtype must be scalar!"
+ assert not np.issubdtype(data.dtype, np.complexfloating), "Data dtype must be scalar!"
- true_data = np.zeros(self.shape[:-1] + (self.shape[-1] * 2,), dtype=np.float32)
+ real_dtype = to_numpy_dtype(self.real_type)
+ true_data = np.zeros(self.shape[:-1] + (self.shape[-1] * 2,), dtype=real_dtype)
true_data[..., :self.real_shape[-1]] = data
- self.write(np.ascontiguousarray(true_data).view(np.complex64), index)
+ self.write(np.ascontiguousarray(true_data), index)
def write_fourier(self, data, index: int = None):
npc.require_numpy("RFFTBuffer.write_fourier")
np = npc.numpy_module()
assert data.shape == self.fourier_shape, f"Data shape {data.shape} must match fourier shape {self.fourier_shape}!"
- assert np.issubdtype(data.dtype, np.complexfloating) , "Data dtype must be complex!"
+ assert np.issubdtype(data.dtype, np.complexfloating), "Data dtype must be complex!"
+
+ target_fourier_dtype = to_numpy_dtype(self.var_type)
+        if npc.is_host_dtype(target_fourier_dtype):  # NOTE(review): verify polarity — the comment below says this branch is the complex32 packing path; confirm is_host_dtype() is False for natively supported dtypes like complex64
+ # complex32: pack complex values into float16 real/imag pairs.
+ complex_data = np.ascontiguousarray(data.astype(np.complex64))
+ packed_pairs = np.empty(complex_data.shape + (2,), dtype=np.float16)
+ packed_pairs[..., 0] = complex_data.real.astype(np.float16)
+ packed_pairs[..., 1] = complex_data.imag.astype(np.float16)
+
+ packed_real_shape = self.shape[:-1] + (self.shape[-1] * 2,)
+ self.write(np.ascontiguousarray(packed_pairs).reshape(packed_real_shape), index)
+ return
- self.write(np.ascontiguousarray(data.astype(np.complex64)).view(np.float32), index)
+ self.write(np.ascontiguousarray(data.astype(target_fourier_dtype)), index)
-def asrfftbuffer(data) -> RFFTBuffer:
+
+def asrfftbuffer(data, fourier_type: Optional[dtype] = None) -> RFFTBuffer:
npc.require_numpy("asrfftbuffer")
np = npc.numpy_module()
assert not np.issubdtype(data.dtype, np.complexfloating), "Data dtype must be scalar!"
- buffer = RFFTBuffer(data.shape)
+ if fourier_type is None:
+ scalar_dtype = from_numpy_dtype(data.dtype)
+ scalar_dtype = dtypes.make_floating_dtype(scalar_dtype)
+ fourier_type = dtypes.complex_from_float(scalar_dtype)
+
+ buffer = RFFTBuffer(data.shape, fourier_type=fourier_type)
buffer.write_real(data)
return buffer
diff --git a/vkdispatch/base/buffer_allocators.py b/vkdispatch/base/buffer_allocators.py
new file mode 100644
index 00000000..e14fed86
--- /dev/null
+++ b/vkdispatch/base/buffer_allocators.py
@@ -0,0 +1,119 @@
+from .buffer import Buffer
+from . import dtype as dt
+from typing import Tuple
+
+def buffer_u32(shape: Tuple[int, ...]) -> Buffer:
+ """Create a buffer of unsigned 32-bit integers with the specified shape."""
+ return Buffer(shape, dt.uint32)
+
+def buffer_uv2(shape: Tuple[int, ...]) -> Buffer:
+ """Create a buffer of unsigned 32-bit integer vectors of size 2 with the specified shape."""
+ return Buffer(shape, dt.uvec2)
+
+def buffer_uv3(shape: Tuple[int, ...]) -> Buffer:
+ """Create a buffer of unsigned 32-bit integer vectors of size 3 with the specified shape."""
+ return Buffer(shape, dt.uvec3)
+
+def buffer_uv4(shape: Tuple[int, ...]) -> Buffer:
+ """Create a buffer of unsigned 32-bit integer vectors of size 4 with the specified shape."""
+ return Buffer(shape, dt.uvec4)
+
+def buffer_i32(shape: Tuple[int, ...]) -> Buffer:
+ """Create a buffer of signed 32-bit integers with the specified shape."""
+ return Buffer(shape, dt.int32)
+
+def buffer_iv2(shape: Tuple[int, ...]) -> Buffer:
+ """Create a buffer of signed 32-bit integer vectors of size 2 with the specified shape."""
+ return Buffer(shape, dt.ivec2)
+
+def buffer_iv3(shape: Tuple[int, ...]) -> Buffer:
+ """Create a buffer of signed 32-bit integer vectors of size 3 with the specified shape."""
+ return Buffer(shape, dt.ivec3)
+
+def buffer_iv4(shape: Tuple[int, ...]) -> Buffer:
+ """Create a buffer of signed 32-bit integer vectors of size 4 with the specified shape."""
+ return Buffer(shape, dt.ivec4)
+
+def buffer_f32(shape: Tuple[int, ...]) -> Buffer:
+ """Create a buffer of 32-bit floating-point numbers with the specified shape."""
+ return Buffer(shape, dt.float32)
+
+def buffer_v2(shape: Tuple[int, ...]) -> Buffer:
+ """Create a buffer of 32-bit floating-point vectors of size 2 with the specified shape."""
+ return Buffer(shape, dt.vec2)
+
+def buffer_v3(shape: Tuple[int, ...]) -> Buffer:
+ """Create a buffer of 32-bit floating-point vectors of size 3 with the specified shape."""
+ return Buffer(shape, dt.vec3)
+
+def buffer_v4(shape: Tuple[int, ...]) -> Buffer:
+ """Create a buffer of 32-bit floating-point vectors of size 4 with the specified shape."""
+ return Buffer(shape, dt.vec4)
+
+def buffer_c64(shape: Tuple[int, ...]) -> Buffer:
+ """Create a buffer of 64-bit complex numbers with the specified shape."""
+ return Buffer(shape, dt.complex64)
+
+def buffer_u16(shape: Tuple[int, ...]) -> Buffer:
+ """Create a buffer of unsigned 16-bit integers with the specified shape."""
+ return Buffer(shape, dt.uint16)
+
+def buffer_uhv2(shape: Tuple[int, ...]) -> Buffer:
+ """Create a buffer of unsigned 16-bit integer vectors of size 2 with the specified shape."""
+ return Buffer(shape, dt.uhvec2)
+
+def buffer_uhv3(shape: Tuple[int, ...]) -> Buffer:
+ """Create a buffer of unsigned 16-bit integer vectors of size 3 with the specified shape."""
+ return Buffer(shape, dt.uhvec3)
+
+def buffer_uhv4(shape: Tuple[int, ...]) -> Buffer:
+ """Create a buffer of unsigned 16-bit integer vectors of size 4 with the specified shape."""
+ return Buffer(shape, dt.uhvec4)
+
+def buffer_i16(shape: Tuple[int, ...]) -> Buffer:
+ """Create a buffer of signed 16-bit integers with the specified shape."""
+ return Buffer(shape, dt.int16)
+
+def buffer_ihv2(shape: Tuple[int, ...]) -> Buffer:
+ """Create a buffer of signed 16-bit integer vectors of size 2 with the specified shape."""
+ return Buffer(shape, dt.ihvec2)
+
+def buffer_ihv3(shape: Tuple[int, ...]) -> Buffer:
+ """Create a buffer of signed 16-bit integer vectors of size 3 with the specified shape."""
+ return Buffer(shape, dt.ihvec3)
+
+def buffer_ihv4(shape: Tuple[int, ...]) -> Buffer:
+ """Create a buffer of signed 16-bit integer vectors of size 4 with the specified shape."""
+ return Buffer(shape, dt.ihvec4)
+
+def buffer_f16(shape: Tuple[int, ...]) -> Buffer:
+ """Create a buffer of 16-bit floating-point numbers with the specified shape."""
+ return Buffer(shape, dt.float16)
+
+def buffer_hv2(shape: Tuple[int, ...]) -> Buffer:
+ """Create a buffer of 16-bit floating-point vectors of size 2 with the specified shape."""
+ return Buffer(shape, dt.hvec2)
+
+def buffer_hv3(shape: Tuple[int, ...]) -> Buffer:
+ """Create a buffer of 16-bit floating-point vectors of size 3 with the specified shape."""
+ return Buffer(shape, dt.hvec3)
+
+def buffer_hv4(shape: Tuple[int, ...]) -> Buffer:
+ """Create a buffer of 16-bit floating-point vectors of size 4 with the specified shape."""
+ return Buffer(shape, dt.hvec4)
+
+def buffer_f64(shape: Tuple[int, ...]) -> Buffer:
+ """Create a buffer of 64-bit floating-point numbers with the specified shape."""
+ return Buffer(shape, dt.float64)
+
+def buffer_dv2(shape: Tuple[int, ...]) -> Buffer:
+ """Create a buffer of 64-bit floating-point vectors of size 2 with the specified shape."""
+ return Buffer(shape, dt.dvec2)
+
+def buffer_dv3(shape: Tuple[int, ...]) -> Buffer:
+ """Create a buffer of 64-bit floating-point vectors of size 3 with the specified shape."""
+ return Buffer(shape, dt.dvec3)
+
+def buffer_dv4(shape: Tuple[int, ...]) -> Buffer:
+ """Create a buffer of 64-bit floating-point vectors of size 4 with the specified shape."""
+ return Buffer(shape, dt.dvec4)
\ No newline at end of file
diff --git a/vkdispatch/base/command_list.py b/vkdispatch/base/command_list.py
index 5ebd7194..99fa2799 100644
--- a/vkdispatch/base/command_list.py
+++ b/vkdispatch/base/command_list.py
@@ -1,11 +1,14 @@
from typing import Tuple
from typing import Optional
-from .backend import native
+from ..backends.backend_selection import native
+from .init import is_cuda
from .context import Handle
from .errors import check_for_errors
+from ..execution_pipeline.cuda_graph_capture import get_cuda_capture
+
from .compute_plan import ComputePlan
from .descriptor_set import DescriptorSet
@@ -76,7 +79,13 @@ def reset(self) -> None:
self.clear_parents()
- def submit(self, data: Optional[bytes] = None, queue_index: int = -2, instance_count: Optional[int] = None) -> None:
+ def submit(
+ self,
+ data: Optional[bytes] = None,
+ queue_index: int = -2,
+ instance_count: Optional[int] = None,
+ cuda_stream=None
+ ) -> None:
"""
Submits the recorded command list to the GPU queue for execution.
@@ -106,9 +115,22 @@ def submit(self, data: Optional[bytes] = None, queue_index: int = -2, instance_c
if self.get_instance_size() != 0:
assert self.get_instance_size() * instance_count == len(data), "Data length must be the product of the instance size and instance count!"
+        if cuda_stream is None and (capture := get_cuda_capture()) is not None:
+            cuda_stream = capture.cuda_stream
+
+ if cuda_stream is not None:
+ if not is_cuda():
+ raise RuntimeError("cuda_stream=... is currently only supported with CUDA backends.")
+
+ native.cuda_stream_override_begin(cuda_stream)
+ check_for_errors()
+
done = False
while not done:
done = native.command_list_submit(
self._handle, data, instance_count, queue_index
)
check_for_errors()
+
+ if cuda_stream is not None:
+ native.cuda_stream_override_end()
diff --git a/vkdispatch/base/compute_plan.py b/vkdispatch/base/compute_plan.py
index fd997705..995ae177 100644
--- a/vkdispatch/base/compute_plan.py
+++ b/vkdispatch/base/compute_plan.py
@@ -1,4 +1,4 @@
-from .backend import native
+from ..backends.backend_selection import native
from .context import Handle
from .errors import check_for_compute_stage_errors, check_for_errors
@@ -34,7 +34,6 @@ def __init__(self, shader_source: str, binding_type_list: list, pc_size: int, sh
self.context._handle, shader_source.encode(), self.binding_list, pc_size, shader_name.encode()
)
check_for_compute_stage_errors()
-
self.register_handle(handle)
def _destroy(self) -> None:
diff --git a/vkdispatch/base/context.py b/vkdispatch/base/context.py
index 11aef807..45351b32 100644
--- a/vkdispatch/base/context.py
+++ b/vkdispatch/base/context.py
@@ -7,12 +7,16 @@
import atexit
import weakref
+import sys
import os, signal
from .errors import check_for_errors, set_running
-from .init import DeviceInfo, get_backend, get_devices, initialize, set_log_level, LogLevel, log_info
-from .backend import BACKEND_PYCUDA, native
+from .init import DeviceInfo, is_cuda, is_opencl, is_dummy, get_devices, initialize, log_info
+from ..backends.backend_selection import native
+VK_SHADER_STAGE_COMPUTE_BIT = 0x00000020
+
+VK_SUBGROUP_FEATURE_ARITHMETIC_BIT = 0x00000004
class Handle:
context: "Context"
@@ -53,6 +57,8 @@ def clear_parents(self) -> None:
"""
Clears the parent handles.
"""
+ # children_dict uses weak references, so a child key may disappear
+ # before teardown reaches this point.
for parent in self.parents.values():
parent.remove_child_handle(self)
@@ -71,10 +77,8 @@ def remove_child_handle(self, child: "Handle") -> None:
"""
Removes a child handle from the current handle.
"""
- if child._handle not in self.children_dict.keys():
- raise ValueError(f"Child handle {child._handle} does not exist in parent handle!")
-
- self.children_dict.pop(child._handle)
+ # Be idempotent to tolerate repeated teardown paths and weakref eviction.
+ self.children_dict.pop(child._handle, None)
def _destroy(self) -> None:
raise NotImplementedError("destroy is an abstract method and must be implemented by subclasses.")
@@ -84,7 +88,10 @@ def destroy(self) -> None:
Destroys the context handle and cleans up resources.
"""
if self.destroyed:
- return
+ return
+
+ self.destroyed = True
+ self.clear_parents()
child_keys = list(self.children_dict.keys())
@@ -101,13 +108,11 @@ def destroy(self) -> None:
check_for_errors()
self.canary = True
-
- self.clear_parents()
if self._handle in self.context.handles_dict.keys():
self.context.handles_dict.pop(self._handle)
- self.destroyed = True
+
class Signal:
ptr_addr: int
@@ -159,6 +164,8 @@ class Context:
queue_families: List[List[int]]
queue_count: int
subgroup_size: int
+ subgroup_enabled: bool
+ subgroup_arithmetic: bool
max_workgroup_size: Tuple[int]
max_workgroup_invocations: int
max_workgroup_count: Tuple[int, int, int]
@@ -179,7 +186,10 @@ def __init__(
self.mapped_device_ids = [dev.dev_index for dev in self.device_infos]
self._handle = native.context_create(self.mapped_device_ids, queue_families)
check_for_errors()
-
+
+ self.refresh_limits_from_device_infos()
+
+ def refresh_limits_from_device_infos(self) -> None:
subgroup_sizes = []
max_workgroup_sizes_x = []
max_workgroup_sizes_y = []
@@ -191,6 +201,9 @@ def __init__(
uniform_buffer_alignments = []
max_shared_memory = []
+ subgroup_enabled = True
+ subgroup_arithmetic = True
+
for device in self.device_infos:
subgroup_sizes.append(device.sub_group_size)
@@ -208,6 +221,14 @@ def __init__(
max_shared_memory.append(device.max_compute_shared_memory_size)
+ if not device.supported_stages & VK_SHADER_STAGE_COMPUTE_BIT:
+ subgroup_enabled = False
+
+ if not device.supported_operations & VK_SUBGROUP_FEATURE_ARITHMETIC_BIT:
+ subgroup_arithmetic = False
+
+ self.subgroup_enabled = subgroup_enabled
+ self.subgroup_arithmetic = subgroup_arithmetic
self.subgroup_size = min(subgroup_sizes)
self.max_workgroup_size = (
min(max_workgroup_sizes_x),
@@ -371,15 +392,16 @@ def make_context(
select_queue_families(dev_index, queue_family_count)
)
- if get_backend() == BACKEND_PYCUDA:
+ if is_cuda() or is_opencl():
+ backend_name = "CUDA" if is_cuda() else "OpenCL"
if len(device_ids) != 1:
raise NotImplementedError(
- "The PyCUDA backend currently supports exactly one device."
+ f"The {backend_name} backend currently supports exactly one device."
)
if len(queue_families) != 1 or len(queue_families[0]) != 1:
raise NotImplementedError(
- "The PyCUDA backend currently supports exactly one queue."
+ f"The {backend_name} backend currently supports exactly one queue."
)
total_devices = len(get_devices())
@@ -413,6 +435,125 @@ def get_context() -> Context:
def get_context_handle() -> int:
return get_context()._handle
+def _as_positive_int(name: str, value) -> int:
+ try:
+ result = int(value)
+ except Exception as exc:
+ raise ValueError(f"{name} must be a positive integer") from exc
+
+ if result <= 0:
+ raise ValueError(f"{name} must be a positive integer")
+
+ return result
+
+def _as_positive_triplet(name: str, value) -> Tuple[int, int, int]:
+ try:
+ parts = list(value)
+ except Exception as exc:
+ raise ValueError(f"{name} must contain exactly 3 positive integers") from exc
+
+ if len(parts) != 3:
+ raise ValueError(f"{name} must contain exactly 3 positive integers")
+
+ return (
+ _as_positive_int(f"{name}[0]", parts[0]),
+ _as_positive_int(f"{name}[1]", parts[1]),
+ _as_positive_int(f"{name}[2]", parts[2]),
+ )
+
+def set_dummy_context_params(
+ subgroup_size: int = None,
+ subgroup_enabled: bool = None,
+ max_workgroup_size: Tuple[int, int, int] = None,
+ max_workgroup_invocations: int = None,
+ max_workgroup_count: Tuple[int, int, int] = None,
+ max_shared_memory: int = None,
+) -> None:
+ """
+ Update cached context/device limit values for the active dummy backend context.
+
+ This only works when a dummy context already exists.
+ """
+ global __context
+
+ if not is_dummy():
+ raise RuntimeError(
+ "set_dummy_context_params() is only supported when running with backend='dummy'."
+ )
+
+ if __context is None:
+ __context = get_context()
+
+ validated_subgroup_size = None
+ validated_subgroup_enabled = None
+ validated_max_workgroup_size = None
+ validated_max_workgroup_invocations = None
+ validated_max_workgroup_count = None
+ validated_max_shared_memory = None
+
+ backend_kwargs = {}
+
+ if subgroup_size is not None:
+ validated_subgroup_size = _as_positive_int("subgroup_size", subgroup_size)
+ backend_kwargs["subgroup_size"] = validated_subgroup_size
+
+ if subgroup_enabled is not None:
+ if not isinstance(subgroup_enabled, bool):
+ raise ValueError("subgroup_enabled must be a boolean value")
+ validated_subgroup_enabled = bool(subgroup_enabled)
+        backend_kwargs["subgroup_enabled"] = validated_subgroup_enabled
+
+ if max_workgroup_size is not None:
+ validated_max_workgroup_size = _as_positive_triplet("max_workgroup_size", max_workgroup_size)
+ backend_kwargs["max_workgroup_size"] = validated_max_workgroup_size
+
+ if max_workgroup_invocations is not None:
+ validated_max_workgroup_invocations = _as_positive_int(
+ "max_workgroup_invocations",
+ max_workgroup_invocations,
+ )
+ backend_kwargs["max_workgroup_invocations"] = validated_max_workgroup_invocations
+
+ if max_workgroup_count is not None:
+ validated_max_workgroup_count = _as_positive_triplet("max_workgroup_count", max_workgroup_count)
+ backend_kwargs["max_workgroup_count"] = validated_max_workgroup_count
+
+ if max_shared_memory is not None:
+ validated_max_shared_memory = _as_positive_int("max_shared_memory", max_shared_memory)
+ backend_kwargs["max_compute_shared_memory_size"] = validated_max_shared_memory
+
+ if backend_kwargs:
+ native.set_device_options(**backend_kwargs)
+ check_for_errors()
+
+ for device in __context.device_infos:
+ if validated_subgroup_size is not None:
+ device.sub_group_size = validated_subgroup_size
+
+ if validated_subgroup_enabled is not None:
+ if validated_subgroup_enabled:
+ device.supported_stages |= VK_SHADER_STAGE_COMPUTE_BIT
+ device.supported_operations |= VK_SUBGROUP_FEATURE_ARITHMETIC_BIT
+ else:
+ device.supported_stages &= ~VK_SHADER_STAGE_COMPUTE_BIT
+ device.supported_operations &= ~VK_SUBGROUP_FEATURE_ARITHMETIC_BIT
+
+ if validated_max_workgroup_size is not None:
+ device.max_workgroup_size = validated_max_workgroup_size
+
+ if validated_max_workgroup_invocations is not None:
+ device.max_workgroup_invocations = validated_max_workgroup_invocations
+
+ if validated_max_workgroup_count is not None:
+ device.max_workgroup_count = validated_max_workgroup_count
+
+ if validated_max_shared_memory is not None:
+ device.max_compute_shared_memory_size = validated_max_shared_memory
+
+ device.uniform_buffer_alignment = 0
+
+ __context.refresh_limits_from_device_infos()
+
def queue_wait_idle(queue_index: int = None, context: Context = None) -> None:
"""
Wait for the specified queue to finish processing. For all queues, leave queue_index as None.
@@ -490,8 +631,7 @@ def _sig_handler(signum, frame):
signal.signal(signum, signal.SIG_DFL)
os.kill(os.getpid(), signum)
-
-from .brython_utils import is_brython
-if not is_brython():
+# No need to register signal handlers in Brython, since it runs in a browser
+if sys.implementation.name != "brython":
signal.signal(signal.SIGINT, _sig_handler)
signal.signal(signal.SIGTERM, _sig_handler)
diff --git a/vkdispatch/base/descriptor_set.py b/vkdispatch/base/descriptor_set.py
index b4512456..e9d2823a 100644
--- a/vkdispatch/base/descriptor_set.py
+++ b/vkdispatch/base/descriptor_set.py
@@ -1,4 +1,4 @@
-from .backend import native
+from ..backends.backend_selection import native
from .errors import check_for_errors
@@ -8,6 +8,7 @@
from .image import Sampler
from .init import log_info
+from .init import is_cuda
class DescriptorSet(Handle):
"""TODO: Docstring"""
@@ -28,6 +29,9 @@ def __del__(self) -> None:
self.destroy()
def bind_buffer(self, buffer: Buffer, binding: int, offset: int = 0, range: int = 0, uniform: bool = False, read_access: bool = True, write_access: bool = True) -> None:
+ if write_access and not getattr(buffer, "is_writable", True):
+ raise ValueError("Cannot bind a read-only buffer with write access enabled.")
+
self.register_parent(buffer)
native.descriptor_set_write_buffer(
@@ -54,3 +58,10 @@ def bind_sampler(self, sampler: Sampler, binding: int, read_access: bool = True,
1 if write_access else 0
)
check_for_errors()
+
+ def set_inline_uniform_payload(self, payload: bytes) -> None:
+ if not is_cuda():
+ raise RuntimeError("Inline uniform payloads are currently only supported on CUDA backends.")
+
+ native.descriptor_set_write_inline_uniform(self._handle, payload)
+ check_for_errors()
diff --git a/vkdispatch/base/dtype.py b/vkdispatch/base/dtype.py
index fa796001..62ea81d3 100644
--- a/vkdispatch/base/dtype.py
+++ b/vkdispatch/base/dtype.py
@@ -1,6 +1,6 @@
from typing import Any, Optional
-from .._compat import numpy_compat as npc
+from ..compat import numpy_compat as npc
class dtype:
name: str
@@ -24,6 +24,18 @@ class _Scalar(dtype):
child_type = None
scalar = None
+class _I16(_Scalar):
+ name = "int16"
+ item_size = 2
+ glsl_type = "int16_t"
+ format_str = "%d"
+
+class _U16(_Scalar):
+ name = "uint16"
+ item_size = 2
+ glsl_type = "uint16_t"
+ format_str = "%u"
+
class _I32(_Scalar):
name = "int32"
item_size = 4
@@ -36,20 +48,61 @@ class _U32(_Scalar):
glsl_type = "uint"
format_str = "%u"
+class _I64(_Scalar):
+ name = "int64"
+ item_size = 8
+ glsl_type = "int64_t"
+ format_str = "%lld"
+
+class _U64(_Scalar):
+ name = "uint64"
+ item_size = 8
+ glsl_type = "uint64_t"
+ format_str = "%llu"
+
+class _F16(_Scalar):
+ name = "float16"
+ item_size = 2
+ glsl_type = "float16_t"
+ format_str = "%f"
+
class _F32(_Scalar):
name = "float32"
item_size = 4
glsl_type = "float"
format_str = "%f"
+class _F64(_Scalar):
+ name = "float64"
+ item_size = 8
+ glsl_type = "double"
+ format_str = "%lf"
+
+int16 = _I16 # type: ignore
+uint16 = _U16 # type: ignore
int32 = _I32 # type: ignore
uint32 = _U32 # type: ignore
+int64 = _I64 # type: ignore
+uint64 = _U64 # type: ignore
+float16 = _F16 # type: ignore
float32 = _F32 # type: ignore
+float64 = _F64 # type: ignore
class _Complex(dtype):
dimentions = 0
child_count = 2
+class _CF32(_Complex):
+ name = "complex32"
+ item_size = 4
+ glsl_type = "f16vec2"
+ format_str = "(%f, %f)"
+ child_type = float16
+ shape = (2,)
+ numpy_shape = (1,)
+ true_numpy_shape = ()
+ scalar = None
+
class _CF64(_Complex):
name = "complex64"
item_size = 8
@@ -61,11 +114,64 @@ class _CF64(_Complex):
true_numpy_shape = ()
scalar = None
+class _CF128(_Complex):
+ name = "complex128"
+ item_size = 16
+ glsl_type = "dvec2"
+ format_str = "(%lf, %lf)"
+ child_type = float64
+ shape = (2,)
+ numpy_shape = (1,)
+ true_numpy_shape = ()
+ scalar = None
+
+complex32 = _CF32 # type: ignore
complex64 = _CF64 # type: ignore
+complex128 = _CF128 # type: ignore
class _Vector(dtype):
dimentions = 1
+# --- float16 vectors ---
+
+class _V2F16(_Vector):
+ name = "hvec2"
+ item_size = 4
+ glsl_type = "f16vec2"
+ format_str = "(%f, %f)"
+ child_type = float16
+ child_count = 2
+ shape = (2,)
+ numpy_shape = (2,)
+ true_numpy_shape = (2,)
+ scalar = float16
+
+class _V3F16(_Vector):
+ name = "hvec3"
+ item_size = 6
+ glsl_type = "f16vec3"
+ format_str = "(%f, %f, %f)"
+ child_type = float16
+ child_count = 3
+ shape = (3,)
+ numpy_shape = (3,)
+ true_numpy_shape = (3,)
+ scalar = float16
+
+class _V4F16(_Vector):
+ name = "hvec4"
+ item_size = 8
+ glsl_type = "f16vec4"
+ format_str = "(%f, %f, %f, %f)"
+ child_type = float16
+ child_count = 4
+ shape = (4,)
+ numpy_shape = (4,)
+ true_numpy_shape = (4,)
+ scalar = float16
+
+# --- float32 vectors ---
+
class _V2F32(_Vector):
name = "vec2"
item_size = 8
@@ -102,6 +208,84 @@ class _V4F32(_Vector):
true_numpy_shape = (4,)
scalar = float32
+# --- float64 vectors ---
+
+class _V2F64(_Vector):
+ name = "dvec2"
+ item_size = 16
+ glsl_type = "dvec2"
+ format_str = "(%lf, %lf)"
+ child_type = float64
+ child_count = 2
+ shape = (2,)
+ numpy_shape = (2,)
+ true_numpy_shape = (2,)
+ scalar = float64
+
+class _V3F64(_Vector):
+ name = "dvec3"
+ item_size = 24
+ glsl_type = "dvec3"
+ format_str = "(%lf, %lf, %lf)"
+ child_type = float64
+ child_count = 3
+ shape = (3,)
+ numpy_shape = (3,)
+ true_numpy_shape = (3,)
+ scalar = float64
+
+class _V4F64(_Vector):
+ name = "dvec4"
+ item_size = 32
+ glsl_type = "dvec4"
+ format_str = "(%lf, %lf, %lf, %lf)"
+ child_type = float64
+ child_count = 4
+ shape = (4,)
+ numpy_shape = (4,)
+ true_numpy_shape = (4,)
+ scalar = float64
+
+# --- int16 vectors ---
+
+class _V2I16(_Vector):
+ name = "ihvec2"
+ item_size = 4
+ glsl_type = "i16vec2"
+ format_str = "(%d, %d)"
+ child_type = int16
+ child_count = 2
+ shape = (2,)
+ numpy_shape = (2,)
+ true_numpy_shape = (2,)
+ scalar = int16
+
+class _V3I16(_Vector):
+ name = "ihvec3"
+ item_size = 6
+ glsl_type = "i16vec3"
+ format_str = "(%d, %d, %d)"
+ child_type = int16
+ child_count = 3
+ shape = (3,)
+ numpy_shape = (3,)
+ true_numpy_shape = (3,)
+ scalar = int16
+
+class _V4I16(_Vector):
+ name = "ihvec4"
+ item_size = 8
+ glsl_type = "i16vec4"
+ format_str = "(%d, %d, %d, %d)"
+ child_type = int16
+ child_count = 4
+ shape = (4,)
+ numpy_shape = (4,)
+ true_numpy_shape = (4,)
+ scalar = int16
+
+# --- int32 vectors ---
+
class _V2I32(_Vector):
name = "ivec2"
item_size = 8
@@ -138,6 +322,46 @@ class _V4I32(_Vector):
true_numpy_shape = (4,)
scalar = int32
+# --- uint16 vectors ---
+
+class _V2U16(_Vector):
+ name = "uhvec2"
+ item_size = 4
+ glsl_type = "u16vec2"
+ format_str = "(%u, %u)"
+ child_type = uint16
+ child_count = 2
+ shape = (2,)
+ numpy_shape = (2,)
+ true_numpy_shape = (2,)
+ scalar = uint16
+
+class _V3U16(_Vector):
+ name = "uhvec3"
+ item_size = 6
+ glsl_type = "u16vec3"
+ format_str = "(%u, %u, %u)"
+ child_type = uint16
+ child_count = 3
+ shape = (3,)
+ numpy_shape = (3,)
+ true_numpy_shape = (3,)
+ scalar = uint16
+
+class _V4U16(_Vector):
+ name = "uhvec4"
+ item_size = 8
+ glsl_type = "u16vec4"
+ format_str = "(%u, %u, %u, %u)"
+ child_type = uint16
+ child_count = 4
+ shape = (4,)
+ numpy_shape = (4,)
+ true_numpy_shape = (4,)
+ scalar = uint16
+
+# --- uint32 vectors ---
+
class _V2U32(_Vector):
name = "uvec2"
item_size = 8
@@ -174,12 +398,24 @@ class _V4U32(_Vector):
true_numpy_shape = (4,)
scalar = uint32
+hvec2 = _V2F16 # type: ignore
+hvec3 = _V3F16 # type: ignore
+hvec4 = _V4F16 # type: ignore
vec2 = _V2F32 # type: ignore
vec3 = _V3F32 # type: ignore
vec4 = _V4F32 # type: ignore
+dvec2 = _V2F64 # type: ignore
+dvec3 = _V3F64 # type: ignore
+dvec4 = _V4F64 # type: ignore
+ihvec2 = _V2I16 # type: ignore
+ihvec3 = _V3I16 # type: ignore
+ihvec4 = _V4I16 # type: ignore
ivec2 = _V2I32 # type: ignore
ivec3 = _V3I32 # type: ignore
ivec4 = _V4I32 # type: ignore
+uhvec2 = _V2U16 # type: ignore
+uhvec3 = _V3U16 # type: ignore
+uhvec4 = _V4U16 # type: ignore
uvec2 = _V2U32 # type: ignore
uvec3 = _V3U32 # type: ignore
uvec4 = _V4U32 # type: ignore
@@ -227,39 +463,25 @@ class _M4F32(_Matrix):
mat3 = _M3F32
mat4 = _M4F32
+# Maps scalar dtype -> {count: vector_dtype}
+_VECTOR_TABLE = {
+ int16: {1: int16, 2: ihvec2, 3: ihvec3, 4: ihvec4},
+ uint16: {1: uint16, 2: uhvec2, 3: uhvec3, 4: uhvec4},
+ int32: {1: int32, 2: ivec2, 3: ivec3, 4: ivec4},
+ uint32: {1: uint32, 2: uvec2, 3: uvec3, 4: uvec4},
+ float16: {1: float16, 2: hvec2, 3: hvec3, 4: hvec4},
+ float32: {1: float32, 2: vec2, 3: vec3, 4: vec4},
+ float64: {1: float64, 2: dvec2, 3: dvec3, 4: dvec4},
+}
+
def to_vector(dtype: dtype, count: int) -> dtype: # type: ignore
if count < 1 or count > 4:
raise ValueError(f"Unsupported count ({count})!")
- if dtype == int32:
- if count == 1:
- return int32
- elif count == 2:
- return ivec2
- elif count == 3:
- return ivec3
- elif count == 4:
- return ivec4
- elif dtype == uint32:
- if count == 1:
- return uint32
- elif count == 2:
- return uvec2
- elif count == 3:
- return uvec3
- elif count == 4:
- return uvec4
- elif dtype == float32:
- if count == 1:
- return float32
- elif count == 2:
- return vec2
- elif count == 3:
- return vec3
- elif count == 4:
- return vec4
- else:
+ table = _VECTOR_TABLE.get(dtype)
+ if table is None:
raise ValueError(f"Unsupported dtype ({dtype})!")
+ return table[count]
def is_dtype(in_type: dtype) -> bool:
return issubclass(in_type, dtype) # type: ignore
@@ -280,23 +502,65 @@ def is_float_dtype(dtype: dtype) -> bool:
if not is_scalar(dtype):
dtype = dtype.scalar
- return dtype == float32 # or dtype == complex64
+ return dtype == float16 or dtype == float32 or dtype == float64
def is_integer_dtype(dtype: dtype) -> bool:
if not is_scalar(dtype):
dtype = dtype.scalar
- return dtype == int32 or dtype == uint32
+ return dtype in (int16, uint16, int32, uint32, int64, uint64)
+
+# Promotion precedence: float64 > float32 > float16 > int64 > uint64 > int32 > uint32 > int16 > uint16
+_SCALAR_RANK = {
+ uint16: 0,
+ int16: 1,
+ uint32: 2,
+ int32: 3,
+ uint64: 4,
+ int64: 5,
+ float16: 6,
+ float32: 7,
+ float64: 8,
+}
+
+_COMPLEX_FROM_FLOAT = {
+ float16: complex32,
+ float32: complex64,
+ float64: complex128,
+}
+
+def complex_from_float(dtype: dtype) -> dtype:
+ if not is_scalar(dtype):
+ raise ValueError(f"Unsupported dtype ({dtype})!")
+
+ result = _COMPLEX_FROM_FLOAT.get(dtype)
+ if result is None:
+ raise ValueError(f"Unsupported complex base dtype ({dtype})!")
+ return result
+
+def _promote_scalar(dtype: dtype) -> dtype:
+ """Return the floating-point type that matches the width of *dtype*.
+
+ Used by make_floating_dtype to convert integer scalars to their natural
+ floating counterpart.
+ """
+ if dtype == int16 or dtype == uint16:
+ return float16
+ if dtype == int32 or dtype == uint32:
+ return float32
+ if dtype == int64 or dtype == uint64:
+ return float64
+ return dtype
def make_floating_dtype(dtype: dtype) -> dtype:
if is_scalar(dtype):
- return float32
+ return _promote_scalar(dtype)
elif is_vector(dtype):
- return to_vector(float32, dtype.child_count)
+ return to_vector(_promote_scalar(dtype.scalar), dtype.child_count)
elif is_matrix(dtype):
return dtype
elif is_complex(dtype):
- return complex64
+ return dtype
else:
raise ValueError(f"Unsupported dtype ({dtype})!")
@@ -308,14 +572,10 @@ def vector_size(dtype: dtype) -> int:
def cross_scalar_scalar(dtype1: dtype, dtype2: dtype) -> dtype:
assert is_scalar(dtype1) and is_scalar(dtype2), "Both types must be scalar types!"
-
- if dtype1 == float32 or dtype2 == float32:
- return float32
-
- if dtype1 == int32 or dtype2 == int32:
- return int32
-
- return uint32
+
+ r1 = _SCALAR_RANK[dtype1]
+ r2 = _SCALAR_RANK[dtype2]
+ return dtype1 if r1 >= r2 else dtype2
def cross_vector_scalar(dtype1: dtype, dtype2: dtype) -> dtype:
assert is_vector(dtype1) and is_scalar(dtype2), "First type must be vector type and second type must be scalar type!"
@@ -354,10 +614,10 @@ def cross_matrix(dtype1: dtype, dtype2: dtype) -> dtype:
if is_vector(dtype2) or is_complex(dtype2):
raise ValueError("Cannot cross matrix and vector/complex types!")
-
+
if is_scalar(dtype2):
return dtype1
-
+
raise ValueError("Second type must be matrix or scalar type!")
def cross_type(dtype1: dtype, dtype2: dtype) -> dtype:
@@ -370,38 +630,91 @@ def cross_type(dtype1: dtype, dtype2: dtype) -> dtype:
return cross_vector(dtype1, dtype2)
elif is_vector(dtype2):
return cross_vector(dtype2, dtype1)
-
+
if is_complex(dtype1):
- return complex64
+ if is_complex(dtype2):
+ return complex_from_float(cross_scalar_scalar(dtype1.child_type, dtype2.child_type))
+ if is_scalar(dtype2):
+ return complex_from_float(cross_scalar_scalar(dtype1.child_type, _promote_scalar(dtype2)))
+ raise ValueError("Cannot cross complex and non-scalar types!")
elif is_complex(dtype2):
- return complex64
-
+ if is_scalar(dtype1):
+ return complex_from_float(cross_scalar_scalar(dtype2.child_type, _promote_scalar(dtype1)))
+ raise ValueError("Cannot cross complex and non-scalar types!")
+
if is_scalar(dtype1) and is_scalar(dtype2):
return cross_scalar_scalar(dtype1, dtype2)
+def cross_multiply_type(dtype1: dtype, dtype2: dtype) -> dtype:
+ """Resolve result type for multiplication.
+
+ Unlike ``cross_type``, multiplication is order-sensitive for matrix/vector
+ combinations and supports ``matN * vecN`` and ``vecN * matN``.
+ """
+ if is_matrix(dtype1) and is_vector(dtype2):
+ if dtype1.child_count != dtype2.child_count:
+ raise ValueError(
+ f"Cannot multiply matrix '{dtype1.name}' and vector '{dtype2.name}' with incompatible dimensions!"
+ )
+ if dtype1.scalar != float32 or dtype2.scalar != float32:
+ raise ValueError("Matrix/vector multiplication only supports float32 matrix and vector types.")
+ return dtype2
+
+ if is_vector(dtype1) and is_matrix(dtype2):
+ if dtype1.child_count != dtype2.child_count:
+ raise ValueError(
+ f"Cannot multiply vector '{dtype1.name}' and matrix '{dtype2.name}' with incompatible dimensions!"
+ )
+ if dtype1.scalar != float32 or dtype2.scalar != float32:
+ raise ValueError("Matrix/vector multiplication only supports float32 matrix and vector types.")
+ return dtype1
+
+ return cross_type(dtype1, dtype2)
+
def from_numpy_dtype(dtype: Any) -> dtype:
dtype_name = npc.host_dtype_name(dtype)
- if dtype_name == "int32":
- return int32
- elif dtype_name == "uint32":
- return uint32
- elif dtype_name == "float32":
- return float32
- elif dtype_name == "complex64":
- return complex64
- else:
+ _NAME_MAP = {
+ "int16": int16,
+ "uint16": uint16,
+ "int32": int32,
+ "uint32": uint32,
+ "int64": int64,
+ "uint64": uint64,
+ "float16": float16,
+ "float32": float32,
+ "float64": float64,
+ "complex32": complex32,
+ "complex64": complex64,
+ "complex128": complex128,
+ }
+
+ result = _NAME_MAP.get(dtype_name)
+ if result is None:
raise ValueError(f"Unsupported dtype ({dtype})!")
+ return result
def to_numpy_dtype(shader_type: dtype) -> Any:
- if shader_type == int32:
- return npc.host_dtype("int32") if not npc.HAS_NUMPY else npc.numpy_module().int32
- elif shader_type == uint32:
- return npc.host_dtype("uint32") if not npc.HAS_NUMPY else npc.numpy_module().uint32
- elif shader_type == float32:
- return npc.host_dtype("float32") if not npc.HAS_NUMPY else npc.numpy_module().float32
- elif shader_type == complex64:
- return npc.host_dtype("complex64") if not npc.HAS_NUMPY else npc.numpy_module().complex64
- else:
+ _TYPE_MAP = {
+ int16: "int16",
+ uint16: "uint16",
+ int32: "int32",
+ uint32: "uint32",
+ int64: "int64",
+ uint64: "uint64",
+ float16: "float16",
+ float32: "float32",
+ float64: "float64",
+ complex32: "complex32",
+ complex64: "complex64",
+ complex128: "complex128",
+ }
+
+ name = _TYPE_MAP.get(shader_type)
+ if name is None:
raise ValueError(f"Unsupported shader_type ({shader_type})!")
+
+ if npc.HAS_NUMPY and hasattr(npc.numpy_module(), name):
+ return getattr(npc.numpy_module(), name)
+ return npc.host_dtype(name)
diff --git a/vkdispatch/base/errors.py b/vkdispatch/base/errors.py
index 51bd308a..ca6068b1 100644
--- a/vkdispatch/base/errors.py
+++ b/vkdispatch/base/errors.py
@@ -1,4 +1,4 @@
-from .backend import native
+from ..backends.backend_selection import native
running = True
@@ -26,7 +26,8 @@ def check_for_errors():
raise RuntimeError(error)
else:
raise RuntimeError("Unknown error occurred")
-
+
+
def check_for_compute_stage_errors():
"""
Check for errors in the shader compilation stage of the vkdispatch_native library and raise a RuntimeError if found.
diff --git a/vkdispatch/base/image.py b/vkdispatch/base/image.py
index bb1d1427..f78ec483 100644
--- a/vkdispatch/base/image.py
+++ b/vkdispatch/base/image.py
@@ -1,9 +1,9 @@
import typing
from enum import Enum
-from .backend import native
+from ..backends.backend_selection import native
-from .._compat import numpy_compat as npc
+from ..compat import numpy_compat as npc
from . import dtype as vdt
from .context import Handle
diff --git a/vkdispatch/base/init.py b/vkdispatch/base/init.py
index 34a084a4..a4aa7c26 100644
--- a/vkdispatch/base/init.py
+++ b/vkdispatch/base/init.py
@@ -6,12 +6,17 @@
import inspect
from .errors import check_for_errors
-from .backend import (
+from ..backends.backend_selection import (
+ BACKEND_CUDA,
+ BACKEND_OPENCL,
BACKEND_VULKAN,
+ BACKEND_DUMMY,
+ BackendUnavailableError,
clear_active_backend,
get_active_backend_name,
+ get_backend_module,
native,
- normalize_backend_name,
+ get_environment_backend,
set_active_backend,
)
@@ -262,7 +267,19 @@ def get_info_string(self, verbose: bool = False) -> str:
result = f"Device {self.sorted_index}: {self.device_name}\n"
- result += f"\tVulkan Version: {self.version_major}.{self.version_minor}.{self.version_patch}\n"
+ backend_type = "Vulkan"
+ version_number = f"{self.version_major}.{self.version_minor}.{self.version_patch}"
+
+ if is_cuda():
+ backend_type = "CUDA Compute Capability"
+ version_number = f"{self.version_major}.{self.version_minor}"
+ elif is_opencl():
+ backend_type = "OpenCL"
+ version_number = f"{self.version_major}.{self.version_minor}"
+ elif is_dummy():
+ backend_type = "Dummy"
+
+ result += f"\t{backend_type} Version: {version_number}\n"
result += f"\tDevice Type: {device_type_id_to_str_dict[self.device_type]}\n"
if self.version_variant != 0:
@@ -396,46 +413,55 @@ def get_cuda_device_map():
return uuid_map
-def initialize(
- debug_mode: bool = False,
- log_level: LogLevel = LogLevel.WARNING,
- loader_debug_logs: bool = False,
- backend: Optional[str] = None,
-):
- """
- A function which initializes the Vulkan dispatch library.
-
- Args:
- debug_mode (`bool`): A flag to enable debug mode.
- log_level (`LogLevel`): The log level, which is one of the following:
- LogLevel.VERBOSE
- LogLevel.INFO
- LogLevel.WARNING
- LogLevel.ERROR
- loader_debug_logs (bool): A flag to enable vulkan loader debug logs.
- backend (`Optional[str]`): Runtime backend to use. Supported values are
- "vulkan" and "pycuda". If omitted, the currently selected backend is
- reused. If no backend was selected yet, `VKDISPATCH_BACKEND` is used
- when set, otherwise "vulkan" is used.
- """
+def _set_initialized_state(backend_name: str, devices: List[DeviceInfo]) -> None:
global __initilized_instance
- global __device_infos
global __backend_name
+ global __device_infos
- backend_name = normalize_backend_name(
- backend
- if backend is not None
- else get_active_backend_name(os.environ.get("VKDISPATCH_BACKEND"))
+ __initilized_instance = True
+ __backend_name = backend_name
+ __device_infos = devices
+
+ for ii, dev in enumerate(__device_infos):
+ dev.sorted_index = ii
+
+
+def _build_no_gpu_backend_error(
+ vulkan_error: Exception,
+ cuda_python_error: Exception,
+ opencl_error: Exception,
+) -> RuntimeError:
+ return RuntimeError(
+ "vkdispatch could not find an available GPU backend.\n"
+ f"Vulkan backend unavailable: {vulkan_error}\n"
+ f"CUDA Python backend unavailable: {cuda_python_error}\n"
+ f"OpenCL backend unavailable: {opencl_error}\n"
+ "Install the Vulkan backend with `pip install vkdispatch`, or install CUDA support "
+ "(`pip install cuda-python`), or install OpenCL support (`pip install pyopencl`), "
+ "or explicitly use `vd.initialize(backend='dummy')` "
+ "for codegen-only workflows."
)
- if __initilized_instance:
- if __backend_name != backend_name:
- raise RuntimeError(
- f"vkdispatch is already initialized with backend '{__backend_name}'. "
- f"Cannot reinitialize with '{backend_name}' in the same process."
- )
- return
+
+def _build_vulkan_backend_error(vulkan_error: Exception) -> RuntimeError:
+ return RuntimeError(
+ "vkdispatch could not load the Vulkan backend.\n"
+ f"Vulkan backend unavailable: {vulkan_error}\n"
+ "Install the Vulkan backend with `pip install vkdispatch`, use a CUDA backend "
+ "(`pip install cuda-python`), use an OpenCL backend (`pip install pyopencl`), "
+ "or explicitly use `vd.initialize(backend='dummy')` "
+ "for codegen-only workflows."
+ )
+
+
+def _initialize_with_backend(
+ backend_name: str,
+ debug_mode: bool,
+ log_level: LogLevel,
+ loader_debug_logs: bool,
+) -> None:
+ global __initilized_instance
set_active_backend(backend_name)
@@ -443,6 +469,9 @@ def initialize(
if loader_debug_logs and backend_name == BACKEND_VULKAN:
os.environ["VK_LOADER_DEBUG"] = "all"
+ # Force import now so backend availability errors are distinct from runtime init errors.
+ get_backend_module(backend_name)
+
native.init(debug_mode, log_level.value)
check_for_errors()
@@ -452,59 +481,126 @@ def initialize(
]
if backend_name != BACKEND_VULKAN:
- __initilized_instance = True
- __backend_name = backend_name
- __device_infos = devivces
- for ii, dev in enumerate(__device_infos):
- dev.sorted_index = ii
+ _set_initialized_state(backend_name, devivces)
return
is_cuda = any(dev.is_nvidia() for dev in devivces)
-
cuda_uuids = get_cuda_device_map() if is_cuda else None
if cuda_uuids is None:
- __initilized_instance = True
- __backend_name = backend_name
- __device_infos = devivces
- for ii, dev in enumerate(__device_infos):
- dev.sorted_index = ii
+ _set_initialized_state(backend_name, devivces)
return
-
+
# try to match CUDA devices to Vulkan devices by UUID
cuda_uuid_to_index = {
uuid_bytes: cuda_index
for cuda_index, uuid_bytes in cuda_uuids.items()
}
- matched_devices: List[Tuple[int, DeviceInfo, int]]= []
+ matched_devices: List[Tuple[int, DeviceInfo]] = []
unmatched_devices: List[DeviceInfo] = []
for dev in devivces:
if dev.uuid is not None and dev.uuid in cuda_uuid_to_index:
- #print(f"Matched Vulkan device {ii} ({dev.device_name}) to CUDA device {cuda_uuid_to_index[dev.uuid]} with UUID {dev.uuid.hex()}")
- matched_devices.append( (cuda_uuid_to_index[dev.uuid], dev) )
+ matched_devices.append((cuda_uuid_to_index[dev.uuid], dev))
else:
- #print(f"Could not match Vulkan device {ii} ({dev.device_name}) with UUID {dev.uuid.hex()} to any CUDA device")
unmatched_devices.append(dev)
- # sort matched devices by CUDA index
matched_devices.sort(key=lambda x: x[0])
-
- # return matched devices first (by CUDA index), then unmatched devices (by Vulkan order)
result = [dev for _, dev in matched_devices] + unmatched_devices
- #result_ids = [ii for _, _, ii in matched_devices] + unmatched_device_ids
for dev_id, dev in enumerate(result):
- #print(f"Final device order index {dev.sorted_index} -> Vulkan device {dev_id} ({dev.device_name})")
dev.sorted_index = dev_id
-
- __initilized_instance = True
- __backend_name = backend_name
- __device_infos = result
+
+ _set_initialized_state(backend_name, result)
except Exception:
if not __initilized_instance:
clear_active_backend()
raise
+def initialize(
+ debug_mode: bool = False,
+ log_level: LogLevel = LogLevel.WARNING,
+ loader_debug_logs: bool = False,
+ backend: Optional[str] = None,
+):
+ """
+ A function which initializes the Vulkan dispatch library.
+
+ Args:
+ debug_mode (`bool`): A flag to enable debug mode.
+ log_level (`LogLevel`): The log level, which is one of the following:
+ LogLevel.VERBOSE
+ LogLevel.INFO
+ LogLevel.WARNING
+ LogLevel.ERROR
+ loader_debug_logs (bool): A flag to enable vulkan loader debug logs.
+ backend (`Optional[str]`): Runtime backend to use. Supported values are
+ "vulkan", "cuda", "opencl", and "dummy". If omitted, the currently selected backend is
+ reused. If no backend was selected yet, `VKDISPATCH_BACKEND` is used
+ when set, otherwise "vulkan" is used.
+ """
+
+ global __initilized_instance
+
+ backend_name = get_active_backend_name(backend)
+ backend_explicitly_selected = (backend is not None) or (get_environment_backend() is not None)
+
+ if __initilized_instance:
+ if __backend_name != backend_name:
+ raise RuntimeError(
+ f"vkdispatch is already initialized with backend '{__backend_name}'. "
+ f"Cannot reinitialize with '{backend_name}' in the same process."
+ )
+ return
+
+ if (
+ not backend_explicitly_selected
+ and backend_name == BACKEND_VULKAN
+ ):
+ try:
+ _initialize_with_backend(
+ BACKEND_VULKAN,
+ debug_mode=debug_mode,
+ log_level=log_level,
+ loader_debug_logs=loader_debug_logs,
+ )
+ return
+ except BackendUnavailableError as vulkan_error:
+ try:
+ _initialize_with_backend(
+ BACKEND_CUDA,
+ debug_mode=debug_mode,
+ log_level=log_level,
+ loader_debug_logs=loader_debug_logs,
+ )
+ return
+ except Exception as cuda_python_error:
+ try:
+ _initialize_with_backend(
+ BACKEND_OPENCL,
+ debug_mode=debug_mode,
+ log_level=log_level,
+ loader_debug_logs=loader_debug_logs,
+ )
+ return
+ except Exception as opencl_error:
+ raise _build_no_gpu_backend_error(
+ vulkan_error,
+ cuda_python_error,
+ opencl_error,
+ ) from opencl_error
+
+ try:
+ _initialize_with_backend(
+ backend_name,
+ debug_mode=debug_mode,
+ log_level=log_level,
+ loader_debug_logs=loader_debug_logs,
+ )
+ except BackendUnavailableError as backend_error:
+ if backend_name == BACKEND_VULKAN:
+ raise _build_vulkan_backend_error(backend_error) from backend_error
+ raise
+
def get_devices() -> List[DeviceInfo]:
"""
@@ -516,7 +612,7 @@ def get_devices() -> List[DeviceInfo]:
global __device_infos
- initialize(backend=get_active_backend_name())
+ initialize()
return __device_infos
@@ -527,6 +623,46 @@ def get_backend() -> str:
return get_active_backend_name()
+def is_vulkan() -> bool:
+ """
+ A function which checks if the active backend is the Vulkan backend.
+
+ Returns:
+ `bool`: A flag indicating whether the active backend is the Vulkan backend.
+ """
+
+ return get_backend() == BACKEND_VULKAN
+
+def is_cuda() -> bool:
+ """
+ A function which checks if the active backend is a CUDA backend.
+
+ Returns:
+ `bool`: A flag indicating whether the active backend is a CUDA backend.
+ """
+
+ return get_backend() == BACKEND_CUDA
+
+def is_opencl() -> bool:
+ """
+ A function which checks if the active backend is the OpenCL backend.
+
+ Returns:
+ `bool`: A flag indicating whether the active backend is the OpenCL backend.
+ """
+
+ return get_backend() == BACKEND_OPENCL
+
+def is_dummy() -> bool:
+ """
+ A function which checks if the active backend is the dummy backend.
+
+ Returns:
+ `bool`: A flag indicating whether the active backend is the dummy backend.
+ """
+
+ return get_backend() == BACKEND_DUMMY
+
def __log_noinit(text: str, end: str = '\n', level: LogLevel = LogLevel.ERROR, stack_offset: int = 1):
"""
A function which logs a message at the specified log level.
@@ -553,7 +689,7 @@ def log(text: str, end: str = '\n', level: LogLevel = LogLevel.ERROR, stack_offs
message (`str`): The message to log.
"""
- initialize(backend=get_active_backend_name())
+ initialize()
__log_noinit(text, end, level, stack_offset + 1)
@@ -605,6 +741,6 @@ def set_log_level(level: LogLevel):
level (`LogLevel`): The log level.
"""
- initialize(backend=get_active_backend_name())
+ initialize()
native.set_log_level(level.value)
diff --git a/vkdispatch/codegen/__init__.py b/vkdispatch/codegen/__init__.py
index 0aa98580..1d07e8eb 100644
--- a/vkdispatch/codegen/__init__.py
+++ b/vkdispatch/codegen/__init__.py
@@ -30,16 +30,31 @@
from .functions.atomic_memory import atomic_add
-from .functions.type_casting import to_dtype, str_to_dtype, to_float, to_int, to_uint
-from .functions.type_casting import to_vec2, to_vec3, to_vec4, to_complex
-from .functions.type_casting import to_uvec2, to_uvec3, to_uvec4
+from .functions.type_casting import to_dtype, str_to_dtype
+from .functions.type_casting import to_float16, to_float, to_float64
+from .functions.type_casting import to_int16, to_int, to_int64, to_uint16, to_uint, to_uint64
+from .functions.type_casting import to_complex, to_complex32, to_complex64, to_complex128
+from .functions.type_casting import to_hvec2, to_hvec3, to_hvec4
+from .functions.type_casting import to_vec2, to_vec3, to_vec4
+from .functions.type_casting import to_dvec2, to_dvec3, to_dvec4
+from .functions.type_casting import to_ihvec2, to_ihvec3, to_ihvec4
from .functions.type_casting import to_ivec2, to_ivec3, to_ivec4
+from .functions.type_casting import to_uhvec2, to_uhvec3, to_uhvec4
+from .functions.type_casting import to_uvec2, to_uvec3, to_uvec4
from .functions.type_casting import to_mat2, to_mat3, to_mat4
-from .functions.registers import new_register, new_float_register, new_int_register, new_uint_register
-from .functions.registers import new_vec2_register, new_ivec2_register, new_uvec2_register, new_complex_register
-from .functions.registers import new_vec3_register, new_ivec3_register, new_uvec3_register
-from .functions.registers import new_vec4_register, new_ivec4_register, new_uvec4_register
+from .functions.registers import new_register, new_complex_register
+from .functions.registers import new_float16_register, new_float_register, new_float64_register
+from .functions.registers import new_int16_register, new_int_register, new_int64_register
+from .functions.registers import new_uint16_register, new_uint_register, new_uint64_register
+from .functions.registers import new_complex32_register, new_complex64_register, new_complex128_register
+from .functions.registers import new_hvec2_register, new_hvec3_register, new_hvec4_register
+from .functions.registers import new_vec2_register, new_vec3_register, new_vec4_register
+from .functions.registers import new_dvec2_register, new_dvec3_register, new_dvec4_register
+from .functions.registers import new_ihvec2_register, new_ihvec3_register, new_ihvec4_register
+from .functions.registers import new_ivec2_register, new_ivec3_register, new_ivec4_register
+from .functions.registers import new_uhvec2_register, new_uhvec3_register, new_uhvec4_register
+from .functions.registers import new_uvec2_register, new_uvec3_register, new_uvec4_register
from .functions.registers import new_mat2_register, new_mat3_register, new_mat4_register
from .functions.subgroups import subgroup_add, subgroup_mul
@@ -56,7 +71,7 @@
from .functions.builtin_constants import global_invocation_id, local_invocation_id, workgroup_id, local_invocation_index
from .functions.builtin_constants import workgroup_size, num_workgroups, num_subgroups, subgroup_id
-from .functions.builtin_constants import subgroup_size, subgroup_invocation_id, inf_f32, ninf_f32
+from .functions.builtin_constants import subgroup_size, subgroup_invocation_id, inf_f32, ninf_f32, inf_f64, ninf_f64, inf_f16, ninf_f16
from .functions.index_raveling import ravel_index, unravel_index
@@ -66,7 +81,7 @@
from .builder import ShaderBinding, ShaderDescription
from .builder import ShaderBuilder, ShaderFlags
-from .backends import CodeGenBackend, GLSLBackend, CUDABackend
+from .backends import CodeGenBackend, GLSLBackend, CUDABackend, OpenCLBackend
from .global_builder import set_builder, get_builder, shared_buffer, set_shader_print_line_numbers, get_shader_print_line_numbers
from .global_builder import set_codegen_backend, get_codegen_backend
diff --git a/vkdispatch/codegen/abreviations.py b/vkdispatch/codegen/abreviations.py
index 1fdff076..f9815812 100644
--- a/vkdispatch/codegen/abreviations.py
+++ b/vkdispatch/codegen/abreviations.py
@@ -7,20 +7,40 @@
from .arguments import Image2D as Img2
from .arguments import Image3D as Img3
+from vkdispatch.base.dtype import float16 as f16
from vkdispatch.base.dtype import float32 as f32
-from vkdispatch.base.dtype import uint32 as u32
+from vkdispatch.base.dtype import float64 as f64
+from vkdispatch.base.dtype import int16 as i16
+from vkdispatch.base.dtype import uint16 as u16
from vkdispatch.base.dtype import int32 as i32
+from vkdispatch.base.dtype import uint32 as u32
+from vkdispatch.base.dtype import int64 as i64
+from vkdispatch.base.dtype import uint64 as u64
+from vkdispatch.base.dtype import complex32 as c32
from vkdispatch.base.dtype import complex64 as c64
+from vkdispatch.base.dtype import complex128 as c128
+from vkdispatch.base.dtype import hvec2 as hv2
+from vkdispatch.base.dtype import hvec3 as hv3
+from vkdispatch.base.dtype import hvec4 as hv4
from vkdispatch.base.dtype import vec2 as v2
from vkdispatch.base.dtype import vec3 as v3
from vkdispatch.base.dtype import vec4 as v4
-from vkdispatch.base.dtype import uvec2 as uv2
-from vkdispatch.base.dtype import uvec3 as uv3
-from vkdispatch.base.dtype import uvec4 as uv4
+from vkdispatch.base.dtype import dvec2 as dv2
+from vkdispatch.base.dtype import dvec3 as dv3
+from vkdispatch.base.dtype import dvec4 as dv4
+from vkdispatch.base.dtype import ihvec2 as ihv2
+from vkdispatch.base.dtype import ihvec3 as ihv3
+from vkdispatch.base.dtype import ihvec4 as ihv4
from vkdispatch.base.dtype import ivec2 as iv2
from vkdispatch.base.dtype import ivec3 as iv3
from vkdispatch.base.dtype import ivec4 as iv4
+from vkdispatch.base.dtype import uhvec2 as uhv2
+from vkdispatch.base.dtype import uhvec3 as uhv3
+from vkdispatch.base.dtype import uhvec4 as uhv4
+from vkdispatch.base.dtype import uvec2 as uv2
+from vkdispatch.base.dtype import uvec3 as uv3
+from vkdispatch.base.dtype import uvec4 as uv4
from vkdispatch.base.dtype import mat2 as m2
from vkdispatch.base.dtype import mat4 as m4
diff --git a/vkdispatch/codegen/backends/__init__.py b/vkdispatch/codegen/backends/__init__.py
index 0ddf53ce..773f5bee 100644
--- a/vkdispatch/codegen/backends/__init__.py
+++ b/vkdispatch/codegen/backends/__init__.py
@@ -1,3 +1,4 @@
from .base import CodeGenBackend
from .glsl import GLSLBackend
from .cuda import CUDABackend
+from .opencl import OpenCLBackend
diff --git a/vkdispatch/codegen/backends/base.py b/vkdispatch/codegen/backends/base.py
index 5c34ab0b..aafdab6f 100644
--- a/vkdispatch/codegen/backends/base.py
+++ b/vkdispatch/codegen/backends/base.py
@@ -40,12 +40,74 @@ def mark_composite_binary_op(
def type_name(self, var_type: dtypes.dtype) -> str:
raise NotImplementedError
- def constructor(self, var_type: dtypes.dtype, args: List[str]) -> str:
+ def constructor(
+ self,
+ var_type: dtypes.dtype,
+ args: List[str],
+ arg_types: Optional[List[Optional[dtypes.dtype]]] = None,
+ ) -> str:
+ _ = arg_types
raise NotImplementedError
+ def component_access_expr(self, expr: str, component: str, base_type: dtypes.dtype) -> str:
+ return f"{expr}.{component}"
+
+ def buffer_component_expr(
+ self,
+ scalar_buffer_expr: str,
+ base_type: dtypes.dtype,
+ element_index_expr: str,
+ component_index_expr: str,
+ ) -> Optional[str]:
+ _ = (scalar_buffer_expr, base_type, element_index_expr, component_index_expr)
+ return None
+
def fma_function_name(self, var_type: dtypes.dtype) -> str:
return "fma"
+ def math_func_name(self, func_name: str, var_type: dtypes.dtype) -> str:
+ """Return the backend-specific function name for a math operation.
+
+ Backends can override this to remap function names for specific types
+ (e.g. CUDA __half intrinsics).
+ """
+ return func_name
+
+ def unary_math_expr(self, func_name: str, arg_type: dtypes.dtype, arg_expr: str) -> str:
+ return f"{self.math_func_name(func_name, arg_type)}({arg_expr})"
+
+ def binary_math_expr(
+ self,
+ func_name: str,
+ lhs_type: dtypes.dtype,
+ lhs_expr: str,
+ rhs_type: dtypes.dtype,
+ rhs_expr: str,
+ ) -> str:
+ mapped = self.math_func_name(func_name, lhs_type)
+ if func_name == "atan2":
+ mapped_atan = self.math_func_name("atan", lhs_type)
+ return f"{mapped_atan}({lhs_expr}, {rhs_expr})"
+
+ return f"{mapped}({lhs_expr}, {rhs_expr})"
+
+ def arithmetic_unary_expr(self, op: str, var_type: dtypes.dtype, var_expr: str) -> Optional[str]:
+ """Optional backend override for unary arithmetic expressions."""
+ _ = (op, var_type, var_expr)
+ return None
+
+ def arithmetic_binary_expr(
+ self,
+ op: str,
+ lhs_type: dtypes.dtype,
+ lhs_expr: str,
+ rhs_type: dtypes.dtype,
+ rhs_expr: str,
+ ) -> Optional[str]:
+ """Optional backend override for binary arithmetic expressions."""
+ _ = (op, lhs_type, lhs_expr, rhs_type, rhs_expr)
+ return None
+
def pre_header(self, *, enable_subgroup_ops: bool, enable_printf: bool) -> str:
raise NotImplementedError
@@ -85,6 +147,18 @@ def inf_f32_expr(self) -> str:
def ninf_f32_expr(self) -> str:
raise NotImplementedError
+ def inf_f64_expr(self) -> str:
+ raise NotImplementedError
+
+ def ninf_f64_expr(self) -> str:
+ raise NotImplementedError
+
+ def inf_f16_expr(self) -> str:
+ raise NotImplementedError
+
+ def ninf_f16_expr(self) -> str:
+ raise NotImplementedError
+
def float_bits_to_int_expr(self, var_expr: str) -> str:
raise NotImplementedError
@@ -145,25 +219,32 @@ def memory_barrier_image_statement(self) -> str:
def group_memory_barrier_statement(self) -> str:
raise NotImplementedError
- def subgroup_add_expr(self, arg_expr: str) -> str:
+ def subgroup_add_expr(self, arg_expr: str, arg_type: Optional[dtypes.dtype] = None) -> str:
+ _ = arg_type
raise NotImplementedError
- def subgroup_mul_expr(self, arg_expr: str) -> str:
+ def subgroup_mul_expr(self, arg_expr: str, arg_type: Optional[dtypes.dtype] = None) -> str:
+ _ = arg_type
raise NotImplementedError
- def subgroup_min_expr(self, arg_expr: str) -> str:
+ def subgroup_min_expr(self, arg_expr: str, arg_type: Optional[dtypes.dtype] = None) -> str:
+ _ = arg_type
raise NotImplementedError
- def subgroup_max_expr(self, arg_expr: str) -> str:
+ def subgroup_max_expr(self, arg_expr: str, arg_type: Optional[dtypes.dtype] = None) -> str:
+ _ = arg_type
raise NotImplementedError
- def subgroup_and_expr(self, arg_expr: str) -> str:
+ def subgroup_and_expr(self, arg_expr: str, arg_type: Optional[dtypes.dtype] = None) -> str:
+ _ = arg_type
raise NotImplementedError
- def subgroup_or_expr(self, arg_expr: str) -> str:
+ def subgroup_or_expr(self, arg_expr: str, arg_type: Optional[dtypes.dtype] = None) -> str:
+ _ = arg_type
raise NotImplementedError
- def subgroup_xor_expr(self, arg_expr: str) -> str:
+ def subgroup_xor_expr(self, arg_expr: str, arg_type: Optional[dtypes.dtype] = None) -> str:
+ _ = arg_type
raise NotImplementedError
def subgroup_elect_expr(self) -> str:
@@ -183,3 +264,8 @@ def sample_texture_expr(self, texture_expr: str, coord_expr: str, lod_expr: Opti
def mark_texture_sample_dimension(self, dimensions: int) -> None:
return
+
+ def atomic_add_expr(self, mem_expr: str, value_expr: str, var_type: dtypes.dtype) -> str:
+ raise NotImplementedError(
+ f"atomic_add is not supported for backend '{self.name}' and type '{var_type.name}'"
+ )
diff --git a/vkdispatch/codegen/backends/cuda.py b/vkdispatch/codegen/backends/cuda.py
deleted file mode 100644
index 51e575f0..00000000
--- a/vkdispatch/codegen/backends/cuda.py
+++ /dev/null
@@ -1,1247 +0,0 @@
-from typing import Dict, List, Optional, Set
-
-import vkdispatch.base.dtype as dtypes
-
-from .base import CodeGenBackend
-
-
-def _cuda_vec_components(dim: int) -> List[str]:
- if dim < 2 or dim > 4:
- raise ValueError(f"Unsupported vector dimension '{dim}'")
- return list("xyzw"[:dim])
-
-
-def _cuda_join_statements(statements: List[str]) -> str:
- if len(statements) == 0:
- return ""
- return " ".join(statements)
-
-
-def _cuda_emit_vec_type(
- vec_name: str,
- scalar_type: str,
- dim: int,
- *,
- allow_unary_neg: bool,
- enable_bitwise: bool,
- needed_ops: Optional[Set[str]] = None,
-) -> str:
- comps = _cuda_vec_components(dim)
- if needed_ops is None:
- needed_ops = set()
- if allow_unary_neg:
- needed_ops.add("un:-")
- if enable_bitwise:
- needed_ops.add("un:~")
- for op in ["+", "-", "*", "/"]:
- needed_ops.add(f"cmpd:{op}=:v")
- needed_ops.add(f"cmpd:{op}=:s")
- needed_ops.add(f"bin:{op}:vv")
- needed_ops.add(f"bin:{op}:vs")
- needed_ops.add(f"bin:{op}:sv")
- if enable_bitwise:
- for op in ["&", "|", "^", "<<", ">>"]:
- needed_ops.add(f"cmpd:{op}=:v")
- needed_ops.add(f"cmpd:{op}=:s")
- needed_ops.add(f"bin:{op}:vv")
- needed_ops.add(f"bin:{op}:vs")
- needed_ops.add(f"bin:{op}:sv")
-
- def has(token: str) -> bool:
- return token in needed_ops
-
- lines: List[str] = [f"struct {vec_name} {{"]
- lines.extend([f" {scalar_type} {c};" for c in comps])
- lines.append("")
- ctor_args = ", ".join([f"{scalar_type} {c}_" for c in comps])
- ctor_init = ", ".join([f"{c}({c}_)" for c in comps])
- splat_init = ", ".join([f"{c}(s)" for c in comps])
- cast_init = ", ".join([f"{c}(({scalar_type})v.{c})" for c in comps])
- lines.append(f" __device__ __forceinline__ {vec_name}() = default;")
- lines.append(f" __device__ __forceinline__ {vec_name}({ctor_args}) : {ctor_init} {{}}")
- lines.append(f" __device__ __forceinline__ explicit {vec_name}({scalar_type} s) : {splat_init} {{}}")
- lines.append(" template ")
- lines.append(f" __device__ __forceinline__ explicit {vec_name}(TVec v) : {cast_init} {{}}")
- lines.append(f" __device__ __forceinline__ {scalar_type}& operator[](int i) {{ return (&x)[i]; }}")
- lines.append(f" __device__ __forceinline__ const {scalar_type}& operator[](int i) const {{ return (&x)[i]; }}")
-
- if allow_unary_neg and has("un:-"):
- neg_expr = ", ".join([f"-{c}" for c in comps])
- lines.append(f" __device__ __forceinline__ {vec_name} operator-() const {{ return {vec_name}({neg_expr}); }}")
-
- if enable_bitwise and has("un:~"):
- not_expr = ", ".join([f"~{c}" for c in comps])
- lines.append(f" __device__ __forceinline__ {vec_name} operator~() const {{ return {vec_name}({not_expr}); }}")
-
- for op in ["+", "-", "*", "/"]:
- op_assign = op + "="
- if has(f"cmpd:{op}=:v"):
- vv_ops = _cuda_join_statements([f"{c} {op_assign} b.{c};" for c in comps])
- lines.append(
- f" __device__ __forceinline__ {vec_name}& operator{op_assign}(const {vec_name}& b) {{ {vv_ops} return *this; }}"
- )
- if has(f"cmpd:{op}=:s"):
- sv_ops = _cuda_join_statements([f"{c} {op_assign} b;" for c in comps])
- lines.append(
- f" __device__ __forceinline__ {vec_name}& operator{op_assign}({scalar_type} b) {{ {sv_ops} return *this; }}"
- )
-
- if enable_bitwise:
- for op in ["&", "|", "^", "<<", ">>"]:
- op_assign = op + "="
- if has(f"cmpd:{op}=:v"):
- vv_ops = _cuda_join_statements([f"{c} {op_assign} b.{c};" for c in comps])
- lines.append(
- f" __device__ __forceinline__ {vec_name}& operator{op_assign}(const {vec_name}& b) {{ {vv_ops} return *this; }}"
- )
- if has(f"cmpd:{op}=:s"):
- sv_ops = _cuda_join_statements([f"{c} {op_assign} b;" for c in comps])
- lines.append(
- f" __device__ __forceinline__ {vec_name}& operator{op_assign}({scalar_type} b) {{ {sv_ops} return *this; }}"
- )
-
- lines.append("};")
-
- # Arithmetic operators (vector/vector, vector/scalar, scalar/vector)
- for op in ["+", "-", "*", "/"]:
- if has(f"bin:{op}:vv"):
- vv_expr = ", ".join([f"(a.{c} {op} b.{c})" for c in comps])
- lines.append(
- f"__device__ __forceinline__ {vec_name} operator{op}(const {vec_name}& a, const {vec_name}& b) {{ return {vec_name}({vv_expr}); }}"
- )
- if has(f"bin:{op}:vs"):
- vs_expr = ", ".join([f"(a.{c} {op} b)" for c in comps])
- lines.append(
- f"__device__ __forceinline__ {vec_name} operator{op}(const {vec_name}& a, {scalar_type} b) {{ return {vec_name}({vs_expr}); }}"
- )
- if has(f"bin:{op}:sv"):
- if op in ["+", "*"]:
- sv_expr = ", ".join([f"(a {op} b.{c})" for c in comps])
- else:
- sv_expr = ", ".join([f"({scalar_type})(a {op} b.{c})" for c in comps])
- lines.append(
- f"__device__ __forceinline__ {vec_name} operator{op}({scalar_type} a, const {vec_name}& b) {{ return {vec_name}({sv_expr}); }}"
- )
-
- if enable_bitwise:
- for op in ["&", "|", "^", "<<", ">>"]:
- if has(f"bin:{op}:vv"):
- vv_expr = ", ".join([f"(a.{c} {op} b.{c})" for c in comps])
- lines.append(
- f"__device__ __forceinline__ {vec_name} operator{op}(const {vec_name}& a, const {vec_name}& b) {{ return {vec_name}({vv_expr}); }}"
- )
- if has(f"bin:{op}:vs"):
- vs_expr = ", ".join([f"(a.{c} {op} b)" for c in comps])
- lines.append(
- f"__device__ __forceinline__ {vec_name} operator{op}(const {vec_name}& a, {scalar_type} b) {{ return {vec_name}({vs_expr}); }}"
- )
- if has(f"bin:{op}:sv"):
- sv_expr = ", ".join([f"({scalar_type})(a {op} b.{c})" for c in comps])
- lines.append(
- f"__device__ __forceinline__ {vec_name} operator{op}({scalar_type} a, const {vec_name}& b) {{ return {vec_name}({sv_expr}); }}"
- )
-
- return "\n".join(lines)
-
-
-def _cuda_emit_vec_helper(helper_suffix: str, vec_name: str, scalar_type: str, dim: int) -> str:
- comps = _cuda_vec_components(dim)
- args = ", ".join([f"{scalar_type} {c}" for c in comps])
- ctor_args = ", ".join(comps)
- return "\n".join(
- [
- f"__device__ __forceinline__ {vec_name} vkdispatch_make_{helper_suffix}({args}) {{ return {vec_name}({ctor_args}); }}",
- f"__device__ __forceinline__ {vec_name} vkdispatch_make_{helper_suffix}({scalar_type} x) {{ return {vec_name}(x); }}",
- "template ",
- f"__device__ __forceinline__ {vec_name} vkdispatch_make_{helper_suffix}(TVec v) {{ return {vec_name}(v); }}",
- ]
- )
-
-
-def _cuda_emit_mat_type(mat_name: str, vec_name: str, dim: int, needed_ops: Optional[Set[str]] = None) -> str:
- cols = [f"c{i}" for i in range(dim)]
- if needed_ops is None:
- needed_ops = {
- "un:-",
- "cmpd:+=:m", "cmpd:+=:s",
- "cmpd:-=:m", "cmpd:-=:s",
- "cmpd:*=:s", "cmpd:/=:s",
- "bin:+:mm", "bin:+:ms", "bin:+:sm",
- "bin:-:mm", "bin:-:ms", "bin:-:sm",
- "bin:*:ms", "bin:*:sm", "bin:/:ms", "bin:/:sm",
- "bin:*:mv", "bin:*:vm", "bin:*:mm",
- }
-
- def has(token: str) -> bool:
- return token in needed_ops
-
- lines: List[str] = [f"struct {mat_name} {{"]
- lines.extend([f" {vec_name} {c};" for c in cols])
- lines.append("")
- lines.append(f" __device__ __forceinline__ {mat_name}() = default;")
- ctor_args = ", ".join([f"{vec_name} {c}_" for c in cols])
- ctor_init = ", ".join([f"{c}({c}_)" for c in cols])
- lines.append(f" __device__ __forceinline__ {mat_name}({ctor_args}) : {ctor_init} {{}}")
-
- zero = "0.0f"
- diag_init = ", ".join(
- [f"c{col_idx}({vec_name}({', '.join(['s' if row_idx == col_idx else zero for row_idx in range(dim)])}))" for col_idx in range(dim)]
- )
- lines.append(f" __device__ __forceinline__ explicit {mat_name}(float s) : {diag_init} {{}}")
- lines.append(f" __device__ __forceinline__ {vec_name}& operator[](int i) {{ return (&c0)[i]; }}")
- lines.append(f" __device__ __forceinline__ const {vec_name}& operator[](int i) const {{ return (&c0)[i]; }}")
- if has("un:-"):
- lines.append(f" __device__ __forceinline__ {mat_name} operator-() const {{ return {mat_name}({', '.join([f'-c{i}' for i in range(dim)])}); }}")
-
- for op in ["+", "-"]:
- op_assign = op + "="
- if has(f"cmpd:{op}=:m"):
- mm_ops = _cuda_join_statements([f"c{i} {op_assign} b.c{i};" for i in range(dim)])
- lines.append(
- f" __device__ __forceinline__ {mat_name}& operator{op_assign}(const {mat_name}& b) {{ {mm_ops} return *this; }}"
- )
- if has(f"cmpd:{op}=:s"):
- ms_ops = _cuda_join_statements([f"c{i} {op_assign} b;" for i in range(dim)])
- lines.append(
- f" __device__ __forceinline__ {mat_name}& operator{op_assign}(float b) {{ {ms_ops} return *this; }}"
- )
-
- for op in ["*", "/"]:
- op_assign = op + "="
- if has(f"cmpd:{op}=:s"):
- ms_ops = _cuda_join_statements([f"c{i} {op_assign} b;" for i in range(dim)])
- lines.append(
- f" __device__ __forceinline__ {mat_name}& operator{op_assign}(float b) {{ {ms_ops} return *this; }}"
- )
-
- lines.append("};")
-
- # Basic arithmetic
- for op in ["+", "-"]:
- if has(f"bin:{op}:mm"):
- cols_expr = ", ".join([f"(a.c{i} {op} b.c{i})" for i in range(dim)])
- lines.append(
- f"__device__ __forceinline__ {mat_name} operator{op}(const {mat_name}& a, const {mat_name}& b) {{ return {mat_name}({cols_expr}); }}"
- )
- if has(f"bin:{op}:ms"):
- cols_expr = ", ".join([f"(a.c{i} {op} b)" for i in range(dim)])
- lines.append(
- f"__device__ __forceinline__ {mat_name} operator{op}(const {mat_name}& a, float b) {{ return {mat_name}({cols_expr}); }}"
- )
- if has(f"bin:{op}:sm"):
- cols_expr = ", ".join([f"(a {op} b.c{i})" for i in range(dim)])
- lines.append(
- f"__device__ __forceinline__ {mat_name} operator{op}(float a, const {mat_name}& b) {{ return {mat_name}({cols_expr}); }}"
- )
-
- for op in ["*", "/"]:
- if has(f"bin:{op}:ms"):
- cols_expr = ", ".join([f"(a.c{i} {op} b)" for i in range(dim)])
- lines.append(
- f"__device__ __forceinline__ {mat_name} operator{op}(const {mat_name}& a, float b) {{ return {mat_name}({cols_expr}); }}"
- )
- if has(f"bin:{op}:sm"):
- cols_expr = ", ".join([f"(a {op} b.c{i})" for i in range(dim)])
- lines.append(
- f"__device__ __forceinline__ {mat_name} operator{op}(float a, const {mat_name}& b) {{ return {mat_name}({cols_expr}); }}"
- )
-
- # GLSL-style matrix/vector products (column-major)
- vec_comps = _cuda_vec_components(dim)
- if has("bin:*:mv"):
- mat_vec_terms = [f"(m.c{i} * v.{vec_comps[i]})" for i in range(dim)]
- mat_vec_expr = " + ".join(mat_vec_terms)
- lines.append(
- f"__device__ __forceinline__ {vec_name} operator* (const {mat_name}& m, const {vec_name}& v) {{ return {mat_vec_expr}; }}"
- )
-
- if has("bin:*:vm"):
- row_exprs: List[str] = []
- for col_idx in range(dim):
- terms = [f"(v.{vec_comps[row_idx]} * m.c{col_idx}.{vec_comps[row_idx]})" for row_idx in range(dim)]
- row_exprs.append(" + ".join(terms))
- lines.append(
- f"__device__ __forceinline__ {vec_name} operator* (const {vec_name}& v, const {mat_name}& m) {{ return {vec_name}({', '.join(row_exprs)}); }}"
- )
-
- if has("bin:*:mm"):
- col_products = ", ".join([f"(a * b.c{i})" for i in range(dim)])
- lines.append(
- f"__device__ __forceinline__ {mat_name} operator* (const {mat_name}& a, const {mat_name}& b) {{ return {mat_name}({col_products}); }}"
- )
-
- return "\n".join(lines)
-
-
-def _cuda_emit_mat_helpers(mat_name: str, helper_suffix: str, vec_name: str, vec_helper_suffix: str, dim: int) -> str:
- col_type = vec_name
- col_args = ", ".join([f"{col_type} c{i}" for i in range(dim)])
- col_ctor = ", ".join([f"c{i}" for i in range(dim)])
-
- flat_names = [f"m{col}{row}" for col in range(dim) for row in range(dim)]
- flat_args = ", ".join([f"float {name}" for name in flat_names])
- flat_cols: List[str] = []
- for col in range(dim):
- values = [f"m{col}{row}" for row in range(dim)]
- flat_cols.append(f"vkdispatch_make_{vec_helper_suffix}({', '.join(values)})")
- flat_ctor = ", ".join(flat_cols)
-
- cast_cols = ", ".join([f"vkdispatch_make_{vec_helper_suffix}(m[{i}])" for i in range(dim)])
-
- return "\n".join(
- [
- f"__device__ __forceinline__ {mat_name} vkdispatch_make_{helper_suffix}({col_args}) {{ return {mat_name}({col_ctor}); }}",
- f"__device__ __forceinline__ {mat_name} vkdispatch_make_{helper_suffix}(float s) {{ return {mat_name}(s); }}",
- f"__device__ __forceinline__ {mat_name} vkdispatch_make_{helper_suffix}({flat_args}) {{ return {mat_name}({flat_ctor}); }}",
- "template ",
- f"__device__ __forceinline__ {mat_name} vkdispatch_make_{helper_suffix}(TMat m) {{ return {mat_name}({cast_cols}); }}",
- ]
- )
-
-
-def _cuda_composite_helpers() -> str:
- parts: List[str] = []
-
- vector_specs = [
- ("vkdispatch_int2", "int", 2, True, True, "int2"),
- ("vkdispatch_int3", "int", 3, True, True, "int3"),
- ("vkdispatch_int4", "int", 4, True, True, "int4"),
- ("vkdispatch_uint2", "unsigned int", 2, False, True, "uint2"),
- ("vkdispatch_uint3", "unsigned int", 3, False, True, "uint3"),
- ("vkdispatch_uint4", "unsigned int", 4, False, True, "uint4"),
- ("vkdispatch_float2", "float", 2, True, False, "float2"),
- ("vkdispatch_float3", "float", 3, True, False, "float3"),
- ("vkdispatch_float4", "float", 4, True, False, "float4"),
- ]
-
- for vec_name, scalar_type, dim, allow_neg, enable_bitwise, helper_suffix in vector_specs:
- parts.append(
- _cuda_emit_vec_type(
- vec_name,
- scalar_type,
- dim,
- allow_unary_neg=allow_neg,
- enable_bitwise=enable_bitwise,
- )
- )
- parts.append(_cuda_emit_vec_helper(helper_suffix, vec_name, scalar_type, dim))
-
- matrix_specs = [
- ("vkdispatch_mat2", "mat2", "vkdispatch_float2", "float2", 2),
- ("vkdispatch_mat3", "mat3", "vkdispatch_float3", "float3", 3),
- ("vkdispatch_mat4", "mat4", "vkdispatch_float4", "float4", 4),
- ]
-
- for mat_name, helper_suffix, vec_name, vec_helper_suffix, dim in matrix_specs:
- parts.append(_cuda_emit_mat_type(mat_name, vec_name, dim))
- parts.append(_cuda_emit_mat_helpers(mat_name, helper_suffix, vec_name, vec_helper_suffix, dim))
-
- return "\n\n".join(parts)
-
-
-_CUDA_VEC_TYPE_SPECS = {
- "int2": ("vkdispatch_int2", "int", 2, True, True),
- "int3": ("vkdispatch_int3", "int", 3, True, True),
- "int4": ("vkdispatch_int4", "int", 4, True, True),
- "uint2": ("vkdispatch_uint2", "unsigned int", 2, False, True),
- "uint3": ("vkdispatch_uint3", "unsigned int", 3, False, True),
- "uint4": ("vkdispatch_uint4", "unsigned int", 4, False, True),
- "float2": ("vkdispatch_float2", "float", 2, True, False),
- "float3": ("vkdispatch_float3", "float", 3, True, False),
- "float4": ("vkdispatch_float4", "float", 4, True, False),
-}
-
-_CUDA_MAT_TYPE_SPECS = {
- "mat2": ("vkdispatch_mat2", "vkdispatch_float2", "float2", 2),
- "mat3": ("vkdispatch_mat3", "vkdispatch_float3", "float3", 3),
- "mat4": ("vkdispatch_mat4", "vkdispatch_float4", "float4", 4),
-}
-
-
-class CUDABackend(CodeGenBackend):
- name = "cuda"
-
- _HELPER_SNIPPETS: Dict[str, str] = {
- "composite_types": "",
- "mat2_type": "",
- "mat3_type": "",
- "mat4_type": "",
- "make_mat2": "",
- "make_mat3": "",
- "make_mat4": "",
- "make_int2": "",
- "make_int3": "",
- "make_int4": "",
- "make_uint2": "",
- "make_uint3": "",
- "make_uint4": "",
- "float2_ops": "",
- "make_float2": "",
- "make_float3": "",
- "make_float4": "",
- "global_invocation_id": (
- "__device__ __forceinline__ vkdispatch_uint3 vkdispatch_global_invocation_id() {\n"
- " return vkdispatch_uint3(\n"
- " (unsigned int)(blockIdx.x * blockDim.x + threadIdx.x),\n"
- " (unsigned int)(blockIdx.y * blockDim.y + threadIdx.y),\n"
- " (unsigned int)(blockIdx.z * blockDim.z + threadIdx.z)\n"
- " );\n"
- "}"
- ),
- "local_invocation_id": (
- "__device__ __forceinline__ vkdispatch_uint3 vkdispatch_local_invocation_id() {\n"
- " return vkdispatch_uint3((unsigned int)threadIdx.x, (unsigned int)threadIdx.y, (unsigned int)threadIdx.z);\n"
- "}"
- ),
- "workgroup_id": (
- "__device__ __forceinline__ vkdispatch_uint3 vkdispatch_workgroup_id() {\n"
- " return vkdispatch_uint3((unsigned int)blockIdx.x, (unsigned int)blockIdx.y, (unsigned int)blockIdx.z);\n"
- "}"
- ),
- "local_invocation_index": (
- "__device__ __forceinline__ unsigned int vkdispatch_local_invocation_index() {\n"
- " return (unsigned int)(threadIdx.x + blockDim.x * (threadIdx.y + blockDim.y * threadIdx.z));\n"
- "}"
- ),
- "subgroup_size": "__device__ __forceinline__ unsigned int vkdispatch_subgroup_size() { return (unsigned int)warpSize; }",
- "num_subgroups": (
- "__device__ __forceinline__ unsigned int vkdispatch_num_subgroups() {\n"
- " unsigned int local_count = (unsigned int)(blockDim.x * blockDim.y * blockDim.z);\n"
- " return (local_count + vkdispatch_subgroup_size() - 1u) / vkdispatch_subgroup_size();\n"
- "}"
- ),
- "subgroup_id": (
- "__device__ __forceinline__ unsigned int vkdispatch_subgroup_id() {\n"
- " return vkdispatch_local_invocation_index() / vkdispatch_subgroup_size();\n"
- "}"
- ),
- "subgroup_invocation_id": (
- "__device__ __forceinline__ unsigned int vkdispatch_subgroup_invocation_id() {\n"
- " return vkdispatch_local_invocation_index() % vkdispatch_subgroup_size();\n"
- "}"
- ),
- "subgroup_add": (
- "template \n"
- "__device__ __forceinline__ T vkdispatch_subgroup_add(T value) {\n"
- " unsigned int mask = 0xffffffffu;\n"
- " for (unsigned int offset = vkdispatch_subgroup_size() >> 1; offset > 0u; offset >>= 1u) {\n"
- " value += __shfl_xor_sync(mask, value, (int)offset);\n"
- " }\n"
- " return value;\n"
- "}"
- ),
- "subgroup_mul": (
- "template \n"
- "__device__ __forceinline__ T vkdispatch_subgroup_mul(T value) {\n"
- " unsigned int mask = 0xffffffffu;\n"
- " for (unsigned int offset = vkdispatch_subgroup_size() >> 1; offset > 0u; offset >>= 1u) {\n"
- " value *= __shfl_xor_sync(mask, value, (int)offset);\n"
- " }\n"
- " return value;\n"
- "}"
- ),
- "subgroup_min": (
- "template \n"
- "__device__ __forceinline__ T vkdispatch_subgroup_min(T value) {\n"
- " unsigned int mask = 0xffffffffu;\n"
- " for (unsigned int offset = vkdispatch_subgroup_size() >> 1; offset > 0u; offset >>= 1u) {\n"
- " T other = __shfl_xor_sync(mask, value, (int)offset);\n"
- " value = other < value ? other : value;\n"
- " }\n"
- " return value;\n"
- "}"
- ),
- "subgroup_max": (
- "template \n"
- "__device__ __forceinline__ T vkdispatch_subgroup_max(T value) {\n"
- " unsigned int mask = 0xffffffffu;\n"
- " for (unsigned int offset = vkdispatch_subgroup_size() >> 1; offset > 0u; offset >>= 1u) {\n"
- " T other = __shfl_xor_sync(mask, value, (int)offset);\n"
- " value = other > value ? other : value;\n"
- " }\n"
- " return value;\n"
- "}"
- ),
- "subgroup_and": (
- "template \n"
- "__device__ __forceinline__ T vkdispatch_subgroup_and(T value) {\n"
- " unsigned int mask = 0xffffffffu;\n"
- " for (unsigned int offset = vkdispatch_subgroup_size() >> 1; offset > 0u; offset >>= 1u) {\n"
- " value &= __shfl_xor_sync(mask, value, (int)offset);\n"
- " }\n"
- " return value;\n"
- "}"
- ),
- "subgroup_or": (
- "template \n"
- "__device__ __forceinline__ T vkdispatch_subgroup_or(T value) {\n"
- " unsigned int mask = 0xffffffffu;\n"
- " for (unsigned int offset = vkdispatch_subgroup_size() >> 1; offset > 0u; offset >>= 1u) {\n"
- " value |= __shfl_xor_sync(mask, value, (int)offset);\n"
- " }\n"
- " return value;\n"
- "}"
- ),
- "subgroup_xor": (
- "template \n"
- "__device__ __forceinline__ T vkdispatch_subgroup_xor(T value) {\n"
- " unsigned int mask = 0xffffffffu;\n"
- " for (unsigned int offset = vkdispatch_subgroup_size() >> 1; offset > 0u; offset >>= 1u) {\n"
- " value ^= __shfl_xor_sync(mask, value, (int)offset);\n"
- " }\n"
- " return value;\n"
- "}"
- ),
- "mod": "__device__ __forceinline__ float mod(float x, float y) { return fmodf(x, y); }",
- "fract": "__device__ __forceinline__ float fract(float x) { return x - floorf(x); }",
- "roundEven": "__device__ __forceinline__ float roundEven(float x) { return nearbyintf(x); }",
- "mix": "__device__ __forceinline__ float mix(float x, float y, float a) { return x + (y - x) * a; }",
- "step": "__device__ __forceinline__ float step(float edge, float x) { return x < edge ? 0.0f : 1.0f; }",
- "smoothstep": (
- "__device__ __forceinline__ float smoothstep(float edge0, float edge1, float x) {\n"
- " float t = fminf(fmaxf((x - edge0) / (edge1 - edge0), 0.0f), 1.0f);\n"
- " return t * t * (3.0f - 2.0f * t);\n"
- "}"
- ),
- "radians": "__device__ __forceinline__ float radians(float x) { return x * (3.14159265358979323846f / 180.0f); }",
- "degrees": "__device__ __forceinline__ float degrees(float x) { return x * (180.0f / 3.14159265358979323846f); }",
- "inversesqrt": "__device__ __forceinline__ float inversesqrt(float x) { return rsqrtf(x); }",
- "floatBitsToInt": "__device__ __forceinline__ int floatBitsToInt(float x) { return __float_as_int(x); }",
- "floatBitsToUint": "__device__ __forceinline__ unsigned int floatBitsToUint(float x) { return __float_as_uint(x); }",
- "intBitsToFloat": "__device__ __forceinline__ float intBitsToFloat(int x) { return __int_as_float(x); }",
- "uintBitsToFloat": "__device__ __forceinline__ float uintBitsToFloat(unsigned int x) { return __uint_as_float(x); }",
- "sample_texture": "",
- }
-
- _HELPER_ORDER: List[str] = [
- "composite_types",
- "global_invocation_id",
- "local_invocation_id",
- "workgroup_id",
- "local_invocation_index",
- "subgroup_size",
- "num_subgroups",
- "subgroup_id",
- "subgroup_invocation_id",
- "subgroup_add",
- "subgroup_mul",
- "subgroup_min",
- "subgroup_max",
- "subgroup_and",
- "subgroup_or",
- "subgroup_xor",
- "mod",
- "fract",
- "roundEven",
- "mix",
- "step",
- "smoothstep",
- "radians",
- "degrees",
- "inversesqrt",
- "floatBitsToInt",
- "floatBitsToUint",
- "intBitsToFloat",
- "uintBitsToFloat",
- "sample_texture",
- ]
-
- _HELPER_DEPENDENCIES: Dict[str, List[str]] = {
- "mat2_type": ["composite_types"],
- "mat3_type": ["composite_types"],
- "mat4_type": ["composite_types"],
- "make_mat2": ["composite_types"],
- "make_mat3": ["composite_types"],
- "make_mat4": ["composite_types"],
- "make_int2": ["composite_types"],
- "make_int3": ["composite_types"],
- "make_int4": ["composite_types"],
- "make_uint2": ["composite_types"],
- "make_uint3": ["composite_types"],
- "make_uint4": ["composite_types"],
- "float2_ops": ["composite_types"],
- "make_float2": ["composite_types"],
- "make_float3": ["composite_types"],
- "make_float4": ["composite_types"],
- "global_invocation_id": ["composite_types"],
- "local_invocation_id": ["composite_types"],
- "workgroup_id": ["composite_types"],
- "sample_texture": ["composite_types"],
- "num_subgroups": ["subgroup_size"],
- "subgroup_id": ["local_invocation_index", "subgroup_size"],
- "subgroup_invocation_id": ["local_invocation_index", "subgroup_size"],
- "subgroup_add": ["subgroup_size"],
- "subgroup_mul": ["subgroup_size"],
- "subgroup_min": ["subgroup_size"],
- "subgroup_max": ["subgroup_size"],
- "subgroup_and": ["subgroup_size"],
- "subgroup_or": ["subgroup_size"],
- "subgroup_xor": ["subgroup_size"],
- }
-
- def __init__(self) -> None:
- self._fixed_preamble = ""
- self.reset_state()
-
- def reset_state(self) -> None:
- self._kernel_params: List[str] = []
- self._entry_alias_lines: List[str] = []
- self._composite_type_usage: Set[str] = set()
- self._composite_vec_op_usage: Dict[str, Set[str]] = {}
- self._composite_mat_op_usage: Dict[str, Set[str]] = {}
- self._sample_texture_dims: Set[int] = set()
- self._feature_usage: Dict[str, bool] = {
- feature_name: False
- for feature_name in self._HELPER_SNIPPETS
- }
-
- def mark_feature_usage(self, feature_name: str) -> None:
- if feature_name in self._feature_usage:
- self._feature_usage[feature_name] = True
-
- def _composite_key_for_dtype(self, var_type: dtypes.dtype) -> Optional[str]:
- if var_type == dtypes.complex64 or var_type == dtypes.vec2:
- return "float2"
- if var_type == dtypes.vec3:
- return "float3"
- if var_type == dtypes.vec4:
- return "float4"
- if var_type == dtypes.ivec2:
- return "int2"
- if var_type == dtypes.ivec3:
- return "int3"
- if var_type == dtypes.ivec4:
- return "int4"
- if var_type == dtypes.uvec2:
- return "uint2"
- if var_type == dtypes.uvec3:
- return "uint3"
- if var_type == dtypes.uvec4:
- return "uint4"
- if var_type == dtypes.mat2:
- return "mat2"
- if var_type == dtypes.mat3:
- return "mat3"
- if var_type == dtypes.mat4:
- return "mat4"
- return None
-
- def _record_composite_type_key(self, key: str) -> None:
- self.mark_feature_usage("composite_types")
- self._composite_type_usage.add(key)
-
- if key in _CUDA_MAT_TYPE_SPECS:
- dim = _CUDA_MAT_TYPE_SPECS[key][3]
- self._composite_type_usage.add(f"float{dim}")
-
- def _record_composite_type(self, var_type: dtypes.dtype) -> Optional[str]:
- key = self._composite_key_for_dtype(var_type)
- if key is None:
- return None
- self._record_composite_type_key(key)
- return key
-
- def _record_vec_op(self, key: str, token: str) -> None:
- self._record_composite_type_key(key)
- self._composite_vec_op_usage.setdefault(key, set()).add(token)
-
- def _record_mat_op(self, key: str, token: str) -> None:
- self._record_composite_type_key(key)
- self._composite_mat_op_usage.setdefault(key, set()).add(token)
-
- def _propagate_matrix_vec_dependencies(self, mat_key: str, token: str) -> None:
- dim = _CUDA_MAT_TYPE_SPECS[mat_key][3]
- vec_key = f"float{dim}"
-
- if token == "un:-":
- self._record_vec_op(vec_key, "un:-")
- return
-
- if token.startswith("cmpd:"):
- if token.endswith(":m"):
- vec_token = token[:-1] + "v"
- self._record_vec_op(vec_key, vec_token)
- return
- if token.endswith(":s"):
- self._record_vec_op(vec_key, token)
- return
-
- if token.startswith("bin:"):
- parts = token.split(":")
- if len(parts) != 3:
- return
- _, op, shape = parts
- if shape == "mm":
- if op in ["+", "-"]:
- self._record_vec_op(vec_key, f"bin:{op}:vv")
- elif op == "*":
- self._record_mat_op(mat_key, "bin:*:mv")
- self._propagate_matrix_vec_dependencies(mat_key, "bin:*:mv")
- return
- if shape == "ms":
- self._record_vec_op(vec_key, f"bin:{op}:vs")
- return
- if shape == "sm":
- self._record_vec_op(vec_key, f"bin:{op}:sv")
- return
- if shape == "mv":
- self._record_vec_op(vec_key, "bin:*:vs")
- self._record_vec_op(vec_key, "bin:+:vv")
- return
- if shape == "vm":
- return
-
- def mark_composite_unary_op(self, var_type: dtypes.dtype, op: str) -> None:
- key = self._record_composite_type(var_type)
- if key is None:
- return
-
- token = f"un:{op}"
- if key in _CUDA_VEC_TYPE_SPECS:
- self._record_vec_op(key, token)
- return
- if key in _CUDA_MAT_TYPE_SPECS:
- self._record_mat_op(key, token)
- self._propagate_matrix_vec_dependencies(key, token)
-
- def mark_composite_binary_op(
- self,
- lhs_type: dtypes.dtype,
- rhs_type: dtypes.dtype,
- op: str,
- *,
- inplace: bool = False,
- ) -> None:
- lhs_key = self._record_composite_type(lhs_type)
- rhs_key = self._record_composite_type(rhs_type)
-
- lhs_is_composite = lhs_key is not None
- rhs_is_composite = rhs_key is not None
- if not lhs_is_composite and not rhs_is_composite:
- return
-
- lhs_is_scalar = dtypes.is_scalar(lhs_type)
- rhs_is_scalar = dtypes.is_scalar(rhs_type)
-
- if lhs_key in _CUDA_VEC_TYPE_SPECS and (rhs_is_scalar or rhs_key in _CUDA_VEC_TYPE_SPECS):
- if inplace:
- suffix = "s" if rhs_is_scalar else "v"
- self._record_vec_op(lhs_key, f"cmpd:{op}=:{suffix}")
- return
- shape = "vs" if rhs_is_scalar else "vv"
- self._record_vec_op(lhs_key, f"bin:{op}:{shape}")
- return
-
- if rhs_key in _CUDA_VEC_TYPE_SPECS and lhs_is_scalar and not inplace:
- self._record_vec_op(rhs_key, f"bin:{op}:sv")
- return
-
- if lhs_key in _CUDA_MAT_TYPE_SPECS:
- if inplace:
- if rhs_is_scalar:
- token = f"cmpd:{op}=:s"
- elif rhs_key in _CUDA_MAT_TYPE_SPECS:
- token = f"cmpd:{op}=:m"
- else:
- return
- self._record_mat_op(lhs_key, token)
- self._propagate_matrix_vec_dependencies(lhs_key, token)
- return
-
- if rhs_is_scalar:
- token = f"bin:{op}:ms"
- self._record_mat_op(lhs_key, token)
- self._propagate_matrix_vec_dependencies(lhs_key, token)
- return
-
- if rhs_key in _CUDA_MAT_TYPE_SPECS:
- token = "bin:*:mm" if op == "*" else f"bin:{op}:mm"
- self._record_mat_op(lhs_key, token)
- self._propagate_matrix_vec_dependencies(lhs_key, token)
- return
-
- if rhs_key in _CUDA_VEC_TYPE_SPECS and op == "*":
- token = "bin:*:mv"
- self._record_mat_op(lhs_key, token)
- self._propagate_matrix_vec_dependencies(lhs_key, token)
- return
-
- if rhs_key in _CUDA_MAT_TYPE_SPECS and lhs_is_scalar and not inplace:
- token = f"bin:{op}:sm"
- self._record_mat_op(rhs_key, token)
- self._propagate_matrix_vec_dependencies(rhs_key, token)
- return
-
- if lhs_key in _CUDA_VEC_TYPE_SPECS and rhs_key in _CUDA_MAT_TYPE_SPECS and op == "*" and not inplace:
- token = "bin:*:vm"
- self._record_mat_op(rhs_key, token)
- self._propagate_matrix_vec_dependencies(rhs_key, token)
-
- def mark_texture_sample_dimension(self, dimensions: int) -> None:
- self._sample_texture_dims.add(dimensions)
- self.mark_feature_usage("sample_texture")
- self._record_composite_type_key("float4")
- if dimensions == 2:
- self._record_composite_type_key("float2")
- elif dimensions == 3:
- self._record_composite_type_key("float3")
-
- def _emit_used_composite_helpers(self) -> str:
- if len(self._composite_type_usage) == 0:
- return ""
-
- parts: List[str] = []
-
- vec_order = ["int2", "int3", "int4", "uint2", "uint3", "uint4", "float2", "float3", "float4"]
- for key in vec_order:
- if key not in self._composite_type_usage:
- continue
- vec_name, scalar_type, dim, allow_neg, enable_bitwise = _CUDA_VEC_TYPE_SPECS[key]
- parts.append(
- _cuda_emit_vec_type(
- vec_name,
- scalar_type,
- dim,
- allow_unary_neg=allow_neg,
- enable_bitwise=enable_bitwise,
- needed_ops=self._composite_vec_op_usage.get(key, set()),
- )
- )
- parts.append(_cuda_emit_vec_helper(key, vec_name, scalar_type, dim))
-
- mat_order = ["mat2", "mat3", "mat4"]
- for key in mat_order:
- if key not in self._composite_type_usage:
- continue
- mat_name, vec_name, vec_helper_suffix, dim = _CUDA_MAT_TYPE_SPECS[key]
- parts.append(_cuda_emit_mat_type(mat_name, vec_name, dim, self._composite_mat_op_usage.get(key, set())))
- parts.append(_cuda_emit_mat_helpers(mat_name, key, vec_name, vec_helper_suffix, dim))
-
- return "\n\n".join(parts)
-
- def _emit_sample_texture_helpers(self) -> str:
- dims = set(self._sample_texture_dims)
- if len(dims) == 0:
- dims = {1, 2, 3}
-
- lines: List[str] = []
- if 1 in dims:
- lines.append(
- "__device__ __forceinline__ vkdispatch_float4 vkdispatch_sample_texture(cudaTextureObject_t tex, float coord) { return vkdispatch_make_float4(tex1D(tex, coord)); }"
- )
- lines.append(
- "__device__ __forceinline__ vkdispatch_float4 vkdispatch_sample_texture(cudaTextureObject_t tex, float coord, float lod) { return vkdispatch_make_float4(tex1DLod(tex, coord, lod)); }"
- )
- self._record_composite_type_key("float4")
- if 2 in dims:
- lines.append(
- "__device__ __forceinline__ vkdispatch_float4 vkdispatch_sample_texture(cudaTextureObject_t tex, vkdispatch_float2 coord) { return vkdispatch_make_float4(tex2D(tex, coord.x, coord.y)); }"
- )
- lines.append(
- "__device__ __forceinline__ vkdispatch_float4 vkdispatch_sample_texture(cudaTextureObject_t tex, vkdispatch_float2 coord, float lod) { return vkdispatch_make_float4(tex2DLod(tex, coord.x, coord.y, lod)); }"
- )
- self._record_composite_type_key("float2")
- self._record_composite_type_key("float4")
- if 3 in dims:
- lines.append(
- "__device__ __forceinline__ vkdispatch_float4 vkdispatch_sample_texture(cudaTextureObject_t tex, vkdispatch_float3 coord) { return vkdispatch_make_float4(tex3D(tex, coord.x, coord.y, coord.z)); }"
- )
- lines.append(
- "__device__ __forceinline__ vkdispatch_float4 vkdispatch_sample_texture(cudaTextureObject_t tex, vkdispatch_float3 coord, float lod) { return vkdispatch_make_float4(tex3DLod(tex, coord.x, coord.y, coord.z, lod)); }"
- )
- self._record_composite_type_key("float3")
- self._record_composite_type_key("float4")
-
- return "\n".join(lines)
-
- def _register_kernel_param(self, param_decl: str) -> None:
- if param_decl not in self._kernel_params:
- self._kernel_params.append(param_decl)
-
- def _register_alias_line(self, alias_line: str) -> None:
- if alias_line not in self._entry_alias_lines:
- self._entry_alias_lines.append(alias_line)
-
- @staticmethod
- def _is_plain_integer_literal(expr: str) -> bool:
- if len(expr) == 0:
- return False
- if expr[0] in "+-":
- return len(expr) > 1 and expr[1:].isdigit()
- return expr.isdigit()
-
- def type_name(self, var_type: dtypes.dtype) -> str:
- if var_type == dtypes.int32:
- return "int"
- if var_type == dtypes.uint32:
- return "unsigned int"
- if var_type == dtypes.float32:
- return "float"
- if var_type == dtypes.complex64:
- self._record_composite_type(var_type)
- return "vkdispatch_float2"
-
- if var_type == dtypes.ivec2:
- self._record_composite_type(var_type)
- return "vkdispatch_int2"
- if var_type == dtypes.ivec3:
- self._record_composite_type(var_type)
- return "vkdispatch_int3"
- if var_type == dtypes.ivec4:
- self._record_composite_type(var_type)
- return "vkdispatch_int4"
-
- if var_type == dtypes.uvec2:
- self._record_composite_type(var_type)
- return "vkdispatch_uint2"
- if var_type == dtypes.uvec3:
- self._record_composite_type(var_type)
- return "vkdispatch_uint3"
- if var_type == dtypes.uvec4:
- self._record_composite_type(var_type)
- return "vkdispatch_uint4"
-
- if var_type == dtypes.vec2:
- self._record_composite_type(var_type)
- return "vkdispatch_float2"
- if var_type == dtypes.vec3:
- self._record_composite_type(var_type)
- return "vkdispatch_float3"
- if var_type == dtypes.vec4:
- self._record_composite_type(var_type)
- return "vkdispatch_float4"
-
- if var_type == dtypes.mat2:
- self._record_composite_type(var_type)
- return "vkdispatch_mat2"
- if var_type == dtypes.mat3:
- self._record_composite_type(var_type)
- return "vkdispatch_mat3"
- if var_type == dtypes.mat4:
- self._record_composite_type(var_type)
- return "vkdispatch_mat4"
-
- raise ValueError(f"Unsupported CUDA type mapping for '{var_type.name}'")
-
- def constructor(self, var_type: dtypes.dtype, args: List[str]) -> str:
- if (
- len(args) == 1
- and var_type in (dtypes.complex64, dtypes.vec2, dtypes.vec3, dtypes.vec4)
- and self._is_plain_integer_literal(args[0])
- ):
- args = [f"{args[0]}.0f"]
-
- target_type = self.type_name(var_type)
-
- if dtypes.is_scalar(var_type):
- assert len(args) > 0, f"Constructor for scalar type '{var_type.name}' needs at least one argument."
- return f"(({target_type})({args[0]}))"
-
- if var_type == dtypes.mat2:
- self.mark_feature_usage("make_mat2")
- return f"vkdispatch_make_mat2({', '.join(args)})"
- if var_type == dtypes.mat3:
- self.mark_feature_usage("make_mat3")
- return f"vkdispatch_make_mat3({', '.join(args)})"
- if var_type == dtypes.mat4:
- self.mark_feature_usage("make_mat4")
- return f"vkdispatch_make_mat4({', '.join(args)})"
-
- helper_suffix = target_type[len("vkdispatch_"):] if target_type.startswith("vkdispatch_") else target_type
- helper_name = f"vkdispatch_make_{helper_suffix}"
- self.mark_feature_usage(f"make_{helper_suffix}")
- return f"{helper_name}({', '.join(args)})"
-
- def pre_header(self, *, enable_subgroup_ops: bool, enable_printf: bool) -> str:
- self.reset_state()
-
- subgroup_support = "1" if enable_subgroup_ops else "0"
- printf_support = "1" if enable_printf else "0"
-
- self._fixed_preamble = (
- "#include \n"
- "#include \n"
- "#include \n\n"
- f"#define VKDISPATCH_ENABLE_SUBGROUP_OPS {subgroup_support}\n"
- f"#define VKDISPATCH_ENABLE_PRINTF {printf_support}\n\n"
- )
-
- return self._fixed_preamble
-
- def _resolve_helper_dependencies(self, helpers: Set[str]) -> Set[str]:
- pending = list(helpers)
- resolved = set(helpers)
-
- while len(pending) > 0:
- helper_name = pending.pop()
-
- for dependency in self._HELPER_DEPENDENCIES.get(helper_name, []):
- if dependency not in resolved:
- resolved.add(dependency)
- pending.append(dependency)
-
- return resolved
-
- def _helper_header(self) -> str:
- enabled_helpers = {
- helper_name
- for helper_name, is_enabled in self._feature_usage.items()
- if is_enabled
- }
-
- resolved_helpers = self._resolve_helper_dependencies(enabled_helpers)
-
- if len(resolved_helpers) == 0:
- return ""
-
- helper_sections: List[str] = []
-
- for helper_name in self._HELPER_ORDER:
- if helper_name in resolved_helpers:
- if helper_name == "composite_types":
- composite_helpers = self._emit_used_composite_helpers()
- if len(composite_helpers) > 0:
- helper_sections.append(composite_helpers)
- continue
- if helper_name == "sample_texture":
- texture_helpers = self._emit_sample_texture_helpers()
- if len(texture_helpers) > 0:
- helper_sections.append(texture_helpers)
- continue
-
- snippet = self._HELPER_SNIPPETS[helper_name]
- if len(snippet) > 0:
- helper_sections.append(snippet)
-
- return "\n\n".join(helper_sections) + "\n\n"
-
- def make_source(self, header: str, body: str, x: int, y: int, z: int) -> str:
- expected_size_header = (
- f"// Expected local size: ({x}, {y}, {z})\n"
- f"#define VKDISPATCH_EXPECTED_LOCAL_SIZE_X {x}\n"
- f"#define VKDISPATCH_EXPECTED_LOCAL_SIZE_Y {y}\n"
- f"#define VKDISPATCH_EXPECTED_LOCAL_SIZE_Z {z}\n"
- )
-
- helper_header = self._helper_header()
-
- if len(helper_header) == 0:
- return f"{expected_size_header}\n{header}\n{body}"
-
- if len(self._fixed_preamble) > 0 and header.startswith(self._fixed_preamble):
- header_suffix = header[len(self._fixed_preamble):]
- finalized_header = f"{self._fixed_preamble}{helper_header}{header_suffix}"
- else:
- finalized_header = f"{header}\n{helper_header}"
-
- return f"{expected_size_header}\n{finalized_header}\n{body}"
-
- def constant_namespace(self) -> str:
- return "UBO"
-
- def variable_namespace(self) -> str:
- return "PC"
-
- def exec_bounds_guard(self, exec_count_expr: str) -> str:
- gid = self.global_invocation_id_expr()
- return (
- f"if (({exec_count_expr}).x <= ({gid}).x || "
- f"({exec_count_expr}).y <= ({gid}).y || "
- f"({exec_count_expr}).z <= ({gid}).z) {{ return; }}\n"
- )
-
- def shared_buffer_declaration(self, var_type: dtypes.dtype, name: str, size: int) -> str:
- return f"__shared__ {self.type_name(var_type)} {name}[{size}];"
-
- def uniform_block_declaration(self, contents: str) -> str:
- self._register_kernel_param("const UniformObjectBuffer* vkdispatch_uniform_ptr")
- self._register_alias_line("const UniformObjectBuffer& UBO = *vkdispatch_uniform_ptr;")
- return f"\nstruct UniformObjectBuffer {{\n{contents}\n}};\n"
-
- def storage_buffer_declaration(self, binding: int, var_type: dtypes.dtype, name: str) -> str:
- struct_name = f"Buffer{binding}"
- param_name = f"vkdispatch_binding_{binding}_ptr"
- self._register_kernel_param(f"{self.type_name(var_type)}* {param_name}")
- self._register_alias_line(f"{struct_name} {name} = {{{param_name}}};")
- return f"struct {struct_name} {{ {self.type_name(var_type)}* data; }};\n"
-
- def sampler_declaration(self, binding: int, dimensions: int, name: str) -> str:
- param_name = f"vkdispatch_sampler_{binding}"
- self._register_kernel_param(f"cudaTextureObject_t {param_name}")
- self._register_alias_line(f"cudaTextureObject_t {name} = {param_name};")
- return f"// sampler binding {binding}, dimensions={dimensions}\n"
-
- def push_constant_declaration(self, contents: str) -> str:
- self._register_kernel_param("const PushConstant* vkdispatch_pc_ptr")
- self._register_alias_line("const PushConstant& PC = *vkdispatch_pc_ptr;")
- return f"\nstruct PushConstant {{\n{contents}\n}};\n"
-
- def entry_point(self, body_contents: str) -> str:
- params = ", ".join(self._kernel_params)
-
- alias_block = ""
- for line in self._entry_alias_lines:
- alias_block += f" {line}\n"
-
- return (
- f'extern "C" __global__ void vkdispatch_main({params}) {{\n'
- f"{alias_block}"
- f"{body_contents}"
- f"}}\n"
- )
-
- def inf_f32_expr(self) -> str:
- self.mark_feature_usage("uintBitsToFloat")
- return "uintBitsToFloat(0x7F800000u)"
-
- def ninf_f32_expr(self) -> str:
- self.mark_feature_usage("uintBitsToFloat")
- return "uintBitsToFloat(0xFF800000u)"
-
- def fma_function_name(self, var_type: dtypes.dtype) -> str:
- if var_type == dtypes.float32:
- return "fmaf"
- return "fma"
-
- def float_bits_to_int_expr(self, var_expr: str) -> str:
- self.mark_feature_usage("floatBitsToInt")
- return f"floatBitsToInt({var_expr})"
-
- def float_bits_to_uint_expr(self, var_expr: str) -> str:
- self.mark_feature_usage("floatBitsToUint")
- return f"floatBitsToUint({var_expr})"
-
- def int_bits_to_float_expr(self, var_expr: str) -> str:
- self.mark_feature_usage("intBitsToFloat")
- return f"intBitsToFloat({var_expr})"
-
- def uint_bits_to_float_expr(self, var_expr: str) -> str:
- self.mark_feature_usage("uintBitsToFloat")
- return f"uintBitsToFloat({var_expr})"
-
- def global_invocation_id_expr(self) -> str:
- self._record_composite_type_key("uint3")
- self.mark_feature_usage("global_invocation_id")
- return "vkdispatch_global_invocation_id()"
-
- def local_invocation_id_expr(self) -> str:
- self._record_composite_type_key("uint3")
- self.mark_feature_usage("local_invocation_id")
- return "vkdispatch_local_invocation_id()"
-
- def local_invocation_index_expr(self) -> str:
- self.mark_feature_usage("local_invocation_index")
- return "vkdispatch_local_invocation_index()"
-
- def workgroup_id_expr(self) -> str:
- self._record_composite_type_key("uint3")
- self.mark_feature_usage("workgroup_id")
- return "vkdispatch_workgroup_id()"
-
- def workgroup_size_expr(self) -> str:
- self._record_composite_type_key("uint3")
- self.mark_feature_usage("make_uint3")
- return "vkdispatch_make_uint3((unsigned int)blockDim.x, (unsigned int)blockDim.y, (unsigned int)blockDim.z)"
-
- def num_workgroups_expr(self) -> str:
- self._record_composite_type_key("uint3")
- self.mark_feature_usage("make_uint3")
- return "vkdispatch_make_uint3((unsigned int)gridDim.x, (unsigned int)gridDim.y, (unsigned int)gridDim.z)"
-
- def num_subgroups_expr(self) -> str:
- self.mark_feature_usage("num_subgroups")
- return "vkdispatch_num_subgroups()"
-
- def subgroup_id_expr(self) -> str:
- self.mark_feature_usage("subgroup_id")
- return "vkdispatch_subgroup_id()"
-
- def subgroup_size_expr(self) -> str:
- self.mark_feature_usage("subgroup_size")
- return "vkdispatch_subgroup_size()"
-
- def subgroup_invocation_id_expr(self) -> str:
- self.mark_feature_usage("subgroup_invocation_id")
- return "vkdispatch_subgroup_invocation_id()"
-
- def barrier_statement(self) -> str:
- return "__syncthreads();"
-
- def memory_barrier_statement(self) -> str:
- return "__threadfence();"
-
- def memory_barrier_buffer_statement(self) -> str:
- return "__threadfence();"
-
- def memory_barrier_shared_statement(self) -> str:
- return "__threadfence_block();"
-
- def memory_barrier_image_statement(self) -> str:
- return "__threadfence();"
-
- def group_memory_barrier_statement(self) -> str:
- return "__threadfence_block();"
-
- def subgroup_add_expr(self, arg_expr: str) -> str:
- self.mark_feature_usage("subgroup_add")
- return f"vkdispatch_subgroup_add({arg_expr})"
-
- def subgroup_mul_expr(self, arg_expr: str) -> str:
- self.mark_feature_usage("subgroup_mul")
- return f"vkdispatch_subgroup_mul({arg_expr})"
-
- def subgroup_min_expr(self, arg_expr: str) -> str:
- self.mark_feature_usage("subgroup_min")
- return f"vkdispatch_subgroup_min({arg_expr})"
-
- def subgroup_max_expr(self, arg_expr: str) -> str:
- self.mark_feature_usage("subgroup_max")
- return f"vkdispatch_subgroup_max({arg_expr})"
-
- def subgroup_and_expr(self, arg_expr: str) -> str:
- self.mark_feature_usage("subgroup_and")
- return f"vkdispatch_subgroup_and({arg_expr})"
-
- def subgroup_or_expr(self, arg_expr: str) -> str:
- self.mark_feature_usage("subgroup_or")
- return f"vkdispatch_subgroup_or({arg_expr})"
-
- def subgroup_xor_expr(self, arg_expr: str) -> str:
- self.mark_feature_usage("subgroup_xor")
- return f"vkdispatch_subgroup_xor({arg_expr})"
-
- def subgroup_elect_expr(self) -> str:
- self.mark_feature_usage("subgroup_invocation_id")
- return "((int)(vkdispatch_subgroup_invocation_id() == 0u))"
-
- def subgroup_barrier_statement(self) -> str:
- return "__syncwarp();"
-
- def printf_statement(self, fmt: str, args: List[str]) -> str:
- safe_fmt = fmt.replace("\\", "\\\\").replace('"', '\\"')
-
- if len(args) == 0:
- return f'printf("{safe_fmt}");'
-
- return f'printf("{safe_fmt}", {", ".join(args)});'
-
- def texture_size_expr(self, texture_expr: str, lod: int, dimensions: int) -> str:
- # CUDA texture objects do not expose shape directly in device code.
- # The future CUDA backend should pass explicit texture shape parameters.
- if dimensions == 1:
- return "1.0f"
- if dimensions == 2:
- self.mark_feature_usage("make_float2")
- return "vkdispatch_make_float2(1.0f)"
- if dimensions == 3:
- self.mark_feature_usage("make_float3")
- return "vkdispatch_make_float3(1.0f)"
-
- raise ValueError(f"Unsupported texture dimensions '{dimensions}'")
-
- def sample_texture_expr(self, texture_expr: str, coord_expr: str, lod_expr: Optional[str] = None) -> str:
- self.mark_feature_usage("sample_texture")
- if lod_expr is None:
- return f"vkdispatch_sample_texture({texture_expr}, {coord_expr})"
-
- return f"vkdispatch_sample_texture({texture_expr}, {coord_expr}, {lod_expr})"
diff --git a/vkdispatch/codegen/backends/cuda/__init__.py b/vkdispatch/codegen/backends/cuda/__init__.py
new file mode 100644
index 00000000..31730746
--- /dev/null
+++ b/vkdispatch/codegen/backends/cuda/__init__.py
@@ -0,0 +1,3 @@
+from .backend import CUDABackend
+
+__all__ = ["CUDABackend"]
diff --git a/vkdispatch/codegen/backends/cuda/backend.py b/vkdispatch/codegen/backends/cuda/backend.py
new file mode 100644
index 00000000..7cd91f29
--- /dev/null
+++ b/vkdispatch/codegen/backends/cuda/backend.py
@@ -0,0 +1,892 @@
+from typing import Dict, List, Optional, Set, Tuple
+
+import vkdispatch.base.dtype as dtypes
+
+from ..base import CodeGenBackend
+from .composite_emitters import (
+ _cuda_emit_mat_helpers,
+ _cuda_emit_mat_type,
+ _cuda_emit_subgroup_shuffle_xor_vec_overloads,
+ _cuda_emit_vec_helper,
+ _cuda_emit_vec_type,
+ _cuda_emit_vec_wrapper_conversion_helpers,
+)
+from .helper_snippets import (
+ _HELPER_DEPENDENCIES as _CUDA_HELPER_DEPENDENCIES,
+ _HELPER_ORDER as _CUDA_HELPER_ORDER,
+ _HELPER_SNIPPETS as _CUDA_HELPER_SNIPPETS,
+ initialize_feature_usage,
+)
+from .math_utils import (
+ cuda_fast_binary_math_name,
+ cuda_fast_unary_math_name,
+ cuda_float_vec_components_for_suffix,
+ cuda_float_vec_helper_suffix,
+ cuda_scalar_binary_math_name,
+ cuda_scalar_unary_math_name,
+ emit_used_vec_math_helpers,
+)
+from .specs import (
+ _CUDA_MAT_ORDER,
+ _CUDA_MAT_TYPE_SPECS,
+ _CUDA_VEC_ORDER,
+ _CUDA_VEC_TYPE_SPECS,
+ _DTYPE_TO_COMPOSITE_KEY as _CUDA_DTYPE_TO_COMPOSITE_KEY,
+ _FLOAT_VEC_DTYPES as _CUDA_FLOAT_VEC_DTYPES,
+ _FLOAT_VEC_HELPER_SUFFIX_MAP as _CUDA_FLOAT_VEC_HELPER_SUFFIX_MAP,
+ _SCALAR_TYPE_NAMES as _CUDA_SCALAR_TYPE_NAMES,
+)
+
+class CUDABackend(CodeGenBackend):
+ name = "cuda"
+ _CUDA_BUILTIN_UVEC3_SENTINELS: Dict[str, Dict[str, str]] = {
+ "global_invocation_id": {
+ "sentinel": "VKDISPATCH_CUDA_GLOBAL_INVOCATION_ID_SENTINEL()",
+ "x": "(unsigned int)(blockIdx.x * blockDim.x + threadIdx.x)",
+ "y": "(unsigned int)(blockIdx.y * blockDim.y + threadIdx.y)",
+ "z": "(unsigned int)(blockIdx.z * blockDim.z + threadIdx.z)",
+ },
+ "local_invocation_id": {
+ "sentinel": "VKDISPATCH_CUDA_LOCAL_INVOCATION_ID_SENTINEL()",
+ "x": "(unsigned int)threadIdx.x",
+ "y": "(unsigned int)threadIdx.y",
+ "z": "(unsigned int)threadIdx.z",
+ },
+ "workgroup_id": {
+ "sentinel": "VKDISPATCH_CUDA_WORKGROUP_ID_SENTINEL()",
+ "x": "(unsigned int)blockIdx.x",
+ "y": "(unsigned int)blockIdx.y",
+ "z": "(unsigned int)blockIdx.z",
+ },
+ }
+
+ _HELPER_SNIPPETS: Dict[str, str] = _CUDA_HELPER_SNIPPETS
+ _HELPER_ORDER: List[str] = _CUDA_HELPER_ORDER
+ _HELPER_DEPENDENCIES: Dict[str, List[str]] = _CUDA_HELPER_DEPENDENCIES
+
+ def __init__(self) -> None:
+ self._fixed_preamble = ""
+ self.reset_state()
+
+ def reset_state(self) -> None:
+ self._kernel_params: List[str] = []
+ self._entry_alias_lines: List[str] = []
+ self._composite_type_usage: Set[str] = set()
+ self._composite_vec_op_usage: Dict[str, Set[str]] = {}
+ self._composite_mat_op_usage: Dict[str, Set[str]] = {}
+ self._composite_vec_unary_math_usage: Dict[str, Set[str]] = {}
+ self._composite_vec_binary_math_usage: Dict[str, Set[str]] = {}
+ self._sample_texture_dims: Set[int] = set()
+ self._needs_cuda_fp16: bool = False
+ self._feature_usage: Dict[str, bool] = initialize_feature_usage()
+
+ def mark_feature_usage(self, feature_name: str) -> None:
+ if feature_name in self._feature_usage:
+ self._feature_usage[feature_name] = True
+
+ _DTYPE_TO_COMPOSITE_KEY = _CUDA_DTYPE_TO_COMPOSITE_KEY
+
+ def _composite_key_for_dtype(self, var_type: dtypes.dtype) -> Optional[str]:
+ return self._DTYPE_TO_COMPOSITE_KEY.get(var_type)
+
+ def _record_composite_type_key(self, key: str) -> None:
+ self.mark_feature_usage("composite_types")
+ self._composite_type_usage.add(key)
+
+ if key in _CUDA_MAT_TYPE_SPECS:
+ dim = _CUDA_MAT_TYPE_SPECS[key][3]
+ self._composite_type_usage.add(f"float{dim}")
+
+ def _record_composite_type(self, var_type: dtypes.dtype) -> Optional[str]:
+ key = self._composite_key_for_dtype(var_type)
+ if key is None:
+ return None
+ self._record_composite_type_key(key)
+ return key
+
+ def _record_vec_op(self, key: str, token: str) -> None:
+ self._record_composite_type_key(key)
+ self._composite_vec_op_usage.setdefault(key, set()).add(token)
+
+ def _record_mat_op(self, key: str, token: str) -> None:
+ self._record_composite_type_key(key)
+ self._composite_mat_op_usage.setdefault(key, set()).add(token)
+
+ def _record_vec_unary_math(self, key: str, func_name: str) -> None:
+ self._record_composite_type_key(key)
+ self._composite_vec_unary_math_usage.setdefault(key, set()).add(func_name)
+
+ def _record_vec_binary_math(self, key: str, func_name: str, signature: str) -> None:
+ self._record_composite_type_key(key)
+ self._composite_vec_binary_math_usage.setdefault(key, set()).add(f"{func_name}:{signature}")
+
+ def _propagate_matrix_vec_dependencies(self, mat_key: str, token: str) -> None:
+ dim = _CUDA_MAT_TYPE_SPECS[mat_key][3]
+ vec_key = f"float{dim}"
+
+ if token == "un:-":
+ self._record_vec_op(vec_key, "un:-")
+ return
+
+ if token.startswith("cmpd:"):
+ if token.endswith(":m"):
+ vec_token = token[:-1] + "v"
+ self._record_vec_op(vec_key, vec_token)
+ return
+ if token.endswith(":s"):
+ self._record_vec_op(vec_key, token)
+ return
+
+ if token.startswith("bin:"):
+ parts = token.split(":")
+ if len(parts) != 3:
+ return
+ _, op, shape = parts
+ if shape == "mm":
+ if op in ["+", "-"]:
+ self._record_vec_op(vec_key, f"bin:{op}:vv")
+ elif op == "*":
+ self._record_mat_op(mat_key, "bin:*:mv")
+ self._propagate_matrix_vec_dependencies(mat_key, "bin:*:mv")
+ return
+ if shape == "ms":
+ self._record_vec_op(vec_key, f"bin:{op}:vs")
+ return
+ if shape == "sm":
+ self._record_vec_op(vec_key, f"bin:{op}:sv")
+ return
+ if shape == "mv":
+ self._record_vec_op(vec_key, "bin:*:vs")
+ self._record_vec_op(vec_key, "bin:+:vv")
+ return
+ if shape == "vm":
+ return
+
+ def mark_composite_unary_op(self, var_type: dtypes.dtype, op: str) -> None:
+ key = self._record_composite_type(var_type)
+ if key is None:
+ return
+
+ token = f"un:{op}"
+ if key in _CUDA_VEC_TYPE_SPECS:
+ self._record_vec_op(key, token)
+ return
+ if key in _CUDA_MAT_TYPE_SPECS:
+ self._record_mat_op(key, token)
+ self._propagate_matrix_vec_dependencies(key, token)
+
+ def mark_composite_binary_op(
+ self,
+ lhs_type: dtypes.dtype,
+ rhs_type: dtypes.dtype,
+ op: str,
+ *,
+ inplace: bool = False,
+ ) -> None:
+ lhs_key = self._record_composite_type(lhs_type)
+ rhs_key = self._record_composite_type(rhs_type)
+
+ lhs_is_composite = lhs_key is not None
+ rhs_is_composite = rhs_key is not None
+ if not lhs_is_composite and not rhs_is_composite:
+ return
+
+ lhs_is_scalar = dtypes.is_scalar(lhs_type)
+ rhs_is_scalar = dtypes.is_scalar(rhs_type)
+
+ if lhs_key in _CUDA_VEC_TYPE_SPECS and (rhs_is_scalar or rhs_key in _CUDA_VEC_TYPE_SPECS):
+ if inplace:
+ suffix = "s" if rhs_is_scalar else "v"
+ self._record_vec_op(lhs_key, f"cmpd:{op}=:{suffix}")
+ return
+ shape = "vs" if rhs_is_scalar else "vv"
+ self._record_vec_op(lhs_key, f"bin:{op}:{shape}")
+ return
+
+ if rhs_key in _CUDA_VEC_TYPE_SPECS and lhs_is_scalar and not inplace:
+ self._record_vec_op(rhs_key, f"bin:{op}:sv")
+ return
+
+ if lhs_key in _CUDA_MAT_TYPE_SPECS:
+ if inplace:
+ if rhs_is_scalar:
+ token = f"cmpd:{op}=:s"
+ elif rhs_key in _CUDA_MAT_TYPE_SPECS:
+ token = f"cmpd:{op}=:m"
+ else:
+ return
+ self._record_mat_op(lhs_key, token)
+ self._propagate_matrix_vec_dependencies(lhs_key, token)
+ return
+
+ if rhs_is_scalar:
+ token = f"bin:{op}:ms"
+ self._record_mat_op(lhs_key, token)
+ self._propagate_matrix_vec_dependencies(lhs_key, token)
+ return
+
+ if rhs_key in _CUDA_MAT_TYPE_SPECS:
+ token = "bin:*:mm" if op == "*" else f"bin:{op}:mm"
+ self._record_mat_op(lhs_key, token)
+ self._propagate_matrix_vec_dependencies(lhs_key, token)
+ return
+
+ if rhs_key in _CUDA_VEC_TYPE_SPECS and op == "*":
+ token = "bin:*:mv"
+ self._record_mat_op(lhs_key, token)
+ self._propagate_matrix_vec_dependencies(lhs_key, token)
+ return
+
+ if rhs_key in _CUDA_MAT_TYPE_SPECS and lhs_is_scalar and not inplace:
+ token = f"bin:{op}:sm"
+ self._record_mat_op(rhs_key, token)
+ self._propagate_matrix_vec_dependencies(rhs_key, token)
+ return
+
+ if lhs_key in _CUDA_VEC_TYPE_SPECS and rhs_key in _CUDA_MAT_TYPE_SPECS and op == "*" and not inplace:
+ token = "bin:*:vm"
+ self._record_mat_op(rhs_key, token)
+ self._propagate_matrix_vec_dependencies(rhs_key, token)
+
+ def _emit_used_composite_helpers(self) -> str:
+ if len(self._composite_type_usage) == 0:
+ return ""
+
+ parts: List[str] = []
+
+ # Subgroup helpers use vector binary operators internally (e.g. value = value + shuffled)
+ # even if user code never directly emits the corresponding operator on that vector type.
+ subgroup_vec_op_requirements = [
+ ("subgroup_add", "bin:+:vv"),
+ ("subgroup_mul", "bin:*:vv"),
+ ("subgroup_and", "bin:&:vv"),
+ ("subgroup_or", "bin:|:vv"),
+ ("subgroup_xor", "bin:^:vv"),
+ ]
+ for feature_name, token in subgroup_vec_op_requirements:
+ if not self._feature_usage.get(feature_name, False):
+ continue
+ for key in self._composite_type_usage:
+ if key in _CUDA_VEC_TYPE_SPECS:
+ self._composite_vec_op_usage.setdefault(key, set()).add(token)
+
+ emitted_vec_keys: Set[str] = set()
+ for key in _CUDA_VEC_ORDER:
+ if key not in self._composite_type_usage:
+ continue
+ vec_name, scalar_type, dim, cuda_native_type, allow_neg, enable_bitwise = _CUDA_VEC_TYPE_SPECS[key]
+ emitted_vec_keys.add(key)
+ parts.append(
+ _cuda_emit_vec_type(
+ vec_name,
+ scalar_type,
+ dim,
+ cuda_native_type,
+ allow_unary_neg=allow_neg,
+ enable_bitwise=enable_bitwise,
+ needed_ops=self._composite_vec_op_usage.get(key, set()),
+ )
+ )
+ parts.append(_cuda_emit_vec_helper(key, vec_name, scalar_type, dim))
+ for key in _CUDA_VEC_ORDER:
+ if key not in emitted_vec_keys:
+ continue
+ vec_name, scalar_type, dim, _, _, _ = _CUDA_VEC_TYPE_SPECS[key]
+ conversion_helpers = _cuda_emit_vec_wrapper_conversion_helpers(
+ key,
+ vec_name,
+ scalar_type,
+ dim,
+ available_keys=emitted_vec_keys,
+ )
+ if len(conversion_helpers) > 0:
+ parts.append(conversion_helpers)
+
+ subgroup_shuffle_overloads = _cuda_emit_subgroup_shuffle_xor_vec_overloads(emitted_vec_keys)
+ if len(subgroup_shuffle_overloads) > 0:
+ parts.append(subgroup_shuffle_overloads)
+
+ for key in _CUDA_MAT_ORDER:
+ if key not in self._composite_type_usage:
+ continue
+ mat_name, vec_name, vec_helper_suffix, dim = _CUDA_MAT_TYPE_SPECS[key]
+ parts.append(_cuda_emit_mat_type(mat_name, vec_name, dim, self._composite_mat_op_usage.get(key, set())))
+ parts.append(_cuda_emit_mat_helpers(mat_name, key, vec_name, vec_helper_suffix, dim))
+
+ vec_math_helpers = self._emit_used_vec_math_helpers()
+ if len(vec_math_helpers) > 0:
+ parts.append(vec_math_helpers)
+
+ return "\n\n".join(parts)
+
+ @staticmethod
+ def _cuda_scalar_unary_math_name(func_name: str, scalar_type: str) -> str:
+ return cuda_scalar_unary_math_name(func_name, scalar_type)
+
+ @staticmethod
+ def _cuda_scalar_binary_math_name(func_name: str, scalar_type: str) -> str:
+ return cuda_scalar_binary_math_name(func_name, scalar_type)
+
+ def _emit_used_vec_math_helpers(self) -> str:
+ return emit_used_vec_math_helpers(
+ self._composite_vec_unary_math_usage,
+ self._composite_vec_binary_math_usage,
+ )
+
+ def _register_kernel_param(self, param_decl: str) -> None:
+ if param_decl not in self._kernel_params:
+ self._kernel_params.append(param_decl)
+
+ def _register_alias_line(self, alias_line: str) -> None:
+ if alias_line not in self._entry_alias_lines:
+ self._entry_alias_lines.append(alias_line)
+
+ @staticmethod
+ def _is_plain_integer_literal(expr: str) -> bool:
+ if len(expr) == 0:
+ return False
+ if expr[0] in "+-":
+ return len(expr) > 1 and expr[1:].isdigit()
+ return expr.isdigit()
+
+ _SCALAR_TYPE_NAMES = _CUDA_SCALAR_TYPE_NAMES
+
+ def type_name(self, var_type: dtypes.dtype) -> str:
+ scalar_name = self._SCALAR_TYPE_NAMES.get(var_type)
+ if scalar_name is not None:
+ if var_type == dtypes.float16:
+ self._needs_cuda_fp16 = True
+ return scalar_name
+
+ key = self._composite_key_for_dtype(var_type)
+ if key is not None:
+ self._record_composite_type(var_type)
+ if key in _CUDA_VEC_TYPE_SPECS:
+ # Track fp16 header need when half vector types are used.
+ if _CUDA_VEC_TYPE_SPECS[key][1] == "__half":
+ self._needs_cuda_fp16 = True
+ return _CUDA_VEC_TYPE_SPECS[key][0]
+ if key in _CUDA_MAT_TYPE_SPECS:
+ return _CUDA_MAT_TYPE_SPECS[key][0]
+
+ raise ValueError(f"Unsupported CUDA type mapping for '{var_type.name}'")
+
+ _FLOAT_VEC_DTYPES = _CUDA_FLOAT_VEC_DTYPES
+
+ def constructor(
+ self,
+ var_type: dtypes.dtype,
+ args: List[str],
+ arg_types: Optional[List[Optional[dtypes.dtype]]] = None,
+ ) -> str:
+ _ = arg_types
+ if (
+ len(args) == 1
+ and var_type in self._FLOAT_VEC_DTYPES
+ and self._is_plain_integer_literal(args[0])
+ ):
+ scalar_type = None
+ if dtypes.is_complex(var_type):
+ scalar_type = var_type.child_type
+ elif dtypes.is_vector(var_type):
+ scalar_type = var_type.scalar
+
+ if scalar_type == dtypes.float64:
+ args = [f"{args[0]}.0"]
+ else:
+ args = [f"{args[0]}.0f"]
+
+ target_type = self.type_name(var_type)
+
+ if dtypes.is_scalar(var_type):
+ assert len(args) > 0, f"Constructor for scalar type '{var_type.name}' needs at least one argument."
+ return f"(({target_type})({args[0]}))"
+
+ if var_type == dtypes.mat2:
+ self.mark_feature_usage("make_mat2")
+ return f"vkdispatch_make_mat2({', '.join(args)})"
+ if var_type == dtypes.mat3:
+ self.mark_feature_usage("make_mat3")
+ return f"vkdispatch_make_mat3({', '.join(args)})"
+ if var_type == dtypes.mat4:
+ self.mark_feature_usage("make_mat4")
+ return f"vkdispatch_make_mat4({', '.join(args)})"
+
+ helper_suffix = target_type[len("vkdispatch_"):] if target_type.startswith("vkdispatch_") else target_type
+ helper_name = f"vkdispatch_make_{helper_suffix}"
+ self.mark_feature_usage(f"make_{helper_suffix}")
+ return f"{helper_name}({', '.join(args)})"
+
+ def component_access_expr(self, expr: str, component: str, base_type: dtypes.dtype) -> str:
+ if dtypes.is_scalar(base_type):
+ if component == "x":
+ return expr
+ return super().component_access_expr(expr, component, base_type)
+
+ if dtypes.is_vector(base_type) or dtypes.is_complex(base_type):
+ direct_builtin_component = self._cuda_builtin_uvec3_component_expr(expr, component, base_type)
+ if direct_builtin_component is not None:
+ return direct_builtin_component
+ return f"{expr}.v.{component}"
+
+ return super().component_access_expr(expr, component, base_type)
+
+ def pre_header(self, *, enable_subgroup_ops: bool, enable_printf: bool) -> str:
+ subgroup_support = "1" if enable_subgroup_ops else "0"
+ printf_support = "1" if enable_printf else "0"
+
+ self._enable_subgroup_ops = enable_subgroup_ops
+ self._enable_printf = enable_printf
+
+ helper_header = self._helper_header()
+ fp16_include = "#include <cuda_fp16.h>\n" if self._needs_cuda_fp16 else ""
+
+ self._fixed_preamble = (
+ "#include <cuda_runtime.h>\n"
+ f"{fp16_include}\n"
+ f"#define VKDISPATCH_ENABLE_SUBGROUP_OPS {subgroup_support}\n"
+ f"#define VKDISPATCH_ENABLE_PRINTF {printf_support}\n\n"
+ f"{helper_header}\n\n"
+ )
+
+ return self._fixed_preamble
+
+ def _resolve_helper_dependencies(self, helpers: Set[str]) -> Set[str]:
+ pending = list(helpers)
+ resolved = set(helpers)
+
+ while len(pending) > 0:
+ helper_name = pending.pop()
+
+ for dependency in self._HELPER_DEPENDENCIES.get(helper_name, []):
+ if dependency not in resolved:
+ resolved.add(dependency)
+ pending.append(dependency)
+
+ return resolved
+
+ def _helper_header(self) -> str:
+ enabled_helpers = {
+ helper_name
+ for helper_name, is_enabled in self._feature_usage.items()
+ if is_enabled
+ }
+
+ resolved_helpers = self._resolve_helper_dependencies(enabled_helpers)
+
+ if len(resolved_helpers) == 0:
+ return ""
+
+ helper_sections: List[str] = []
+
+ for helper_name in self._HELPER_ORDER:
+ if helper_name in resolved_helpers:
+ if helper_name == "composite_types":
+ composite_helpers = self._emit_used_composite_helpers()
+ if len(composite_helpers) > 0:
+ helper_sections.append(composite_helpers)
+ continue
+
+ snippet = self._HELPER_SNIPPETS[helper_name]
+ if len(snippet) > 0:
+ helper_sections.append(snippet)
+
+ return "\n\n".join(helper_sections) + "\n\n"
+
+ def make_source(self, header: str, body: str, x: int, y: int, z: int) -> str:
+ header, body = self._finalize_cuda_builtin_uvec3_sentinels(header, body)
+
+ expected_size_header = (
+ f"// Expected local size: ({x}, {y}, {z})\n"
+ f"#define VKDISPATCH_EXPECTED_LOCAL_SIZE_X {x}\n"
+ f"#define VKDISPATCH_EXPECTED_LOCAL_SIZE_Y {y}\n"
+ f"#define VKDISPATCH_EXPECTED_LOCAL_SIZE_Z {z}\n"
+ )
+
+ return f"{expected_size_header}\n{header}\n{body}"
+
+ def constant_namespace(self) -> str:
+ return "UBO"
+
+ def variable_namespace(self) -> str:
+ return "PC"
+
+ def exec_bounds_guard(self, exec_count_expr: str) -> str:
+ gid = self.global_invocation_id_expr()
+ exec_expr = f"({exec_count_expr})"
+ gid_expr = f"({gid})"
+ return (
+ f"if ({self.component_access_expr(exec_expr, 'x', dtypes.uvec4)} <= {self.component_access_expr(gid_expr, 'x', dtypes.uvec3)} || "
+ f"{self.component_access_expr(exec_expr, 'y', dtypes.uvec4)} <= {self.component_access_expr(gid_expr, 'y', dtypes.uvec3)} || "
+ f"{self.component_access_expr(exec_expr, 'z', dtypes.uvec4)} <= {self.component_access_expr(gid_expr, 'z', dtypes.uvec3)}) {{ return; }}\n"
+ )
+
+ def shared_buffer_declaration(self, var_type: dtypes.dtype, name: str, size: int) -> str:
+ return f"__shared__ {self.type_name(var_type)} {name}[{size}];"
+
+ def uniform_block_declaration(self, contents: str) -> str:
+ self._register_kernel_param("const UniformObjectBuffer vkdispatch_uniform_value")
+ self._register_alias_line("const UniformObjectBuffer& UBO = vkdispatch_uniform_value;")
+ return f"\nstruct UniformObjectBuffer {{\n{contents}\n}};\n"
+
+ def storage_buffer_declaration(self, binding: int, var_type: dtypes.dtype, name: str) -> str:
+ struct_name = f"Buffer{binding}"
+ param_name = f"vkdispatch_binding_{binding}_ptr"
+ self._register_kernel_param(f"{self.type_name(var_type)}* {param_name}")
+ self._register_alias_line(f"{struct_name} {name} = {{{param_name}}};")
+ return f"struct {struct_name} {{ {self.type_name(var_type)}* data; }};\n"
+
+ def sampler_declaration(self, binding: int, dimensions: int, name: str) -> str:
+ param_name = f"vkdispatch_sampler_{binding}"
+ self._register_kernel_param(f"cudaTextureObject_t {param_name}")
+ self._register_alias_line(f"cudaTextureObject_t {name} = {param_name};")
+ return f"// sampler binding {binding}, dimensions={dimensions}\n"
+
+ def push_constant_declaration(self, contents: str) -> str:
+ self._register_kernel_param("const PushConstant vkdispatch_pc_value")
+ self._register_alias_line("const PushConstant& PC = vkdispatch_pc_value;")
+ return f"\nstruct PushConstant {{\n{contents}\n}};\n"
+
+ def entry_point(self, body_contents: str) -> str:
+ params = ", ".join(self._kernel_params)
+
+ alias_block = ""
+ for line in self._entry_alias_lines:
+ alias_block += f" {line}\n"
+
+ return (
+ f'extern "C" __global__ void vkdispatch_main({params}) {{\n'
+ f"{alias_block}"
+ f"{body_contents}"
+ f"}}\n"
+ )
+
+ def inf_f32_expr(self) -> str:
+ self.mark_feature_usage("uintBitsToFloat")
+ return "uintBitsToFloat(0x7F800000u)"
+
+ def ninf_f32_expr(self) -> str:
+ self.mark_feature_usage("uintBitsToFloat")
+ return "uintBitsToFloat(0xFF800000u)"
+
+ def inf_f64_expr(self) -> str:
+ self.mark_feature_usage("longlong_as_double")
+ return "__longlong_as_double(0x7FF0000000000000LL)"
+
+ def ninf_f64_expr(self) -> str:
+ self.mark_feature_usage("longlong_as_double")
+ return "__longlong_as_double(0xFFF0000000000000LL)"
+
+ def inf_f16_expr(self) -> str:
+ self.mark_feature_usage("ushort_as_half")
+ return "__ushort_as_half(0x7C00u)"
+
+ def ninf_f16_expr(self) -> str:
+ self.mark_feature_usage("ushort_as_half")
+ return "__ushort_as_half(0xFC00u)"
+
+ def fma_function_name(self, var_type: dtypes.dtype) -> str:
+ if var_type == dtypes.float16:
+ return "__hfma"
+ if var_type == dtypes.float32:
+ return "fmaf"
+ return "fma"
+
+ def math_func_name(self, func_name: str, var_type: dtypes.dtype) -> str:
+ scalar = var_type
+ if dtypes.is_vector(var_type) or dtypes.is_matrix(var_type):
+ scalar = var_type.scalar
+ elif dtypes.is_complex(var_type):
+ scalar = var_type.child_type
+
+ if scalar == dtypes.float16:
+ return self._cuda_scalar_unary_math_name(func_name, "__half")
+ if scalar == dtypes.float32:
+ return self._cuda_fast_unary_math_name(func_name)
+ # double and integer types use standard C names
+ return func_name
+
+ @staticmethod
+ def _cuda_fast_unary_math_name(func_name: str) -> str:
+ return cuda_fast_unary_math_name(func_name)
+
+ @staticmethod
+ def _cuda_fast_binary_math_name(func_name: str) -> str:
+ return cuda_fast_binary_math_name(func_name)
+
+ _FLOAT_VEC_HELPER_SUFFIX_MAP = _CUDA_FLOAT_VEC_HELPER_SUFFIX_MAP
+
+ @staticmethod
+ def _cuda_float_vec_helper_suffix(var_type: dtypes.dtype) -> Optional[str]:
+ return cuda_float_vec_helper_suffix(var_type)
+
+ @staticmethod
+ def _cuda_float_vec_components_for_suffix(helper_suffix: str) -> List[str]:
+ return cuda_float_vec_components_for_suffix(helper_suffix)
+
+ def _cuda_componentwise_unary_math_expr(self, func_name: str, arg_type: dtypes.dtype, arg_expr: str) -> Optional[str]:
+ helper_suffix = self._cuda_float_vec_helper_suffix(arg_type)
+ if helper_suffix is None:
+ return None
+
+ self._record_vec_unary_math(helper_suffix, func_name)
+ return f"{func_name}({arg_expr})"
+
+ def _cuda_componentwise_binary_math_expr(
+ self,
+ func_name: str,
+ lhs_type: dtypes.dtype,
+ lhs_expr: str,
+ rhs_type: dtypes.dtype,
+ rhs_expr: str,
+ ) -> Optional[str]:
+ lhs_helper = self._cuda_float_vec_helper_suffix(lhs_type)
+ rhs_helper = self._cuda_float_vec_helper_suffix(rhs_type)
+
+ if lhs_helper is None and rhs_helper is None:
+ return None
+
+ if lhs_helper is not None and rhs_helper is not None and lhs_helper != rhs_helper:
+ return None
+
+ helper_suffix = lhs_helper if lhs_helper is not None else rhs_helper
+ assert helper_suffix is not None
+
+ signature = ("v" if lhs_helper is not None else "s") + ("v" if rhs_helper is not None else "s")
+ self._record_vec_binary_math(helper_suffix, func_name, signature)
+ return f"{func_name}({lhs_expr}, {rhs_expr})"
+
+ def unary_math_expr(self, func_name: str, arg_type: dtypes.dtype, arg_expr: str) -> str:
+ vector_expr = self._cuda_componentwise_unary_math_expr(func_name, arg_type, arg_expr)
+ if vector_expr is not None:
+ return vector_expr
+
+ mapped = self.math_func_name(func_name, arg_type)
+ return f"{mapped}({arg_expr})"
+
+ def binary_math_expr(
+ self,
+ func_name: str,
+ lhs_type: dtypes.dtype,
+ lhs_expr: str,
+ rhs_type: dtypes.dtype,
+ rhs_expr: str,
+ ) -> str:
+ vector_expr = self._cuda_componentwise_binary_math_expr(
+ func_name,
+ lhs_type,
+ lhs_expr,
+ rhs_type,
+ rhs_expr,
+ )
+ if vector_expr is not None:
+ return vector_expr
+
+ if dtypes.is_scalar(lhs_type) and dtypes.is_scalar(rhs_type):
+ scalar = lhs_type
+ scalar_name = self._SCALAR_TYPE_NAMES.get(scalar, "float")
+ return f"{self._cuda_scalar_binary_math_name(func_name, scalar_name)}({lhs_expr}, {rhs_expr})"
+
+ return f"{func_name}({lhs_expr}, {rhs_expr})"
+
+ def float_bits_to_int_expr(self, var_expr: str) -> str:
+ self.mark_feature_usage("floatBitsToInt")
+ return f"floatBitsToInt({var_expr})"
+
+ def float_bits_to_uint_expr(self, var_expr: str) -> str:
+ self.mark_feature_usage("floatBitsToUint")
+ return f"floatBitsToUint({var_expr})"
+
+ def int_bits_to_float_expr(self, var_expr: str) -> str:
+ self.mark_feature_usage("intBitsToFloat")
+ return f"intBitsToFloat({var_expr})"
+
+ def uint_bits_to_float_expr(self, var_expr: str) -> str:
+ self.mark_feature_usage("uintBitsToFloat")
+ return f"uintBitsToFloat({var_expr})"
+
+ def global_invocation_id_expr(self) -> str:
+ return self._CUDA_BUILTIN_UVEC3_SENTINELS["global_invocation_id"]["sentinel"]
+
+ def local_invocation_id_expr(self) -> str:
+ return self._CUDA_BUILTIN_UVEC3_SENTINELS["local_invocation_id"]["sentinel"]
+
+ def local_invocation_index_expr(self) -> str:
+ self.mark_feature_usage("local_invocation_index")
+ return "vkdispatch_local_invocation_index()"
+
+ def workgroup_id_expr(self) -> str:
+ return self._CUDA_BUILTIN_UVEC3_SENTINELS["workgroup_id"]["sentinel"]
+
+ def workgroup_size_expr(self) -> str:
+ self._record_composite_type_key("uint3")
+ self.mark_feature_usage("make_uint3")
+ return "vkdispatch_make_uint3((unsigned int)blockDim.x, (unsigned int)blockDim.y, (unsigned int)blockDim.z)"
+
+ def num_workgroups_expr(self) -> str:
+ self._record_composite_type_key("uint3")
+ self.mark_feature_usage("make_uint3")
+ return "vkdispatch_make_uint3((unsigned int)gridDim.x, (unsigned int)gridDim.y, (unsigned int)gridDim.z)"
+
+ def num_subgroups_expr(self) -> str:
+ self.mark_feature_usage("num_subgroups")
+ return "vkdispatch_num_subgroups()"
+
+ def subgroup_id_expr(self) -> str:
+ self.mark_feature_usage("subgroup_id")
+ return "vkdispatch_subgroup_id()"
+
+ def subgroup_size_expr(self) -> str:
+ self.mark_feature_usage("subgroup_size")
+ return "vkdispatch_subgroup_size()"
+
+ def subgroup_invocation_id_expr(self) -> str:
+ self.mark_feature_usage("subgroup_invocation_id")
+ return "vkdispatch_subgroup_invocation_id()"
+
+ def barrier_statement(self) -> str:
+ return "__syncthreads();"
+
+ def memory_barrier_statement(self) -> str:
+ return "__threadfence();"
+
+ def memory_barrier_buffer_statement(self) -> str:
+ return "__threadfence();"
+
+ def memory_barrier_shared_statement(self) -> str:
+ return "__threadfence_block();"
+
+ def memory_barrier_image_statement(self) -> str:
+ return "__threadfence();"
+
+ def group_memory_barrier_statement(self) -> str:
+ return "__threadfence_block();"
+
+ @staticmethod
+ def _strip_outer_parens(expr: str) -> str:
+ stripped = expr.strip()
+ while len(stripped) >= 2 and stripped[0] == "(" and stripped[-1] == ")":
+ depth = 0
+ balanced = True
+ for idx, ch in enumerate(stripped):
+ if ch == "(":
+ depth += 1
+ elif ch == ")":
+ depth -= 1
+ if depth < 0:
+ balanced = False
+ break
+ if depth == 0 and idx != len(stripped) - 1:
+ balanced = False
+ break
+ if not balanced or depth != 0:
+ break
+ stripped = stripped[1:-1].strip()
+ return stripped
+
+ def _cuda_builtin_uvec3_component_expr(
+ self,
+ expr: str,
+ component: str,
+ base_type: dtypes.dtype,
+ ) -> Optional[str]:
+ if base_type != dtypes.uvec3 or component not in ("x", "y", "z"):
+ return None
+
+ stripped_expr = self._strip_outer_parens(expr)
+ for builtin_spec in self._CUDA_BUILTIN_UVEC3_SENTINELS.values():
+ if stripped_expr == builtin_spec["sentinel"]:
+ return builtin_spec[component]
+
+ return None
+
+ def _finalize_cuda_builtin_uvec3_sentinels(self, header: str, body: str) -> Tuple[str, str]:
+ for builtin_spec in self._CUDA_BUILTIN_UVEC3_SENTINELS.values():
+ sentinel = builtin_spec["sentinel"]
+ if sentinel not in header and sentinel not in body:
+ continue
+
+ self._record_composite_type_key("uint3")
+ self.mark_feature_usage("make_uint3")
+ replacement = (
+ "vkdispatch_make_uint3("
+ f"{builtin_spec['x']}, {builtin_spec['y']}, {builtin_spec['z']}"
+ ")"
+ )
+ header = header.replace(sentinel, replacement)
+ body = body.replace(sentinel, replacement)
+
+ return header, body
+
+ def subgroup_add_expr(self, arg_expr: str, arg_type: Optional[dtypes.dtype] = None) -> str:
+ _ = arg_type
+ self.mark_feature_usage("subgroup_add")
+ return f"vkdispatch_subgroup_add({arg_expr})"
+
+ def subgroup_mul_expr(self, arg_expr: str, arg_type: Optional[dtypes.dtype] = None) -> str:
+ _ = arg_type
+ self.mark_feature_usage("subgroup_mul")
+ return f"vkdispatch_subgroup_mul({arg_expr})"
+
+ def subgroup_min_expr(self, arg_expr: str, arg_type: Optional[dtypes.dtype] = None) -> str:
+ _ = arg_type
+ self.mark_feature_usage("subgroup_min")
+ return f"vkdispatch_subgroup_min({arg_expr})"
+
+ def subgroup_max_expr(self, arg_expr: str, arg_type: Optional[dtypes.dtype] = None) -> str:
+ _ = arg_type
+ self.mark_feature_usage("subgroup_max")
+ return f"vkdispatch_subgroup_max({arg_expr})"
+
+ def subgroup_and_expr(self, arg_expr: str, arg_type: Optional[dtypes.dtype] = None) -> str:
+ _ = arg_type
+ self.mark_feature_usage("subgroup_and")
+ return f"vkdispatch_subgroup_and({arg_expr})"
+
+ def subgroup_or_expr(self, arg_expr: str, arg_type: Optional[dtypes.dtype] = None) -> str:
+ _ = arg_type
+ self.mark_feature_usage("subgroup_or")
+ return f"vkdispatch_subgroup_or({arg_expr})"
+
+ def subgroup_xor_expr(self, arg_expr: str, arg_type: Optional[dtypes.dtype] = None) -> str:
+ _ = arg_type
+ self.mark_feature_usage("subgroup_xor")
+ return f"vkdispatch_subgroup_xor({arg_expr})"
+
+ def subgroup_elect_expr(self) -> str:
+ self.mark_feature_usage("subgroup_invocation_id")
+ return "((int)(vkdispatch_subgroup_invocation_id() == 0u))"
+
+ def subgroup_barrier_statement(self) -> str:
+ return "__syncwarp();"
+
+ def printf_statement(self, fmt: str, args: List[str]) -> str:
+ #safe_fmt = fmt.replace("\\", "\\\\").replace('"', '\\"')
+
+ if len(args) == 0:
+ return f'printf("{fmt}");'
+
+ return f'printf("{fmt}", {", ".join(args)});'
+
+ def texture_size_expr(self, texture_expr: str, lod: int, dimensions: int) -> str:
+ # CUDA texture objects do not expose shape directly in device code.
+ # The future CUDA backend should pass explicit texture shape parameters.
+ if dimensions == 1:
+ return "1.0f"
+ if dimensions == 2:
+ self.mark_feature_usage("make_float2")
+ return "vkdispatch_make_float2(1.0f)"
+ if dimensions == 3:
+ self.mark_feature_usage("make_float3")
+ return "vkdispatch_make_float3(1.0f)"
+
+ raise ValueError(f"Unsupported texture dimensions '{dimensions}'")
+
+ def sample_texture_expr(self, texture_expr: str, coord_expr: str, lod_expr: Optional[str] = None) -> str:
+ raise NotImplementedError("Direct texture sampling is not supported in CUDA backend. Use vkdispatch_sample_texture helper functions instead.")
+
+ def atomic_add_expr(self, mem_expr: str, value_expr: str, var_type: dtypes.dtype) -> str:
+ if var_type not in (dtypes.int32, dtypes.uint32):
+ raise NotImplementedError(f"CUDA atomic_add only supports int32/uint32, got '{var_type.name}'")
+
+ return f"atomicAdd(&({mem_expr}), {value_expr})"
diff --git a/vkdispatch/codegen/backends/cuda/composite_emitters.py b/vkdispatch/codegen/backends/cuda/composite_emitters.py
new file mode 100644
index 00000000..abb23ed6
--- /dev/null
+++ b/vkdispatch/codegen/backends/cuda/composite_emitters.py
@@ -0,0 +1,380 @@
+from typing import List, Optional, Set
+
+from .specs import _CUDA_MAT_TYPE_SPECS, _CUDA_VEC_ORDER, _CUDA_VEC_TYPE_SPECS
+
+
+def _cuda_vec_components(dim: int) -> List[str]:
+ if dim < 2 or dim > 4:
+ raise ValueError(f"Unsupported vector dimension '{dim}'")
+ return list("xyzw"[:dim])
+
+
+def _cuda_join_statements(statements: List[str]) -> str:
+ if len(statements) == 0:
+ return ""
+ return " ".join(statements)
+
+
+def _cuda_emit_vec_type(
+ vec_name: str,
+ scalar_type: str,
+ dim: int,
+ cuda_native_type: str,
+ *,
+ allow_unary_neg: bool,
+ enable_bitwise: bool,
+ needed_ops: Optional[Set[str]] = None,
+) -> str:
+ comps = _cuda_vec_components(dim)
+ if needed_ops is None:
+ needed_ops = set()
+ if allow_unary_neg:
+ needed_ops.add("un:-")
+ if enable_bitwise:
+ needed_ops.add("un:~")
+ for op in ["+", "-", "*", "/"]:
+ needed_ops.add(f"cmpd:{op}=:v")
+ needed_ops.add(f"cmpd:{op}=:s")
+ needed_ops.add(f"bin:{op}:vv")
+ needed_ops.add(f"bin:{op}:vs")
+ needed_ops.add(f"bin:{op}:sv")
+ if enable_bitwise:
+ for op in ["&", "|", "^", "<<", ">>"]:
+ needed_ops.add(f"cmpd:{op}=:v")
+ needed_ops.add(f"cmpd:{op}=:s")
+ needed_ops.add(f"bin:{op}:vv")
+ needed_ops.add(f"bin:{op}:vs")
+ needed_ops.add(f"bin:{op}:sv")
+
+ def has(token: str) -> bool:
+ return token in needed_ops
+
+ def self_comp(c: str) -> str:
+ return f"v.{c}"
+
+ def wrap_comp(obj: str, c: str) -> str:
+ return f"{obj}.v.{c}"
+
+ def native_comp(obj: str, c: str) -> str:
+ return f"{obj}.{c}"
+
+ def index_op_body() -> str:
+ branches: List[str] = []
+ for idx, c in enumerate(comps):
+ prefix = "if" if idx == 0 else "else if"
+ branches.append(f"{prefix} (i == {idx}) return v.{c};")
+ branches.append(f"else return v.{comps[0]};")
+ return " ".join(branches)
+
+ lines: List[str] = [f"struct {vec_name} {{"]
+ lines.append(f" {cuda_native_type} v;")
+ lines.append("")
+ ctor_args = ", ".join([f"{scalar_type} {c}_" for c in comps])
+ ctor_init = "{" + ", ".join([f"{c}_" for c in comps]) + "}"
+ splat_init = "{" + ", ".join(["s" for _ in comps]) + "}"
+ cast_init = "{" + ", ".join([f"({scalar_type}){native_comp('src', c)}" for c in comps]) + "}"
+ member_guard = ", ".join([f"(void)(((const TVec*)0)->{c})" for c in comps])
+ lines.append(f" __device__ __forceinline__ {vec_name}() = default;")
+ lines.append(f" __device__ __forceinline__ {vec_name}({ctor_args}) : v{ctor_init} {{}}")
+ lines.append(f" __device__ __forceinline__ explicit {vec_name}({scalar_type} s) : v{splat_init} {{}}")
+ lines.append(f" __device__ __forceinline__ explicit {vec_name}(const {cuda_native_type}& native) : v(native) {{}}")
+ lines.append(f" template <typename TVec>")
+ lines.append(f" __device__ __forceinline__ explicit {vec_name}(const TVec& src) : v{cast_init} {{}}")
+ lines.append(f" __device__ __forceinline__ {scalar_type}& operator[](int i) {{ {index_op_body()} }}")
+ lines.append(f" __device__ __forceinline__ const {scalar_type}& operator[](int i) const {{ {index_op_body()} }}")
+
+ if allow_unary_neg and has("un:-"):
+ neg_expr = ", ".join([f"-{self_comp(c)}" for c in comps])
+ lines.append(f" __device__ __forceinline__ {vec_name} operator-() const {{ return {vec_name}({neg_expr}); }}")
+
+ if enable_bitwise and has("un:~"):
+ not_expr = ", ".join([f"~{self_comp(c)}" for c in comps])
+ lines.append(f" __device__ __forceinline__ {vec_name} operator~() const {{ return {vec_name}({not_expr}); }}")
+
+ for op in ["+", "-", "*", "/"]:
+ op_assign = op + "="
+ if has(f"cmpd:{op}=:v"):
+ vv_ops = _cuda_join_statements([f"{self_comp(c)} {op_assign} {wrap_comp('b', c)};" for c in comps])
+ lines.append(
+ f" __device__ __forceinline__ {vec_name}& operator{op_assign}(const {vec_name}& b) {{ {vv_ops} return *this; }}"
+ )
+ if has(f"cmpd:{op}=:s"):
+ sv_ops = _cuda_join_statements([f"{self_comp(c)} {op_assign} b;" for c in comps])
+ lines.append(
+ f" __device__ __forceinline__ {vec_name}& operator{op_assign}({scalar_type} b) {{ {sv_ops} return *this; }}"
+ )
+
+ if enable_bitwise:
+ for op in ["&", "|", "^", "<<", ">>"]:
+ op_assign = op + "="
+ if has(f"cmpd:{op}=:v"):
+ vv_ops = _cuda_join_statements([f"{self_comp(c)} {op_assign} {wrap_comp('b', c)};" for c in comps])
+ lines.append(
+ f" __device__ __forceinline__ {vec_name}& operator{op_assign}(const {vec_name}& b) {{ {vv_ops} return *this; }}"
+ )
+ if has(f"cmpd:{op}=:s"):
+ sv_ops = _cuda_join_statements([f"{self_comp(c)} {op_assign} b;" for c in comps])
+ lines.append(
+ f" __device__ __forceinline__ {vec_name}& operator{op_assign}({scalar_type} b) {{ {sv_ops} return *this; }}"
+ )
+
+ lines.append("};")
+ lines.append(
+ f'static_assert(sizeof({vec_name}) == sizeof({cuda_native_type}), "{vec_name} size must match {cuda_native_type}");'
+ )
+ lines.append(
+ f'static_assert(alignof({vec_name}) == alignof({cuda_native_type}), "{vec_name} alignment must match {cuda_native_type}");'
+ )
+
+ for op in ["+", "-", "*", "/"]:
+ if has(f"bin:{op}:vv"):
+ vv_expr = ", ".join([f"({wrap_comp('a', c)} {op} {wrap_comp('b', c)})" for c in comps])
+ lines.append(
+ f"__device__ __forceinline__ {vec_name} operator{op}(const {vec_name}& a, const {vec_name}& b) {{ return {vec_name}({vv_expr}); }}"
+ )
+ if has(f"bin:{op}:vs"):
+ vs_expr = ", ".join([f"({wrap_comp('a', c)} {op} b)" for c in comps])
+ lines.append(
+ f"__device__ __forceinline__ {vec_name} operator{op}(const {vec_name}& a, {scalar_type} b) {{ return {vec_name}({vs_expr}); }}"
+ )
+ if has(f"bin:{op}:sv"):
+ if op in ["+", "*"]:
+ sv_expr = ", ".join([f"(a {op} {wrap_comp('b', c)})" for c in comps])
+ else:
+ sv_expr = ", ".join([f"({scalar_type})(a {op} {wrap_comp('b', c)})" for c in comps])
+ lines.append(
+ f"__device__ __forceinline__ {vec_name} operator{op}({scalar_type} a, const {vec_name}& b) {{ return {vec_name}({sv_expr}); }}"
+ )
+
+ if enable_bitwise:
+ for op in ["&", "|", "^", "<<", ">>"]:
+ if has(f"bin:{op}:vv"):
+ vv_expr = ", ".join([f"({wrap_comp('a', c)} {op} {wrap_comp('b', c)})" for c in comps])
+ lines.append(
+ f"__device__ __forceinline__ {vec_name} operator{op}(const {vec_name}& a, const {vec_name}& b) {{ return {vec_name}({vv_expr}); }}"
+ )
+ if has(f"bin:{op}:vs"):
+ vs_expr = ", ".join([f"({wrap_comp('a', c)} {op} b)" for c in comps])
+ lines.append(
+ f"__device__ __forceinline__ {vec_name} operator{op}(const {vec_name}& a, {scalar_type} b) {{ return {vec_name}({vs_expr}); }}"
+ )
+ if has(f"bin:{op}:sv"):
+ sv_expr = ", ".join([f"({scalar_type})(a {op} {wrap_comp('b', c)})" for c in comps])
+ lines.append(
+ f"__device__ __forceinline__ {vec_name} operator{op}({scalar_type} a, const {vec_name}& b) {{ return {vec_name}({sv_expr}); }}"
+ )
+
+ return "\n".join(lines)
+
+
+def _cuda_emit_vec_helper(helper_suffix: str, vec_name: str, scalar_type: str, dim: int) -> str:
+ comps = _cuda_vec_components(dim)
+ args = ", ".join([f"{scalar_type} {c}" for c in comps])
+ ctor_args = ", ".join(comps)
+ member_guard = ", ".join([f"(void)(((const TVec*)0)->{c})" for c in comps])
+ return "\n".join(
+ [
+ f"__device__ __forceinline__ {vec_name} vkdispatch_make_{helper_suffix}({args}) {{ return {vec_name}({ctor_args}); }}",
+ f"__device__ __forceinline__ {vec_name} vkdispatch_make_{helper_suffix}({scalar_type} x) {{ return {vec_name}(x); }}",
+ f"__device__ __forceinline__ {vec_name} vkdispatch_make_{helper_suffix}(const {vec_name}& v) {{ return v; }}",
+ f"template <typename TVec>",
+ f"__device__ __forceinline__ {vec_name} vkdispatch_make_{helper_suffix}(const TVec& v) {{ return {vec_name}(v); }}",
+ ]
+ )
+
+
+def _cuda_emit_vec_wrapper_conversion_helpers(
+ helper_suffix: str,
+ vec_name: str,
+ scalar_type: str,
+ dim: int,
+ *,
+ available_keys: Optional[Set[str]] = None,
+) -> str:
+ comps = _cuda_vec_components(dim)
+ dim_keys = [key for key in _CUDA_VEC_TYPE_SPECS if key.endswith(str(dim))]
+ if available_keys is not None:
+ dim_keys = [key for key in dim_keys if key in available_keys]
+
+ lines: List[str] = []
+ for src_key in dim_keys:
+ if src_key == helper_suffix:
+ continue
+ src_vec_name = _CUDA_VEC_TYPE_SPECS[src_key][0]
+ ctor_args = ", ".join([f"({scalar_type})src.v.{c}" for c in comps])
+ lines.append(
+ f"__device__ __forceinline__ {vec_name} vkdispatch_make_{helper_suffix}(const {src_vec_name}& src) {{ return {vec_name}({ctor_args}); }}"
+ )
+
+ return "\n".join(lines)
+
+
+def _cuda_emit_mat_type(mat_name: str, vec_name: str, dim: int, needed_ops: Optional[Set[str]] = None) -> str:
+ cols = [f"c{i}" for i in range(dim)]
+ if needed_ops is None:
+ needed_ops = {
+ "un:-",
+ "cmpd:+=:m",
+ "cmpd:+=:s",
+ "cmpd:-=:m",
+ "cmpd:-=:s",
+ "cmpd:*=:s",
+ "cmpd:/=:s",
+ "bin:+:mm",
+ "bin:+:ms",
+ "bin:+:sm",
+ "bin:-:mm",
+ "bin:-:ms",
+ "bin:-:sm",
+ "bin:*:ms",
+ "bin:*:sm",
+ "bin:/:ms",
+ "bin:/:sm",
+ "bin:*:mv",
+ "bin:*:vm",
+ "bin:*:mm",
+ }
+
+ def has(token: str) -> bool:
+ return token in needed_ops
+
+ lines: List[str] = [f"struct {mat_name} {{"]
+ lines.extend([f" {vec_name} {c};" for c in cols])
+ lines.append("")
+ lines.append(f" __device__ __forceinline__ {mat_name}() = default;")
+ ctor_args = ", ".join([f"{vec_name} {c}_" for c in cols])
+ ctor_init = ", ".join([f"{c}({c}_)" for c in cols])
+ lines.append(f" __device__ __forceinline__ {mat_name}({ctor_args}) : {ctor_init} {{}}")
+
+ zero = "0.0f"
+ diag_init = ", ".join(
+ [f"c{col_idx}({vec_name}({', '.join(['s' if row_idx == col_idx else zero for row_idx in range(dim)])}))" for col_idx in range(dim)]
+ )
+ lines.append(f" __device__ __forceinline__ explicit {mat_name}(float s) : {diag_init} {{}}")
+ lines.append(f" __device__ __forceinline__ {vec_name}& operator[](int i) {{ return (&c0)[i]; }}")
+ lines.append(f" __device__ __forceinline__ const {vec_name}& operator[](int i) const {{ return (&c0)[i]; }}")
+ if has("un:-"):
+ lines.append(
+ f" __device__ __forceinline__ {mat_name} operator-() const {{ return {mat_name}({', '.join([f'-c{i}' for i in range(dim)])}); }}"
+ )
+
+ for op in ["+", "-"]:
+ op_assign = op + "="
+ if has(f"cmpd:{op}=:m"):
+ mm_ops = _cuda_join_statements([f"c{i} {op_assign} b.c{i};" for i in range(dim)])
+ lines.append(
+ f" __device__ __forceinline__ {mat_name}& operator{op_assign}(const {mat_name}& b) {{ {mm_ops} return *this; }}"
+ )
+ if has(f"cmpd:{op}=:s"):
+ ms_ops = _cuda_join_statements([f"c{i} {op_assign} b;" for i in range(dim)])
+ lines.append(
+ f" __device__ __forceinline__ {mat_name}& operator{op_assign}(float b) {{ {ms_ops} return *this; }}"
+ )
+
+ for op in ["*", "/"]:
+ op_assign = op + "="
+ if has(f"cmpd:{op}=:s"):
+ ms_ops = _cuda_join_statements([f"c{i} {op_assign} b;" for i in range(dim)])
+ lines.append(
+ f" __device__ __forceinline__ {mat_name}& operator{op_assign}(float b) {{ {ms_ops} return *this; }}"
+ )
+
+ lines.append("};")
+
+ for op in ["+", "-"]:
+ if has(f"bin:{op}:mm"):
+ cols_expr = ", ".join([f"(a.c{i} {op} b.c{i})" for i in range(dim)])
+ lines.append(
+ f"__device__ __forceinline__ {mat_name} operator{op}(const {mat_name}& a, const {mat_name}& b) {{ return {mat_name}({cols_expr}); }}"
+ )
+ if has(f"bin:{op}:ms"):
+ cols_expr = ", ".join([f"(a.c{i} {op} b)" for i in range(dim)])
+ lines.append(
+ f"__device__ __forceinline__ {mat_name} operator{op}(const {mat_name}& a, float b) {{ return {mat_name}({cols_expr}); }}"
+ )
+ if has(f"bin:{op}:sm"):
+ cols_expr = ", ".join([f"(a {op} b.c{i})" for i in range(dim)])
+ lines.append(
+ f"__device__ __forceinline__ {mat_name} operator{op}(float a, const {mat_name}& b) {{ return {mat_name}({cols_expr}); }}"
+ )
+
+ for op in ["*", "/"]:
+ if has(f"bin:{op}:ms"):
+ cols_expr = ", ".join([f"(a.c{i} {op} b)" for i in range(dim)])
+ lines.append(
+ f"__device__ __forceinline__ {mat_name} operator{op}(const {mat_name}& a, float b) {{ return {mat_name}({cols_expr}); }}"
+ )
+ if has(f"bin:{op}:sm"):
+ cols_expr = ", ".join([f"(a {op} b.c{i})" for i in range(dim)])
+ lines.append(
+ f"__device__ __forceinline__ {mat_name} operator{op}(float a, const {mat_name}& b) {{ return {mat_name}({cols_expr}); }}"
+ )
+
+ vec_comps = _cuda_vec_components(dim)
+ if has("bin:*:mv"):
+ mat_vec_terms = [f"(m.c{i} * v.v.{vec_comps[i]})" for i in range(dim)]
+ mat_vec_expr = " + ".join(mat_vec_terms)
+ lines.append(
+ f"__device__ __forceinline__ {vec_name} operator* (const {mat_name}& m, const {vec_name}& v) {{ return {mat_vec_expr}; }}"
+ )
+
+ if has("bin:*:vm"):
+ row_exprs: List[str] = []
+ for col_idx in range(dim):
+ terms = [f"(v.v.{vec_comps[row_idx]} * m.c{col_idx}.v.{vec_comps[row_idx]})" for row_idx in range(dim)]
+ row_exprs.append(" + ".join(terms))
+ lines.append(
+ f"__device__ __forceinline__ {vec_name} operator* (const {vec_name}& v, const {mat_name}& m) {{ return {vec_name}({', '.join(row_exprs)}); }}"
+ )
+
+ if has("bin:*:mm"):
+ col_products = ", ".join([f"(a * b.c{i})" for i in range(dim)])
+ lines.append(
+ f"__device__ __forceinline__ {mat_name} operator* (const {mat_name}& a, const {mat_name}& b) {{ return {mat_name}({col_products}); }}"
+ )
+
+ return "\n".join(lines)
+
+
+def _cuda_emit_mat_helpers(mat_name: str, helper_suffix: str, vec_name: str, vec_helper_suffix: str, dim: int) -> str:
+ col_type = vec_name
+ col_args = ", ".join([f"{col_type} c{i}" for i in range(dim)])
+ col_ctor = ", ".join([f"c{i}" for i in range(dim)])
+
+ flat_names = [f"m{col}{row}" for col in range(dim) for row in range(dim)]
+ flat_args = ", ".join([f"float {name}" for name in flat_names])
+ flat_cols: List[str] = []
+ for col in range(dim):
+ values = [f"m{col}{row}" for row in range(dim)]
+ flat_cols.append(f"vkdispatch_make_{vec_helper_suffix}({', '.join(values)})")
+ flat_ctor = ", ".join(flat_cols)
+
+ cast_cols = ", ".join([f"vkdispatch_make_{vec_helper_suffix}(m[{i}])" for i in range(dim)])
+
+ return "\n".join(
+ [
+ f"__device__ __forceinline__ {mat_name} vkdispatch_make_{helper_suffix}({col_args}) {{ return {mat_name}({col_ctor}); }}",
+ f"__device__ __forceinline__ {mat_name} vkdispatch_make_{helper_suffix}(float s) {{ return {mat_name}(s); }}",
+ f"__device__ __forceinline__ {mat_name} vkdispatch_make_{helper_suffix}({flat_args}) {{ return {mat_name}({flat_ctor}); }}",
+ "template <typename TMat>",
+ f"__device__ __forceinline__ {mat_name} vkdispatch_make_{helper_suffix}(TMat m) {{ return {mat_name}({cast_cols}); }}",
+ ]
+ )
+
+
+def _cuda_emit_subgroup_shuffle_xor_vec_overloads(vec_keys: Set[str]) -> str:
+ lines: List[str] = []
+
+ for key in _CUDA_VEC_ORDER:
+ if key not in vec_keys:
+ continue
+
+ vec_name, _, dim, _, _, _ = _CUDA_VEC_TYPE_SPECS[key]
+ comps = _cuda_vec_components(dim)
+ comp_exprs = ", ".join([f"__shfl_xor_sync(mask, value.v.{c}, lane_mask)" for c in comps])
+ lines.append(
+ f"__device__ __forceinline__ {vec_name} vkdispatch_subgroup_shuffle_xor(unsigned int mask, const {vec_name}& value, int lane_mask) "
+ f"{{ return vkdispatch_make_{key}({comp_exprs}); }}"
+ )
+
+ return "\n".join(lines)
diff --git a/vkdispatch/codegen/backends/cuda/helper_snippets.py b/vkdispatch/codegen/backends/cuda/helper_snippets.py
new file mode 100644
index 00000000..93fa3eeb
--- /dev/null
+++ b/vkdispatch/codegen/backends/cuda/helper_snippets.py
@@ -0,0 +1,287 @@
+from typing import Dict, List
+
+
+_HELPER_SNIPPETS: Dict[str, str] = {
+ "composite_types": "",
+ "mat2_type": "",
+ "mat3_type": "",
+ "mat4_type": "",
+ "make_mat2": "",
+ "make_mat3": "",
+ "make_mat4": "",
+ "make_short2": "",
+ "make_short3": "",
+ "make_short4": "",
+ "make_ushort2": "",
+ "make_ushort3": "",
+ "make_ushort4": "",
+ "make_int2": "",
+ "make_int3": "",
+ "make_int4": "",
+ "make_uint2": "",
+ "make_uint3": "",
+ "make_uint4": "",
+ "make_half2": "",
+ "make_half3": "",
+ "make_half4": "",
+ "float2_ops": "",
+ "make_float2": "",
+ "make_float3": "",
+ "make_float4": "",
+ "make_double2": "",
+ "make_double3": "",
+ "make_double4": "",
+ "global_invocation_id": (
+ "__device__ __forceinline__ vkdispatch_uint3 vkdispatch_global_invocation_id() {\n"
+ " return vkdispatch_uint3(\n"
+ " (unsigned int)(blockIdx.x * blockDim.x + threadIdx.x),\n"
+ " (unsigned int)(blockIdx.y * blockDim.y + threadIdx.y),\n"
+ " (unsigned int)(blockIdx.z * blockDim.z + threadIdx.z)\n"
+ " );\n"
+ "}"
+ ),
+ "local_invocation_id": (
+ "__device__ __forceinline__ vkdispatch_uint3 vkdispatch_local_invocation_id() {\n"
+ " return vkdispatch_uint3((unsigned int)threadIdx.x, (unsigned int)threadIdx.y, (unsigned int)threadIdx.z);\n"
+ "}"
+ ),
+ "workgroup_id": (
+ "__device__ __forceinline__ vkdispatch_uint3 vkdispatch_workgroup_id() {\n"
+ " return vkdispatch_uint3((unsigned int)blockIdx.x, (unsigned int)blockIdx.y, (unsigned int)blockIdx.z);\n"
+ "}"
+ ),
+ "local_invocation_index": (
+ "__device__ __forceinline__ unsigned int vkdispatch_local_invocation_index() {\n"
+ " return (unsigned int)(threadIdx.x + blockDim.x * (threadIdx.y + blockDim.y * threadIdx.z));\n"
+ "}"
+ ),
+ "subgroup_size": "__device__ __forceinline__ unsigned int vkdispatch_subgroup_size() { return (unsigned int)warpSize; }",
+ "num_subgroups": (
+ "__device__ __forceinline__ unsigned int vkdispatch_num_subgroups() {\n"
+ " unsigned int local_count = (unsigned int)(blockDim.x * blockDim.y * blockDim.z);\n"
+ " return (local_count + vkdispatch_subgroup_size() - 1u) / vkdispatch_subgroup_size();\n"
+ "}"
+ ),
+ "subgroup_id": (
+ "__device__ __forceinline__ unsigned int vkdispatch_subgroup_id() {\n"
+ " return vkdispatch_local_invocation_index() / vkdispatch_subgroup_size();\n"
+ "}"
+ ),
+ "subgroup_invocation_id": (
+ "__device__ __forceinline__ unsigned int vkdispatch_subgroup_invocation_id() {\n"
+ " return vkdispatch_local_invocation_index() % vkdispatch_subgroup_size();\n"
+ "}"
+ ),
+ "subgroup_shuffle_xor": (
+ "template \n"
+ "__device__ __forceinline__ T vkdispatch_subgroup_shuffle_xor(unsigned int mask, T value, int lane_mask) {\n"
+ " return __shfl_xor_sync(mask, value, lane_mask);\n"
+ "}"
+ ),
+ "subgroup_add": (
+ "template \n"
+ "__device__ __forceinline__ T vkdispatch_subgroup_add(T value) {\n"
+ " unsigned int mask = 0xffffffffu;\n"
+ " for (unsigned int offset = vkdispatch_subgroup_size() >> 1; offset > 0u; offset >>= 1u) {\n"
+ " value = value + vkdispatch_subgroup_shuffle_xor(mask, value, (int)offset);\n"
+ " }\n"
+ " return value;\n"
+ "}"
+ ),
+ "subgroup_mul": (
+ "template \n"
+ "__device__ __forceinline__ T vkdispatch_subgroup_mul(T value) {\n"
+ " unsigned int mask = 0xffffffffu;\n"
+ " for (unsigned int offset = vkdispatch_subgroup_size() >> 1; offset > 0u; offset >>= 1u) {\n"
+ " value = value * vkdispatch_subgroup_shuffle_xor(mask, value, (int)offset);\n"
+ " }\n"
+ " return value;\n"
+ "}"
+ ),
+ "subgroup_min": (
+ "template \n"
+ "__device__ __forceinline__ T vkdispatch_subgroup_min(T value) {\n"
+ " unsigned int mask = 0xffffffffu;\n"
+ " for (unsigned int offset = vkdispatch_subgroup_size() >> 1; offset > 0u; offset >>= 1u) {\n"
+ " T other = vkdispatch_subgroup_shuffle_xor(mask, value, (int)offset);\n"
+ " value = other < value ? other : value;\n"
+ " }\n"
+ " return value;\n"
+ "}"
+ ),
+ "subgroup_max": (
+ "template \n"
+ "__device__ __forceinline__ T vkdispatch_subgroup_max(T value) {\n"
+ " unsigned int mask = 0xffffffffu;\n"
+ " for (unsigned int offset = vkdispatch_subgroup_size() >> 1; offset > 0u; offset >>= 1u) {\n"
+ " T other = vkdispatch_subgroup_shuffle_xor(mask, value, (int)offset);\n"
+ " value = other > value ? other : value;\n"
+ " }\n"
+ " return value;\n"
+ "}"
+ ),
+ "subgroup_and": (
+ "template \n"
+ "__device__ __forceinline__ T vkdispatch_subgroup_and(T value) {\n"
+ " unsigned int mask = 0xffffffffu;\n"
+ " for (unsigned int offset = vkdispatch_subgroup_size() >> 1; offset > 0u; offset >>= 1u) {\n"
+ " value = value & vkdispatch_subgroup_shuffle_xor(mask, value, (int)offset);\n"
+ " }\n"
+ " return value;\n"
+ "}"
+ ),
+ "subgroup_or": (
+ "template \n"
+ "__device__ __forceinline__ T vkdispatch_subgroup_or(T value) {\n"
+ " unsigned int mask = 0xffffffffu;\n"
+ " for (unsigned int offset = vkdispatch_subgroup_size() >> 1; offset > 0u; offset >>= 1u) {\n"
+ " value = value | vkdispatch_subgroup_shuffle_xor(mask, value, (int)offset);\n"
+ " }\n"
+ " return value;\n"
+ "}"
+ ),
+ "subgroup_xor": (
+ "template \n"
+ "__device__ __forceinline__ T vkdispatch_subgroup_xor(T value) {\n"
+ " unsigned int mask = 0xffffffffu;\n"
+ " for (unsigned int offset = vkdispatch_subgroup_size() >> 1; offset > 0u; offset >>= 1u) {\n"
+ " value = value ^ vkdispatch_subgroup_shuffle_xor(mask, value, (int)offset);\n"
+ " }\n"
+ " return value;\n"
+ "}"
+ ),
+ "mod": (
+ "__device__ __forceinline__ float mod(float x, float y) { return fmodf(x, y); }\n"
+ "__device__ __forceinline__ double mod(double x, double y) { return fmod(x, y); }"
+ ),
+ "fract": (
+ "__device__ __forceinline__ float fract(float x) { return x - floorf(x); }\n"
+ "__device__ __forceinline__ double fract(double x) { return x - floor(x); }"
+ ),
+ "roundEven": (
+ "__device__ __forceinline__ float roundEven(float x) { return nearbyintf(x); }\n"
+ "__device__ __forceinline__ double roundEven(double x) { return nearbyint(x); }"
+ ),
+ "mix": (
+ "__device__ __forceinline__ float mix(float x, float y, float a) { return x + (y - x) * a; }\n"
+ "__device__ __forceinline__ double mix(double x, double y, double a) { return x + (y - x) * a; }"
+ ),
+ "step": (
+ "__device__ __forceinline__ float step(float edge, float x) { return x < edge ? 0.0f : 1.0f; }\n"
+ "__device__ __forceinline__ double step(double edge, double x) { return x < edge ? 0.0 : 1.0; }"
+ ),
+ "smoothstep": (
+ "__device__ __forceinline__ float smoothstep(float edge0, float edge1, float x) {\n"
+ " float t = fminf(fmaxf((x - edge0) / (edge1 - edge0), 0.0f), 1.0f);\n"
+ " return t * t * (3.0f - 2.0f * t);\n"
+ "}\n"
+ "__device__ __forceinline__ double smoothstep(double edge0, double edge1, double x) {\n"
+ " double t = fmin(fmax((x - edge0) / (edge1 - edge0), 0.0), 1.0);\n"
+ " return t * t * (3.0 - 2.0 * t);\n"
+ "}"
+ ),
+ "radians": (
+ "__device__ __forceinline__ float radians(float x) { return x * (3.14159265358979323846f / 180.0f); }\n"
+ "__device__ __forceinline__ double radians(double x) { return x * (3.14159265358979323846 / 180.0); }"
+ ),
+ "degrees": (
+ "__device__ __forceinline__ float degrees(float x) { return x * (180.0f / 3.14159265358979323846f); }\n"
+ "__device__ __forceinline__ double degrees(double x) { return x * (180.0 / 3.14159265358979323846); }"
+ ),
+ "inversesqrt": (
+ "__device__ __forceinline__ float inversesqrt(float x) { return rsqrtf(x); }\n"
+ "__device__ __forceinline__ double inversesqrt(double x) { return rsqrt(x); }"
+ ),
+ "floatBitsToInt": "__device__ __forceinline__ int floatBitsToInt(float x) { return __float_as_int(x); }",
+ "floatBitsToUint": "__device__ __forceinline__ unsigned int floatBitsToUint(float x) { return __float_as_uint(x); }",
+ "intBitsToFloat": "__device__ __forceinline__ float intBitsToFloat(int x) { return __int_as_float(x); }",
+ "uintBitsToFloat": "__device__ __forceinline__ float uintBitsToFloat(unsigned int x) { return __uint_as_float(x); }",
+ "longlong_as_double": "__device__ __forceinline__ double longlong_as_double(long long x) { return __longlong_as_double(x); }",
+ "ushort_as_half": "__device__ __forceinline__ __half ushort_as_half(unsigned short x) { __half h; *reinterpret_cast(&h) = x; return h; }",
+ "sample_texture": "",
+}
+
+_HELPER_ORDER: List[str] = [
+ "composite_types",
+ "global_invocation_id",
+ "local_invocation_id",
+ "workgroup_id",
+ "local_invocation_index",
+ "subgroup_size",
+ "num_subgroups",
+ "subgroup_id",
+ "subgroup_invocation_id",
+ "subgroup_shuffle_xor",
+ "subgroup_add",
+ "subgroup_mul",
+ "subgroup_min",
+ "subgroup_max",
+ "subgroup_and",
+ "subgroup_or",
+ "subgroup_xor",
+ "mod",
+ "fract",
+ "roundEven",
+ "mix",
+ "step",
+ "smoothstep",
+ "radians",
+ "degrees",
+ "inversesqrt",
+ "floatBitsToInt",
+ "floatBitsToUint",
+ "intBitsToFloat",
+ "uintBitsToFloat",
+ "longlong_as_double",
+ "ushort_as_half",
+ "sample_texture",
+]
+
+_HELPER_DEPENDENCIES: Dict[str, List[str]] = {
+ "mat2_type": ["composite_types"],
+ "mat3_type": ["composite_types"],
+ "mat4_type": ["composite_types"],
+ "make_mat2": ["composite_types"],
+ "make_mat3": ["composite_types"],
+ "make_mat4": ["composite_types"],
+ "make_short2": ["composite_types"],
+ "make_short3": ["composite_types"],
+ "make_short4": ["composite_types"],
+ "make_ushort2": ["composite_types"],
+ "make_ushort3": ["composite_types"],
+ "make_ushort4": ["composite_types"],
+ "make_int2": ["composite_types"],
+ "make_int3": ["composite_types"],
+ "make_int4": ["composite_types"],
+ "make_uint2": ["composite_types"],
+ "make_uint3": ["composite_types"],
+ "make_uint4": ["composite_types"],
+ "make_half2": ["composite_types"],
+ "make_half3": ["composite_types"],
+ "make_half4": ["composite_types"],
+ "float2_ops": ["composite_types"],
+ "make_float2": ["composite_types"],
+ "make_float3": ["composite_types"],
+ "make_float4": ["composite_types"],
+ "make_double2": ["composite_types"],
+ "make_double3": ["composite_types"],
+ "make_double4": ["composite_types"],
+ "global_invocation_id": ["composite_types"],
+ "local_invocation_id": ["composite_types"],
+ "workgroup_id": ["composite_types"],
+ "sample_texture": ["composite_types"],
+ "num_subgroups": ["subgroup_size"],
+ "subgroup_id": ["local_invocation_index", "subgroup_size"],
+ "subgroup_invocation_id": ["local_invocation_index", "subgroup_size"],
+ "subgroup_add": ["subgroup_size", "subgroup_shuffle_xor"],
+ "subgroup_mul": ["subgroup_size", "subgroup_shuffle_xor"],
+ "subgroup_min": ["subgroup_size", "subgroup_shuffle_xor"],
+ "subgroup_max": ["subgroup_size", "subgroup_shuffle_xor"],
+ "subgroup_and": ["subgroup_size", "subgroup_shuffle_xor"],
+ "subgroup_or": ["subgroup_size", "subgroup_shuffle_xor"],
+ "subgroup_xor": ["subgroup_size", "subgroup_shuffle_xor"],
+}
+
+
+def initialize_feature_usage() -> Dict[str, bool]:
+ return {feature_name: False for feature_name in _HELPER_SNIPPETS}
diff --git a/vkdispatch/codegen/backends/cuda/math_utils.py b/vkdispatch/codegen/backends/cuda/math_utils.py
new file mode 100644
index 00000000..fc5ce5ad
--- /dev/null
+++ b/vkdispatch/codegen/backends/cuda/math_utils.py
@@ -0,0 +1,174 @@
+from typing import Dict, List, Optional, Set
+
+import vkdispatch.base.dtype as dtypes
+
+from .composite_emitters import _cuda_vec_components
+from .specs import _CUDA_VEC_TYPE_SPECS, _FLOAT_VEC_HELPER_SUFFIX_MAP
+
+
+def cuda_fast_unary_math_name(func_name: str) -> str:
+ if func_name == "sin":
+ return "__sinf"
+ if func_name == "cos":
+ return "__cosf"
+ if func_name == "tan":
+ return "__tanf"
+ if func_name == "exp":
+ return "__expf"
+ if func_name == "exp2":
+ return "__exp2f"
+ if func_name == "log":
+ return "__logf"
+ if func_name == "log2":
+ return "__log2f"
+ if func_name == "asin":
+ return "asinf"
+ if func_name == "acos":
+ return "acosf"
+ if func_name == "atan":
+ return "atanf"
+ if func_name == "sinh":
+ return "sinhf"
+ if func_name == "cosh":
+ return "coshf"
+ if func_name == "tanh":
+ return "tanhf"
+ if func_name == "asinh":
+ return "asinhf"
+ if func_name == "acosh":
+ return "acoshf"
+ if func_name == "atanh":
+ return "atanhf"
+ if func_name == "sqrt":
+ return "sqrtf"
+
+ return func_name
+
+
+def cuda_fast_binary_math_name(func_name: str) -> str:
+ if func_name == "atan2":
+ return "atan2f"
+ if func_name == "pow":
+ return "__powf"
+
+ return func_name
+
+
+def cuda_scalar_unary_math_name(func_name: str, scalar_type: str) -> str:
+ if scalar_type == "__half":
+ half_math = {
+ "sin": "hsin",
+ "cos": "hcos",
+ "exp": "hexp",
+ "exp2": "hexp2",
+ "log": "hlog",
+ "log2": "hlog2",
+ "sqrt": "hsqrt",
+ }
+ return half_math.get(func_name, func_name)
+ if scalar_type == "double":
+ return func_name
+ return cuda_fast_unary_math_name(func_name)
+
+
+def cuda_scalar_binary_math_name(func_name: str, scalar_type: str) -> str:
+ if scalar_type == "__half":
+ return func_name
+ if scalar_type == "double":
+ return func_name
+ return cuda_fast_binary_math_name(func_name)
+
+
+def cuda_float_vec_components_for_suffix(helper_suffix: str) -> List[str]:
+ dim_char = helper_suffix[-1]
+ if dim_char == "2":
+ return ["x", "y"]
+ if dim_char == "3":
+ return ["x", "y", "z"]
+ if dim_char == "4":
+ return ["x", "y", "z", "w"]
+
+ raise ValueError(f"Unsupported CUDA float vector helper suffix '{helper_suffix}'")
+
+
+def cuda_float_vec_helper_suffix(var_type: dtypes.dtype) -> Optional[str]:
+ return _FLOAT_VEC_HELPER_SUFFIX_MAP.get(var_type)
+
+
+def emit_used_vec_math_helpers(
+ composite_vec_unary_math_usage: Dict[str, Set[str]],
+ composite_vec_binary_math_usage: Dict[str, Set[str]],
+) -> str:
+ helper_sections: List[str] = []
+
+ unary_order = [
+ "sin",
+ "cos",
+ "tan",
+ "asin",
+ "acos",
+ "atan",
+ "sinh",
+ "cosh",
+ "tanh",
+ "asinh",
+ "acosh",
+ "atanh",
+ "exp",
+ "exp2",
+ "log",
+ "log2",
+ "sqrt",
+ ]
+ binary_order = ["atan2", "pow"]
+ signature_order = ["vv", "vs", "sv"]
+
+ for key in ["half2", "half3", "half4", "float2", "float3", "float4", "double2", "double3", "double4"]:
+ unary_funcs = composite_vec_unary_math_usage.get(key, set())
+ binary_tokens = composite_vec_binary_math_usage.get(key, set())
+ if len(unary_funcs) == 0 and len(binary_tokens) == 0:
+ continue
+
+ if key not in _CUDA_VEC_TYPE_SPECS:
+ continue
+
+ vec_name, scalar_type, dim, _, _, _ = _CUDA_VEC_TYPE_SPECS[key]
+ comps = _cuda_vec_components(dim)
+ lines: List[str] = []
+
+ for func_name in unary_order:
+ if func_name not in unary_funcs:
+ continue
+ scalar_func = cuda_scalar_unary_math_name(func_name, scalar_type)
+ comp_args = ", ".join([f"{scalar_func}(v.v.{c})" for c in comps])
+ lines.append(
+ f"__device__ __forceinline__ {vec_name} {func_name}(const {vec_name}& v) {{ return vkdispatch_make_{key}({comp_args}); }}"
+ )
+
+ for func_name in binary_order:
+ scalar_func = cuda_scalar_binary_math_name(func_name, scalar_type)
+ for signature in signature_order:
+ token = f"{func_name}:{signature}"
+ if token not in binary_tokens:
+ continue
+
+ if signature == "vv":
+ comp_args = ", ".join([f"{scalar_func}(a.v.{c}, b.v.{c})" for c in comps])
+ lines.append(
+ f"__device__ __forceinline__ {vec_name} {func_name}(const {vec_name}& a, const {vec_name}& b) {{ return vkdispatch_make_{key}({comp_args}); }}"
+ )
+ elif signature == "vs":
+ comp_args = ", ".join([f"{scalar_func}(a.v.{c}, b)" for c in comps])
+ lines.append(
+ f"__device__ __forceinline__ {vec_name} {func_name}(const {vec_name}& a, {scalar_type} b) {{ return vkdispatch_make_{key}({comp_args}); }}"
+ )
+ elif signature == "sv":
+ comp_args = ", ".join([f"{scalar_func}(a, b.v.{c})" for c in comps])
+ lines.append(
+ f"__device__ __forceinline__ {vec_name} {func_name}({scalar_type} a, const {vec_name}& b) {{ return vkdispatch_make_{key}({comp_args}); }}"
+ )
+
+ if len(lines) > 0:
+ helper_sections.append("\n".join(lines))
+
+ return "\n\n".join(helper_sections)
diff --git a/vkdispatch/codegen/backends/cuda/specs.py b/vkdispatch/codegen/backends/cuda/specs.py
new file mode 100644
index 00000000..c029b5b0
--- /dev/null
+++ b/vkdispatch/codegen/backends/cuda/specs.py
@@ -0,0 +1,120 @@
+from typing import Dict, FrozenSet, Tuple
+
+import vkdispatch.base.dtype as dtypes
+
+
+_CUDA_VEC_TYPE_SPECS: Dict[str, Tuple[str, str, int, str, bool, bool]] = {
+ "short2": ("vkdispatch_short2", "short", 2, "short2", True, True),
+ "short3": ("vkdispatch_short3", "short", 3, "short3", True, True),
+ "short4": ("vkdispatch_short4", "short", 4, "short4", True, True),
+ "ushort2": ("vkdispatch_ushort2", "unsigned short", 2, "ushort2", False, True),
+ "ushort3": ("vkdispatch_ushort3", "unsigned short", 3, "ushort3", False, True),
+ "ushort4": ("vkdispatch_ushort4", "unsigned short", 4, "ushort4", False, True),
+ "int2": ("vkdispatch_int2", "int", 2, "int2", True, True),
+ "int3": ("vkdispatch_int3", "int", 3, "int3", True, True),
+ "int4": ("vkdispatch_int4", "int", 4, "int4", True, True),
+ "uint2": ("vkdispatch_uint2", "unsigned int", 2, "uint2", False, True),
+ "uint3": ("vkdispatch_uint3", "unsigned int", 3, "uint3", False, True),
+ "uint4": ("vkdispatch_uint4", "unsigned int", 4, "uint4", False, True),
+ "half2": ("vkdispatch_half2", "__half", 2, "half2", True, False),
+ "half3": ("vkdispatch_half3", "__half", 3, "half3", True, False),
+ "half4": ("vkdispatch_half4", "__half", 4, "half4", True, False),
+ "float2": ("vkdispatch_float2", "float", 2, "float2", True, False),
+ "float3": ("vkdispatch_float3", "float", 3, "float3", True, False),
+ "float4": ("vkdispatch_float4", "float", 4, "float4", True, False),
+ "double2": ("vkdispatch_double2", "double", 2, "double2", True, False),
+ "double3": ("vkdispatch_double3", "double", 3, "double3", True, False),
+ "double4": ("vkdispatch_double4", "double", 4, "double4", True, False),
+}
+
+_CUDA_MAT_TYPE_SPECS: Dict[str, Tuple[str, str, str, int]] = {
+ "mat2": ("vkdispatch_mat2", "vkdispatch_float2", "float2", 2),
+ "mat3": ("vkdispatch_mat3", "vkdispatch_float3", "float3", 3),
+ "mat4": ("vkdispatch_mat4", "vkdispatch_float4", "float4", 4),
+}
+
+_CUDA_VEC_ORDER = [
+ "short2", "short3", "short4",
+ "ushort2", "ushort3", "ushort4",
+ "int2", "int3", "int4",
+ "uint2", "uint3", "uint4",
+ "half2", "half3", "half4",
+ "float2", "float3", "float4",
+ "double2", "double3", "double4",
+]
+
+_CUDA_MAT_ORDER = ["mat2", "mat3", "mat4"]
+
+_DTYPE_TO_COMPOSITE_KEY = {
+ dtypes.ihvec2: "short2",
+ dtypes.ihvec3: "short3",
+ dtypes.ihvec4: "short4",
+ dtypes.uhvec2: "ushort2",
+ dtypes.uhvec3: "ushort3",
+ dtypes.uhvec4: "ushort4",
+ dtypes.ivec2: "int2",
+ dtypes.ivec3: "int3",
+ dtypes.ivec4: "int4",
+ dtypes.uvec2: "uint2",
+ dtypes.uvec3: "uint3",
+ dtypes.uvec4: "uint4",
+ dtypes.hvec2: "half2",
+ dtypes.hvec3: "half3",
+ dtypes.hvec4: "half4",
+ dtypes.complex32: "half2",
+ dtypes.complex64: "float2",
+ dtypes.complex128: "double2",
+ dtypes.vec2: "float2",
+ dtypes.vec3: "float3",
+ dtypes.vec4: "float4",
+ dtypes.dvec2: "double2",
+ dtypes.dvec3: "double3",
+ dtypes.dvec4: "double4",
+ dtypes.mat2: "mat2",
+ dtypes.mat3: "mat3",
+ dtypes.mat4: "mat4",
+}
+
+_SCALAR_TYPE_NAMES = {
+ dtypes.int16: "short",
+ dtypes.uint16: "unsigned short",
+ dtypes.int32: "int",
+ dtypes.uint32: "unsigned int",
+ dtypes.int64: "long long",
+ dtypes.uint64: "unsigned long long",
+ dtypes.float16: "__half",
+ dtypes.float32: "float",
+ dtypes.float64: "double",
+}
+
+_FLOAT_VEC_DTYPES: FrozenSet[dtypes.dtype] = frozenset(
+ {
+ dtypes.complex32,
+ dtypes.complex64,
+ dtypes.complex128,
+ dtypes.hvec2,
+ dtypes.hvec3,
+ dtypes.hvec4,
+ dtypes.vec2,
+ dtypes.vec3,
+ dtypes.vec4,
+ dtypes.dvec2,
+ dtypes.dvec3,
+ dtypes.dvec4,
+ }
+)
+
+_FLOAT_VEC_HELPER_SUFFIX_MAP = {
+ dtypes.hvec2: "half2",
+ dtypes.hvec3: "half3",
+ dtypes.hvec4: "half4",
+ dtypes.complex32: "half2",
+ dtypes.complex64: "float2",
+ dtypes.complex128: "double2",
+ dtypes.vec2: "float2",
+ dtypes.vec3: "float3",
+ dtypes.vec4: "float4",
+ dtypes.dvec2: "double2",
+ dtypes.dvec3: "double3",
+ dtypes.dvec4: "double4",
+}
diff --git a/vkdispatch/codegen/backends/glsl.py b/vkdispatch/codegen/backends/glsl.py
index e0c82738..c2187e06 100644
--- a/vkdispatch/codegen/backends/glsl.py
+++ b/vkdispatch/codegen/backends/glsl.py
@@ -1,17 +1,52 @@
-from typing import List, Optional
+from typing import List, Optional, Set
import vkdispatch.base.dtype as dtypes
from .base import CodeGenBackend
+# Map scalar dtypes to GLSL extension names.
+_GLSL_TYPE_EXTENSIONS = {
+ dtypes.float16: "GL_EXT_shader_explicit_arithmetic_types_float16",
+ dtypes.int16: "GL_EXT_shader_explicit_arithmetic_types_int16",
+ dtypes.uint16: "GL_EXT_shader_explicit_arithmetic_types_int16",
+ dtypes.int64: "GL_ARB_gpu_shader_int64",
+ dtypes.uint64: "GL_ARB_gpu_shader_int64",
+ dtypes.float64: "GL_ARB_gpu_shader_fp64",
+}
+
class GLSLBackend(CodeGenBackend):
name = "glsl"
+ def __init__(self) -> None:
+ super().__init__()
+ self._needed_extensions: Set[str] = set()
+
+ def reset_state(self) -> None:
+ self._needed_extensions = set()
+
+ def _track_type_extension(self, var_type: dtypes.dtype) -> None:
+ """Record the GLSL extension required by *var_type* (if any)."""
+ scalar = var_type
+ if dtypes.is_vector(var_type) or dtypes.is_matrix(var_type):
+ scalar = var_type.scalar
+ elif dtypes.is_complex(var_type):
+ scalar = var_type.child_type
+ ext = _GLSL_TYPE_EXTENSIONS.get(scalar)
+ if ext is not None:
+ self._needed_extensions.add(ext)
+
def type_name(self, var_type: dtypes.dtype) -> str:
+ self._track_type_extension(var_type)
return var_type.glsl_type
- def constructor(self, var_type: dtypes.dtype, args: List[str]) -> str:
+ def constructor(
+ self,
+ var_type: dtypes.dtype,
+ args: List[str],
+ arg_types: Optional[List[Optional[dtypes.dtype]]] = None,
+ ) -> str:
+ _ = arg_types
return f"{self.type_name(var_type)}({', '.join(args)})"
def pre_header(self, *, enable_subgroup_ops: bool, enable_printf: bool) -> str:
@@ -24,10 +59,17 @@ def pre_header(self, *, enable_subgroup_ops: bool, enable_printf: bool) -> str:
if enable_printf:
header += "#extension GL_EXT_debug_printf : require\n"
- return header
+ ext_block = ""
+ for ext in sorted(self._needed_extensions):
+ ext_line = f"#extension {ext} : require\n"
+ if ext_line not in header:
+ ext_block += ext_line
+
+ return header + ext_block
def make_source(self, header: str, body: str, x: int, y: int, z: int) -> str:
layout_str = f"layout(local_size_x = {x}, local_size_y = {y}, local_size_z = {z}) in;"
+
return f"{header}\n{layout_str}\n{body}"
def constant_namespace(self) -> str:
@@ -63,6 +105,18 @@ def inf_f32_expr(self) -> str:
def ninf_f32_expr(self) -> str:
return "uintBitsToFloat(0xFF800000)"
+ def inf_f64_expr(self) -> str:
+ return "packDouble2x32(uvec2(0x00000000u, 0x7FF00000u))"
+
+ def ninf_f64_expr(self) -> str:
+ return "packDouble2x32(uvec2(0x00000000u, 0xFFF00000u))"
+
+ def inf_f16_expr(self) -> str:
+ return "float16_t(uintBitsToFloat(0x7F800000))"
+
+ def ninf_f16_expr(self) -> str:
+ return "float16_t(uintBitsToFloat(0xFF800000))"
+
def float_bits_to_int_expr(self, var_expr: str) -> str:
return f"floatBitsToInt({var_expr})"
@@ -123,25 +177,32 @@ def memory_barrier_image_statement(self) -> str:
def group_memory_barrier_statement(self) -> str:
return "groupMemoryBarrier();"
- def subgroup_add_expr(self, arg_expr: str) -> str:
+ def subgroup_add_expr(self, arg_expr: str, arg_type: Optional[dtypes.dtype] = None) -> str:
+ _ = arg_type
return f"subgroupAdd({arg_expr})"
- def subgroup_mul_expr(self, arg_expr: str) -> str:
+ def subgroup_mul_expr(self, arg_expr: str, arg_type: Optional[dtypes.dtype] = None) -> str:
+ _ = arg_type
return f"subgroupMul({arg_expr})"
- def subgroup_min_expr(self, arg_expr: str) -> str:
+ def subgroup_min_expr(self, arg_expr: str, arg_type: Optional[dtypes.dtype] = None) -> str:
+ _ = arg_type
return f"subgroupMin({arg_expr})"
- def subgroup_max_expr(self, arg_expr: str) -> str:
+ def subgroup_max_expr(self, arg_expr: str, arg_type: Optional[dtypes.dtype] = None) -> str:
+ _ = arg_type
return f"subgroupMax({arg_expr})"
- def subgroup_and_expr(self, arg_expr: str) -> str:
+ def subgroup_and_expr(self, arg_expr: str, arg_type: Optional[dtypes.dtype] = None) -> str:
+ _ = arg_type
return f"subgroupAnd({arg_expr})"
- def subgroup_or_expr(self, arg_expr: str) -> str:
+ def subgroup_or_expr(self, arg_expr: str, arg_type: Optional[dtypes.dtype] = None) -> str:
+ _ = arg_type
return f"subgroupOr({arg_expr})"
- def subgroup_xor_expr(self, arg_expr: str) -> str:
+ def subgroup_xor_expr(self, arg_expr: str, arg_type: Optional[dtypes.dtype] = None) -> str:
+ _ = arg_type
return f"subgroupXor({arg_expr})"
def subgroup_elect_expr(self) -> str:
@@ -166,3 +227,9 @@ def sample_texture_expr(self, texture_expr: str, coord_expr: str, lod_expr: Opti
return f"texture({texture_expr}, {coord_expr})"
return f"texture({texture_expr}, {coord_expr}, {lod_expr})"
+
+ def atomic_add_expr(self, mem_expr: str, value_expr: str, var_type: dtypes.dtype) -> str:
+ if var_type not in (dtypes.int32, dtypes.uint32):
+ raise NotImplementedError(f"GLSL atomic_add only supports int32/uint32, got '{var_type.name}'")
+
+ return f"atomicAdd({mem_expr}, {value_expr})"
diff --git a/vkdispatch/codegen/backends/opencl.py b/vkdispatch/codegen/backends/opencl.py
new file mode 100644
index 00000000..907f0508
--- /dev/null
+++ b/vkdispatch/codegen/backends/opencl.py
@@ -0,0 +1,701 @@
+from typing import List, Optional, Set
+
+import vkdispatch.base.dtype as dtypes
+
+from .base import CodeGenBackend
+
+
+class OpenCLBackend(CodeGenBackend):
+ name = "opencl"
+
+ _SCALAR_TYPE_NAMES = {
+ dtypes.int16: "short",
+ dtypes.uint16: "ushort",
+ dtypes.int32: "int",
+ dtypes.uint32: "uint",
+ dtypes.int64: "long",
+ dtypes.uint64: "ulong",
+ dtypes.float16: "half",
+ dtypes.float32: "float",
+ dtypes.float64: "double",
+ }
+
+ _MATRIX_TYPE_NAMES = {
+ dtypes.mat2: "vkdispatch_mat2",
+ dtypes.mat3: "vkdispatch_mat3",
+ dtypes.mat4: "vkdispatch_mat4",
+ }
+
+ def __init__(self) -> None:
+ self.reset_state()
+
+ def reset_state(self) -> None:
+ self._kernel_params: List[str] = []
+ self._entry_alias_lines: List[str] = []
+ self._shared_buffer_lines: List[str] = []
+ self._matrix_type_usage: Set[int] = set()
+
+ def _register_kernel_param(self, param_decl: str) -> None:
+ if param_decl not in self._kernel_params:
+ self._kernel_params.append(param_decl)
+
+ def _register_alias_line(self, alias_line: str) -> None:
+ self._entry_alias_lines.append(alias_line)
+
+ def _record_matrix_dim(self, dim: int) -> None:
+ if dim not in (2, 3, 4):
+ raise ValueError(f"Unsupported OpenCL matrix dimension '{dim}'")
+ self._matrix_type_usage.add(dim)
+
+ def _record_matrix_type(self, var_type: dtypes.dtype) -> None:
+ if dtypes.is_matrix(var_type):
+ self._record_matrix_dim(var_type.child_count)
+
+ @staticmethod
+ def _matrix_helper_name(dim: int, constructor_kind: str) -> str:
+ return f"vkdispatch_make_mat{dim}_{constructor_kind}"
+
+ def _is_matrix_copy_constructor_arg(self, arg_expr: str, dim: int) -> bool:
+ stripped = arg_expr.strip()
+ mat_type = self._matrix_struct_name(dim)
+
+ if stripped.startswith(f"({mat_type})") or stripped.startswith(f"(({mat_type})"):
+ return True
+
+ if f"vkdispatch_make_mat{dim}_" in stripped:
+ return True
+
+ if f"vkdispatch_mat{dim}_" in stripped:
+ return True
+
+ return False
+
+ @classmethod
+ def _scalar_type_name(cls, scalar_type: dtypes.dtype) -> str:
+ type_name = cls._SCALAR_TYPE_NAMES.get(scalar_type)
+ if type_name is None:
+ raise ValueError(f"Unsupported OpenCL scalar type mapping for '{scalar_type.name}'")
+ return type_name
+
+ def type_name(self, var_type: dtypes.dtype) -> str:
+ if dtypes.is_scalar(var_type):
+ return self._scalar_type_name(var_type)
+
+ if dtypes.is_vector(var_type):
+ return f"{self._scalar_type_name(var_type.scalar)}{var_type.child_count}"
+
+ if dtypes.is_complex(var_type):
+ return f"{self._scalar_type_name(var_type.child_type)}2"
+
+ if dtypes.is_matrix(var_type):
+ self._record_matrix_type(var_type)
+ matrix_name = self._MATRIX_TYPE_NAMES.get(var_type)
+ if matrix_name is None:
+ raise ValueError(f"Unsupported OpenCL matrix type mapping for '{var_type.name}'")
+ return matrix_name
+
+ raise ValueError(f"Unsupported OpenCL type mapping for '{var_type.name}'")
+
+ def constructor(
+ self,
+ var_type: dtypes.dtype,
+ args: List[str],
+ arg_types: Optional[List[Optional[dtypes.dtype]]] = None,
+ ) -> str:
+ target_type = self.type_name(var_type)
+
+ if dtypes.is_scalar(var_type):
+ assert len(args) > 0, f"Constructor for scalar type '{var_type.name}' needs at least one argument."
+ return f"(({target_type})({args[0]}))"
+
+ if dtypes.is_matrix(var_type):
+ dim = var_type.child_count
+ assert len(args) in (1, dim, dim * dim), (
+ f"Constructor for matrix type '{var_type.name}' needs 1, {dim}, or {dim * dim} arguments."
+ )
+ if len(args) == 1:
+ single_arg = args[0]
+ helper_name = self._matrix_helper_name(
+ dim,
+ "copy" if self._is_matrix_copy_constructor_arg(single_arg, dim) else "scalar",
+ )
+ return f"{helper_name}({single_arg})"
+
+ if len(args) == dim:
+ return f"{self._matrix_helper_name(dim, 'cols')}({', '.join(args)})"
+
+ return f"{self._matrix_helper_name(dim, 'flat')}({', '.join(args)})"
+
+ # NVIDIA's OpenCL frontend rejects direct vector casts between different
+ # vector base types (e.g. uint2 -> float2). Use convert_* builtins when
+ # we know this is a vector/complex-to-vector/complex conversion.
+ if (
+ len(args) == 1
+ and arg_types is not None
+ and len(arg_types) == 1
+ and arg_types[0] is not None
+ and (dtypes.is_vector(var_type) or dtypes.is_complex(var_type))
+ and (dtypes.is_vector(arg_types[0]) or dtypes.is_complex(arg_types[0]))
+ ):
+ return f"convert_{target_type}({args[0]})"
+
+ return f"(({target_type})({', '.join(args)}))"
+
+ def component_access_expr(self, expr: str, component: str, base_type: dtypes.dtype) -> str:
+ if dtypes.is_scalar(base_type) and component == "x":
+ return expr
+ return super().component_access_expr(expr, component, base_type)
+
+ def buffer_component_expr(
+ self,
+ scalar_buffer_expr: str,
+ base_type: dtypes.dtype,
+ element_index_expr: str,
+ component_index_expr: str,
+ ) -> Optional[str]:
+ if dtypes.is_complex(base_type):
+ component_count = base_type.child_count
+ elif dtypes.is_vector(base_type):
+ component_count = base_type.child_count
+ else:
+ return None
+
+ return (
+ f"{scalar_buffer_expr}["
+ f"(({element_index_expr}) * {component_count}) + ({component_index_expr})"
+ f"]"
+ )
+
+ def _cast_math_arg(self, arg_type: dtypes.dtype, arg_expr: str) -> str:
+ if dtypes.is_scalar(arg_type) or dtypes.is_vector(arg_type) or dtypes.is_complex(arg_type):
+ return self.constructor(arg_type, [arg_expr], arg_types=[arg_type])
+
+ return arg_expr
+
+ def math_func_name(self, func_name: str, var_type: dtypes.dtype) -> str:
+ func_name_dict = {
+ "sin": "native_sin",
+ "cos": "native_cos",
+ "tan": "native_tan",
+ "sqrt": "native_sqrt",
+ "exp": "native_exp",
+ "exp2": "native_exp2",
+ "log": "native_log",
+ "log2": "native_log2",
+ }
+
+ if func_name in func_name_dict:
+ return func_name_dict[func_name]
+
+ return func_name
+
+ def unary_math_expr(self, func_name: str, arg_type: dtypes.dtype, arg_expr: str) -> str:
+ mapped = self.math_func_name(func_name, arg_type)
+ return f"{mapped}({self._cast_math_arg(arg_type, arg_expr)})"
+
+ def binary_math_expr(
+ self,
+ func_name: str,
+ lhs_type: dtypes.dtype,
+ lhs_expr: str,
+ rhs_type: dtypes.dtype,
+ rhs_expr: str,
+ ) -> str:
+ mapped = self.math_func_name(func_name, lhs_type)
+ lhs_cast_expr = self._cast_math_arg(lhs_type, lhs_expr)
+ rhs_cast_expr = self._cast_math_arg(rhs_type, rhs_expr)
+ return f"{mapped}({lhs_cast_expr}, {rhs_cast_expr})"
+
+ def pre_header(self, *, enable_subgroup_ops: bool, enable_printf: bool) -> str:
+ _ = enable_subgroup_ops
+ _ = enable_printf
+ header = (
+ "// OpenCL C source generated by vkdispatch\n"
+ "#ifdef cl_khr_global_int32_base_atomics\n"
+ "#pragma OPENCL EXTENSION cl_khr_global_int32_base_atomics : enable\n"
+ "#endif\n"
+ "#ifdef cl_khr_local_int32_base_atomics\n"
+ "#pragma OPENCL EXTENSION cl_khr_local_int32_base_atomics : enable\n"
+ "#endif\n"
+ "#ifdef cl_khr_fp64\n"
+ "#pragma OPENCL EXTENSION cl_khr_fp64 : enable\n"
+ "#endif\n"
+ "#ifdef cl_khr_fp16\n"
+ "#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n"
+ "#endif\n"
+ )
+ matrix_helpers = self._emit_matrix_helpers()
+ if len(matrix_helpers) > 0:
+ header += f"\n{matrix_helpers}\n"
+ return header
+
+ def _emit_matrix_helpers(self) -> str:
+ if len(self._matrix_type_usage) == 0:
+ return ""
+
+ sections: List[str] = []
+ if 3 in self._matrix_type_usage:
+ sections.append(
+ "typedef struct __attribute__((packed)) vkdispatch_packed_float3 {\n"
+ " float x;\n"
+ " float y;\n"
+ " float z;\n"
+ "} vkdispatch_packed_float3;\n"
+ "static inline float3 vkdispatch_unpack_float3(vkdispatch_packed_float3 v) { return (float3)(v.x, v.y, v.z); }\n"
+ "static inline vkdispatch_packed_float3 vkdispatch_pack_float3(float3 v) {\n"
+ " vkdispatch_packed_float3 out = {v.x, v.y, v.z};\n"
+ " return out;\n"
+ "}"
+ )
+
+ for dim in sorted(self._matrix_type_usage):
+ sections.append(self._emit_matrix_helpers_for_dim(dim))
+
+ return "\n\n".join(sections)
+
+ @staticmethod
+ def _vector_components(dim: int) -> List[str]:
+ return list("xyzw"[:dim])
+
+ @staticmethod
+ def _matrix_struct_name(dim: int) -> str:
+ return f"vkdispatch_mat{dim}"
+
+ @staticmethod
+ def _vector_type_name(dim: int) -> str:
+ return f"float{dim}"
+
+ def _matrix_col_expr(self, mat_expr: str, col: int, dim: int) -> str:
+ if dim == 3:
+ return f"vkdispatch_unpack_float3({mat_expr}.c{col})"
+ return f"{mat_expr}.c{col}"
+
+ def _matrix_col_assign_stmt(self, target_expr: str, col: int, value_expr: str, dim: int) -> str:
+ if dim == 3:
+ return f"{target_expr}.c{col} = vkdispatch_pack_float3({value_expr});"
+ return f"{target_expr}.c{col} = {value_expr};"
+
+ def _emit_matrix_helpers_for_dim(self, dim: int) -> str:
+ mat_type = self._matrix_struct_name(dim)
+ vec_type = self._vector_type_name(dim)
+ comps = self._vector_components(dim)
+ scalar_helper_name = self._matrix_helper_name(dim, "scalar")
+ copy_helper_name = self._matrix_helper_name(dim, "copy")
+ cols_helper_name = self._matrix_helper_name(dim, "cols")
+ flat_helper_name = self._matrix_helper_name(dim, "flat")
+
+ lines: List[str] = []
+
+ if dim == 3:
+ lines.append(
+ "typedef struct __attribute__((packed)) vkdispatch_mat3 {\n"
+ " vkdispatch_packed_float3 c0;\n"
+ " vkdispatch_packed_float3 c1;\n"
+ " vkdispatch_packed_float3 c2;\n"
+ "} vkdispatch_mat3;"
+ )
+ else:
+ cols = "\n".join([f" {vec_type} c{i};" for i in range(dim)])
+ lines.append(f"typedef struct {mat_type} {{\n{cols}\n}} {mat_type};")
+
+ # Constructors.
+ lines.append(f"static inline {mat_type} {scalar_helper_name}(float s) {{")
+ lines.append(f" {mat_type} out;")
+ for col_idx in range(dim):
+ diag_values = [("s" if row_idx == col_idx else "0.0f") for row_idx in range(dim)]
+ vec_expr = f"({vec_type})(" + ", ".join(diag_values) + ")"
+ lines.append(f" {self._matrix_col_assign_stmt('out', col_idx, vec_expr, dim)}")
+ lines.append(" return out;")
+ lines.append("}")
+
+ lines.append(f"static inline {mat_type} {copy_helper_name}({mat_type} m) {{ return m; }}")
+
+ col_args = ", ".join([f"{vec_type} c{i}" for i in range(dim)])
+ lines.append(f"static inline {mat_type} {cols_helper_name}({col_args}) {{")
+ lines.append(f" {mat_type} out;")
+ for col_idx in range(dim):
+ lines.append(f" {self._matrix_col_assign_stmt('out', col_idx, f'c{col_idx}', dim)}")
+ lines.append(" return out;")
+ lines.append("}")
+
+ flat_names = [f"m{col}{row}" for col in range(dim) for row in range(dim)]
+ flat_args = ", ".join([f"float {name}" for name in flat_names])
+ lines.append(f"static inline {mat_type} {flat_helper_name}({flat_args}) {{")
+ lines.append(f" return {cols_helper_name}(")
+ for col_idx in range(dim):
+ values = [f"m{col_idx}{row_idx}" for row_idx in range(dim)]
+ suffix = "," if col_idx < dim - 1 else ""
+ lines.append(f" ({vec_type})({', '.join(values)}){suffix}")
+ lines.append(" );")
+ lines.append("}")
+
+ # Unary negation.
+ lines.append(f"static inline {mat_type} vkdispatch_mat{dim}_neg({mat_type} a) {{")
+ lines.append(f" {mat_type} out;")
+ for col_idx in range(dim):
+ col_expr = self._matrix_col_expr("a", col_idx, dim)
+ lines.append(
+ f" {self._matrix_col_assign_stmt('out', col_idx, f'-{col_expr}', dim)}"
+ )
+ lines.append(" return out;")
+ lines.append("}")
+
+ # Matrix +/- matrix.
+ for op_name, op_symbol in (("add", "+"), ("sub", "-")):
+ lines.append(
+ f"static inline {mat_type} vkdispatch_mat{dim}_{op_name}_mm({mat_type} a, {mat_type} b) {{"
+ )
+ lines.append(f" {mat_type} out;")
+ for col_idx in range(dim):
+ lhs_col = self._matrix_col_expr("a", col_idx, dim)
+ rhs_col = self._matrix_col_expr("b", col_idx, dim)
+ lines.append(
+ f" {self._matrix_col_assign_stmt('out', col_idx, f'{lhs_col} {op_symbol} {rhs_col}', dim)}"
+ )
+ lines.append(" return out;")
+ lines.append("}")
+
+ # Matrix/scalar and scalar/matrix arithmetic.
+ for op_name, op_symbol in (("add", "+"), ("sub", "-"), ("mul", "*"), ("div", "/")):
+ lines.append(
+ f"static inline {mat_type} vkdispatch_mat{dim}_{op_name}_ms({mat_type} a, float b) {{"
+ )
+ lines.append(f" {mat_type} out;")
+ for col_idx in range(dim):
+ lhs_col = self._matrix_col_expr("a", col_idx, dim)
+ lines.append(
+ f" {self._matrix_col_assign_stmt('out', col_idx, f'{lhs_col} {op_symbol} b', dim)}"
+ )
+ lines.append(" return out;")
+ lines.append("}")
+
+ lines.append(
+ f"static inline {mat_type} vkdispatch_mat{dim}_{op_name}_sm(float a, {mat_type} b) {{"
+ )
+ lines.append(f" {mat_type} out;")
+ for col_idx in range(dim):
+ rhs_col = self._matrix_col_expr("b", col_idx, dim)
+ lines.append(
+ f" {self._matrix_col_assign_stmt('out', col_idx, f'a {op_symbol} {rhs_col}', dim)}"
+ )
+ lines.append(" return out;")
+ lines.append("}")
+
+ # Matrix/vector product (column-major, GLSL-style): m * v.
+ mat_vec_terms = [f"({self._matrix_col_expr('m', i, dim)} * v.{comps[i]})" for i in range(dim)]
+ lines.append(f"static inline {vec_type} vkdispatch_mat{dim}_mul_mv({mat_type} m, {vec_type} v) {{")
+ lines.append(f" return {' + '.join(mat_vec_terms)};")
+ lines.append("}")
+
+ # Vector/matrix product (column-major, GLSL-style): v * m.
+ lines.append(f"static inline {vec_type} vkdispatch_mat{dim}_mul_vm({vec_type} v, {mat_type} m) {{")
+ for col_idx in range(dim):
+ lines.append(f" {vec_type} col{col_idx} = {self._matrix_col_expr('m', col_idx, dim)};")
+ row_exprs = []
+ for col_idx in range(dim):
+ terms = [f"(v.{comps[row_idx]} * col{col_idx}.{comps[row_idx]})" for row_idx in range(dim)]
+ row_exprs.append(" + ".join(terms))
+ lines.append(f" return ({vec_type})({', '.join(row_exprs)});")
+ lines.append("}")
+
+ return "\n".join(lines)
+
+ def arithmetic_unary_expr(self, op: str, var_type: dtypes.dtype, var_expr: str) -> Optional[str]:
+ if op == "-" and dtypes.is_matrix(var_type):
+ dim = var_type.child_count
+ self._record_matrix_dim(dim)
+ return f"vkdispatch_mat{dim}_neg({var_expr})"
+ return None
+
+ def arithmetic_binary_expr(
+ self,
+ op: str,
+ lhs_type: dtypes.dtype,
+ lhs_expr: str,
+ rhs_type: dtypes.dtype,
+ rhs_expr: str,
+ ) -> Optional[str]:
+ if not (dtypes.is_matrix(lhs_type) or dtypes.is_matrix(rhs_type)):
+ return None
+
+ if op not in ("+", "-", "*", "/"):
+ raise NotImplementedError(
+ f"OpenCL matrix arithmetic override does not support operator '{op}' "
+ f"for ({lhs_type.name}, {rhs_type.name})."
+ )
+
+ if dtypes.is_matrix(lhs_type):
+ dim = lhs_type.child_count
+ if dtypes.is_matrix(rhs_type):
+ if rhs_type.child_count != dim:
+ raise ValueError(
+ f"OpenCL matrix arithmetic requires matching dimensions, got '{lhs_type.name}' and '{rhs_type.name}'."
+ )
+ if op not in ("+", "-"):
+ raise NotImplementedError(
+ f"OpenCL matrix arithmetic does not support operator '{op}' for two matrices."
+ )
+ self._record_matrix_dim(dim)
+ return f"vkdispatch_mat{dim}_{'add' if op == '+' else 'sub'}_mm({lhs_expr}, {rhs_expr})"
+
+ if dtypes.is_scalar(rhs_type):
+ self._record_matrix_dim(dim)
+ suffix = "add" if op == "+" else "sub" if op == "-" else "mul" if op == "*" else "div"
+ return f"vkdispatch_mat{dim}_{suffix}_ms({lhs_expr}, {rhs_expr})"
+
+ if dtypes.is_vector(rhs_type) and op == "*":
+ if rhs_type.child_count != dim or rhs_type.scalar != dtypes.float32:
+ raise ValueError(
+ f"OpenCL matrix/vector multiplication requires float32 vec{dim}, got '{rhs_type.name}'."
+ )
+ self._record_matrix_dim(dim)
+ return f"vkdispatch_mat{dim}_mul_mv({lhs_expr}, {rhs_expr})"
+
+ raise NotImplementedError(
+ f"Unsupported OpenCL matrix arithmetic for ({lhs_type.name}, {rhs_type.name}) with operator '{op}'."
+ )
+
+ # lhs is not matrix; rhs is matrix
+ dim = rhs_type.child_count
+ if dtypes.is_scalar(lhs_type):
+ self._record_matrix_dim(dim)
+ suffix = "add" if op == "+" else "sub" if op == "-" else "mul" if op == "*" else "div"
+ return f"vkdispatch_mat{dim}_{suffix}_sm({lhs_expr}, {rhs_expr})"
+
+ if dtypes.is_vector(lhs_type) and op == "*":
+ if lhs_type.child_count != dim or lhs_type.scalar != dtypes.float32:
+ raise ValueError(
+ f"OpenCL vector/matrix multiplication requires float32 vec{dim}, got '{lhs_type.name}'."
+ )
+ self._record_matrix_dim(dim)
+ return f"vkdispatch_mat{dim}_mul_vm({lhs_expr}, {rhs_expr})"
+
+ raise NotImplementedError(
+ f"Unsupported OpenCL matrix arithmetic for ({lhs_type.name}, {rhs_type.name}) with operator '{op}'."
+ )
+
+ def make_source(self, header: str, body: str, x: int, y: int, z: int) -> str:
+ expected_size_header = (
+ f"// Expected local size: ({x}, {y}, {z})\n"
+ f"#define VKDISPATCH_EXPECTED_LOCAL_SIZE_X {x}\n"
+ f"#define VKDISPATCH_EXPECTED_LOCAL_SIZE_Y {y}\n"
+ f"#define VKDISPATCH_EXPECTED_LOCAL_SIZE_Z {z}\n"
+ )
+ workgroup_attribute = f"__attribute__((reqd_work_group_size({x}, {y}, {z})))"
+ if "__kernel void vkdispatch_main" in body:
+ body = body.replace(
+ "__kernel void vkdispatch_main",
+ f"{workgroup_attribute}\n__kernel void vkdispatch_main",
+ 1,
+ )
+ else:
+ body = f"{workgroup_attribute}\n{body}"
+
+ return f"{expected_size_header}\n{header}\n{body}"
+
+ def constant_namespace(self) -> str:
+ return "UBO"
+
+ def variable_namespace(self) -> str:
+ return "PC"
+
+ def exec_bounds_guard(self, exec_count_expr: str) -> str:
+ gid_expr = f"({self.global_invocation_id_expr()})"
+ exec_expr = f"({exec_count_expr})"
+ return (
+ f"if ({self.component_access_expr(exec_expr, 'x', dtypes.uvec4)} <= {self.component_access_expr(gid_expr, 'x', dtypes.uvec3)} || "
+ f"{self.component_access_expr(exec_expr, 'y', dtypes.uvec4)} <= {self.component_access_expr(gid_expr, 'y', dtypes.uvec3)} || "
+ f"{self.component_access_expr(exec_expr, 'z', dtypes.uvec4)} <= {self.component_access_expr(gid_expr, 'z', dtypes.uvec3)}) {{ return; }}\n"
+ )
+
+ def shared_buffer_declaration(self, var_type: dtypes.dtype, name: str, size: int) -> str:
+ self._shared_buffer_lines.append(f"__local {self.type_name(var_type)} {name}[{size}];")
+ # OpenCL requires __local storage declarations at kernel/function scope.
+ return ""
+
+ def uniform_block_declaration(self, contents: str) -> str:
+ self._register_kernel_param("__global const UniformObjectBuffer* vkdispatch_uniform_ptr")
+ self._register_alias_line("const UniformObjectBuffer UBO = *vkdispatch_uniform_ptr;")
+ return f"\ntypedef struct UniformObjectBuffer {{\n{contents}\n}} UniformObjectBuffer;\n"
+
+ def storage_buffer_declaration(self, binding: int, var_type: dtypes.dtype, name: str) -> str:
+ struct_name = f"Buffer{binding}"
+ param_name = f"vkdispatch_binding_{binding}_ptr"
+ data_type = self.type_name(var_type)
+ self._register_kernel_param(f"__global {data_type}* {param_name}")
+ if dtypes.is_complex(var_type):
+ scalar_type = self.type_name(var_type.child_type)
+ self._register_alias_line(
+ f"__global {scalar_type}* {name}_scalar = (__global {scalar_type}*)({param_name});"
+ )
+ elif dtypes.is_vector(var_type):
+ scalar_type = self.type_name(var_type.scalar)
+ self._register_alias_line(
+ f"__global {scalar_type}* {name}_scalar = (__global {scalar_type}*)({param_name});"
+ )
+ self._register_alias_line(f"{struct_name} {name} = {{{param_name}}};")
+ return f"typedef struct {struct_name} {{ __global {data_type}* data; }} {struct_name};\n"
+
+ def sampler_declaration(self, binding: int, dimensions: int, name: str) -> str:
+ _ = (binding, dimensions, name)
+ raise NotImplementedError("image/sampler unsupported in OpenCL backend")
+
+ def push_constant_declaration(self, contents: str) -> str:
+ self._register_kernel_param("const PushConstant vkdispatch_pc_value")
+ self._register_alias_line("const PushConstant PC = vkdispatch_pc_value;")
+ return f"\ntypedef struct PushConstant {{\n{contents}\n}} PushConstant;\n"
+
+ def entry_point(self, body_contents: str) -> str:
+ params = ", ".join(self._kernel_params)
+ alias_block = ""
+ shared_block = ""
+ for line in self._shared_buffer_lines:
+ shared_block += f" {line}\n"
+ for line in self._entry_alias_lines:
+ alias_block += f" {line}\n"
+
+ return (
+ f"__kernel void vkdispatch_main({params}) {{\n"
+ f"{shared_block}"
+ f"{alias_block}"
+ f"{body_contents}"
+ f"}}\n"
+ )
+
+ def inf_f32_expr(self) -> str:
+ return "as_float((uint)0x7F800000u)"
+
+ def ninf_f32_expr(self) -> str:
+ return "as_float((uint)0xFF800000u)"
+
+ def inf_f64_expr(self) -> str:
+ return "as_double((ulong)0x7FF0000000000000UL)"
+
+ def ninf_f64_expr(self) -> str:
+ return "as_double((ulong)0xFFF0000000000000UL)"
+
+ def inf_f16_expr(self) -> str:
+ return "as_half((ushort)0x7C00u)"
+
+ def ninf_f16_expr(self) -> str:
+ return "as_half((ushort)0xFC00u)"
+
+ def float_bits_to_int_expr(self, var_expr: str) -> str:
+ return f"as_int({var_expr})"
+
+ def float_bits_to_uint_expr(self, var_expr: str) -> str:
+ return f"as_uint({var_expr})"
+
+ def int_bits_to_float_expr(self, var_expr: str) -> str:
+ return f"as_float({var_expr})"
+
+ def uint_bits_to_float_expr(self, var_expr: str) -> str:
+ return f"as_float({var_expr})"
+
+ def global_invocation_id_expr(self) -> str:
+ return "((uint3)((uint)get_global_id(0), (uint)get_global_id(1), (uint)get_global_id(2)))"
+
+ def local_invocation_id_expr(self) -> str:
+ return "((uint3)((uint)get_local_id(0), (uint)get_local_id(1), (uint)get_local_id(2)))"
+
+ def local_invocation_index_expr(self) -> str:
+ return (
+ "((uint)(get_local_id(0) + "
+ "get_local_size(0) * (get_local_id(1) + get_local_size(1) * get_local_id(2))))"
+ )
+
+ def workgroup_id_expr(self) -> str:
+ return "((uint3)((uint)get_group_id(0), (uint)get_group_id(1), (uint)get_group_id(2)))"
+
+ def workgroup_size_expr(self) -> str:
+ return "((uint3)((uint)get_local_size(0), (uint)get_local_size(1), (uint)get_local_size(2)))"
+
+ def num_workgroups_expr(self) -> str:
+ return "((uint3)((uint)get_num_groups(0), (uint)get_num_groups(1), (uint)get_num_groups(2)))"
+
+ def num_subgroups_expr(self) -> str:
+ raise NotImplementedError("subgroup operations unsupported in OpenCL backend")
+
+ def subgroup_id_expr(self) -> str:
+ raise NotImplementedError("subgroup operations unsupported in OpenCL backend")
+
+ def subgroup_size_expr(self) -> str:
+ raise NotImplementedError("subgroup operations unsupported in OpenCL backend")
+
+ def subgroup_invocation_id_expr(self) -> str:
+ raise NotImplementedError("subgroup operations unsupported in OpenCL backend")
+
+ def barrier_statement(self) -> str:
+ return "barrier(CLK_LOCAL_MEM_FENCE);"
+
+ def memory_barrier_statement(self) -> str:
+ return "mem_fence(CLK_LOCAL_MEM_FENCE);"
+
+ def memory_barrier_buffer_statement(self) -> str:
+ return "mem_fence(CLK_GLOBAL_MEM_FENCE);"
+
+ def memory_barrier_shared_statement(self) -> str:
+ return "mem_fence(CLK_LOCAL_MEM_FENCE);"
+
+ def memory_barrier_image_statement(self) -> str:
+ raise NotImplementedError("image/sampler unsupported in OpenCL backend")
+
+ def group_memory_barrier_statement(self) -> str:
+ return "mem_fence(CLK_LOCAL_MEM_FENCE | CLK_GLOBAL_MEM_FENCE);"
+
+ def subgroup_add_expr(self, arg_expr: str) -> str:
+ _ = arg_expr
+ raise NotImplementedError("subgroup operations unsupported in OpenCL backend")
+
+ def subgroup_mul_expr(self, arg_expr: str) -> str:
+ _ = arg_expr
+ raise NotImplementedError("subgroup operations unsupported in OpenCL backend")
+
+ def subgroup_min_expr(self, arg_expr: str) -> str:
+ _ = arg_expr
+ raise NotImplementedError("subgroup operations unsupported in OpenCL backend")
+
+ def subgroup_max_expr(self, arg_expr: str) -> str:
+ _ = arg_expr
+ raise NotImplementedError("subgroup operations unsupported in OpenCL backend")
+
+ def subgroup_and_expr(self, arg_expr: str) -> str:
+ _ = arg_expr
+ raise NotImplementedError("subgroup operations unsupported in OpenCL backend")
+
+ def subgroup_or_expr(self, arg_expr: str) -> str:
+ _ = arg_expr
+ raise NotImplementedError("subgroup operations unsupported in OpenCL backend")
+
+ def subgroup_xor_expr(self, arg_expr: str) -> str:
+ _ = arg_expr
+ raise NotImplementedError("subgroup operations unsupported in OpenCL backend")
+
+ def subgroup_elect_expr(self) -> str:
+ raise NotImplementedError("subgroup operations unsupported in OpenCL backend")
+
+ def subgroup_barrier_statement(self) -> str:
+ raise NotImplementedError("subgroup operations unsupported in OpenCL backend")
+
+ def printf_statement(self, fmt: str, args: List[str]) -> str:
+ if len(args) == 0:
+ return f'printf("{fmt}");'
+ return f'printf("{fmt}", {", ".join(args)});'
+
+ def texture_size_expr(self, texture_expr: str, lod: int, dimensions: int) -> str:
+ _ = (texture_expr, lod, dimensions)
+ raise NotImplementedError("image/sampler unsupported in OpenCL backend")
+
+ def sample_texture_expr(self, texture_expr: str, coord_expr: str, lod_expr: Optional[str] = None) -> str:
+ _ = (texture_expr, coord_expr, lod_expr)
+ raise NotImplementedError("image/sampler unsupported in OpenCL backend")
+
+ def mark_texture_sample_dimension(self, dimensions: int) -> None:
+ _ = dimensions
+ raise NotImplementedError("image/sampler unsupported in OpenCL backend")
+
+ def atomic_add_expr(self, mem_expr: str, value_expr: str, var_type: dtypes.dtype) -> str:
+ if var_type not in (dtypes.int32, dtypes.uint32):
+ raise NotImplementedError(f"OpenCL atomic_add only supports int32/uint32, got '{var_type.name}'")
+
+ return f"atomic_add(&({mem_expr}), {value_expr})"
diff --git a/vkdispatch/codegen/builder.py b/vkdispatch/codegen/builder.py
index c3214976..44e50e48 100644
--- a/vkdispatch/codegen/builder.py
+++ b/vkdispatch/codegen/builder.py
@@ -17,6 +17,15 @@
from .variables.variables import BaseVariable, ShaderVariable, ScaledAndOfftsetIntVariable
from .variables.bound_variables import BufferVariable, ImageVariable
+_PUSH_CONSTANT_UNSUPPORTED_BACKENDS = set()
+
+
+def _push_constant_not_supported_error(backend_name: str) -> str:
+ return (
+ f"Push Constants are not supported for the {backend_name.upper()} backend. "
+ "Use Const instead."
+ )
+
@dataclasses.dataclass
class SharedBuffer:
"""
@@ -61,15 +70,22 @@ class ShaderDescription:
uniform_structure: List[StructElement]
binding_type_list: List[BindingType]
binding_access: List[Tuple[bool, bool]] # List of tuples indicating read and write access for each binding
- exec_count_name: str
+ exec_count_name: Optional[str]
+ resource_binding_base: int
backend: Optional[CodeGenBackend] = None
def make_source(self, x: int, y: int, z: int) -> str:
if self.backend is None:
layout_str = f"layout(local_size_x = {x}, local_size_y = {y}, local_size_z = {z}) in;"
- return f"{self.header}\n{layout_str}\n{self.body}"
+ shader_source = f"{self.header}\n{layout_str}\n{self.body}"
+ else:
+ shader_source = self.backend.make_source(self.header, self.body, x, y, z)
+
+ # ff = open(f"sources/{self.name}.comp", "w")
+ # ff.write(shader_source)
+ # ff.close()
- return self.backend.make_source(self.header, self.body, x, y, z)
+ return shader_source
def __repr__(self):
description_string = ""
@@ -123,7 +139,6 @@ class ShaderBuilder(ShaderWriter):
pc_struct: StructBuilder
uniform_struct: StructBuilder
exec_count: Optional[ShaderVariable]
- pre_header: str
flags: ShaderFlags
backend: CodeGenBackend
@@ -140,11 +155,6 @@ def __init__(self,
else:
# Use the selected backend type while keeping per-builder backend state isolated.
self.backend = get_codegen_backend().__class__()
-
- self.pre_header = self.backend.pre_header(
- enable_subgroup_ops=not (self.flags & ShaderFlags.NO_SUBGROUP_OPS),
- enable_printf=not (self.flags & ShaderFlags.NO_PRINTF)
- )
self.reset()
@@ -159,9 +169,10 @@ def reset(self) -> None:
self.shared_buffers = []
self.scope_num = 1
- self.exec_count = self.declare_constant(dtypes.uvec4, var_name="exec_count")
-
+ self.exec_count = None
+
if not (self.flags & ShaderFlags.NO_EXEC_BOUNDS):
+ self.exec_count = self.declare_constant(dtypes.uvec4, var_name="exec_count")
self.append_contents(self.backend.exec_bounds_guard(self.exec_count.resolve()))
def new_var(self,
@@ -211,6 +222,9 @@ def declare_constant(self, var_type: dtypes.dtype, count: int = 1, var_name: Opt
return new_var
def declare_variable(self, var_type: dtypes.dtype, count: int = 1, var_name: Optional[str] = None):
+ if self.backend.name in _PUSH_CONSTANT_UNSUPPORTED_BACKENDS:
+ raise NotImplementedError(_push_constant_not_supported_error(self.backend.name))
+
if var_name is None:
var_name = self.new_name()
@@ -235,6 +249,10 @@ def declare_buffer(self, var_type: dtypes.dtype, var_name: Optional[str] = None)
buffer_name = f"buf{self.binding_count}" if var_name is None else var_name
shape_name = f"{buffer_name}_shape"
+ scalar_expr = None
+
+ if self.backend.name == "opencl" and (dtypes.is_vector(var_type) or dtypes.is_complex(var_type)):
+ scalar_expr = f"{buffer_name}_scalar"
self.binding_list.append(ShaderBinding(var_type, buffer_name, 0, BindingType.STORAGE_BUFFER))
self.binding_read_access[self.binding_count] = False
@@ -247,13 +265,18 @@ def read_lambda():
def write_lambda():
self.binding_write_access[current_binding_count] = True
+
+ def shape_var_factory():
+ return self.declare_constant(dtypes.ivec4, var_name=shape_name)
return BufferVariable(
var_type,
self.binding_count,
f"{buffer_name}.data",
- self.declare_constant(dtypes.ivec4, var_name=shape_name),
- shape_name,
+ shape_var_factory=shape_var_factory,
+ shape_name=shape_name,
+ scalar_expr=scalar_expr,
+ codegen_backend=self.backend,
read_lambda=read_lambda,
write_lambda=write_lambda
)
@@ -287,12 +310,17 @@ def shared_buffer(self, var_type: dtypes.dtype, size: int, var_name: Optional[st
shape_name = f"{var_name}_shape"
+ def shape_var_factory():
+ return self.declare_constant(dtypes.ivec4, var_name=shape_name)
+
new_var = BufferVariable(
var_type,
-1,
var_name,
- self.declare_constant(dtypes.ivec4, var_name=shape_name),
- shape_name,
+ shape_var_factory=shape_var_factory,
+ shape_name=shape_name,
+ scalar_expr=None,
+ codegen_backend=self.backend,
read_lambda=lambda: None,
write_lambda=lambda: None
)
@@ -316,7 +344,7 @@ def compose_struct_decleration(self, elements: List[StructElement]) -> str:
return "\n".join(declerations)
def build(self, name: str) -> ShaderDescription:
- header = "" + self.pre_header
+ header = ""
for shared_buffer in self.shared_buffers:
header += self.backend.shared_buffer_declaration(
@@ -328,22 +356,28 @@ def build(self, name: str) -> ShaderDescription:
uniform_elements = self.uniform_struct.build()
uniform_decleration_contents = self.compose_struct_decleration(uniform_elements)
- if len(uniform_decleration_contents) > 0:
+ has_uniform_buffer = len(uniform_decleration_contents) > 0
+ if has_uniform_buffer:
header += self.backend.uniform_block_declaration(uniform_decleration_contents)
- binding_type_list = [BindingType.UNIFORM_BUFFER]
- binding_access = [(True, False)] # UBO is read-only
+ binding_base = 1 if has_uniform_buffer else 0
+ binding_type_list = []
+ binding_access = []
+ if has_uniform_buffer:
+ binding_type_list.append(BindingType.UNIFORM_BUFFER)
+ binding_access.append((True, False)) # UBO is read-only
for ii, binding in enumerate(self.binding_list):
+ emitted_binding = ii + binding_base
if binding.binding_type == BindingType.STORAGE_BUFFER:
- header += self.backend.storage_buffer_declaration(ii + 1, binding.dtype, binding.name)
+ header += self.backend.storage_buffer_declaration(emitted_binding, binding.dtype, binding.name)
binding_type_list.append(binding.binding_type)
binding_access.append((
self.binding_read_access[ii + 1],
self.binding_write_access[ii + 1]
))
else:
- header += self.backend.sampler_declaration(ii + 1, binding.dimension, binding.name)
+ header += self.backend.sampler_declaration(emitted_binding, binding.dimension, binding.name)
binding_type_list.append(binding.binding_type)
binding_access.append((
self.binding_read_access[ii + 1],
@@ -355,10 +389,18 @@ def build(self, name: str) -> ShaderDescription:
pc_decleration_contents = self.compose_struct_decleration(pc_elements)
if len(pc_decleration_contents) > 0:
+ assert self.backend.name not in _PUSH_CONSTANT_UNSUPPORTED_BACKENDS, (
+ _push_constant_not_supported_error(self.backend.name)
+ )
header += self.backend.push_constant_declaration(pc_decleration_contents)
+ pre_header = self.backend.pre_header(
+ enable_subgroup_ops=not (self.flags & ShaderFlags.NO_SUBGROUP_OPS),
+ enable_printf=not (self.flags & ShaderFlags.NO_PRINTF)
+ )
+
return ShaderDescription(
- header=header,
+ header=f"{pre_header}{header}",
body=self.backend.entry_point(self.contents),
name=name,
pc_size=self.pc_struct.size,
@@ -366,6 +408,7 @@ def build(self, name: str) -> ShaderDescription:
uniform_structure=uniform_elements,
binding_type_list=[binding.value for binding in binding_type_list],
binding_access=binding_access,
- exec_count_name=self.exec_count.raw_name,
+ exec_count_name=self.exec_count.raw_name if self.exec_count is not None else None,
+ resource_binding_base=binding_base,
backend=self.backend
)
diff --git a/vkdispatch/codegen/functions/atomic_memory.py b/vkdispatch/codegen/functions/atomic_memory.py
index 000350f7..7efb8590 100644
--- a/vkdispatch/codegen/functions/atomic_memory.py
+++ b/vkdispatch/codegen/functions/atomic_memory.py
@@ -1,20 +1,76 @@
+from typing import Any, List
+
+import vkdispatch.base.dtype as dtypes
+
+from ..variables.base_variable import BaseVariable
+from ..variables.bound_variables import BufferVariable
from ..variables.variables import ShaderVariable
+from . import utils
-from typing import Any
-# https://docs.vulkan.org/glsl/latest/chapters/builtinfunctions.html#atomic-memory-functions
+def _is_buffer_backed_target(var: ShaderVariable) -> bool:
+ stack: List[BaseVariable] = [var]
+ visited_ids = set()
+
+ while len(stack) > 0:
+ current = stack.pop()
+ current_id = id(current)
+ if current_id in visited_ids:
+ continue
+ visited_ids.add(current_id)
+
+ if isinstance(current, BufferVariable):
+ return True
+
+ stack.extend(current.parents)
+
+ return False
+
+# https://docs.vulkan.org/glsl/latest/chapters/builtinfunctions.html#atomic-memory-functions
def atomic_add(mem: ShaderVariable, y: Any) -> ShaderVariable:
- raise NotImplementedError("atomic_add is not implemented yet")
+ assert isinstance(mem, ShaderVariable), f"atomic_add target must be a ShaderVariable, got {type(mem)}"
+ assert dtypes.is_scalar(mem.var_type), "atomic_add target must be a scalar lvalue"
+ assert mem.is_setable(), "atomic_add target must be a writable lvalue"
+ assert not mem.is_register(), "atomic_add does not support register/local variables as target"
+ assert _is_buffer_backed_target(mem), "atomic_add target must reference a buffer element (e.g., buf[idx])"
+
+ assert mem.var_type in (dtypes.int32, dtypes.uint32), (
+ f"atomic_add currently supports only int32/uint32 targets, got '{mem.var_type.name}'"
+ )
+
+ parents: List[BaseVariable] = [mem]
+
+ if isinstance(y, ShaderVariable):
+ assert dtypes.is_scalar(y.var_type), "atomic_add increment variable must be scalar"
+ assert dtypes.is_integer_dtype(y.var_type), (
+ f"atomic_add increment variable must be integer-typed, got '{y.var_type.name}'"
+ )
+ y.read_callback()
+ parents.append(y)
+ y_expr = utils.backend_constructor(mem.var_type, y)
+ elif utils.is_int_number(y):
+ y_expr = utils.backend_constructor(mem.var_type, y)
+ elif utils.is_number(y):
+ raise TypeError(f"atomic_add increment must be an integer scalar, got {y!r}")
+ else:
+ raise TypeError(f"atomic_add increment must be an integer scalar or ShaderVariable, got {type(y)}")
+
+ mem.read_callback()
+ mem.write_callback()
- # assert isinstance(mem, BaseVariable), "mem must be a BaseVariable"
+ result_var = utils.new_var(
+ mem.var_type,
+ None,
+ parents=parents,
+ lexical_unit=True,
+ settable=True,
+ register=True
+ )
- # new_var = self.make_var(arg1.var_type, None, [])
- # self.append_contents(f"{new_var.var_type.glsl_type} {new_var.name} = atomicAdd({arg1.resolve()}, {arg2.resolve()});\n")
+ atomic_expr = utils.codegen_backend().atomic_add_expr(mem.resolve(), y_expr, mem.var_type)
+ utils.append_contents(
+ f"{utils.backend_type_name(result_var.var_type)} {result_var.name} = {atomic_expr};\n"
+ )
- # return mem.new_var(
- # mem.var_type,
- # f"atomicAdd({mem.resolve()}, {resolve_input(y)})",
- # parents=[y, x],
- # lexical_unit=True
- # )
\ No newline at end of file
+ return result_var
diff --git a/vkdispatch/codegen/functions/base_functions/arithmetic.py b/vkdispatch/codegen/functions/base_functions/arithmetic.py
index 4ecab608..79e890e5 100644
--- a/vkdispatch/codegen/functions/base_functions/arithmetic.py
+++ b/vkdispatch/codegen/functions/base_functions/arithmetic.py
@@ -1,11 +1,11 @@
import vkdispatch.base.dtype as dtypes
from vkdispatch.codegen.variables.base_variable import BaseVariable
-from typing import Any
+from typing import Any, Tuple, Union
-from ...._compat import numpy_compat as npc
+from .. import scalar_eval as se
def my_log2_int(x: int) -> int:
- return int(npc.round(npc.log2(x)))
+ return int(se.round(se.log2(x)))
from . import base_utils
@@ -18,6 +18,26 @@ def _mark_arith_unary(var: BaseVariable, op: str) -> None:
def _mark_arith_binary(lhs_type: dtypes.dtype, rhs_type: dtypes.dtype, op: str, *, inplace: bool = False) -> None:
base_utils.get_codegen_backend().mark_composite_binary_op(lhs_type, rhs_type, op, inplace=inplace)
+def _resolve_arithmetic_binary_expr(
+ op: str,
+ lhs_type: dtypes.dtype,
+ lhs_expr: str,
+ rhs_type: dtypes.dtype,
+ rhs_expr: str,
+) -> Tuple[str, bool]:
+ override_expr = base_utils.get_codegen_backend().arithmetic_binary_expr(
+ op, lhs_type, lhs_expr, rhs_type, rhs_expr
+ )
+ if override_expr is not None:
+ return override_expr, True
+ return f"{lhs_expr} {op} {rhs_expr}", False
+
+def _resolve_arithmetic_unary_expr(op: str, var_type: dtypes.dtype, var_expr: str) -> Tuple[str, bool]:
+ override_expr = base_utils.get_codegen_backend().arithmetic_unary_expr(op, var_type, var_expr)
+ if override_expr is not None:
+ return override_expr, True
+ return f"{op}{var_expr}", False
+
def arithmetic_op_common(var: BaseVariable,
other: Any,
reverse: bool = False,
@@ -54,27 +74,55 @@ def add(var: BaseVariable, other: Any, inplace: bool = False) -> BaseVariable:
return_type = arithmetic_op_common(var, other, inplace=inplace)
if base_utils.is_scalar_number(other):
- _mark_arith_binary(var.var_type, base_utils.number_to_dtype(other), "+", inplace=inplace)
+ scalar_type = base_utils.number_to_dtype(other)
+ scalar_expr = base_utils.format_number_literal(other)
+ _mark_arith_binary(var.var_type, scalar_type, "+", inplace=inplace)
+ expr, use_assignment = _resolve_arithmetic_binary_expr(
+ "+",
+ var.var_type,
+ var.resolve(),
+ scalar_type,
+ scalar_expr,
+ )
if not inplace:
+ if use_assignment:
+ return base_utils.new_base_var(
+ return_type,
+ expr,
+ parents=[var],
+ )
return base_utils.new_scaled_var(
return_type,
var.resolve(),
offset=other,
parents=[var])
- base_utils.append_contents(f"{var.resolve()} += {other};\n")
+ if use_assignment:
+ base_utils.append_contents(f"{var.resolve()} = {expr};\n")
+ else:
+ base_utils.append_contents(f"{var.resolve()} += {scalar_expr};\n")
return var
assert isinstance(other, BaseVariable)
_mark_arith_binary(var.var_type, other.var_type, "+", inplace=inplace)
+ expr, use_assignment = _resolve_arithmetic_binary_expr(
+ "+",
+ var.var_type,
+ var.resolve(),
+ other.var_type,
+ other.resolve(),
+ )
if not inplace:
return base_utils.new_base_var(
return_type,
- f"{var.resolve()} + {other.resolve()}",
+ expr,
parents=[var, other])
- base_utils.append_contents(f"{var.resolve()} += {other.resolve()};\n")
+ if use_assignment:
+ base_utils.append_contents(f"{var.resolve()} = {expr};\n")
+ else:
+ base_utils.append_contents(f"{var.resolve()} += {other.resolve()};\n")
return var
def sub(var: BaseVariable, other: Any, reverse: bool = False, inplace: bool = False) -> BaseVariable:
@@ -82,60 +130,103 @@ def sub(var: BaseVariable, other: Any, reverse: bool = False, inplace: bool = Fa
if base_utils.is_scalar_number(other):
scalar_type = base_utils.number_to_dtype(other)
+ scalar_expr = base_utils.format_number_literal(other)
if reverse and not inplace:
_mark_arith_unary(var, "-")
_mark_arith_binary(var.var_type, scalar_type, "+", inplace=False)
else:
# Non-reverse scalar subtraction is emitted as `+ (-scalar)` via scaled-var optimization.
_mark_arith_binary(var.var_type, scalar_type, "+" if not inplace else "-", inplace=inplace)
+ expr, use_assignment = _resolve_arithmetic_binary_expr(
+ "-",
+ scalar_type if reverse else var.var_type,
+ scalar_expr if reverse else var.resolve(),
+ var.var_type if reverse else scalar_type,
+ var.resolve() if reverse else scalar_expr,
+ )
if not inplace:
+ if use_assignment:
+ return base_utils.new_base_var(
+ return_type,
+ expr,
+ parents=[var],
+ )
return base_utils.new_scaled_var(
return_type,
f"(-{var.resolve()})" if reverse else var.resolve(),
- offset=other,
+ offset=other if reverse else -other,
parents=[var])
- base_utils.append_contents(f"{var.resolve()} -= {other};\n")
+ if use_assignment:
+ base_utils.append_contents(f"{var.resolve()} = {expr};\n")
+ else:
+ base_utils.append_contents(f"{var.resolve()} -= {scalar_expr};\n")
return var
assert isinstance(other, BaseVariable)
- _mark_arith_binary(var.var_type if not reverse else other.var_type, other.var_type if not reverse else var.var_type, "-", inplace=inplace)
+ lhs_type = var.var_type if not reverse else other.var_type
+ rhs_type = other.var_type if not reverse else var.var_type
+ _mark_arith_binary(lhs_type, rhs_type, "-", inplace=inplace)
+ expr, use_assignment = _resolve_arithmetic_binary_expr(
+ "-",
+ lhs_type,
+ var.resolve() if not reverse else other.resolve(),
+ rhs_type,
+ other.resolve() if not reverse else var.resolve(),
+ )
if not inplace:
return base_utils.new_base_var(
return_type,
- (
- f"{var.resolve()} - {other.resolve()}"
- if not reverse else
- f"{other.resolve()} - {var.resolve()}"
- ),
+ expr,
parents=[var, other])
- base_utils.append_contents(f"{var.resolve()} -= {other.resolve()};\n")
+ if use_assignment:
+ base_utils.append_contents(f"{var.resolve()} = {expr};\n")
+ else:
+ base_utils.append_contents(f"{var.resolve()} -= {other.resolve()};\n")
return var
def mul(var: BaseVariable, other: Any, inplace: bool = False) -> BaseVariable:
- return_type = arithmetic_op_common(var, other, inplace=inplace)
-
if base_utils.is_scalar_number(other):
+ return_type = arithmetic_op_common(var, other, inplace=inplace)
+ scalar_type = base_utils.number_to_dtype(other)
+ scalar_expr = base_utils.format_number_literal(other)
+ expr, use_assignment = _resolve_arithmetic_binary_expr(
+ "*",
+ var.var_type,
+ var.resolve(),
+ scalar_type,
+ scalar_expr,
+ )
if not inplace:
if other == 1:
return var
- if dtypes.is_integer_dtype(var.var_type) and base_utils.is_int_number(other) and base_utils.is_int_power_of_2(other):
+ if (
+ not use_assignment
+ and dtypes.is_integer_dtype(var.var_type)
+ and base_utils.is_int_number(other)
+ and base_utils.is_int_power_of_2(other)
+ ):
power = my_log2_int(other)
- _mark_arith_binary(var.var_type, base_utils.number_to_dtype(other), "<<", inplace=False)
+ _mark_arith_binary(var.var_type, scalar_type, "<<", inplace=False)
return base_utils.new_base_var(var.var_type, f"{var.resolve()} << {power}", [var])
- _mark_arith_binary(var.var_type, base_utils.number_to_dtype(other), "*", inplace=False)
- return base_utils.new_scaled_var(
- return_type,
- var.resolve(),
- scale=other,
- parents=[var])
-
- _mark_arith_binary(var.var_type, base_utils.number_to_dtype(other), "*", inplace=True)
- base_utils.append_contents(f"{var.resolve()} *= {other};\n")
+ _mark_arith_binary(var.var_type, scalar_type, "*", inplace=False)
+ if use_assignment:
+ return base_utils.new_base_var(
+ return_type,
+ expr,
+ parents=[var],
+ )
+ return base_utils.new_scaled_var(return_type, var.resolve(), scale=other, parents=[var])
+
+ _mark_arith_binary(var.var_type, scalar_type, "*", inplace=True)
+ if use_assignment:
+ base_utils.append_contents(f"{var.resolve()} = {expr};\n")
+ else:
+ base_utils.append_contents(f"{var.resolve()} *= {scalar_expr};\n")
return var
assert isinstance(other, BaseVariable)
@@ -146,14 +237,32 @@ def mul(var: BaseVariable, other: Any, inplace: bool = False) -> BaseVariable:
if dtypes.is_matrix(var.var_type) and dtypes.is_matrix(other.var_type):
raise ValueError("Matrix multiplication is not supported via the `*` operator. Use `@` operator instead.")
+ return_type = dtypes.cross_multiply_type(var.var_type, other.var_type)
+ if inplace:
+ assert var.is_setable(), "Inplace arithmetic requires the variable to be settable."
+ var.read_callback()
+ var.write_callback()
+ other.read_callback()
+ assert return_type == var.var_type, "Inplace arithmetic requires the result type to match the variable type."
+
_mark_arith_binary(var.var_type, other.var_type, "*", inplace=inplace)
+ expr, use_assignment = _resolve_arithmetic_binary_expr(
+ "*",
+ var.var_type,
+ var.resolve(),
+ other.var_type,
+ other.resolve(),
+ )
if not inplace:
return base_utils.new_base_var(
- var.var_type,
- f"{var.resolve()} * {other.resolve()}",
+ return_type,
+ expr,
parents=[var, other])
- base_utils.append_contents(f"{var.resolve()} *= {other.resolve()};\n")
+ if use_assignment:
+ base_utils.append_contents(f"{var.resolve()} = {expr};\n")
+ else:
+ base_utils.append_contents(f"{var.resolve()} *= {other.resolve()};\n")
return var
def truediv(var: BaseVariable, other: Any, reverse: bool = False, inplace: bool = False) -> BaseVariable:
@@ -165,21 +274,39 @@ def truediv(var: BaseVariable, other: Any, reverse: bool = False, inplace: bool
if base_utils.is_scalar_number(other):
scalar_f_type = dtypes.float32
+ other_expr = base_utils.format_number_literal(other, force_float32=True)
if not reverse:
_mark_arith_binary(return_type, scalar_f_type, "/", inplace=inplace)
else:
_mark_arith_binary(scalar_f_type, return_type, "/", inplace=inplace)
+ lhs_expr = base_utils.to_dtype_base(return_type, var).resolve() if not reverse else other_expr
+ rhs_expr = other_expr if not reverse else base_utils.to_dtype_base(return_type, var).resolve()
+ lhs_type = return_type if not reverse else scalar_f_type
+ rhs_type = scalar_f_type if not reverse else return_type
+ expr, use_assignment = _resolve_arithmetic_binary_expr(
+ "/",
+ lhs_type,
+ lhs_expr,
+ rhs_type,
+ rhs_expr,
+ )
if not inplace:
return base_utils.new_base_var(
return_type,
- (
- f"{base_utils.to_dtype_base(return_type, var).resolve()} / {float(other)}"
- if not reverse else
- f"{float(other)} / {base_utils.to_dtype_base(return_type, var).resolve()}"
- ),
+ expr,
parents=[var])
- base_utils.append_contents(f"{var.resolve()} /= {float(other)};\n")
+ if use_assignment:
+ inplace_expr, _ = _resolve_arithmetic_binary_expr(
+ "/",
+ var.var_type,
+ var.resolve(),
+ scalar_f_type,
+ other_expr,
+ )
+ base_utils.append_contents(f"{var.resolve()} = {inplace_expr};\n")
+ else:
+ base_utils.append_contents(f"{var.resolve()} /= {other_expr};\n")
return var
assert isinstance(other, BaseVariable)
@@ -193,17 +320,42 @@ def truediv(var: BaseVariable, other: Any, reverse: bool = False, inplace: bool
lhs_mark_type = return_type if not reverse else dtypes.make_floating_dtype(other.var_type)
rhs_mark_type = dtypes.make_floating_dtype(other.var_type) if not reverse else return_type
_mark_arith_binary(lhs_mark_type, rhs_mark_type, "/", inplace=inplace)
+
+ lhs_expr = (
+ base_utils.to_dtype_base(lhs_mark_type, var).resolve()
+ if not reverse else
+ base_utils.to_dtype_base(lhs_mark_type, other).resolve()
+ )
+ rhs_expr = (
+ base_utils.to_dtype_base(rhs_mark_type, other).resolve()
+ if not reverse else
+ base_utils.to_dtype_base(rhs_mark_type, var).resolve()
+ )
+ expr, use_assignment = _resolve_arithmetic_binary_expr(
+ "/",
+ lhs_mark_type,
+ lhs_expr,
+ rhs_mark_type,
+ rhs_expr,
+ )
+
if not inplace:
return base_utils.new_base_var(
return_type,
- (
- f"{base_utils.to_dtype_base(return_type, var).resolve()} / {base_utils.to_dtype_base(return_type, other).resolve()}"
- if not reverse else
- f"{base_utils.to_dtype_base(return_type, other).resolve()} / {base_utils.to_dtype_base(return_type, var).resolve()}"
- ),
+ expr,
parents=[var, other])
- base_utils.append_contents(f"{var.resolve()} /= {base_utils.to_dtype_base(return_type, other).resolve()};\n")
+ if use_assignment:
+ inplace_expr, _ = _resolve_arithmetic_binary_expr(
+ "/",
+ var.var_type,
+ var.resolve(),
+ rhs_mark_type,
+ rhs_expr,
+ )
+ base_utils.append_contents(f"{var.resolve()} = {inplace_expr};\n")
+ else:
+ base_utils.append_contents(f"{var.resolve()} /= {rhs_expr};\n")
return var
def floordiv(var: BaseVariable, other: Any, reverse: bool = False, inplace: bool = False) -> BaseVariable:
@@ -291,43 +443,84 @@ def mod(var: BaseVariable, other: Any, reverse: bool = False, inplace: bool = Fa
base_utils.append_contents(f"{var.resolve()} %= {other.resolve()};\n")
return var
-def pow(var: BaseVariable, other: Any, reverse: bool = False, inplace: bool = False) -> BaseVariable:
- return_type = arithmetic_op_common(var, other, reverse=reverse, inplace=inplace)
-
- if base_utils.is_scalar_number(other):
- if not inplace:
- return base_utils.new_base_var(
- return_type,
- (
- f"pow({var.resolve()}, {other})"
- if not reverse else
- f"pow({other}, {var.resolve()})"
- ),
- parents=[var])
- base_utils.append_contents(f"{var.resolve()} = pow({var.resolve()}, {other});\n")
- return var
+def pow_expr(x: Any, y: Any) -> Union[BaseVariable, float]:
+ if base_utils.is_int_number(y) and y == 0:
+ return 1
+
+ if base_utils.is_number(y) and base_utils.is_number(x):
+ return se.power(x, y)
+
+ if base_utils.is_number(x) and isinstance(y, BaseVariable):
+ result_type = base_utils.dtype_to_floating(y.var_type)
+ return base_utils.new_base_var(
+ result_type,
+ base_utils.get_codegen_backend().binary_math_expr(
+ "pow",
+ dtypes.float32,
+ base_utils.resolve_input(x),
+ result_type,
+ y.resolve(),
+ ),
+ parents=[y]
+ )
+
+ if base_utils.is_number(y) and isinstance(x, BaseVariable):
+ result_type = base_utils.dtype_to_floating(x.var_type)
- assert isinstance(other, BaseVariable)
+ if base_utils.is_int_number(y) and x.is_register():
+ if y > 0 and y <= 4:
+ expr = " * ".join([x.resolve()] * int(y))
+ return base_utils.new_base_var(result_type, expr, parents=[x])
+ elif y < 0 and y >= -4:
+ expr = " * ".join([x.resolve()] * int(-y))
+ return base_utils.new_base_var(result_type, f"1 / ({expr})", parents=[x])
- if not inplace:
return base_utils.new_base_var(
- return_type,
- (
- f"pow({var.resolve()}, {other.resolve()})"
- if not reverse else
- f"pow({other.resolve()}, {var.resolve()})"
+ result_type,
+ base_utils.get_codegen_backend().binary_math_expr(
+ "pow",
+ result_type,
+ x.resolve(),
+ dtypes.float32,
+ base_utils.resolve_input(y),
),
- parents=[var, other])
+ parents=[x]
+ )
+
+ assert isinstance(x, BaseVariable), "First argument must be a ShaderVariable or number"
+ assert isinstance(y, BaseVariable), "Second argument must be a ShaderVariable or number"
+
+ result_type = base_utils.dtype_to_floating(dtypes.cross_type(x.var_type, y.var_type))
+ return base_utils.new_base_var(
+ result_type,
+ base_utils.get_codegen_backend().binary_math_expr(
+ "pow",
+ base_utils.dtype_to_floating(x.var_type),
+ x.resolve(),
+ base_utils.dtype_to_floating(y.var_type),
+ y.resolve(),
+ ),
+ parents=[y, x],
+ lexical_unit=True
+ )
+
+def pow(var: BaseVariable, other: Any, reverse: bool = False, inplace: bool = False) -> BaseVariable:
+ _ = arithmetic_op_common(var, other, reverse=reverse, inplace=inplace)
+ expression = pow_expr(other, var) if reverse else pow_expr(var, other)
+
+ if not inplace:
+ return expression
- base_utils.append_contents(f"{var.resolve()} = pow({var.resolve()}, {other.resolve()});\n")
+ base_utils.append_contents(f"{var.resolve()} = {expression};\n")
return var
def neg(var: BaseVariable) -> BaseVariable:
_mark_arith_unary(var, "-")
+ expr, _ = _resolve_arithmetic_unary_expr("-", var.var_type, var.resolve())
return base_utils.new_base_var(
var.var_type,
- f"-{var.resolve()}",
+ expr,
parents=[var])
def absolute(var: BaseVariable) -> BaseVariable:
diff --git a/vkdispatch/codegen/functions/base_functions/base_utils.py b/vkdispatch/codegen/functions/base_functions/base_utils.py
index a6daaf5f..7a5d7d71 100644
--- a/vkdispatch/codegen/functions/base_functions/base_utils.py
+++ b/vkdispatch/codegen/functions/base_functions/base_utils.py
@@ -4,13 +4,18 @@
from typing import Any, Optional
import numbers
+import math
-from ...._compat import numpy_compat as npc
+from ....compat import numpy_compat as npc
from vkdispatch.codegen.shader_writer import new_scaled_var, append_contents, new_name
from vkdispatch.codegen.global_builder import get_codegen_backend
from vkdispatch.codegen.shader_writer import new_var as new_var_impl
+_I32_MIN = -(2 ** 31)
+_I32_MAX = 2 ** 31 - 1
+_U32_MAX = 2 ** 32 - 1
+
def new_base_var(var_type: dtypes.dtype,
var_name: Optional[str],
parents: list,
@@ -45,9 +50,13 @@ def is_int_power_of_2(n: int) -> bool:
def number_to_dtype(number: numbers.Number):
if is_int_number(number):
if number >= 0:
- return dtypes.uint32
+ if number <= _U32_MAX:
+ return dtypes.uint32
+ return dtypes.uint64
- return dtypes.int32
+ if number >= _I32_MIN and number <= _I32_MAX:
+ return dtypes.int32
+ return dtypes.int64
elif is_float_number(number):
return dtypes.float32
elif is_complex_number(number):
@@ -62,33 +71,65 @@ def check_is_int(variable):
return npc.is_integer_scalar(variable)
def dtype_to_floating(var_type: dtypes.dtype) -> dtypes.dtype:
- if var_type == dtypes.int32 or var_type == dtypes.uint32:
- return dtypes.float32
+ return dtypes.make_floating_dtype(var_type)
+
+def _inf_scalar_type(var_type: dtypes.dtype) -> dtypes.dtype:
+ """Extract the scalar float type from any dtype."""
+ if dtypes.is_complex(var_type):
+ return var_type.child_type
+ if dtypes.is_vector(var_type) or dtypes.is_matrix(var_type):
+ return var_type.scalar
+ return var_type
+
+def format_number_literal(var: numbers.Number, *, force_float32: bool = False, dtype: Optional[dtypes.dtype] = None) -> str:
+ if is_complex_number(var):
+ return str(var)
- if var_type == dtypes.ivec2 or var_type == dtypes.uvec2:
- return dtypes.vec2
+ if is_float_number(var) or (force_float32 and is_int_number(var)):
+ value = float(var)
- if var_type == dtypes.ivec3 or var_type == dtypes.uvec3:
- return dtypes.vec3
-
- if var_type == dtypes.ivec4 or var_type == dtypes.uvec4:
- return dtypes.vec4
-
- return var_type
+ if math.isinf(value):
+ backend = get_codegen_backend()
+ scalar = _inf_scalar_type(dtype) if dtype is not None else dtypes.float32
+ if scalar is dtypes.float64:
+ return backend.inf_f64_expr() if value > 0 else backend.ninf_f64_expr()
+ if scalar is dtypes.float16:
+ return backend.inf_f16_expr() if value > 0 else backend.ninf_f16_expr()
+ return backend.inf_f32_expr() if value > 0 else backend.ninf_f32_expr()
+
+ if math.isnan(value):
+ return "(0.0f / 0.0f)"
+
+ literal = repr(value)
+ if "e" not in literal and "E" not in literal and "." not in literal:
+ literal += ".0"
+ return literal + "f"
-def resolve_input(var: Any) -> str:
+ return str(var)
+
+def resolve_input(var: Any, dtype: Optional[dtypes.dtype] = None) -> str:
#print("Resolving input:", var)
if is_number(var):
- return str(var)
-
+ return format_number_literal(var, dtype=dtype)
+
assert isinstance(var, BaseVariable), "Argument must be a ShaderVariable or number"
return var.resolve()
+def resolve_input_type(var: Any) -> Optional[dtypes.dtype]:
+ if is_number(var):
+ return number_to_dtype(var)
+
+ if isinstance(var, BaseVariable):
+ return var.var_type
+
+ return None
+
def backend_constructor(var_type: dtypes.dtype, *args) -> str:
return get_codegen_backend().constructor(
var_type,
- [resolve_input(elem) for elem in args]
+ [resolve_input(elem, dtype=var_type) for elem in args],
+ arg_types=[resolve_input_type(elem) for elem in args],
)
def to_dtype_base(var_type: dtypes.dtype, *args):
diff --git a/vkdispatch/codegen/functions/block_synchonization.py b/vkdispatch/codegen/functions/block_synchonization.py
index ca0da11c..3deccc45 100644
--- a/vkdispatch/codegen/functions/block_synchonization.py
+++ b/vkdispatch/codegen/functions/block_synchonization.py
@@ -1,4 +1,4 @@
-from ..global_builder import get_builder
+from ..global_builder import get_builder, get_codegen_backend
from . import utils
@@ -6,7 +6,7 @@ def barrier():
# On Apple devices, a memory barrier is required before a barrier
# to ensure memory operations are visible to all threads
# (for some reason)
- if get_builder().is_apple_device:
+ if get_builder().is_apple_device and get_codegen_backend().name == "glsl":
memory_barrier()
utils.append_contents(utils.codegen_backend().barrier_statement() + "\n")
diff --git a/vkdispatch/codegen/functions/builtin_constants.py b/vkdispatch/codegen/functions/builtin_constants.py
index f023fdb6..47812331 100644
--- a/vkdispatch/codegen/functions/builtin_constants.py
+++ b/vkdispatch/codegen/functions/builtin_constants.py
@@ -17,6 +17,38 @@ def ninf_f32():
lexical_unit=True
)
+def inf_f64():
+ return utils.new_var(
+ dtypes.float64,
+ utils.codegen_backend().inf_f64_expr(),
+ [],
+ lexical_unit=True
+ )
+
+def ninf_f64():
+ return utils.new_var(
+ dtypes.float64,
+ utils.codegen_backend().ninf_f64_expr(),
+ [],
+ lexical_unit=True
+ )
+
+def inf_f16():
+ return utils.new_var(
+ dtypes.float16,
+ utils.codegen_backend().inf_f16_expr(),
+ [],
+ lexical_unit=True
+ )
+
+def ninf_f16():
+ return utils.new_var(
+ dtypes.float16,
+ utils.codegen_backend().ninf_f16_expr(),
+ [],
+ lexical_unit=True
+ )
+
def global_invocation_id():
return utils.new_var(
dtypes.uvec3,
diff --git a/vkdispatch/codegen/functions/common_builtins.py b/vkdispatch/codegen/functions/common_builtins.py
index 741d590a..e801bdda 100644
--- a/vkdispatch/codegen/functions/common_builtins.py
+++ b/vkdispatch/codegen/functions/common_builtins.py
@@ -3,7 +3,7 @@
from typing import Any, Union, Tuple
from . import utils
-from ..._compat import numpy_compat as npc
+from . import scalar_eval as se
def comment(comment: str, preceding_new_line: bool = True) -> None:
comment_text = str(comment).replace("\r\n", "\n").replace("\r", "\n")
@@ -45,7 +45,7 @@ def abs(var: Any) -> Union[ShaderVariable, float]:
def sign(var: Any) -> Union[ShaderVariable, float]:
if utils.is_number(var):
- return npc.sign(var)
+ return se.sign(var)
assert isinstance(var, ShaderVariable), "Argument must be a ShaderVariable or number"
@@ -58,7 +58,7 @@ def sign(var: Any) -> Union[ShaderVariable, float]:
def floor(var: Any) -> Union[ShaderVariable, float]:
if utils.is_number(var):
- return npc.floor(var)
+ return se.floor(var)
assert isinstance(var, ShaderVariable), "Argument must be a ShaderVariable or number"
@@ -71,7 +71,7 @@ def floor(var: Any) -> Union[ShaderVariable, float]:
def ceil(var: Any) -> Union[ShaderVariable, float]:
if utils.is_number(var):
- return npc.ceil(var)
+ return se.ceil(var)
assert isinstance(var, ShaderVariable), "Argument must be a ShaderVariable or number"
@@ -84,7 +84,7 @@ def ceil(var: Any) -> Union[ShaderVariable, float]:
def trunc(var: Any) -> Union[ShaderVariable, float]:
if utils.is_number(var):
- return npc.trunc(var)
+ return se.trunc(var)
assert isinstance(var, ShaderVariable), "Argument must be a ShaderVariable or number"
@@ -97,7 +97,7 @@ def trunc(var: Any) -> Union[ShaderVariable, float]:
def round(var: Any) -> Union[ShaderVariable, float]:
if utils.is_number(var):
- return npc.round(var)
+ return se.round(var)
assert isinstance(var, ShaderVariable), "Argument must be a ShaderVariable or number"
@@ -110,7 +110,7 @@ def round(var: Any) -> Union[ShaderVariable, float]:
def round_even(var: Any) -> Union[ShaderVariable, float]:
if utils.is_number(var):
- return npc.round(var)
+ return se.round(var)
assert isinstance(var, ShaderVariable), "Argument must be a ShaderVariable or number"
utils.mark_backend_feature("roundEven")
@@ -124,7 +124,7 @@ def round_even(var: Any) -> Union[ShaderVariable, float]:
def fract(var: Any) -> Union[ShaderVariable, float]:
if utils.is_number(var):
- return float(var - npc.floor(var))
+ return float(var - se.floor(var))
assert isinstance(var, ShaderVariable), "Argument must be a ShaderVariable or number"
utils.mark_backend_feature("fract")
@@ -138,7 +138,7 @@ def fract(var: Any) -> Union[ShaderVariable, float]:
def mod(x: Any, y: Any) -> Union[ShaderVariable, float]:
if utils.is_number(y) and utils.is_number(x):
- return npc.mod(x, y)
+ return se.mod(x, y)
base_var = None
@@ -160,14 +160,14 @@ def mod(x: Any, y: Any) -> Union[ShaderVariable, float]:
def modf(x: Any, y: Any) -> Tuple[ShaderVariable, ShaderVariable]:
if utils.is_number(y) and utils.is_number(x):
- a, b = npc.modf(x, y)
+ a, b = se.modf(x, y)
return float(a), float(b)
if utils.is_number(x) and isinstance(y, ShaderVariable):
utils.mark_backend_feature("mod")
return utils.new_var(
utils.dtype_to_floating(y.var_type),
- f"mod({x}, {y.resolve()})",
+ f"mod({utils.resolve_input(x)}, {y.resolve()})",
parents=[y]
)
@@ -175,7 +175,7 @@ def modf(x: Any, y: Any) -> Tuple[ShaderVariable, ShaderVariable]:
utils.mark_backend_feature("mod")
return utils.new_var(
utils.dtype_to_floating(x.var_type),
- f"mod({x.resolve()}, {y})",
+ f"mod({x.resolve()}, {utils.resolve_input(y)})",
parents=[x]
)
@@ -192,7 +192,7 @@ def modf(x: Any, y: Any) -> Tuple[ShaderVariable, ShaderVariable]:
def min(x: Any, y: Any) -> Union[ShaderVariable, float]:
if utils.is_number(y) and utils.is_number(x):
- return npc.minimum(x, y)
+ return se.minimum(x, y)
base_var = None
@@ -212,7 +212,7 @@ def min(x: Any, y: Any) -> Union[ShaderVariable, float]:
def max(x: Any, y: Any) -> Union[ShaderVariable, float]:
if utils.is_number(y) and utils.is_number(x):
- return npc.maximum(x, y)
+ return se.maximum(x, y)
base_var = None
@@ -232,7 +232,7 @@ def max(x: Any, y: Any) -> Union[ShaderVariable, float]:
def clip(x: Any, min_val: Any, max_val: Any) -> Union[ShaderVariable, float]:
if utils.is_number(x) and utils.is_number(min_val) and utils.is_number(max_val):
- return npc.clip(x, min_val, max_val)
+ return se.clip(x, min_val, max_val)
base_var = None
@@ -257,7 +257,7 @@ def clamp(x: Any, min_val: Any, max_val: Any) -> Union[ShaderVariable, float]:
def mix(x: Any, y: Any, a: Any) -> Union[ShaderVariable, float]:
if utils.is_number(y) and utils.is_number(x) and utils.is_number(a):
- return npc.interp(a, [0, 1], [x, y])
+ return se.interp(a, [0, 1], [x, y])
base_var = None
@@ -303,7 +303,7 @@ def step(edge: Any, x: Any) -> Union[ShaderVariable, float]:
def smoothstep(edge0: Any, edge1: Any, x: Any) -> Union[ShaderVariable, float]:
if utils.is_number(edge0) and utils.is_number(edge1) and utils.is_number(x):
- t = npc.clip((x - edge0) / (edge1 - edge0), 0.0, 1.0)
+ t = se.clip((x - edge0) / (edge1 - edge0), 0.0, 1.0)
return float(t * t * (3.0 - 2.0 * t))
base_var = None
@@ -328,7 +328,7 @@ def smoothstep(edge0: Any, edge1: Any, x: Any) -> Union[ShaderVariable, float]:
def isnan(var: Any) -> Union[ShaderVariable, bool]:
if utils.is_number(var):
- return npc.isnan(var)
+ return se.isnan(var)
assert isinstance(var, ShaderVariable), "Argument must be a ShaderVariable or number"
@@ -341,7 +341,7 @@ def isnan(var: Any) -> Union[ShaderVariable, bool]:
def isinf(var: Any) -> Union[ShaderVariable, bool]:
if utils.is_number(var):
- return npc.isinf(var)
+ return se.isinf(var)
assert isinstance(var, ShaderVariable), "Argument must be a ShaderVariable or number"
@@ -354,7 +354,7 @@ def isinf(var: Any) -> Union[ShaderVariable, bool]:
def float_bits_to_int(var: Any) -> Union[ShaderVariable, int]:
if utils.is_number(var):
- return npc.float_bits_to_int(var)
+ return se.float_bits_to_int(var)
assert isinstance(var, ShaderVariable), "Argument must be a ShaderVariable or number"
@@ -367,7 +367,7 @@ def float_bits_to_int(var: Any) -> Union[ShaderVariable, int]:
def float_bits_to_uint(var: Any) -> Union[ShaderVariable, int]:
if utils.is_number(var):
- return npc.float_bits_to_uint(var)
+ return se.float_bits_to_uint(var)
assert isinstance(var, ShaderVariable), "Argument must be a ShaderVariable or number"
@@ -380,7 +380,7 @@ def float_bits_to_uint(var: Any) -> Union[ShaderVariable, int]:
def int_bits_to_float(var: Any) -> Union[ShaderVariable, float]:
if utils.is_number(var):
- return npc.int_bits_to_float(var)
+ return se.int_bits_to_float(var)
assert isinstance(var, ShaderVariable), "Argument must be a ShaderVariable or number"
@@ -393,7 +393,7 @@ def int_bits_to_float(var: Any) -> Union[ShaderVariable, float]:
def uint_bits_to_float(var: Any) -> Union[ShaderVariable, float]:
if utils.is_number(var):
- return npc.uint_bits_to_float(var)
+ return se.uint_bits_to_float(var)
assert isinstance(var, ShaderVariable), "Argument must be a ShaderVariable or number"
diff --git a/vkdispatch/codegen/functions/complex_numbers.py b/vkdispatch/codegen/functions/complex_numbers.py
index af6a33ce..e99f3d7b 100644
--- a/vkdispatch/codegen/functions/complex_numbers.py
+++ b/vkdispatch/codegen/functions/complex_numbers.py
@@ -4,28 +4,32 @@
from .common_builtins import fma
-from .type_casting import to_complex
+from .type_casting import to_complex, to_dtype
from . import utils
from .trigonometry import cos, sin
def complex_from_euler_angle(angle: ShaderVariable):
- return to_complex(cos(angle), sin(angle))
+ if not isinstance(angle, ShaderVariable):
+ raise TypeError("complex_from_euler_angle expects a ShaderVariable angle")
+
+ target_complex_type = dtypes.complex_from_float(dtypes.make_floating_dtype(angle.var_type))
+ return to_dtype(target_complex_type, cos(angle), sin(angle))
def validate_complex_number(arg1: Any) -> Union[ShaderVariable, complex]:
if isinstance(arg1, ShaderVariable):
- assert arg1.var_type == dtypes.complex64, "Input variables to complex multiplication must be complex"
+ assert dtypes.is_complex(arg1.var_type), "Input variables to complex multiplication must be complex"
return arg1
assert utils.is_number(arg1), "Argument must be ShaderVariable or number"
return complex(arg1)
-def _new_big_complex(arg1: Any, arg2: Any):
- var_str = utils.backend_constructor(dtypes.complex64, arg1, arg2)
+def _new_big_complex(var_type: dtypes.dtype, arg1: Any, arg2: Any):
+ var_str = utils.backend_constructor(var_type, arg1, arg2)
return utils.new_var(
- dtypes.complex64,
+ var_type,
var_str,
[utils.resolve_input(arg1), utils.resolve_input(arg2)],
lexical_unit=True
@@ -35,4 +39,19 @@ def mult_complex(arg1: ShaderVariable, arg2: ShaderVariable):
a1 = validate_complex_number(arg1)
a2 = validate_complex_number(arg2)
- return _new_big_complex(fma(a1.real, a2.real, -a1.imag * a2.imag), fma(a1.real, a2.imag, a1.imag * a2.real))
+ fallback_type = dtypes.complex64
+ for normalized_arg in (a1, a2):
+ if isinstance(normalized_arg, ShaderVariable):
+ fallback_type = normalized_arg.var_type
+ break
+
+ result_type = None
+ for normalized_arg in (a1, a2):
+ arg_type = normalized_arg.var_type if isinstance(normalized_arg, ShaderVariable) else fallback_type
+ result_type = arg_type if result_type is None else dtypes.cross_type(result_type, arg_type)
+
+ return _new_big_complex(
+ result_type, # type: ignore[arg-type]
+ fma(a1.real, a2.real, -a1.imag * a2.imag),
+ fma(a1.real, a2.imag, a1.imag * a2.real),
+ )
diff --git a/vkdispatch/codegen/functions/control_flow.py b/vkdispatch/codegen/functions/control_flow.py
index 107627c3..88fcad45 100644
--- a/vkdispatch/codegen/functions/control_flow.py
+++ b/vkdispatch/codegen/functions/control_flow.py
@@ -52,8 +52,18 @@ def else_if_all(*args: List[ShaderVariable]):
utils.scope_increment()
def return_statement(arg=None):
- arg = arg if arg is not None else ""
- utils.append_contents(f"return {arg};\n")
+ if arg is None:
+ utils.append_contents("return;\n")
+ return
+
+ if isinstance(arg, str):
+ arg_expr = arg
+ elif isinstance(arg, ShaderVariable) or utils.is_number(arg):
+ arg_expr = utils.resolve_input(arg)
+ else:
+ arg_expr = str(arg)
+
+ utils.append_contents(f"return {arg_expr};\n")
def while_statement(arg: ShaderVariable):
utils.append_contents(f"while({proc_bool(arg)}) {'{'}\n")
@@ -75,7 +85,7 @@ def end(indent: bool = True):
utils.append_contents("}\n")
def logical_and(arg1: ShaderVariable, arg2: ShaderVariable):
- return utils.new_var(dtypes.int32, f"({arg1} && {arg2})", [arg1, arg2])
+ return utils.new_var(dtypes.int32, f"({proc_bool(arg1)} && {proc_bool(arg2)})", [arg1, arg2])
def logical_or(arg1: ShaderVariable, arg2: ShaderVariable):
- return utils.new_var(dtypes.int32, f"({arg1} || {arg2})", [arg1, arg2])
\ No newline at end of file
+ return utils.new_var(dtypes.int32, f"({proc_bool(arg1)} || {proc_bool(arg2)})", [arg1, arg2])
diff --git a/vkdispatch/codegen/functions/exponential.py b/vkdispatch/codegen/functions/exponential.py
index 1b67e6b4..68b2ebc6 100644
--- a/vkdispatch/codegen/functions/exponential.py
+++ b/vkdispatch/codegen/functions/exponential.py
@@ -1,105 +1,149 @@
+import vkdispatch.base.dtype as dtypes
from ..variables.variables import ShaderVariable
from typing import Any, Union
from . import utils
-from ..._compat import numpy_compat as npc
+from . import scalar_eval as se
+
+def _is_glsl_backend() -> bool:
+ return utils.codegen_backend().name == "glsl"
+
+def _is_float64_dtype(var_type: dtypes.dtype) -> bool:
+ if dtypes.is_scalar(var_type):
+ return var_type == dtypes.float64
+
+ if dtypes.is_vector(var_type):
+ return var_type.scalar == dtypes.float64
+
+ return False
+
+def _float64_to_float32_dtype(var_type: dtypes.dtype) -> dtypes.dtype:
+ if var_type == dtypes.float64:
+ return dtypes.float32
+
+ if dtypes.is_vector(var_type) and var_type.scalar == dtypes.float64:
+ return dtypes.to_vector(dtypes.float32, var_type.child_count)
+
+ raise TypeError(f"Unsupported fp64 fallback dtype: {var_type}")
+
+def _needs_glsl_float64_trig_fallback(var_type: dtypes.dtype) -> bool:
+ return _is_glsl_backend() and _is_float64_dtype(var_type)
+
+def process_float_var(var: ShaderVariable) -> bool:
+ pass
+
+def _unary_math_var(func_name: str, var: ShaderVariable) -> ShaderVariable:
+ result_type = utils.dtype_to_floating(var.var_type)
+ expr_arg_type = result_type
+ expr_arg = var.resolve()
+ expr_result_type = result_type
+
+ if _needs_glsl_float64_trig_fallback(result_type) and func_name in {"exp", "exp2", "log", "log2"}:
+ expr_arg_type = _float64_to_float32_dtype(result_type)
+ expr_result_type = expr_arg_type
+ expr_arg = utils.backend_constructor_from_resolved(expr_arg_type, [expr_arg])
+
+ expr = utils.codegen_backend().unary_math_expr(func_name, expr_result_type, expr_arg)
+
+ if expr_result_type != result_type:
+ expr = utils.backend_constructor_from_resolved(result_type, [expr])
+
+ return utils.new_var(
+ result_type,
+ expr,
+ parents=[var],
+ lexical_unit=True
+ )
def pow(x: Any, y: Any) -> Union[ShaderVariable, float]:
if utils.is_number(y) and utils.is_number(x):
- return npc.power(x, y)
+ return se.power(x, y)
if utils.is_number(x) and isinstance(y, ShaderVariable):
+ result_type = utils.dtype_to_floating(y.var_type)
return utils.new_var(
- utils.dtype_to_floating(y.var_type),
- f"pow({x}, {y.resolve()})",
+ result_type,
+ utils.codegen_backend().binary_math_expr(
+ "pow",
+ dtypes.float32,
+ utils.resolve_input(x),
+ result_type,
+ y.resolve(),
+ ),
parents=[y]
)
if utils.is_number(y) and isinstance(x, ShaderVariable):
+ result_type = utils.dtype_to_floating(x.var_type)
return utils.new_var(
- utils.dtype_to_floating(x.var_type),
- f"pow({x.resolve()}, {y})",
+ result_type,
+ utils.codegen_backend().binary_math_expr(
+ "pow",
+ result_type,
+ x.resolve(),
+ dtypes.float32,
+ utils.resolve_input(y),
+ ),
parents=[x]
)
assert isinstance(y, ShaderVariable), "First argument must be a ShaderVariable or number"
assert isinstance(x, ShaderVariable), "Second argument must be a ShaderVariable or number"
+ result_type = utils.dtype_to_floating(dtypes.cross_type(x.var_type, y.var_type))
return utils.new_var(
- utils.dtype_to_floating(y.var_type),
- f"pow({x.resolve()}, {y.resolve()})",
+ result_type,
+ utils.codegen_backend().binary_math_expr(
+ "pow",
+ utils.dtype_to_floating(x.var_type),
+ x.resolve(),
+ utils.dtype_to_floating(y.var_type),
+ y.resolve(),
+ ),
parents=[y, x],
lexical_unit=True
)
def exp(var: Any) -> Union[ShaderVariable, float]:
if utils.is_number(var):
- return npc.exp(var)
+ return se.exp(var)
assert isinstance(var, ShaderVariable), "Argument must be a ShaderVariable or number"
-
- return utils.new_var(
- utils.dtype_to_floating(var.var_type),
- f"exp({var.resolve()})",
- parents=[var],
- lexical_unit=True
- )
+ return _unary_math_var("exp", var)
def exp2(var: Any) -> Union[ShaderVariable, float]:
if utils.is_number(var):
- return npc.exp2(var)
+ return se.exp2(var)
assert isinstance(var, ShaderVariable), "Argument must be a ShaderVariable or number"
-
- return utils.new_var(
- utils.dtype_to_floating(var.var_type),
- f"exp2({var.resolve()})",
- parents=[var],
- lexical_unit=True
- )
+ return _unary_math_var("exp2", var)
def log(var: Any) -> Union[ShaderVariable, float]:
if utils.is_number(var):
- return npc.log(var)
+ return se.log(var)
assert isinstance(var, ShaderVariable), "Argument must be a ShaderVariable or number"
-
- return utils.new_var(
- utils.dtype_to_floating(var.var_type),
- f"log({var.resolve()})",
- parents=[var],
- lexical_unit=True
- )
+ return _unary_math_var("log", var)
def log2(var: Any) -> Union[ShaderVariable, float]:
if utils.is_number(var):
- return npc.log2(var)
+ return se.log2(var)
assert isinstance(var, ShaderVariable), "Argument must be a ShaderVariable or number"
+ return _unary_math_var("log2", var)
- return utils.new_var(
- utils.dtype_to_floating(var.var_type),
- f"log2({var.resolve()})",
- parents=[var],
- lexical_unit=True
- )
-
+# GLSL fp64 (GL_ARB_gpu_shader_fp64) provides a native double-precision sqrt, so no float32 fallback is needed here.
def sqrt(var: Any) -> Union[ShaderVariable, float]:
if utils.is_number(var):
- return npc.sqrt(var)
+ return se.sqrt(var)
assert isinstance(var, ShaderVariable), "Argument must be a ShaderVariable or number"
+ return _unary_math_var("sqrt", var)
- return utils.new_var(
- utils.dtype_to_floating(var.var_type),
- f"sqrt({var.resolve()})",
- parents=[var],
- lexical_unit=True
- )
-
+# GLSL fp64 provides a native double-precision inversesqrt, so no float32 fallback is needed here.
def inversesqrt(var: Any) -> Union[ShaderVariable, float]:
if utils.is_number(var):
- return float(1.0 / npc.sqrt(var))
+ return float(1.0 / se.sqrt(var))
assert isinstance(var, ShaderVariable), "Argument must be a ShaderVariable or number"
utils.mark_backend_feature("inversesqrt")
diff --git a/vkdispatch/codegen/functions/geometric.py b/vkdispatch/codegen/functions/geometric.py
index 7e6fa864..6992a8ad 100644
--- a/vkdispatch/codegen/functions/geometric.py
+++ b/vkdispatch/codegen/functions/geometric.py
@@ -3,11 +3,11 @@
from typing import Any, Union
from . import utils
-from ..._compat import numpy_compat as npc
+from . import scalar_eval as se
def length(var: Any) -> Union[ShaderVariable, float]:
if utils.is_number(var):
- return npc.abs_value(var)
+ return se.abs_value(var)
assert isinstance(var, ShaderVariable), "Argument must be a ShaderVariable or number"
@@ -20,7 +20,7 @@ def length(var: Any) -> Union[ShaderVariable, float]:
def distance(x: Any, y: Any) -> Union[ShaderVariable, float]:
if utils.is_number(y) and utils.is_number(x):
- return npc.abs_value(y - x)
+ return se.abs_value(y - x)
base_var = None
@@ -40,7 +40,7 @@ def distance(x: Any, y: Any) -> Union[ShaderVariable, float]:
def dot(x: Any, y: Any) -> Union[ShaderVariable, float]:
if utils.is_number(y) and utils.is_number(x):
- return npc.dot(x, y)
+ return se.dot(x, y)
base_var = None
diff --git a/vkdispatch/codegen/functions/registers.py b/vkdispatch/codegen/functions/registers.py
index 1aa2a622..64387ef1 100644
--- a/vkdispatch/codegen/functions/registers.py
+++ b/vkdispatch/codegen/functions/registers.py
@@ -4,7 +4,7 @@
from . import utils
-from .type_casting import to_dtype, to_complex
+from .type_casting import to_dtype, to_complex, to_complex32, to_complex64, to_complex128
def new_register(var_type: dtypes.dtype, *args, var_name: Optional[str] = None):
new_var = utils.new_var(
@@ -29,22 +29,65 @@ def new_register(var_type: dtypes.dtype, *args, var_name: Optional[str] = None):
return new_var
+def new_float16_register(*args, var_name: Optional[str] = None):
+ return new_register(dtypes.float16, *args, var_name=var_name)
+
def new_float_register(*args, var_name: Optional[str] = None):
return new_register(dtypes.float32, *args, var_name=var_name)
+def new_float64_register(*args, var_name: Optional[str] = None):
+ return new_register(dtypes.float64, *args, var_name=var_name)
+
+def new_int16_register(*args, var_name: Optional[str] = None):
+ return new_register(dtypes.int16, *args, var_name=var_name)
+
def new_int_register(*args, var_name: Optional[str] = None):
return new_register(dtypes.int32, *args, var_name=var_name)
+def new_int64_register(*args, var_name: Optional[str] = None):
+ return new_register(dtypes.int64, *args, var_name=var_name)
+
+def new_uint16_register(*args, var_name: Optional[str] = None):
+ return new_register(dtypes.uint16, *args, var_name=var_name)
+
def new_uint_register(*args, var_name: Optional[str] = None):
return new_register(dtypes.uint32, *args, var_name=var_name)
-def new_complex_register(*args, var_name: Optional[str] = None):
+def new_uint64_register(*args, var_name: Optional[str] = None):
+ return new_register(dtypes.uint64, *args, var_name=var_name)
+
+def _new_complex_register(var_type: dtypes.dtype, complex_ctor, *args, var_name: Optional[str] = None):
if len(args) > 0:
- true_args = (to_complex(*args),)
+ true_args = (complex_ctor(*args),)
else:
true_args = (0,)
- return new_register(dtypes.complex64, *true_args, var_name=var_name)
+ return new_register(var_type, *true_args, var_name=var_name)
+
+def new_complex_register(*args, var_name: Optional[str] = None):
+ if len(args) == 0:
+ return new_register(dtypes.complex64, 0, var_name=var_name)
+
+ complex_value = to_complex(*args)
+ return new_register(complex_value.var_type, complex_value, var_name=var_name)
+
+def new_complex32_register(*args, var_name: Optional[str] = None):
+ return _new_complex_register(dtypes.complex32, to_complex32, *args, var_name=var_name)
+
+def new_complex64_register(*args, var_name: Optional[str] = None):
+ return _new_complex_register(dtypes.complex64, to_complex64, *args, var_name=var_name)
+
+def new_complex128_register(*args, var_name: Optional[str] = None):
+ return _new_complex_register(dtypes.complex128, to_complex128, *args, var_name=var_name)
+
+def new_hvec2_register(*args, var_name: Optional[str] = None):
+ return new_register(dtypes.hvec2, *args, var_name=var_name)
+
+def new_hvec3_register(*args, var_name: Optional[str] = None):
+ return new_register(dtypes.hvec3, *args, var_name=var_name)
+
+def new_hvec4_register(*args, var_name: Optional[str] = None):
+ return new_register(dtypes.hvec4, *args, var_name=var_name)
def new_vec2_register(*args, var_name: Optional[str] = None):
return new_register(dtypes.vec2, *args, var_name=var_name)
@@ -55,6 +98,33 @@ def new_vec3_register(*args, var_name: Optional[str] = None):
def new_vec4_register(*args, var_name: Optional[str] = None):
return new_register(dtypes.vec4, *args, var_name=var_name)
+def new_dvec2_register(*args, var_name: Optional[str] = None):
+ return new_register(dtypes.dvec2, *args, var_name=var_name)
+
+def new_dvec3_register(*args, var_name: Optional[str] = None):
+ return new_register(dtypes.dvec3, *args, var_name=var_name)
+
+def new_dvec4_register(*args, var_name: Optional[str] = None):
+ return new_register(dtypes.dvec4, *args, var_name=var_name)
+
+def new_ihvec2_register(*args, var_name: Optional[str] = None):
+ return new_register(dtypes.ihvec2, *args, var_name=var_name)
+
+def new_ihvec3_register(*args, var_name: Optional[str] = None):
+ return new_register(dtypes.ihvec3, *args, var_name=var_name)
+
+def new_ihvec4_register(*args, var_name: Optional[str] = None):
+ return new_register(dtypes.ihvec4, *args, var_name=var_name)
+
+def new_uhvec2_register(*args, var_name: Optional[str] = None):
+ return new_register(dtypes.uhvec2, *args, var_name=var_name)
+
+def new_uhvec3_register(*args, var_name: Optional[str] = None):
+ return new_register(dtypes.uhvec3, *args, var_name=var_name)
+
+def new_uhvec4_register(*args, var_name: Optional[str] = None):
+ return new_register(dtypes.uhvec4, *args, var_name=var_name)
+
def new_uvec2_register(*args, var_name: Optional[str] = None):
return new_register(dtypes.uvec2, *args, var_name=var_name)
diff --git a/vkdispatch/codegen/functions/scalar_eval.py b/vkdispatch/codegen/functions/scalar_eval.py
new file mode 100644
index 00000000..5d406ba2
--- /dev/null
+++ b/vkdispatch/codegen/functions/scalar_eval.py
@@ -0,0 +1,194 @@
+from __future__ import annotations
+
+import builtins
+import math
+import struct
+
+from typing import Any, Sequence, Tuple
+
+
+def sign(value: float) -> float:
+ if value > 0:
+ return 1.0
+ if value < 0:
+ return -1.0
+ return 0.0
+
+
+def floor(value: float) -> float:
+ return float(math.floor(value))
+
+
+def ceil(value: float) -> float:
+ return float(math.ceil(value))
+
+
+def trunc(value: float) -> float:
+ return float(math.trunc(value))
+
+
+def round(value: float) -> float:
+ return float(builtins.round(value))
+
+
+def abs_value(value: Any) -> float:
+ return float(abs(value))
+
+
+def mod(x: float, y: float) -> float:
+ return float(x % y)
+
+
+def modf(x: float, _unused: Any = None) -> Tuple[float, float]:
+ frac, whole = math.modf(x)
+ return float(frac), float(whole)
+
+
+def minimum(x: float, y: float) -> float:
+ return float(x if x <= y else y)
+
+
+def maximum(x: float, y: float) -> float:
+ return float(x if x >= y else y)
+
+
+def clip(x: float, min_value: float, max_value: float) -> float:
+ return float(min(max(x, min_value), max_value))
+
+
+def interp(x: float, xp: Sequence[float], fp: Sequence[float]) -> float:
+ if len(xp) != len(fp):
+ raise ValueError("xp and fp must have the same length")
+ if len(xp) == 0:
+ raise ValueError("xp and fp must be non-empty")
+ if len(xp) == 1:
+ return float(fp[0])
+
+ if x <= xp[0]:
+ return float(fp[0])
+ if x >= xp[-1]:
+ return float(fp[-1])
+
+ for index in range(1, len(xp)):
+ if x <= xp[index]:
+ x0 = xp[index - 1]
+ x1 = xp[index]
+ y0 = fp[index - 1]
+ y1 = fp[index]
+
+ if x1 == x0:
+ return float(y0)
+
+ t = (x - x0) / (x1 - x0)
+ return float(y0 + t * (y1 - y0))
+
+ return float(fp[-1])
+
+
+def isnan(value: float) -> bool:
+ return math.isnan(value)
+
+
+def isinf(value: float) -> bool:
+ return math.isinf(value)
+
+
+def float_bits_to_int(value: float) -> int:
+ return int(struct.unpack("=i", struct.pack("=f", float(value)))[0])
+
+
+def float_bits_to_uint(value: float) -> int:
+ return int(struct.unpack("=I", struct.pack("=f", float(value)))[0])
+
+
+def int_bits_to_float(value: int) -> float:
+ return float(struct.unpack("=f", struct.pack("=i", int(value)))[0])
+
+
+def uint_bits_to_float(value: int) -> float:
+ return float(struct.unpack("=f", struct.pack("=I", int(value)))[0])
+
+
+def power(x: float, y: float) -> float:
+ return float(math.pow(x, y))
+
+
+def exp(value: float) -> float:
+ return float(math.exp(value))
+
+
+def exp2(value: float) -> float:
+ if hasattr(math, "exp2"):
+ return float(math.exp2(value))
+ return float(math.pow(2.0, value))
+
+
+def log(value: float) -> float:
+ return float(math.log(value))
+
+
+def log2(value: float) -> float:
+ return float(math.log2(value))
+
+
+def sqrt(value: float) -> float:
+ return float(math.sqrt(value))
+
+
+def sin(value: float) -> float:
+ return float(math.sin(value))
+
+
+def cos(value: float) -> float:
+ return float(math.cos(value))
+
+
+def tan(value: float) -> float:
+ return float(math.tan(value))
+
+
+def arcsin(value: float) -> float:
+ return float(math.asin(value))
+
+
+def arccos(value: float) -> float:
+ return float(math.acos(value))
+
+
+def arctan(value: float) -> float:
+ return float(math.atan(value))
+
+
+def arctan2(y: float, x: float) -> float:
+ return float(math.atan2(y, x))
+
+
+def sinh(value: float) -> float:
+ return float(math.sinh(value))
+
+
+def cosh(value: float) -> float:
+ return float(math.cosh(value))
+
+
+def tanh(value: float) -> float:
+ return float(math.tanh(value))
+
+
+def arcsinh(value: float) -> float:
+ return float(math.asinh(value))
+
+
+def arccosh(value: float) -> float:
+ return float(math.acosh(value))
+
+
+def arctanh(value: float) -> float:
+ return float(math.atanh(value))
+
+
+def dot(x: Any, y: Any) -> float:
+ if isinstance(x, (int, float, complex)) and isinstance(y, (int, float, complex)):
+ return float(x * y)  # NOTE(review): float() raises TypeError for complex scalars despite the isinstance guard above — confirm complex never reaches this path
+
+ return float(sum(a * b for a, b in zip(x, y)))
diff --git a/vkdispatch/codegen/functions/subgroups.py b/vkdispatch/codegen/functions/subgroups.py
index 477d3f53..23f90952 100644
--- a/vkdispatch/codegen/functions/subgroups.py
+++ b/vkdispatch/codegen/functions/subgroups.py
@@ -4,25 +4,60 @@
from . import utils
def subgroup_add(arg1: ShaderVariable):
- return utils.new_var(arg1.var_type, utils.codegen_backend().subgroup_add_expr(arg1.resolve()), [arg1], lexical_unit=True)
+ return utils.new_var(
+ arg1.var_type,
+ utils.codegen_backend().subgroup_add_expr(arg1.resolve(), arg1.var_type),
+ [arg1],
+ lexical_unit=True,
+ )
def subgroup_mul(arg1: ShaderVariable):
- return utils.new_var(arg1.var_type, utils.codegen_backend().subgroup_mul_expr(arg1.resolve()), [arg1], lexical_unit=True)
+ return utils.new_var(
+ arg1.var_type,
+ utils.codegen_backend().subgroup_mul_expr(arg1.resolve(), arg1.var_type),
+ [arg1],
+ lexical_unit=True,
+ )
def subgroup_min(arg1: ShaderVariable):
- return utils.new_var(arg1.var_type, utils.codegen_backend().subgroup_min_expr(arg1.resolve()), [arg1], lexical_unit=True)
+ return utils.new_var(
+ arg1.var_type,
+ utils.codegen_backend().subgroup_min_expr(arg1.resolve(), arg1.var_type),
+ [arg1],
+ lexical_unit=True,
+ )
def subgroup_max(arg1: ShaderVariable):
- return utils.new_var(arg1.var_type, utils.codegen_backend().subgroup_max_expr(arg1.resolve()), [arg1], lexical_unit=True)
+ return utils.new_var(
+ arg1.var_type,
+ utils.codegen_backend().subgroup_max_expr(arg1.resolve(), arg1.var_type),
+ [arg1],
+ lexical_unit=True,
+ )
def subgroup_and(arg1: ShaderVariable):
- return utils.new_var(arg1.var_type, utils.codegen_backend().subgroup_and_expr(arg1.resolve()), [arg1], lexical_unit=True)
+ return utils.new_var(
+ arg1.var_type,
+ utils.codegen_backend().subgroup_and_expr(arg1.resolve(), arg1.var_type),
+ [arg1],
+ lexical_unit=True,
+ )
def subgroup_or(arg1: ShaderVariable):
- return utils.new_var(arg1.var_type, utils.codegen_backend().subgroup_or_expr(arg1.resolve()), [arg1], lexical_unit=True)
+ return utils.new_var(
+ arg1.var_type,
+ utils.codegen_backend().subgroup_or_expr(arg1.resolve(), arg1.var_type),
+ [arg1],
+ lexical_unit=True,
+ )
def subgroup_xor(arg1: ShaderVariable):
- return utils.new_var(arg1.var_type, utils.codegen_backend().subgroup_xor_expr(arg1.resolve()), [arg1], lexical_unit=True)
+ return utils.new_var(
+ arg1.var_type,
+ utils.codegen_backend().subgroup_xor_expr(arg1.resolve(), arg1.var_type),
+ [arg1],
+ lexical_unit=True,
+ )
def subgroup_elect():
return utils.new_var(dtypes.int32, utils.codegen_backend().subgroup_elect_expr(), [], lexical_unit=True)
diff --git a/vkdispatch/codegen/functions/trigonometry.py b/vkdispatch/codegen/functions/trigonometry.py
index 504f25cc..19251db1 100644
--- a/vkdispatch/codegen/functions/trigonometry.py
+++ b/vkdispatch/codegen/functions/trigonometry.py
@@ -1,233 +1,250 @@
import vkdispatch.base.dtype as dtypes
from ..variables.variables import ShaderVariable
-from typing import Any, Union
+from typing import Any, List, Union
from . import utils
-from ..._compat import numpy_compat as npc
+from . import scalar_eval as se
def dtype_to_floating(var_type: dtypes.dtype) -> dtypes.dtype:
- if var_type == dtypes.int32 or var_type == dtypes.uint32:
+ return dtypes.make_floating_dtype(var_type)
+
+def _is_glsl_backend() -> bool:
+ return utils.codegen_backend().name == "glsl"
+
+def _is_float64_dtype(var_type: dtypes.dtype) -> bool:
+ if dtypes.is_scalar(var_type):
+ return var_type == dtypes.float64
+
+ if dtypes.is_vector(var_type):
+ return var_type.scalar == dtypes.float64
+
+ return False
+
+def _float64_to_float32_dtype(var_type: dtypes.dtype) -> dtypes.dtype:
+ if var_type == dtypes.float64:
return dtypes.float32
- if var_type == dtypes.ivec2 or var_type == dtypes.uvec2:
- return dtypes.vec2
+ if dtypes.is_vector(var_type) and var_type.scalar == dtypes.float64:
+ return dtypes.to_vector(dtypes.float32, var_type.child_count)
- if var_type == dtypes.ivec3 or var_type == dtypes.uvec3:
- return dtypes.vec3
-
- if var_type == dtypes.ivec4 or var_type == dtypes.uvec4:
- return dtypes.vec4
-
- return var_type
+ raise TypeError(f"Unsupported fp64 fallback dtype: {var_type}")
-def radians(var: Any) -> Union[ShaderVariable, float]:
- if utils.is_number(var):
- return var * (3.141592653589793 / 180.0)
+def _needs_glsl_float64_trig_fallback(var_type: dtypes.dtype) -> bool:
+ return _is_glsl_backend() and _is_float64_dtype(var_type)
- assert isinstance(var, ShaderVariable), "Argument must be a ShaderVariable or number"
- utils.mark_backend_feature("radians")
+def _cast_expr(var_type: dtypes.dtype, expr: str) -> str:
+ return utils.backend_constructor_from_resolved(var_type, [expr])
+
+def _unary_math_var(func_name: str, var: ShaderVariable) -> ShaderVariable:
+ result_type = dtype_to_floating(var.var_type)
+ expr_arg_type = result_type
+ expr_arg = var.resolve()
+ expr_result_type = result_type
+
+ if _needs_glsl_float64_trig_fallback(result_type):
+ expr_arg_type = _float64_to_float32_dtype(result_type)
+ expr_result_type = expr_arg_type
+ expr_arg = _cast_expr(expr_arg_type, expr_arg)
+
+ expr = utils.codegen_backend().unary_math_expr(func_name, expr_result_type, expr_arg)
+
+ if expr_result_type != result_type:
+ expr = _cast_expr(result_type, expr)
return utils.new_var(
- dtype_to_floating(var.var_type),
- f"radians({var.resolve()})",
+ result_type,
+ expr,
parents=[var],
lexical_unit=True
)
+def _binary_math_var(
+ func_name: str,
+ result_type: dtypes.dtype,
+ lhs_type: dtypes.dtype,
+ lhs_expr: str,
+ rhs_type: dtypes.dtype,
+ rhs_expr: str,
+ parents: List[ShaderVariable],
+ *,
+ lexical_unit: bool = False,
+) -> ShaderVariable:
+ expr_result_type = result_type
+ expr_lhs_type = lhs_type
+ expr_rhs_type = rhs_type
+ expr_lhs = lhs_expr
+ expr_rhs = rhs_expr
+
+ if _needs_glsl_float64_trig_fallback(result_type):
+ expr_result_type = _float64_to_float32_dtype(result_type)
+
+ if _is_float64_dtype(lhs_type):
+ expr_lhs_type = _float64_to_float32_dtype(lhs_type)
+ expr_lhs = _cast_expr(expr_lhs_type, lhs_expr)
+
+ if _is_float64_dtype(rhs_type):
+ expr_rhs_type = _float64_to_float32_dtype(rhs_type)
+ expr_rhs = _cast_expr(expr_rhs_type, rhs_expr)
+
+ expr = utils.codegen_backend().binary_math_expr(
+ func_name,
+ expr_lhs_type,
+ expr_lhs,
+ expr_rhs_type,
+ expr_rhs,
+ )
+
+ if expr_result_type != result_type:
+ expr = _cast_expr(result_type, expr)
+
+ return utils.new_var(
+ result_type,
+ expr,
+ parents=parents,
+ lexical_unit=lexical_unit,
+ )
+
+def radians(var: Any) -> Union[ShaderVariable, float]:
+ if utils.is_number(var):
+ return var * (3.141592653589793 / 180.0)
+
+ assert isinstance(var, ShaderVariable), "Argument must be a ShaderVariable or number"
+ utils.mark_backend_feature("radians")
+ return _unary_math_var("radians", var)
+
def degrees(var: Any) -> Union[ShaderVariable, float]:
if utils.is_number(var):
return var * (180.0 / 3.141592653589793)
assert isinstance(var, ShaderVariable), "Argument must be a ShaderVariable or number"
utils.mark_backend_feature("degrees")
-
- return utils.new_var(
- dtype_to_floating(var.var_type),
- f"degrees({var.resolve()})",
- parents=[var],
- lexical_unit=True
- )
+ return _unary_math_var("degrees", var)
def sin(var: Any) -> Union[ShaderVariable, float]:
if utils.is_number(var):
- return npc.sin(var)
+ return se.sin(var)
assert isinstance(var, ShaderVariable), "Argument must be a ShaderVariable or number"
-
- return utils.new_var(
- dtype_to_floating(var.var_type),
- f"sin({var.resolve()})",
- parents=[var],
- lexical_unit=True
- )
+ return _unary_math_var("sin", var)
def cos(var: Any) -> Union[ShaderVariable, float]:
if utils.is_number(var):
- return npc.cos(var)
+ return se.cos(var)
assert isinstance(var, ShaderVariable), "Argument must be a ShaderVariable or number"
-
- return utils.new_var(
- dtype_to_floating(var.var_type),
- f"cos({var.resolve()})",
- parents=[var],
- lexical_unit=True
- )
+ return _unary_math_var("cos", var)
def tan(var: Any) -> Union[ShaderVariable, float]:
if utils.is_number(var):
- return npc.tan(var)
+ return se.tan(var)
assert isinstance(var, ShaderVariable), "Argument must be a ShaderVariable or number"
-
- return utils.new_var(
- dtype_to_floating(var.var_type),
- f"tan({var.resolve()})",
- parents=[var],
- lexical_unit=True
- )
+ return _unary_math_var("tan", var)
def asin(var: Any) -> Union[ShaderVariable, float]:
if utils.is_number(var):
- return npc.arcsin(var)
+ return se.arcsin(var)
assert isinstance(var, ShaderVariable), "Argument must be a ShaderVariable or number"
-
- return utils.new_var(
- dtype_to_floating(var.var_type),
- f"asin({var.resolve()})",
- parents=[var],
- lexical_unit=True
- )
+ return _unary_math_var("asin", var)
def acos(var: Any) -> Union[ShaderVariable, float]:
if utils.is_number(var):
- return npc.arccos(var)
+ return se.arccos(var)
assert isinstance(var, ShaderVariable), "Argument must be a ShaderVariable or number"
-
- return utils.new_var(
- dtype_to_floating(var.var_type),
- f"acos({var.resolve()})",
- parents=[var],
- lexical_unit=True
- )
+ return _unary_math_var("acos", var)
def atan(var: Any) -> Union[ShaderVariable, float]:
if utils.is_number(var):
- return npc.arctan(var)
+ return se.arctan(var)
assert isinstance(var, ShaderVariable), "Argument must be a ShaderVariable or number"
-
- return utils.new_var(
- dtype_to_floating(var.var_type),
- f"atan({var.resolve()})",
- parents=[var],
- lexical_unit=True
- )
+ return _unary_math_var("atan", var)
def atan2(y: Any, x: Any) -> Union[ShaderVariable, float]:
if utils.is_number(y) and utils.is_number(x):
- return npc.arctan2(y, x)
+ return se.arctan2(y, x)
if utils.is_number(x) and isinstance(y, ShaderVariable):
- return utils.new_var(
- dtype_to_floating(y.var_type),
- f"atan({y.resolve()}, {x})",
- parents=[y]
+ result_type = dtype_to_floating(y.var_type)
+ scalar_result_type = result_type.scalar if dtypes.is_vector(result_type) else result_type
+ return _binary_math_var(
+ "atan2",
+ result_type,
+ result_type,
+ y.resolve(),
+ scalar_result_type,
+ utils.resolve_input(x),
+ [y],
)
if utils.is_number(y) and isinstance(x, ShaderVariable):
- return utils.new_var(
- dtype_to_floating(x.var_type),
- f"atan({y}, {x.resolve()})",
- parents=[x]
+ result_type = dtype_to_floating(x.var_type)
+ scalar_result_type = result_type.scalar if dtypes.is_vector(result_type) else result_type
+ return _binary_math_var(
+ "atan2",
+ result_type,
+ scalar_result_type,
+ utils.resolve_input(y),
+ result_type,
+ x.resolve(),
+ [x],
)
assert isinstance(y, ShaderVariable), "First argument must be a ShaderVariable or number"
assert isinstance(x, ShaderVariable), "Second argument must be a ShaderVariable or number"
- return utils.new_var(
- dtype_to_floating(y.var_type),
- f"atan({y.resolve()}, {x.resolve()})",
- parents=[y, x],
- lexical_unit=True
+ result_type = dtype_to_floating(dtypes.cross_type(y.var_type, x.var_type))
+ return _binary_math_var(
+ "atan2",
+ result_type,
+ result_type,
+ y.resolve(),
+ dtype_to_floating(x.var_type),
+ x.resolve(),
+ [y, x],
+ lexical_unit=True,
)
def sinh(var: Any) -> Union[ShaderVariable, float]:
if utils.is_number(var):
- return npc.sinh(var)
+ return se.sinh(var)
assert isinstance(var, ShaderVariable), "Argument must be a ShaderVariable or number"
-
- return utils.new_var(
- dtype_to_floating(var.var_type),
- f"sinh({var.resolve()})",
- parents=[var],
- lexical_unit=True
- )
+ return _unary_math_var("sinh", var)
def cosh(var: Any) -> Union[ShaderVariable, float]:
if utils.is_number(var):
- return npc.cosh(var)
+ return se.cosh(var)
assert isinstance(var, ShaderVariable), "Argument must be a ShaderVariable or number"
-
- return utils.new_var(
- dtype_to_floating(var.var_type),
- f"cosh({var.resolve()})",
- parents=[var],
- lexical_unit=True
- )
+ return _unary_math_var("cosh", var)
def tanh(var: Any) -> Union[ShaderVariable, float]:
if utils.is_number(var):
- return npc.tanh(var)
+ return se.tanh(var)
assert isinstance(var, ShaderVariable), "Argument must be a ShaderVariable or number"
-
- return utils.new_var(
- dtype_to_floating(var.var_type),
- f"tanh({var.resolve()})",
- parents=[var],
- lexical_unit=True
- )
+ return _unary_math_var("tanh", var)
def asinh(var: Any) -> Union[ShaderVariable, float]:
if utils.is_number(var):
- return npc.arcsinh(var)
+ return se.arcsinh(var)
assert isinstance(var, ShaderVariable), "Argument must be a ShaderVariable or number"
-
- return utils.new_var(
- dtype_to_floating(var.var_type),
- f"asinh({var.resolve()})",
- parents=[var],
- lexical_unit=True
- )
+ return _unary_math_var("asinh", var)
def acosh(var: Any) -> Union[ShaderVariable, float]:
if utils.is_number(var):
- return npc.arccosh(var)
+ return se.arccosh(var)
assert isinstance(var, ShaderVariable), "Argument must be a ShaderVariable or number"
-
- return utils.new_var(
- dtype_to_floating(var.var_type),
- f"acosh({var.resolve()})",
- parents=[var],
- lexical_unit=True
- )
+ return _unary_math_var("acosh", var)
def atanh(var: Any) -> Union[ShaderVariable, float]:
if utils.is_number(var):
- return npc.arctanh(var)
+ return se.arctanh(var)
assert isinstance(var, ShaderVariable), "Argument must be a ShaderVariable or number"
-
- return utils.new_var(
- dtype_to_floating(var.var_type),
- f"atanh({var.resolve()})",
- parents=[var],
- lexical_unit=True
- )
+ return _unary_math_var("atanh", var)
diff --git a/vkdispatch/codegen/functions/type_casting.py b/vkdispatch/codegen/functions/type_casting.py
index d70d894f..276a479a 100644
--- a/vkdispatch/codegen/functions/type_casting.py
+++ b/vkdispatch/codegen/functions/type_casting.py
@@ -2,6 +2,7 @@
from typing import Optional
from . import utils
+from ..variables.variables import ShaderVariable
def to_dtype(var_type: dtypes.dtype, *args):
return utils.new_var(
@@ -26,22 +27,88 @@ def str_to_dtype(var_type: dtypes.dtype,
register=register
)
+def to_float16(*args):
+ return to_dtype(dtypes.float16, *args)
+
def to_float(*args):
return to_dtype(dtypes.float32, *args)
+def to_float64(*args):
+ return to_dtype(dtypes.float64, *args)
+
+def to_int16(*args):
+ return to_dtype(dtypes.int16, *args)
+
def to_int(*args):
return to_dtype(dtypes.int32, *args)
+def to_int64(*args):
+ return to_dtype(dtypes.int64, *args)
+
+def to_uint16(*args):
+ return to_dtype(dtypes.uint16, *args)
+
def to_uint(*args):
return to_dtype(dtypes.uint32, *args)
-def to_complex(*args):
+def to_uint64(*args):
+ return to_dtype(dtypes.uint64, *args)
+
+def _complex_from_real_arg(arg) -> dtypes.dtype:
+ if isinstance(arg, ShaderVariable):
+ if dtypes.is_complex(arg.var_type):
+ return arg.var_type
+ if dtypes.is_scalar(arg.var_type):
+ return dtypes.complex_from_float(dtypes.make_floating_dtype(arg.var_type))
+ raise TypeError(f"Unsupported variable type for complex conversion: {arg.var_type}")
+
+ if utils.is_number(arg):
+ base_type = utils.number_to_dtype(arg)
+ if dtypes.is_complex(base_type):
+ return base_type
+ return dtypes.complex_from_float(dtypes.make_floating_dtype(base_type))
+
+ raise TypeError(f"Unsupported argument type for complex conversion: {type(arg)}")
+
+def _infer_complex_dtype(*args) -> dtypes.dtype:
+ complex_type = _complex_from_real_arg(args[0])
+
+ for arg in args[1:]:
+ complex_type = dtypes.cross_type(complex_type, _complex_from_real_arg(arg))
+
+ return complex_type
+
+def _to_complex_dtype(var_type: dtypes.dtype, *args):
assert len(args) == 1 or len(args) == 2, "Must give one of two arguments for complex init"
+ if len(args) == 1 and isinstance(args[0], ShaderVariable) and dtypes.is_complex(args[0].var_type):
+ return to_dtype(var_type, args[0])
+
if len(args) == 1:
- return to_dtype(dtypes.complex64, args[0], 0)
+ return to_dtype(var_type, args[0], 0)
+
+ return to_dtype(var_type, *args)
+
+def to_complex32(*args):
+ return _to_complex_dtype(dtypes.complex32, *args)
+
+def to_complex(*args):
+ return _to_complex_dtype(_infer_complex_dtype(*args), *args)
+
+def to_complex64(*args):
+ return _to_complex_dtype(dtypes.complex64, *args)
+
+def to_complex128(*args):
+ return _to_complex_dtype(dtypes.complex128, *args)
+
+def to_hvec2(*args):
+ return to_dtype(dtypes.hvec2, *args)
+
+def to_hvec3(*args):
+ return to_dtype(dtypes.hvec3, *args)
- return to_dtype(dtypes.complex64, *args)
+def to_hvec4(*args):
+ return to_dtype(dtypes.hvec4, *args)
def to_vec2(*args):
return to_dtype(dtypes.vec2, *args)
@@ -52,14 +119,23 @@ def to_vec3(*args):
def to_vec4(*args):
return to_dtype(dtypes.vec4, *args)
-def to_uvec2(*args):
- return to_dtype(dtypes.uvec2, *args)
+def to_dvec2(*args):
+ return to_dtype(dtypes.dvec2, *args)
-def to_uvec3(*args):
- return to_dtype(dtypes.uvec3, *args)
+def to_dvec3(*args):
+ return to_dtype(dtypes.dvec3, *args)
-def to_uvec4(*args):
- return to_dtype(dtypes.uvec4, *args)
+def to_dvec4(*args):
+ return to_dtype(dtypes.dvec4, *args)
+
+def to_ihvec2(*args):
+ return to_dtype(dtypes.ihvec2, *args)
+
+def to_ihvec3(*args):
+ return to_dtype(dtypes.ihvec3, *args)
+
+def to_ihvec4(*args):
+ return to_dtype(dtypes.ihvec4, *args)
def to_ivec2(*args):
return to_dtype(dtypes.ivec2, *args)
@@ -70,6 +146,24 @@ def to_ivec3(*args):
def to_ivec4(*args):
return to_dtype(dtypes.ivec4, *args)
+def to_uhvec2(*args):
+ return to_dtype(dtypes.uhvec2, *args)
+
+def to_uhvec3(*args):
+ return to_dtype(dtypes.uhvec3, *args)
+
+def to_uhvec4(*args):
+ return to_dtype(dtypes.uhvec4, *args)
+
+def to_uvec2(*args):
+ return to_dtype(dtypes.uvec2, *args)
+
+def to_uvec3(*args):
+ return to_dtype(dtypes.uvec3, *args)
+
+def to_uvec4(*args):
+ return to_dtype(dtypes.uvec4, *args)
+
def to_mat2(*args):
return to_dtype(dtypes.mat2, *args)
diff --git a/vkdispatch/codegen/functions/utils.py b/vkdispatch/codegen/functions/utils.py
index 85879d48..ddb866fb 100644
--- a/vkdispatch/codegen/functions/utils.py
+++ b/vkdispatch/codegen/functions/utils.py
@@ -1,6 +1,6 @@
import vkdispatch.base.dtype as dtypes
from ..variables.variables import ShaderVariable
-from typing import List
+from typing import List, Optional
from .base_functions.base_utils import *
from ..global_builder import get_codegen_backend
@@ -24,11 +24,33 @@ def mark_backend_feature(feature_name: str) -> None:
def backend_type_name(var_type: dtypes.dtype) -> str:
return codegen_backend().type_name(var_type)
+def _resolve_arg_types(args: tuple) -> List[Optional[dtypes.dtype]]:
+ resolved_types: List[Optional[dtypes.dtype]] = []
+
+ for elem in args:
+ if isinstance(elem, ShaderVariable):
+ resolved_types.append(elem.var_type)
+ continue
+
+ if is_number(elem):
+ resolved_types.append(number_to_dtype(elem))
+ continue
+
+ resolved_types.append(None)
+
+ return resolved_types
+
def backend_constructor(var_type: dtypes.dtype, *args) -> str:
+ resolved_types = _resolve_arg_types(args)
return codegen_backend().constructor(
var_type,
- [resolve_input(elem) for elem in args]
+ [resolve_input(elem) for elem in args],
+ arg_types=resolved_types,
)
-def backend_constructor_from_resolved(var_type: dtypes.dtype, args: List[str]) -> str:
- return codegen_backend().constructor(var_type, args)
+def backend_constructor_from_resolved(
+ var_type: dtypes.dtype,
+ args: List[str],
+ arg_types: Optional[List[Optional[dtypes.dtype]]] = None,
+) -> str:
+ return codegen_backend().constructor(var_type, args, arg_types=arg_types)
diff --git a/vkdispatch/codegen/global_builder.py b/vkdispatch/codegen/global_builder.py
index 204cd425..8a14b1b9 100644
--- a/vkdispatch/codegen/global_builder.py
+++ b/vkdispatch/codegen/global_builder.py
@@ -1,7 +1,8 @@
import threading
import vkdispatch.base.dtype as dtypes
from .shader_writer import set_shader_writer
-from .backends import CodeGenBackend, GLSLBackend, CUDABackend
+from .backends import CodeGenBackend, GLSLBackend, CUDABackend, OpenCLBackend
+from vkdispatch.base.init import is_cuda, is_opencl
from typing import Optional, TYPE_CHECKING, Union
if TYPE_CHECKING:
@@ -11,16 +12,12 @@
_shader_print_line_numbers = threading.local()
_codegen_backend = threading.local()
-
def _make_runtime_default_codegen_backend() -> CodeGenBackend:
- try:
- from vkdispatch.base.backend import BACKEND_PYCUDA, get_active_backend_name
+ if is_cuda():
+ return CUDABackend()
- if get_active_backend_name() == BACKEND_PYCUDA:
- return CUDABackend()
- except Exception:
- # If runtime backend metadata is unavailable, fall back to GLSL.
- pass
+ if is_opencl():
+ return OpenCLBackend()
return GLSLBackend()
@@ -52,6 +49,10 @@ def set_codegen_backend(backend: Optional[Union[CodeGenBackend, str]]):
_codegen_backend.active_backend = CUDABackend()
return
+ if backend_name == "opencl":
+ _codegen_backend.active_backend = OpenCLBackend()
+ return
+
raise ValueError(f"Unknown codegen backend '{backend}'")
_codegen_backend.active_backend = backend
diff --git a/vkdispatch/codegen/variables/bound_variables.py b/vkdispatch/codegen/variables/bound_variables.py
index 5c6a25e4..228ff299 100644
--- a/vkdispatch/codegen/variables/bound_variables.py
+++ b/vkdispatch/codegen/variables/bound_variables.py
@@ -2,6 +2,7 @@
import vkdispatch.base.dtype as dtypes
from ..functions import type_casting
+from ..functions.base_functions import base_utils
from ..global_builder import get_codegen_backend
from typing import Callable, Optional
@@ -21,14 +22,19 @@ def __init__(self,
class BufferVariable(BoundVariable):
read_lambda: Callable[[], None]
write_lambda: Callable[[], None]
+ scalar_expr: Optional[str]
+ codegen_backend: Optional[object]
def __init__(self,
var_type: dtypes.dtype,
binding: int,
name: str,
shape_var: "ShaderVariable" = None,
+ shape_var_factory: Optional[Callable[[], "ShaderVariable"]] = None,
shape_name: Optional[str] = None,
raw_name: Optional[str] = None,
+ scalar_expr: Optional[str] = None,
+ codegen_backend: Optional[object] = None,
read_lambda: Callable[[], None] = None,
write_lambda: Callable[[], None] = None,
) -> None:
@@ -41,17 +47,62 @@ def __init__(self,
self.read_lambda = read_lambda
self.write_lambda = write_lambda
- self.shape = shape_var
+ self._shape_var = shape_var
+ self._shape_var_factory = shape_var_factory
self.shape_name = shape_name
+ self.scalar_expr = scalar_expr
+ self.codegen_backend = codegen_backend
self.can_index = True
self.use_child_type = False
+ @property
+ def shape(self) -> "ShaderVariable":
+ if self._shape_var is None:
+ assert self._shape_var_factory is not None, "Buffer shape variable factory is not available!"
+ self._shape_var = self._shape_var_factory()
+
+ return self._shape_var
+
def read_callback(self):
self.read_lambda()
def write_callback(self):
self.write_lambda()
+ def __getitem__(self, index) -> "ShaderVariable":
+ assert self.can_index, f"Variable '{self.resolve()}' of type '{self.var_type.name}' cannot be indexed into!"
+
+ return_type = self.var_type.child_type if self.use_child_type else self.var_type
+
+ if isinstance(index, tuple):
+ assert len(index) == 1, "Only single index is supported, cannot use multi-dimentional indexing!"
+ index = index[0]
+
+ if base_utils.is_int_number(index):
+ return ShaderVariable(
+ return_type,
+ f"{self.resolve()}[{index}]",
+ parents=[self],
+ settable=self.settable,
+ lexical_unit=True,
+ buffer_root=self,
+ buffer_index_expr=str(index),
+ )
+
+ assert isinstance(index, ShaderVariable), f"Index must be a ShaderVariable or int type, not {type(index)}!"
+ assert dtypes.is_scalar(index.var_type), "Indexing variable must be a scalar!"
+ assert dtypes.is_integer_dtype(index.var_type), "Indexing variable must be an integer type!"
+
+ return ShaderVariable(
+ return_type,
+ f"{self.resolve()}[{index.resolve()}]",
+ parents=[self, index],
+ settable=self.settable,
+ lexical_unit=True,
+ buffer_root=self,
+ buffer_index_expr=index.resolve(),
+ )
+
class ImageVariable(BoundVariable):
dimensions: int = 0
read_lambda: Callable[[], None]
@@ -89,14 +140,27 @@ def sample(self, coord: "ShaderVariable", lod: "ShaderVariable" = None) -> "Shad
if self.dimensions == 1:
sample_coord_string = f"((({coord.resolve()}) + 0.5) / {backend.texture_size_expr(self.resolve(), 0, self.dimensions)})"
elif self.dimensions == 2:
- coord_expr = backend.constructor(dtypes.vec2, [f"{coord.resolve()}.x", f"{coord.resolve()}.y"])
+ coord_expr = backend.constructor(
+ dtypes.vec2,
+ [
+ backend.component_access_expr(coord.resolve(), "x", coord.var_type),
+ backend.component_access_expr(coord.resolve(), "y", coord.var_type),
+ ]
+ )
tex_size_expr = backend.constructor(
dtypes.vec2,
[backend.texture_size_expr(self.resolve(), 0, self.dimensions)]
)
sample_coord_string = f"(({coord_expr} + 0.5) / {tex_size_expr})"
elif self.dimensions == 3:
- coord_expr = backend.constructor(dtypes.vec3, [f"{coord.resolve()}.x", f"{coord.resolve()}.y", f"{coord.resolve()}.z"])
+ coord_expr = backend.constructor(
+ dtypes.vec3,
+ [
+ backend.component_access_expr(coord.resolve(), "x", coord.var_type),
+ backend.component_access_expr(coord.resolve(), "y", coord.var_type),
+ backend.component_access_expr(coord.resolve(), "z", coord.var_type),
+ ]
+ )
tex_size_expr = backend.constructor(
dtypes.vec3,
[backend.texture_size_expr(self.resolve(), 0, self.dimensions)]
diff --git a/vkdispatch/codegen/variables/variables.py b/vkdispatch/codegen/variables/variables.py
index 3bebd883..e8e776ee 100644
--- a/vkdispatch/codegen/variables/variables.py
+++ b/vkdispatch/codegen/variables/variables.py
@@ -13,24 +13,14 @@
ENABLE_SCALED_AND_OFFSET_INT = True
def var_types_to_floating(var_type: dtypes.dtype) -> dtypes.dtype:
- if var_type == dtypes.int32 or var_type == dtypes.uint32:
- return dtypes.float32
-
- if var_type == dtypes.ivec2 or var_type == dtypes.uvec2:
- return dtypes.vec2
-
- if var_type == dtypes.ivec3 or var_type == dtypes.uvec3:
- return dtypes.vec3
-
- if var_type == dtypes.ivec4 or var_type == dtypes.uvec4:
- return dtypes.vec4
-
- return var_type
+ return dtypes.make_floating_dtype(var_type)
class ShaderVariable(BaseVariable):
_initilized: bool
is_complex: bool
is_conjugate: Optional[bool]
+ buffer_root: Optional["ShaderVariable"]
+ buffer_index_expr: Optional[str]
def __init__(self,
var_type: dtypes.dtype,
@@ -40,7 +30,9 @@ def __init__(self,
settable: bool = False,
register: bool = False,
parents: List["ShaderVariable"] = None,
- is_conjugate: bool = False
+ is_conjugate: bool = False,
+ buffer_root: Optional["ShaderVariable"] = None,
+ buffer_index_expr: Optional[str] = None,
) -> None:
super().__setattr__("_initilized", False)
@@ -56,6 +48,8 @@ def __init__(self,
self.is_complex = False
self.is_conjugate = None
+ self.buffer_root = buffer_root
+ self.buffer_index_expr = buffer_index_expr
if dtypes.is_complex(self.var_type):
self.can_index = True
@@ -80,6 +74,28 @@ def __init__(self,
self._initilized = True
+ def _buffer_component_expr(self, component_index_expr: str) -> Optional[str]:
+ if self.buffer_root is None or self.buffer_index_expr is None:
+ return None
+
+ if not (dtypes.is_vector(self.var_type) or dtypes.is_complex(self.var_type)):
+ return None
+
+ scalar_expr = getattr(self.buffer_root, "scalar_expr", None)
+ if scalar_expr is None:
+ return None
+
+ backend = getattr(self.buffer_root, "codegen_backend", None)
+ if backend is None:
+ backend = get_codegen_backend()
+
+ return backend.buffer_component_expr(
+ scalar_expr,
+ self.var_type,
+ self.buffer_index_expr,
+ component_index_expr,
+ )
+
def __getitem__(self, index) -> "ShaderVariable":
assert self.can_index, f"Variable '{self.resolve()}' of type '{self.var_type.name}' cannot be indexed into!"
@@ -90,11 +106,31 @@ def __getitem__(self, index) -> "ShaderVariable":
index = index[0]
if base_utils.is_int_number(index):
+ component_expr = self._buffer_component_expr(str(index))
+ if component_expr is not None:
+ return ShaderVariable(
+ return_type,
+ component_expr,
+ parents=[self],
+ settable=self.settable,
+ lexical_unit=True
+ )
+
return ShaderVariable(return_type, f"{self.resolve()}[{index}]", parents=[self], settable=self.settable, lexical_unit=True)
assert isinstance(index, ShaderVariable), f"Index must be a ShaderVariable or int type, not {type(index)}!"
assert dtypes.is_scalar(index.var_type), "Indexing variable must be a scalar!"
assert dtypes.is_integer_dtype(index.var_type), "Indexing variable must be an integer type!"
+
+ component_expr = self._buffer_component_expr(index.resolve())
+ if component_expr is not None:
+ return ShaderVariable(
+ return_type,
+ component_expr,
+ parents=[self, index],
+ settable=self.settable,
+ lexical_unit=True
+ )
return ShaderVariable(return_type, f"{self.resolve()}[{index.resolve()}]", parents=[self, index], settable=self.settable, lexical_unit=True)
@@ -109,15 +145,18 @@ def swizzle(self, components: str) -> "ShaderVariable":
sample_type = self.var_type if dtypes.is_scalar(self.var_type) else self.var_type.child_type
return_type = sample_type if len(components) == 1 else dtypes.to_vector(sample_type, len(components))
+ backend = get_codegen_backend()
+ base_expr = self.resolve()
if dtypes.is_scalar(self.var_type):
assert all(c == 'x' for c in components), f"Cannot swizzle scalar variable '{self.resolve()}' with components other than 'x'!"
- swizzle_expr = f"{self.resolve()}.x"
+ scalar_x_expr = backend.component_access_expr(base_expr, "x", self.var_type)
+ swizzle_expr = scalar_x_expr
if len(components) > 1:
- swizzle_expr = get_codegen_backend().constructor(
+ swizzle_expr = backend.constructor(
return_type,
- [f"{self.resolve()}.x" for _ in components]
+ [scalar_x_expr for _ in components]
)
return ShaderVariable(
@@ -125,8 +164,8 @@ def swizzle(self, components: str) -> "ShaderVariable":
name=swizzle_expr,
parents=[self],
lexical_unit=True,
- settable=self.settable,
- register=self.register
+ settable=self.settable and len(components) == 1,
+ register=self.register and len(components) == 1
)
if self.var_type.shape[0] < 4:
@@ -138,11 +177,24 @@ def swizzle(self, components: str) -> "ShaderVariable":
if self.var_type.shape[0] < 2:
assert 'y' not in components, f"Cannot swizzle variable '{self.resolve()}' of type '{self.var_type.name}' with component 'y'!"
- swizzle_expr = f"{self.resolve()}.{components}"
+ if len(components) == 1:
+ component_index = "xyzw".index(components)
+ component_expr = self._buffer_component_expr(str(component_index))
+ if component_expr is not None:
+ return ShaderVariable(
+ var_type=return_type,
+ name=component_expr,
+ parents=[self],
+ lexical_unit=True,
+ settable=self.settable,
+ register=self.register
+ )
+
+ swizzle_expr = backend.component_access_expr(base_expr, components, self.var_type)
if len(components) > 1:
- swizzle_expr = get_codegen_backend().constructor(
+ swizzle_expr = backend.constructor(
return_type,
- [f"{self.resolve()}.{elem}" for elem in components]
+ [backend.component_access_expr(base_expr, elem, self.var_type) for elem in components]
)
return ShaderVariable(
@@ -150,8 +202,8 @@ def swizzle(self, components: str) -> "ShaderVariable":
name=swizzle_expr,
parents=[self],
lexical_unit=True,
- settable=self.settable,
- register=self.register
+ settable=self.settable and len(components) == 1,
+ register=self.register and len(components) == 1
)
def conjugate(self) -> "ShaderVariable":
@@ -175,16 +227,19 @@ def set_value(self, value: "ShaderVariable") -> None:
self.read_callback()
if base_utils.is_number(value):
- if self.var_type == dtypes.complex64:
+ if dtypes.is_complex(self.var_type):
complex_value = complex(value)
complex_constructor = get_codegen_backend().constructor(
- dtypes.complex64,
- [str(complex_value.real), str(complex_value.imag)]
+ self.var_type,
+ [
+ base_utils.format_number_literal(complex_value.real),
+ base_utils.format_number_literal(complex_value.imag),
+ ]
)
base_utils.append_contents(f"{self.resolve()} = {complex_constructor};\n")
return
- base_utils.append_contents(f"{self.resolve()} = {value};\n")
+ base_utils.append_contents(f"{self.resolve()} = {base_utils.format_number_literal(value)};\n")
return
assert self.var_type == value.var_type, f"Cannot set variable of type '{self.var_type.name}' to value of type '{value.var_type.name}'!"
@@ -218,6 +273,9 @@ def __setattr__(self, name: str, value: "ShaderVariable") -> "ShaderVariable":
self.imag.set_value(value)
return
+
+ if dtypes.is_complex(self.var_type) and (name == "x" or name == "y"):
+ raise ValueError(f"Cannot set attribute '{name}' of complex variable '{self.resolve()}', use 'real' and 'imag' instead!")
if dtypes.is_vector(self.var_type) and (name == "x" or name == "y" or name == "z" or name == "w"):
if name == "x":
@@ -254,7 +312,7 @@ def to_register(self, var_name: str = None) -> "ShaderVariable":
def to_dtype(self, var_type: dtypes.dtype) -> "ShaderVariable":
return base_utils.new_base_var(
var_type,
- get_codegen_backend().constructor(var_type, [self.resolve()]),
+ get_codegen_backend().constructor(var_type, [self.resolve()], arg_types=[self.var_type]),
[self],
lexical_unit=True
)
@@ -316,16 +374,18 @@ def __init__(self,
offset: int = 0,
parents: List["ShaderVariable"] = None
) -> None:
+ # ShaderVariable.__init__ eagerly creates vector swizzles (`x`, `y`, ...),
+ # which call resolve() during construction. Pre-seed these fields so
+ # ScaledAndOfftsetIntVariable.resolve() is safe before super().__init__ completes.
+ object.__setattr__(self, "base_name", str(name))
+ object.__setattr__(self, "scale", scale)
+ object.__setattr__(self, "offset", offset)
super().__init__(var_type, name, parents=parents)
-
- self.base_name = str(name)
- self.scale = scale
- self.offset = offset
def new_from_self(self, scale: int = 1, offset: int = 0):
child_vartype = self.var_type
- if isinstance(scale, float) or isinstance(offset, float):
+ if base_utils.is_float_number(scale) or base_utils.is_float_number(offset):
child_vartype = var_types_to_floating(self.var_type)
return ScaledAndOfftsetIntVariable(
@@ -337,8 +397,14 @@ def new_from_self(self, scale: int = 1, offset: int = 0):
)
def resolve(self) -> str:
- scale_str = f" * {self.scale}" if self.scale != 1 else ""
- offset_str = f" + {self.offset}" if self.offset != 0 else ""
+ scale_str = (
+ f" * {base_utils.format_number_literal(self.scale)}"
+ if self.scale != 1 else ""
+ )
+ offset_str = (
+ f" + {base_utils.format_number_literal(self.offset)}"
+ if self.offset != 0 else ""
+ )
if scale_str == "" and offset_str == "":
return self.base_name
diff --git a/vkdispatch/_compat/__init__.py b/vkdispatch/compat/__init__.py
similarity index 100%
rename from vkdispatch/_compat/__init__.py
rename to vkdispatch/compat/__init__.py
diff --git a/vkdispatch/_compat/numpy_compat.py b/vkdispatch/compat/numpy_compat.py
similarity index 56%
rename from vkdispatch/_compat/numpy_compat.py
rename to vkdispatch/compat/numpy_compat.py
index 62e9dbf9..7d42ab43 100644
--- a/vkdispatch/_compat/numpy_compat.py
+++ b/vkdispatch/compat/numpy_compat.py
@@ -48,245 +48,11 @@ def ceil(value: float) -> float:
return float(_np.ceil(value))
return float(math.ceil(value))
-
-def floor(value: float) -> float:
- if HAS_NUMPY:
- return float(_np.floor(value))
- return float(math.floor(value))
-
-
-def trunc(value: float) -> float:
- if HAS_NUMPY:
- return float(_np.trunc(value))
- return float(math.trunc(value))
-
-
def round(value: float) -> float:
if HAS_NUMPY:
return float(_np.round(value))
return float(builtins.round(value))
-
-def sign(value: float) -> float:
- if HAS_NUMPY:
- return float(_np.sign(value))
-
- if value > 0:
- return 1.0
- if value < 0:
- return -1.0
- return 0.0
-
-
-def abs_value(value: Any) -> float:
- if HAS_NUMPY:
- return float(_np.abs(value))
- return float(abs(value))
-
-
-def minimum(x: float, y: float) -> float:
- if HAS_NUMPY:
- return float(_np.minimum(x, y))
- return float(x if x <= y else y)
-
-
-def maximum(x: float, y: float) -> float:
- if HAS_NUMPY:
- return float(_np.maximum(x, y))
- return float(x if x >= y else y)
-
-
-def clip(x: float, min_value: float, max_value: float) -> float:
- if HAS_NUMPY:
- return float(_np.clip(x, min_value, max_value))
- return float(min(max(x, min_value), max_value))
-
-
-def mod(x: float, y: float) -> float:
- if HAS_NUMPY:
- return float(_np.mod(x, y))
- return float(x % y)
-
-
-def modf(x: float, _unused: Any = None) -> Tuple[float, float]:
- if HAS_NUMPY:
- frac, whole = _np.modf(x)
- return float(frac), float(whole)
-
- frac, whole = math.modf(x)
- return float(frac), float(whole)
-
-
-def interp(x: float, xp: Sequence[float], fp: Sequence[float]) -> float:
- if HAS_NUMPY:
- return float(_np.interp(x, xp, fp))
-
- if len(xp) != len(fp):
- raise ValueError("xp and fp must have the same length")
- if len(xp) == 0:
- raise ValueError("xp and fp must be non-empty")
- if len(xp) == 1:
- return float(fp[0])
-
- if x <= xp[0]:
- return float(fp[0])
- if x >= xp[-1]:
- return float(fp[-1])
-
- for index in range(1, len(xp)):
- if x <= xp[index]:
- x0 = xp[index - 1]
- x1 = xp[index]
- y0 = fp[index - 1]
- y1 = fp[index]
-
- if x1 == x0:
- return float(y0)
-
- t = (x - x0) / (x1 - x0)
- return float(y0 + t * (y1 - y0))
-
- return float(fp[-1])
-
-
-def isnan(value: float) -> bool:
- if HAS_NUMPY:
- return bool(_np.isnan(value))
- return math.isnan(value)
-
-
-def isinf(value: float) -> bool:
- if HAS_NUMPY:
- return bool(_np.isinf(value))
- return math.isinf(value)
-
-
-def power(x: float, y: float) -> float:
- if HAS_NUMPY:
- return float(_np.power(x, y))
- return float(math.pow(x, y))
-
-
-def exp(value: float) -> float:
- if HAS_NUMPY:
- return float(_np.exp(value))
- return float(math.exp(value))
-
-
-def exp2(value: float) -> float:
- if HAS_NUMPY:
- return float(_np.exp2(value))
- if hasattr(math, "exp2"):
- return float(math.exp2(value))
- return float(math.pow(2.0, value))
-
-
-def log(value: float) -> float:
- if HAS_NUMPY:
- return float(_np.log(value))
- return float(math.log(value))
-
-
-def log2(value: float) -> float:
- if HAS_NUMPY:
- return float(_np.log2(value))
- return float(math.log2(value))
-
-
-def sqrt(value: float) -> float:
- if HAS_NUMPY:
- return float(_np.sqrt(value))
- return float(math.sqrt(value))
-
-
-def sin(value: float) -> float:
- if HAS_NUMPY:
- return float(_np.sin(value))
- return float(math.sin(value))
-
-
-def cos(value: float) -> float:
- if HAS_NUMPY:
- return float(_np.cos(value))
- return float(math.cos(value))
-
-
-def tan(value: float) -> float:
- if HAS_NUMPY:
- return float(_np.tan(value))
- return float(math.tan(value))
-
-
-def arcsin(value: float) -> float:
- if HAS_NUMPY:
- return float(_np.arcsin(value))
- return float(math.asin(value))
-
-
-def arccos(value: float) -> float:
- if HAS_NUMPY:
- return float(_np.arccos(value))
- return float(math.acos(value))
-
-
-def arctan(value: float) -> float:
- if HAS_NUMPY:
- return float(_np.arctan(value))
- return float(math.atan(value))
-
-
-def arctan2(y: float, x: float) -> float:
- if HAS_NUMPY:
- return float(_np.arctan2(y, x))
- return float(math.atan2(y, x))
-
-
-def sinh(value: float) -> float:
- if HAS_NUMPY:
- return float(_np.sinh(value))
- return float(math.sinh(value))
-
-
-def cosh(value: float) -> float:
- if HAS_NUMPY:
- return float(_np.cosh(value))
- return float(math.cosh(value))
-
-
-def tanh(value: float) -> float:
- if HAS_NUMPY:
- return float(_np.tanh(value))
- return float(math.tanh(value))
-
-
-def arcsinh(value: float) -> float:
- if HAS_NUMPY:
- return float(_np.arcsinh(value))
- return float(math.asinh(value))
-
-
-def arccosh(value: float) -> float:
- if HAS_NUMPY:
- return float(_np.arccosh(value))
- return float(math.acosh(value))
-
-
-def arctanh(value: float) -> float:
- if HAS_NUMPY:
- return float(_np.arctanh(value))
- return float(math.atanh(value))
-
-
-def dot(x: Any, y: Any) -> float:
- if HAS_NUMPY:
- return float(_np.dot(x, y))
-
- if isinstance(x, (int, float, complex)) and isinstance(y, (int, float, complex)):
- return float(x * y)
-
- return float(sum(a * b for a, b in zip(x, y)))
-
-
def angle(value: complex) -> float:
if HAS_NUMPY:
return float(_np.angle(value))
@@ -319,16 +85,32 @@ class HostDType:
kind: str
+INT16 = HostDType("int16", 2, "h", "int")
+UINT16 = HostDType("uint16", 2, "H", "uint")
INT32 = HostDType("int32", 4, "i", "int")
UINT32 = HostDType("uint32", 4, "I", "uint")
+INT64 = HostDType("int64", 8, "q", "int")
+UINT64 = HostDType("uint64", 8, "Q", "uint")
+FLOAT16 = HostDType("float16", 2, "e", "float")
FLOAT32 = HostDType("float32", 4, "f", "float")
+FLOAT64 = HostDType("float64", 8, "d", "float")
+COMPLEX32 = HostDType("complex32", 4, "ee", "complex")
COMPLEX64 = HostDType("complex64", 8, "ff", "complex")
+COMPLEX128 = HostDType("complex128", 16, "dd", "complex")
_HOST_DTYPES = {
+ "int16": INT16,
+ "uint16": UINT16,
"int32": INT32,
"uint32": UINT32,
+ "int64": INT64,
+ "uint64": UINT64,
+ "float16": FLOAT16,
"float32": FLOAT32,
+ "float64": FLOAT64,
+ "complex32": COMPLEX32,
"complex64": COMPLEX64,
+ "complex128": COMPLEX128,
}
@@ -355,6 +137,16 @@ def host_dtype_name(dtype: Any) -> str:
raise ValueError(f"Unsupported dtype ({dtype})!")
+def _numpy_dtype_or_none(dtype_name: str):
+ if not HAS_NUMPY:
+ return None
+
+ try:
+ return _np.dtype(dtype_name)
+ except TypeError:
+ return None
+
+
def dtype_itemsize(dtype: Any) -> int:
if isinstance(dtype, HostDType):
return dtype.itemsize
@@ -455,7 +247,13 @@ def from_buffer(buffer: bytes, dtype: Any, shape: Tuple[int, ...]):
dtype_name = host_dtype_name(dtype)
if HAS_NUMPY:
- return _np.frombuffer(buffer, dtype=_np.dtype(dtype_name)).reshape(shape)
+ np_dtype = _numpy_dtype_or_none(dtype_name)
+ if np_dtype is not None:
+ return _np.frombuffer(buffer, dtype=np_dtype).reshape(shape)
+
+ if dtype_name == "complex32":
+ half_pairs = _np.frombuffer(buffer, dtype=_np.float16).reshape(*shape, 2)
+ return half_pairs[..., 0].astype(_np.float32) + (1j * half_pairs[..., 1].astype(_np.float32))
return CompatArray(buffer, host_dtype(dtype_name), tuple(shape))
@@ -516,16 +314,19 @@ def pack_values(values: Sequence[Any], dtype: Any) -> bytes:
dtype_name = host_dtype_name(dtype)
if HAS_NUMPY:
- array = _np.asarray(values_list, dtype=_np.dtype(dtype_name))
- return array.tobytes()
+ np_dtype = _numpy_dtype_or_none(dtype_name)
+ if np_dtype is not None:
+ array = _np.asarray(values_list, dtype=np_dtype)
+ return array.tobytes()
host = host_dtype(dtype_name)
if host.kind == "complex":
output = bytearray()
+ pack_fmt = "=" + host.struct_format
for value in values_list:
coerced = _coerce_scalar(value, host)
- output.extend(struct.pack("=ff", float(coerced.real), float(coerced.imag)))
+ output.extend(struct.pack(pack_fmt, float(coerced.real), float(coerced.imag)))
return bytes(output)
pack_fmt = "=" + host.struct_format
@@ -539,13 +340,16 @@ def unpack_values(data: bytes, dtype: Any) -> List[Any]:
dtype_name = host_dtype_name(dtype)
if HAS_NUMPY:
- return _np.frombuffer(data, dtype=_np.dtype(dtype_name)).tolist()
+ np_dtype = _numpy_dtype_or_none(dtype_name)
+ if np_dtype is not None:
+ return _np.frombuffer(data, dtype=np_dtype).tolist()
host = host_dtype(dtype_name)
if host.kind == "complex":
values: List[Any] = []
- for real, imag in struct.iter_unpack("=ff", data):
+ unpack_fmt = "=" + host.struct_format
+ for real, imag in struct.iter_unpack(unpack_fmt, data):
values.append(complex(real, imag))
return values
@@ -558,26 +362,3 @@ def unpack_values(data: bytes, dtype: Any) -> List[Any]:
return values
-
-def float_bits_to_int(value: float) -> int:
- if HAS_NUMPY:
- return int(_np.frombuffer(_np.float32(value).tobytes(), dtype=_np.int32)[0])
- return int(struct.unpack("=i", struct.pack("=f", float(value)))[0])
-
-
-def float_bits_to_uint(value: float) -> int:
- if HAS_NUMPY:
- return int(_np.frombuffer(_np.float32(value).tobytes(), dtype=_np.uint32)[0])
- return int(struct.unpack("=I", struct.pack("=f", float(value)))[0])
-
-
-def int_bits_to_float(value: int) -> float:
- if HAS_NUMPY:
- return float(_np.frombuffer(_np.int32(value).tobytes(), dtype=_np.float32)[0])
- return float(struct.unpack("=f", struct.pack("=i", int(value)))[0])
-
-
-def uint_bits_to_float(value: int) -> float:
- if HAS_NUMPY:
- return float(_np.frombuffer(_np.uint32(value).tobytes(), dtype=_np.float32)[0])
- return float(struct.unpack("=f", struct.pack("=I", int(value)))[0])
diff --git a/vkdispatch/execution_pipeline/buffer_builder.py b/vkdispatch/execution_pipeline/buffer_builder.py
index 43086904..01418bae 100644
--- a/vkdispatch/execution_pipeline/buffer_builder.py
+++ b/vkdispatch/execution_pipeline/buffer_builder.py
@@ -11,7 +11,7 @@
import vkdispatch as vd
import vkdispatch.codegen as vc
-from .._compat import numpy_compat as npc
+from ..compat import numpy_compat as npc
from vkdispatch.base.dtype import to_numpy_dtype
@@ -150,6 +150,12 @@ def _write_payload(self, instance_index: int, element_slice: slice, payload: byt
if len(payload) != expected_size:
raise ValueError(f"Packed value size mismatch! Expected {expected_size}, got {len(payload)}")
+ if npc.HAS_NUMPY:
+ np = npc.numpy_module()
+ row = self.backing_buffer[instance_index]
+ row[element_slice] = np.frombuffer(payload, dtype=np.uint8)
+ return
+
start = instance_index * self.instance_bytes + element_slice.start
end = start + expected_size
@@ -178,7 +184,7 @@ def _setitem_python(self, key: Tuple[str, str], value: Any) -> None:
return
# Broadcast scalar values across all instances for scalar fields.
- if not isinstance(value, (list, tuple)) and not isinstance(value, npc.CompatArray) and buffer_element.shape == (1,):
+ if not isinstance(value, (list, tuple)) and not npc.is_array_like(value) and buffer_element.shape == (1,):
payload = self._pack_single_instance_value([value], key, buffer_element)
for instance_index in range(self.instance_count):
self._write_payload(instance_index, buffer_element.memory_slice, payload)
@@ -186,7 +192,7 @@ def _setitem_python(self, key: Tuple[str, str], value: Any) -> None:
expected_element_count = npc.prod(buffer_element.shape)
- if isinstance(value, npc.CompatArray):
+ if npc.is_array_like(value):
flat_values = npc.flatten(value)
expected_total = expected_element_count * self.instance_count
@@ -224,7 +230,9 @@ def __setitem__(
if self.backing_buffer is None:
raise RuntimeError("BufferBuilder.prepare(...) must be called before assigning values")
- if npc.HAS_NUMPY:
+ buffer_element = self.element_map[key]
+
+ if npc.HAS_NUMPY and not npc.is_host_dtype(buffer_element.dtype):
self._setitem_numpy(key, value)
return
@@ -236,7 +244,7 @@ def __repr__(self) -> str:
for key, elem in self.element_map.items():
buffer_element = self.element_map[key]
- if npc.HAS_NUMPY:
+ if npc.HAS_NUMPY and not npc.is_host_dtype(buffer_element.dtype):
value = (self.backing_buffer[:, buffer_element.memory_slice]).view(buffer_element.dtype)
else:
decoded_instances = []
diff --git a/vkdispatch/execution_pipeline/command_graph.py b/vkdispatch/execution_pipeline/command_graph.py
index 13ac8d25..efdfc40f 100644
--- a/vkdispatch/execution_pipeline/command_graph.py
+++ b/vkdispatch/execution_pipeline/command_graph.py
@@ -18,6 +18,9 @@
import dataclasses
+def _runtime_supports_push_constants() -> bool:
+ return True
+
@dataclasses.dataclass
class BufferBindInfo:
"""A dataclass to hold information about a buffer binding."""
@@ -63,9 +66,10 @@ class CommandGraph(CommandList):
uniform_bindings: Any
uniform_constants_size: int
- uniform_constants_buffer: vd.Buffer
+ uniform_constants_buffer: Optional[vd.Buffer]
uniform_descriptors: List[Tuple[DescriptorSet, int, int]]
+ recorded_descriptor_sets: List[DescriptorSet]
name_to_pc_key_dict: Dict[str, List[Tuple[str, str]]]
queued_pc_values: Dict[Tuple[str, str], Any]
@@ -84,12 +88,75 @@ def __init__(self, reset_on_submit: bool = False, submit_on_record: bool = False
self.queued_pc_values = {}
self.uniform_descriptors = []
+ self.recorded_descriptor_sets = []
self._reset_on_submit = reset_on_submit
self.submit_on_record = submit_on_record
+ # Lazily allocate host-uploaded UBO backing only when needed by non-CUDA backends.
self.uniform_constants_size = 0
- self.uniform_constants_buffer = vd.Buffer(shape=(4096,), var_type=vd.uint32) # Create a base static constants buffer at size 4k bytes
+ self.uniform_constants_buffer = None
+
+ def _ensure_uniform_constants_capacity(self, uniform_word_size: int) -> None:
+ if self.uniform_constants_buffer is not None and uniform_word_size <= self.uniform_constants_size:
+ return
+
+ # Grow exponentially to reduce reallocation churn for larger UBO layouts.
+ if self.uniform_constants_size == 0:
+ self.uniform_constants_size = max(4096, uniform_word_size)
+ else:
+ self.uniform_constants_size = max(uniform_word_size, self.uniform_constants_size * 2)
+ self.uniform_constants_buffer = vd.Buffer(shape=(self.uniform_constants_size,), var_type=vd.uint32)
+
+ def _prepare_submission_state(self, instance_count: int) -> None:
+ if len(self.pc_builder.element_map) > 0 and (
+ self.pc_builder.instance_count != instance_count or not self.buffers_valid
+ ):
+
+ assert _runtime_supports_push_constants(), (
+ "Push constants not supported for backends without push-constant support "
+ "(OpenCL). Use UBO-backed variables instead."
+ )
+
+ self.pc_builder.prepare(instance_count)
+
+ for key, value in self.pc_values.items():
+ self.pc_builder[key] = value
+
+ if len(self.uniform_builder.element_map) > 0 and not self.buffers_valid:
+ self.uniform_builder.prepare(1)
+
+ for key, value in self.uniform_values.items():
+ self.uniform_builder[key] = value
+
+ uniform_word_size = (self.uniform_builder.instance_bytes + 3) // 4
+ uniform_payload = self.uniform_builder.tobytes()
+
+ if vd.is_cuda():
+ for descriptor_set, offset, size in self.uniform_descriptors:
+ descriptor_set.set_inline_uniform_payload(uniform_payload[offset:offset + size])
+ else:
+ self._ensure_uniform_constants_capacity(uniform_word_size)
+ assert self.uniform_constants_buffer is not None
+
+ for descriptor_set, offset, size in self.uniform_descriptors:
+ descriptor_set.bind_buffer(self.uniform_constants_buffer, 0, offset, size, True, write_access=False)
+
+ self.uniform_constants_buffer.write(uniform_payload)
+
+ if not self.buffers_valid:
+ self.buffers_valid = True
+
+ def prepare_for_cuda_graph_capture(self, instance_count: int = None) -> None:
+ """Initialize internal data uploads before torch CUDA graph capture.
+
+ This method performs one-time uniform/push-constant staging without submitting
+ the command list, so only kernel launches are captured by ``torch.cuda.graph``.
+ """
+ if instance_count is None:
+ instance_count = 1
+
+ self._prepare_submission_state(instance_count)
def reset(self) -> None:
"""Reset the command graph by clearing the push constant buffer and descriptor
@@ -100,15 +167,29 @@ def reset(self) -> None:
self.pc_builder.reset()
self.uniform_builder.reset()
- self.pc_values = {}
- self.uniform_values = {}
- self.name_to_pc_key_dict = {}
- self.queued_pc_values = {}
+ for descriptor_set in self.recorded_descriptor_sets:
+ descriptor_set.destroy()
+
+ self.pc_values.clear()
+ self.uniform_values.clear()
+ self.name_to_pc_key_dict.clear()
+ self.queued_pc_values.clear()
+ self.uniform_descriptors.clear()
+ self.recorded_descriptor_sets.clear()
- self.uniform_descriptors = []
self.buffers_valid = False
+
+ def _destroy(self) -> None:
+ self.reset()
+ super()._destroy()
def bind_var(self, name: str):
+ if not _runtime_supports_push_constants():
+ raise RuntimeError(
+ "CommandGraph.bind_var() is disabled for backends without push-constant "
+ "support (OpenCL). Pass Variable values directly at shader invocation."
+ )
+
def register_var(key: Tuple[str, str]):
if not name in self.name_to_pc_key_dict.keys():
self.name_to_pc_key_dict[name] = []
@@ -118,6 +199,12 @@ def register_var(key: Tuple[str, str]):
return register_var
def set_var(self, name: str, value: Any):
+ if not _runtime_supports_push_constants():
+ raise RuntimeError(
+ "CommandGraph.set_var() is disabled for backends without push-constant "
+ "support (OpenCL). Pass Variable values directly at shader invocation."
+ )
+
if name not in self.name_to_pc_key_dict.keys():
raise ValueError("Variable not bound!")
@@ -154,18 +241,36 @@ def record_shader(self,
"""
descriptor_set = DescriptorSet(plan)
+ self.recorded_descriptor_sets.append(descriptor_set)
if shader_uuid is None:
shader_uuid = shader_description.name + "_" + str(uuid.uuid4())
+ if (not _runtime_supports_push_constants()) and len(pc_values) > 0:
+ raise RuntimeError(
+ "Push-constant Variable payloads are disabled for backends without "
+ "push-constant support (OpenCL). "
+ "Variable values must be UBO-backed and provided at shader invocation."
+ )
+
if len(shader_description.pc_structure) != 0:
+ if not _runtime_supports_push_constants():
+ raise RuntimeError(
+ "Kernels should not emit push-constant layouts for backends without "
+ "push-constant support (OpenCL). Use UBO-backed variables."
+ )
self.pc_builder.register_struct(shader_uuid, shader_description.pc_structure)
-
- uniform_offset, uniform_range = self.uniform_builder.register_struct(shader_uuid, shader_description.uniform_structure)
- self.uniform_descriptors.append((descriptor_set, uniform_offset, uniform_range))
+ uniform_field_names = {elem.name for elem in shader_description.uniform_structure}
+ resolved_uniform_values: Dict[Tuple[str, str], Any] = {}
- self.uniform_values[(shader_uuid, shader_description.exec_count_name)] = [exec_limits[0], exec_limits[1], exec_limits[2], 0]
+ if shader_description.exec_count_name is not None:
+ resolved_uniform_values[(shader_uuid, shader_description.exec_count_name)] = [
+ exec_limits[0],
+ exec_limits[1],
+ exec_limits[2],
+ 0,
+ ]
for buffer_bind_info in bound_buffers:
descriptor_set.bind_buffer(
@@ -175,7 +280,8 @@ def record_shader(self,
write_access=buffer_bind_info.write_access,
)
- self.uniform_values[(shader_uuid, buffer_bind_info.shape_name)] = buffer_bind_info.buffer.shader_shape
+ if buffer_bind_info.shape_name in uniform_field_names:
+ resolved_uniform_values[(shader_uuid, buffer_bind_info.shape_name)] = buffer_bind_info.buffer.shader_shape
for sampler_bind_info in bound_samplers:
descriptor_set.bind_sampler(
@@ -186,7 +292,14 @@ def record_shader(self,
)
for key, value in uniform_values.items():
- self.uniform_values[(shader_uuid, key)] = value
+ resolved_uniform_values[(shader_uuid, key)] = value
+
+ if len(shader_description.uniform_structure) > 0:
+ uniform_offset, uniform_range = self.uniform_builder.register_struct(shader_uuid, shader_description.uniform_structure)
+ self.uniform_descriptors.append((descriptor_set, uniform_offset, uniform_range))
+
+ for key, value in resolved_uniform_values.items():
+ self.uniform_values[key] = value
for key, value in pc_values.items():
self.pc_values[(shader_uuid, key)] = value
@@ -194,11 +307,15 @@ def record_shader(self,
super().record_compute_plan(plan, descriptor_set, blocks)
self.buffers_valid = False
-
+
if self.submit_on_record:
self.submit()
-
- def submit(self, instance_count: int = None, queue_index: int = -2) -> None:
+
+ def submit(
+ self,
+ instance_count: int = None,
+ queue_index: int = -2
+ ) -> None:
"""Submit the command list to the specified device with additional data to
append to the front of the command list.
@@ -210,30 +327,8 @@ def submit(self, instance_count: int = None, queue_index: int = -2) -> None:
if instance_count is None:
instance_count = 1
-
- if len(self.pc_builder.element_map) > 0 and (
- self.pc_builder.instance_count != instance_count or not self.buffers_valid
- ):
-
- self.pc_builder.prepare(instance_count)
-
- for key, value in self.pc_values.items():
- self.pc_builder[key] = value
-
- if len(self.uniform_builder.element_map) > 0 and not self.buffers_valid:
-
- self.uniform_builder.prepare(1)
-
- for key, value in self.uniform_values.items():
- self.uniform_builder[key] = value
-
- for descriptor_set, offset, size in self.uniform_descriptors:
- descriptor_set.bind_buffer(self.uniform_constants_buffer, 0, offset, size, True, write_access=False)
- self.uniform_constants_buffer.write(self.uniform_builder.tobytes())
-
- if not self.buffers_valid:
- self.buffers_valid = True
+ self._prepare_submission_state(instance_count)
for key, val in self.queued_pc_values.items():
self.pc_builder[key] = val
@@ -243,7 +338,12 @@ def submit(self, instance_count: int = None, queue_index: int = -2) -> None:
if len(self.pc_builder.element_map) > 0:
my_data = self.pc_builder.tobytes()
- super().submit(data=my_data, queue_index=queue_index, instance_count=instance_count)
+ super().submit(
+ data=my_data,
+ queue_index=queue_index,
+ instance_count=instance_count,
+ cuda_stream=None,
+ )
if self._reset_on_submit:
self.reset()
@@ -253,9 +353,6 @@ def submit_any(self, instance_count: int = None) -> None:
_global_graph = threading.local()
-#__default_graph = None
-#__custom_graph = None
-
def _get_global_graph() -> Optional[CommandGraph]:
return getattr(_global_graph, 'custom_graph', None)
@@ -279,4 +376,4 @@ def set_global_graph(graph: CommandGraph = None) -> CommandGraph:
return
assert _get_global_graph() is None, "A global CommandGraph is already set for the current thread!"
- _global_graph.custom_graph = graph
\ No newline at end of file
+ _global_graph.custom_graph = graph
diff --git a/vkdispatch/execution_pipeline/cuda_graph_capture.py b/vkdispatch/execution_pipeline/cuda_graph_capture.py
new file mode 100644
index 00000000..a96f6a9e
--- /dev/null
+++ b/vkdispatch/execution_pipeline/cuda_graph_capture.py
@@ -0,0 +1,51 @@
+import vkdispatch as vd
+
+from contextlib import contextmanager
+
+import threading
+
+import typing
+
+class CUDAGraphCapture:
+ cuda_stream: typing.Any
+ uniform_buffers: typing.List[typing.Any]
+
+ def add_uniform_buffer(self, buffer):
+ self.uniform_buffers.append(buffer)
+
+_cap = threading.local()
+
+def _set_capture(capture):
+ _cap.capture = capture
+
+def get_cuda_capture() -> CUDAGraphCapture:
+ return getattr(_cap, "capture", None)
+
+@contextmanager
+def cuda_graph_capture(cuda_stream=None):
+ assert vd.is_cuda(), "CUDA graph capture is only supported when using the CUDA backend."
+
+ cap = CUDAGraphCapture()
+ cap.cuda_stream = cuda_stream
+ cap.uniform_buffers = []
+
+ _set_capture(cap)
+
+ try:
+ yield cap
+ finally:
+ _set_capture(None)
+
+@contextmanager
+def suspend_cuda_capture():
+ """Temporarily disable vkdispatch CUDA capture state for non-captured ops."""
+ cap = get_cuda_capture()
+ if cap is None:
+ yield
+ return
+
+ _set_capture(None)
+ try:
+ yield
+ finally:
+ _set_capture(cap)
diff --git a/vkdispatch/fft/__init__.py b/vkdispatch/fft/__init__.py
index b16e51ef..5dab17ff 100644
--- a/vkdispatch/fft/__init__.py
+++ b/vkdispatch/fft/__init__.py
@@ -22,6 +22,15 @@
from .functions import fft, fft2, fft3, ifft, ifft2, ifft3
from .functions import rfft, rfft2, rfft3, irfft, irfft2, irfft3
+from .src_functions import fft_src, fft2_src, fft3_src, ifft_src, ifft2_src, ifft3_src
+from .src_functions import rfft_src, rfft2_src, rfft3_src, irfft_src, irfft2_src, irfft3_src
+
+from .src_functions import fft_print_src, fft2_print_src, fft3_print_src, ifft_print_src, ifft2_print_src, ifft3_print_src
+from .src_functions import rfft_print_src, rfft2_print_src, rfft3_print_src, irfft_print_src, irfft2_print_src, irfft3_print_src
+
from .functions import convolve, convolve2D, convolve2DR, transpose
+from .src_functions import convolve_src, convolve2D_src, convolve2DR_src, transpose_src
+from .src_functions import convolve_print_src, convolve2D_print_src, convolve2DR_print_src
+
from .prime_utils import pad_dim
\ No newline at end of file
diff --git a/vkdispatch/fft/config.py b/vkdispatch/fft/config.py
index ca8e1d6d..02628e84 100644
--- a/vkdispatch/fft/config.py
+++ b/vkdispatch/fft/config.py
@@ -1,98 +1,196 @@
import vkdispatch as vd
import vkdispatch.codegen as vc
import dataclasses
-from typing import List, Tuple, Optional
+from typing import List, Tuple, Optional, Dict
-from .._compat import numpy_compat as npc
+from ..compat import numpy_compat as npc
+import vkdispatch.base.dtype as dtypes
from .prime_utils import prime_factors, group_primes, default_register_limit, default_max_prime
-@dataclasses.dataclass
-class FFTRegisterStageConfig:
- """
- Configuration for an FFT register stage.
-
- Attributes:
-
- primes (Tuple[int]): The prime numbers used for factorization.
- fft_length (int): The length of each FFT stage.
- instance_count (int): The number of instances required to achieve the desired level of parallelism.
- registers_used (int): The total number of registers used by the FFT stage.
- remainder (int): The remainder of `N` divided by `registers_used`.
- remainder_offset (int): A flag indicating whether the remainder is non-zero.
- extra_ffts (int): The additional number of FFT stages required to process the remainder.
- thread_count (int): The total number of threads used in the computation.
- sdata_size (int): The size of the shared memory buffer used to store intermediate results.
- sdata_width (int): The width of each element in the shared memory buffer.
- sdata_width_padded (int): The padded width of each element in the shared memory buffer.
-
- """
-
- primes: Tuple[int]
- fft_length: int
- instance_count: int
- registers_used: int
- remainder: int
- remainder_offset: int
- extra_ffts: int
- thread_count: int
- sdata_size: int
- sdata_width: int
- sdata_width_padded: int
-
- def __init__(self, primes: List[int], max_register_count: int, N: int):
- """
- Initializes the FFTRegisterStageConfig object.
+from .stages import FFTRegisterStageConfig
- Parameters:
+def plan_fft_stages(N: int, max_register_count: int, compute_item_size: int) -> Tuple[FFTRegisterStageConfig]:
+ all_factors = prime_factors(N)
- primes (List[int]): The prime numbers to use for factorization.
- max_register_count (int): The maximum number of registers allowed per thread.
- N (int): The length of the input data.
+ for factor in all_factors:
+ assert factor <= default_max_prime(), f"A prime factor of {N} is {factor}, which exceeds the maximum prime supported {default_max_prime()}"
- """
- self.primes = tuple(primes)
- self.fft_length = int(round(npc.prod(primes)))
- instance_primes = prime_factors(N // self.fft_length)
-
- self.instance_count = 1
+ prime_groups = group_primes(all_factors, max_register_count)
- while len(instance_primes) > 0:
- if self.instance_count * self.fft_length * instance_primes[0] > max_register_count:
- break
- self.instance_count *= instance_primes[0]
- instance_primes = instance_primes[1:]
+ stages = []
+ input_stride = 1
- self.registers_used = self.fft_length * self.instance_count
+ for group in prime_groups:
+ stage = FFTRegisterStageConfig(
+ group,
+ max_register_count,
+ N,
+ compute_item_size,
+ input_stride
+ )
+ stages.append(stage)
+ input_stride = stage.output_stride
- self.remainder = N % self.registers_used
- assert self.remainder % self.fft_length == 0, "Remainder must be divisible by the FFT length"
- self.remainder_offset = 1 if self.remainder != 0 else 0
- self.extra_ffts = self.remainder // self.fft_length
+ return tuple(stages)
- self.thread_count = N // self.registers_used + self.remainder_offset
-
- self.sdata_width = self.registers_used
-
- threads_primes = prime_factors(self.thread_count)
+@dataclasses.dataclass
+class FFTPlanCandidate:
+ max_register_count: int
+ stages: Tuple[FFTRegisterStageConfig]
+ register_count: int
+ batch_threads: int
+ transfer_count: Optional[int] = None
+
+ def __init__(self, N: int, max_register_count: int, compute_item_size: int):
+ stages = plan_fft_stages(N, max_register_count, compute_item_size)
+ register_count = max(stage.registers_used for stage in stages)
+ batch_threads = max(stage.thread_count for stage in stages)
+
+ if register_count > max_register_count:
+ self.max_register_count = None
+ self.stages = None
+ self.register_count = None
+ self.batch_threads = None
+ self.transfer_count = None
+ return
+
+ transfer_count = 0
+ output_stride = 1
+
+ for stage_index in range(len(stages) - 1):
+ output_stage = stages[stage_index]
+ input_stage = stages[stage_index + 1]
+
+ output_keys = output_stage.get_output_format(register_count).keys()
+ input_keys = input_stage.get_input_format(register_count).keys()
+
+ if output_keys != input_keys:
+ transfer_count += 1
+
+ output_stride *= output_stage.fft_length
+
+ self.max_register_count = max_register_count
+ self.stages = stages
+ self.register_count = register_count
+ self.batch_threads = batch_threads
+ self.transfer_count = transfer_count
+
+def register_limit_candidates(N: int, initial_limit: int) -> List[int]:
+ divisors = {1}
+
+ for factor in prime_factors(N):
+ divisors.update(divisor * factor for divisor in tuple(divisors))
+
+ candidates = [initial_limit]
+ candidates.extend(
+ divisor
+ for divisor in sorted(divisors)
+ if initial_limit < divisor <= N
+ )
+ return candidates
+
+def required_batch_threads_limit(batch_inner_count: int) -> int:
+ context = vd.get_context()
+ thread_dimension_limit = (
+ context.max_workgroup_size[1]
+ if batch_inner_count > 1
+ else context.max_workgroup_size[0]
+ )
+ return max(1, min(int(thread_dimension_limit), int(context.max_workgroup_invocations)))
+
+def select_fft_plan_candidate(
+ N: int,
+ batch_inner_count: int,
+ compute_item_size: int,
+ max_register_count: Optional[int],
+) -> FFTPlanCandidate:
+ batch_threads_limit = required_batch_threads_limit(batch_inner_count)
+ dimension_name = "y" if batch_inner_count > 1 else "x"
+
+ if max_register_count is not None:
+ requested_limit = min(max_register_count, N)
+ candidate = FFTPlanCandidate(
+ N=N,
+ max_register_count=requested_limit,
+ compute_item_size=compute_item_size,
+ )
+
+ assert candidate.stages is not None, f"Failed to create an FFT plan candidate for N={N} with max_register_count={requested_limit}"
+
+ if candidate.batch_threads <= batch_threads_limit:
+ return candidate
+
+ best_candidate = candidate
+ explicit_text = "requested"
+ searched_limit = requested_limit
+ else:
+ max_registers = default_register_limit()
- while self.sdata_width < 16 and len(threads_primes) > 0:
- self.sdata_width *= threads_primes[0]
- threads_primes = threads_primes[1:]
+ if N==16 or N==8 or N==4 or N==2 and vd.get_devices()[0].is_nvidia():
+ max_registers = max(2, N//2)
+
+ baseline_limit = min(8, N)
+ requested_limit = baseline_limit
+ candidate_limits = register_limit_candidates(max_registers, baseline_limit)
+ searched_limit = candidate_limits[-1]
+
+ baseline_candidate = FFTPlanCandidate(
+ N=N,
+ max_register_count=baseline_limit,
+ compute_item_size=compute_item_size,
+ )
+ best_candidate = baseline_candidate if baseline_candidate.stages is not None else None
+
+ if best_candidate is not None and baseline_candidate.batch_threads <= batch_threads_limit:
+ for candidate_limit in candidate_limits[1:]:
+ candidate = FFTPlanCandidate(
+ N=N,
+ max_register_count=candidate_limit,
+ compute_item_size=compute_item_size,
+ )
+
+ if candidate.stages is None:
+ continue
+
+ if best_candidate is None or candidate.batch_threads < best_candidate.batch_threads:
+ best_candidate = candidate
+
+ if candidate.batch_threads > batch_threads_limit:
+ continue
+
+ if candidate.transfer_count < baseline_candidate.transfer_count:
+ return candidate
+
+ return baseline_candidate
+
+ for candidate_limit in candidate_limits[1:]:
+ candidate = FFTPlanCandidate(
+ N=N,
+ max_register_count=candidate_limit,
+ compute_item_size=compute_item_size,
+ )
+ if candidate.stages is None:
+ continue
- self.sdata_width_padded = self.sdata_width
+ if best_candidate is None or candidate.batch_threads < best_candidate.batch_threads:
+ best_candidate = candidate
- if self.sdata_width_padded % 2 == 0:
- self.sdata_width_padded += 1
+ if candidate.batch_threads <= batch_threads_limit:
+ return candidate
- self.sdata_size = self.sdata_width_padded * int(npc.prod(threads_primes))
+ explicit_text = "default"
- if self.sdata_size > vd.get_context().max_shared_memory // vd.complex64.item_size:
- self.sdata_width_padded = self.sdata_width
- self.sdata_size = self.sdata_width_padded * int(npc.prod(threads_primes))
+ raise ValueError(
+ f"Unable to build an FFT plan for size {N}: minimum achievable batch thread count "
+ f"{best_candidate.batch_threads} exceeds the device's local {dimension_name}-dimension "
+ f"limit {batch_threads_limit} (starting from {explicit_text} max_register_count="
+ f"{requested_limit}, searched up to {searched_limit})."
+ )
@dataclasses.dataclass
class FFTConfig:
N: int
+ compute_type: dtypes.dtype
register_count: int
max_prime_radix: int
stages: Tuple[FFTRegisterStageConfig]
@@ -107,10 +205,21 @@ class FFTConfig:
sdata_row_size: int
sdata_row_size_padded: int
- def __init__(self, buffer_shape: Tuple, axis: int = None, max_register_count: int = None):
+ def __init__(
+ self,
+ buffer_shape: Tuple,
+ axis: int = None,
+ max_register_count: int = None,
+ compute_type: dtypes.dtype = vd.complex64,
+ ):
if axis is None:
axis = len(buffer_shape) - 1
+ if not dtypes.is_complex(compute_type):
+ raise ValueError(f"compute_type must be a complex dtype, got {compute_type}")
+
+ self.compute_type = compute_type
+
total_buffer_length = int(round(npc.prod(buffer_shape)))
N = buffer_shape[axis]
@@ -123,14 +232,6 @@ def __init__(self, buffer_shape: Tuple, axis: int = None, max_register_count: in
self.N = N
- if max_register_count is None:
- max_register_count = default_register_limit()
-
- if N==16 or N==8 or N==4 or N==2 and vd.get_devices()[0].is_nvidia():
- max_register_count = max(2, N//2)
-
- max_register_count = min(max_register_count, N)
-
all_factors = prime_factors(N)
for factor in all_factors:
@@ -138,13 +239,14 @@ def __init__(self, buffer_shape: Tuple, axis: int = None, max_register_count: in
self.max_prime_radix = max(all_factors)
- prime_groups = group_primes(all_factors, max_register_count)
-
- self.stages = tuple([FFTRegisterStageConfig(group, max_register_count, N) for group in prime_groups])
- register_utilizations = [stage.registers_used for stage in self.stages]
- self.register_count = max(register_utilizations)
-
- assert self.register_count <= max_register_count, f"Register count {self.register_count} exceeds max register count {max_register_count}"
+ plan_candidate = select_fft_plan_candidate(
+ N=N,
+ batch_inner_count=self.batch_inner_count,
+ compute_item_size=self.compute_type.item_size,
+ max_register_count=max_register_count,
+ )
+ self.stages = plan_candidate.stages
+ self.register_count = plan_candidate.register_count
self.sdata_allocation = 1
self.sdata_row_size = 1
@@ -158,9 +260,9 @@ def __init__(self, buffer_shape: Tuple, axis: int = None, max_register_count: in
self.sdata_row_size = stage.sdata_width
self.sdata_row_size_padded = stage.sdata_width_padded
- self.thread_counts = [stage.thread_count for stage in self.stages]
+ self.thread_counts = tuple(stage.thread_count for stage in self.stages)
- self.batch_threads = max(self.thread_counts)
+ self.batch_threads = plan_candidate.batch_threads
def __str__(self):
return f"FFT Config:\nN: {self.N}\nregister_count: {self.register_count}\nstages:\n{self.stages}\nlocal_size: {self.thread_counts}"
diff --git a/vkdispatch/fft/context.py b/vkdispatch/fft/context.py
index 2afa1ece..8a6bc7cc 100644
--- a/vkdispatch/fft/context.py
+++ b/vkdispatch/fft/context.py
@@ -1,5 +1,6 @@
import vkdispatch as vd
import vkdispatch.codegen as vc
+import vkdispatch.base.dtype as dtypes
import contextlib
from typing import Optional, Tuple, Union, List, Dict
@@ -31,12 +32,13 @@ def __init__(self,
buffer_shape: Tuple,
axis: int = None,
max_register_count: int = None,
+ compute_type: dtypes.dtype = vd.complex64,
name: str = None):
self.shader_context = shader_context
self.declared_shader_args = False
self.declarer = None
- self.config = FFTConfig(buffer_shape, axis, max_register_count)
+ self.config = FFTConfig(buffer_shape, axis, max_register_count, compute_type=compute_type)
self.grid = FFTGridManager(self.config, True, True)
self.resources = FFTResources(self.config, self.grid)
@@ -63,6 +65,8 @@ def declare_shader_args(self, types: List) -> List[vc.ShaderVariable]:
def make_io_manager(self,
output_map: Optional[vd.MappingFunction],
+ output_type: dtypes.dtype = vd.complex64,
+ input_type: Optional[dtypes.dtype] = None,
input_map: Optional[vd.MappingFunction] = None,
kernel_map: Optional[vd.MappingFunction] = None) -> IOManager:
assert not self.declared_shader_args, f"Shader arguments already declared with {self.declarer}"
@@ -72,6 +76,8 @@ def make_io_manager(self,
default_registers=self.registers,
shader_context=self.shader_context,
output_map=output_map,
+ output_type=output_type,
+ input_type=input_type,
input_map=input_map,
kernel_map=kernel_map
)
@@ -127,7 +133,8 @@ def register_shuffle(self,
def compile_shader(self):
self.fft_callable = self.shader_context.get_function(
local_size=self.grid.local_size,
- exec_count=self.grid.exec_size
+ exec_count=self.grid.exec_size,
+ name=self.name
)
def get_callable(self) -> vd.ShaderFunction:
@@ -148,7 +155,7 @@ def execute(self, inverse: bool):
self.register_shuffle(output_stage=i-1, input_stage=i)
self.resources.stage_begin(i)
- for ii, invocation in enumerate(self.resources.invocations[i]):
+ for ii, invocation in enumerate(self.config.stages[i].invocations):
self.resources.invocation_gaurd(i, ii)
self.registers.slice_set(invocation.register_selection, radix_composite(
@@ -156,7 +163,7 @@ def execute(self, inverse: bool):
inverse=inverse,
register_list=self.registers.register_slice(invocation.register_selection),
primes=stage.primes,
- twiddle_index=invocation.inner_block_offset,
+ twiddle_index=invocation.get_inner_block_offset(self.resources.tid),
twiddle_N=invocation.block_width
))
@@ -166,7 +173,9 @@ def execute(self, inverse: bool):
@contextlib.contextmanager
def fft_context(buffer_shape: Tuple,
axis: Optional[int] = None,
- max_register_count: Optional[int] = None):
+ max_register_count: Optional[int] = None,
+ compute_type: dtypes.dtype = vd.complex64,
+ name: Optional[str] = None):
try:
with vd.shader_context(vc.ShaderFlags.NO_EXEC_BOUNDS) as context:
@@ -174,7 +183,9 @@ def fft_context(buffer_shape: Tuple,
shader_context=context,
buffer_shape=buffer_shape,
axis=axis,
- max_register_count=max_register_count
+ max_register_count=max_register_count,
+ compute_type=compute_type,
+ name=name
)
yield fft_context
diff --git a/vkdispatch/fft/cooley_tukey.py b/vkdispatch/fft/cooley_tukey.py
index 006e0763..f2821907 100644
--- a/vkdispatch/fft/cooley_tukey.py
+++ b/vkdispatch/fft/cooley_tukey.py
@@ -3,7 +3,7 @@
from typing import List, Union
-from .._compat import numpy_compat as npc
+from ..compat import numpy_compat as npc
def get_angle_factor(inverse: bool) -> float:
return 2 * npc.pi * (1 if inverse else -1)
@@ -46,6 +46,9 @@ def _apply_twiddle_to_register(
if isinstance(twiddle, complex):
if _apply_constant_twiddle(resources, register, twiddle):
return
+
+ twiddle = vc.to_dtype(register.var_type, twiddle.real, twiddle.imag)
+
resources.radix_registers[0][:] = vc.mult_complex(register, twiddle)
register[:] = resources.radix_registers[0]
@@ -81,7 +84,8 @@ def radix_P(resources: FFTResources, inverse: bool, register_list: List[vc.Shade
continue
omega = npc.exp_complex(1j * angle_factor * i * j / len(register_list))
- resources.omega_register[:] = vc.mult_complex(register_list[j], omega)
+ typed_omega = vc.to_dtype(register_list[j].var_type, omega.real, omega.imag)
+ resources.omega_register[:] = vc.mult_complex(register_list[j], typed_omega)
resources.radix_registers[i] += resources.omega_register
for i in range(0, len(register_list)):
@@ -118,7 +122,9 @@ def apply_twiddle_factors(
_apply_twiddle_to_register(resources, register_list[i], omega)
continue
- resources.omega_register.real = (angle_factor * i / twiddle_N) * twiddle_index
+ angle_scale = vc.to_dtype(resources.omega_register.real.var_type, angle_factor * i / twiddle_N)
+ twiddle_scale = vc.to_dtype(resources.omega_register.real.var_type, twiddle_index)
+ resources.omega_register.real = angle_scale * twiddle_scale
resources.omega_register[:] = vc.complex_from_euler_angle(resources.omega_register.real)
resources.radix_registers[0][:] = vc.mult_complex(register_list[i], resources.omega_register)
register_list[i][:] = resources.radix_registers[0]
diff --git a/vkdispatch/fft/functions.py b/vkdispatch/fft/functions.py
index 9c400b4b..0818a8eb 100644
--- a/vkdispatch/fft/functions.py
+++ b/vkdispatch/fft/functions.py
@@ -1,8 +1,105 @@
import vkdispatch as vd
from .shader_factories import make_fft_shader, make_convolution_shader, make_transpose_shader, get_transposed_size
+from .precision import (
+ ensure_supported_complex_precision,
+ resolve_compute_precision,
+ validate_complex_precision,
+)
-from typing import Tuple, Union, Optional
+from typing import List, Tuple, Union, Optional
+
+
+def _validate_map_argument_annotations(map_fn: vd.MappingFunction, map_name: str) -> None:
+ for buffer_type in map_fn.buffer_types:
+ if not hasattr(buffer_type, "__args__") or len(buffer_type.__args__) != 1:
+ raise ValueError(
+ f"{map_name} contains an annotation without exactly one type argument: {buffer_type}"
+ )
+
+
+def _resolve_output_precision(
+ buffers: Tuple[vd.Buffer, ...],
+ output_map: Optional[vd.MappingFunction],
+ output_type: Optional[vd.dtype],
+) -> Optional[vd.dtype]:
+ if output_map is not None:
+ if output_type is not None:
+ raise ValueError("output_type cannot be provided when output_map is used")
+ return None
+
+ resolved_output = buffers[0].var_type if output_type is None else output_type
+ validate_complex_precision(resolved_output, arg_name="output_type")
+ ensure_supported_complex_precision(resolved_output, role="Output")
+ return resolved_output
+
+
+def _resolve_input_precision(
+ buffers: Tuple,
+ input_map: Optional[vd.MappingFunction],
+ output_map: Optional[vd.MappingFunction],
+ input_type: Optional[vd.dtype],
+ output_precision: Optional[vd.dtype],
+) -> Optional[vd.dtype]:
+ if input_map is not None:
+ if input_type is not None:
+ raise ValueError("input_type cannot be provided when input_map is used")
+ return None
+
+ if output_map is not None:
+ output_arg_count = len(output_map.buffer_types)
+ if len(buffers) <= output_arg_count:
+ raise ValueError(
+ "When output_map is used without input_map, an input buffer argument must be provided "
+ "after output_map arguments"
+ )
+
+ resolved_input = input_type
+ if resolved_input is None:
+ inferred_input = buffers[output_arg_count]
+ if not hasattr(inferred_input, "var_type"):
+ raise ValueError(
+ "When output_map is used without input_map, the argument after output_map arguments "
+ "must be a buffer"
+ )
+ resolved_input = inferred_input.var_type
+
+ validate_complex_precision(resolved_input, arg_name="input_type")
+ ensure_supported_complex_precision(resolved_input, role="Input")
+ return resolved_input
+
+ if output_precision is None:
+ raise ValueError("output_precision must be provided when output_map is not used")
+
+ resolved_input = output_precision if input_type is None else input_type
+ validate_complex_precision(resolved_input, arg_name="input_type")
+ ensure_supported_complex_precision(resolved_input, role="Input")
+
+ if resolved_input != output_precision:
+ raise ValueError(
+ "input_type must match output_type when input_map is None (default FFT path is in-place)"
+ )
+
+ return resolved_input
+
+
+def _resolve_kernel_precision(
+ buffers: Tuple[vd.Buffer, ...],
+ kernel_map: Optional[vd.MappingFunction],
+ kernel_type: Optional[vd.dtype],
+) -> Optional[vd.dtype]:
+ if kernel_map is not None:
+ if kernel_type is not None:
+ raise ValueError("kernel_type cannot be provided when kernel_map is used")
+ return None
+
+ if len(buffers) < 2:
+ raise ValueError("Kernel precision inference requires a kernel buffer argument")
+
+ resolved_kernel = buffers[1].var_type if kernel_type is None else kernel_type
+ validate_complex_precision(resolved_kernel, arg_name="kernel_type")
+ ensure_supported_complex_precision(resolved_kernel, role="Kernel")
+ return resolved_kernel
def fft(
*buffers: vd.Buffer,
@@ -16,13 +113,36 @@ def fft(
r2c: bool = False,
input_map: vd.MappingFunction = None,
output_map: vd.MappingFunction = None,
+ output_type: vd.dtype = None,
+ input_type: vd.dtype = None,
+ compute_type: vd.dtype = None,
input_signal_range: Union[Tuple[Optional[int], Optional[int]], None] = None):
assert len(buffers) >= 1, "At least one buffer must be provided"
+
+ if input_map is None and output_map is None and len(buffers) != 1:
+ raise ValueError("fft() expects exactly one buffer unless input_map/output_map are used")
if buffer_shape is None:
buffer_shape = buffers[0].shape
+ resolved_output_type = _resolve_output_precision(buffers, output_map, output_type)
+ resolved_input_type = _resolve_input_precision(buffers, input_map, output_map, input_type, resolved_output_type)
+
+ io_precisions: List[vd.dtype] = []
+ if output_map is None:
+ io_precisions.append(resolved_output_type)
+ else:
+ _validate_map_argument_annotations(output_map, "output_map")
+
+ if input_map is None:
+ if resolved_input_type is not None:
+ io_precisions.append(resolved_input_type)
+ else:
+ _validate_map_argument_annotations(input_map, "input_map")
+
+ resolved_compute_type = resolve_compute_precision(io_precisions, compute_type)
+
fft_shader = make_fft_shader(
tuple(buffer_shape),
axis,
@@ -31,6 +151,9 @@ def fft(
r2c=r2c,
input_map=input_map,
output_map=output_map,
+ input_type=resolved_input_type,
+ output_type=resolved_output_type,
+ compute_type=resolved_compute_type,
input_signal_range=input_signal_range)
if print_shader:
@@ -38,18 +161,80 @@ def fft(
fft_shader(*buffers, graph=graph)
-def fft2(buffer: vd.Buffer, graph: vd.CommandGraph = None, print_shader: bool = False, input_map: vd.MappingFunction = None, output_map: vd.MappingFunction = None):
+def fft2(
+ buffer: vd.Buffer,
+ graph: vd.CommandGraph = None,
+ print_shader: bool = False,
+ input_map: vd.MappingFunction = None,
+ output_map: vd.MappingFunction = None,
+ output_type: vd.dtype = None,
+ input_type: vd.dtype = None,
+ compute_type: vd.dtype = None,
+):
assert len(buffer.shape) == 2 or len(buffer.shape) == 3, 'Buffer must have 2 or 3 dimensions'
- fft(buffer, graph=graph, print_shader=print_shader, axis=len(buffer.shape) - 2, input_map=input_map)
- fft(buffer, graph=graph, print_shader=print_shader, axis=len(buffer.shape) - 1, output_map=output_map)
+ fft(
+ buffer,
+ graph=graph,
+ print_shader=print_shader,
+ axis=len(buffer.shape) - 2,
+ input_map=input_map,
+ output_type=output_type,
+ input_type=input_type,
+ compute_type=compute_type,
+ )
+ fft(
+ buffer,
+ graph=graph,
+ print_shader=print_shader,
+ axis=len(buffer.shape) - 1,
+ output_map=output_map,
+ output_type=output_type if output_map is None else None,
+ input_type=input_type if output_map is None else None,
+ compute_type=compute_type,
+ )
-def fft3(buffer: vd.Buffer, graph: vd.CommandGraph = None, print_shader: bool = False, input_map: vd.MappingFunction = None, output_map: vd.MappingFunction = None):
+def fft3(
+ buffer: vd.Buffer,
+ graph: vd.CommandGraph = None,
+ print_shader: bool = False,
+ input_map: vd.MappingFunction = None,
+ output_map: vd.MappingFunction = None,
+ output_type: vd.dtype = None,
+ input_type: vd.dtype = None,
+ compute_type: vd.dtype = None,
+):
assert len(buffer.shape) == 3, 'Buffer must have 3 dimensions'
- fft(buffer, graph=graph, print_shader=print_shader, axis=0, input_map=input_map)
- fft(buffer, graph=graph, print_shader=print_shader, axis=1)
- fft(buffer, graph=graph, print_shader=print_shader, axis=2, output_map=output_map)
+ fft(
+ buffer,
+ graph=graph,
+ print_shader=print_shader,
+ axis=0,
+ input_map=input_map,
+ output_type=output_type,
+ input_type=input_type,
+ compute_type=compute_type,
+ )
+ fft(
+ buffer,
+ graph=graph,
+ print_shader=print_shader,
+ axis=1,
+ output_type=output_type,
+ input_type=input_type,
+ compute_type=compute_type,
+ )
+ fft(
+ buffer,
+ graph=graph,
+ print_shader=print_shader,
+ axis=2,
+ output_map=output_map,
+ output_type=output_type if output_map is None else None,
+ input_type=input_type if output_map is None else None,
+ compute_type=compute_type,
+ )
def ifft(
@@ -60,54 +245,225 @@ def ifft(
name: str = None,
normalize: bool = True,
input_map: vd.MappingFunction = None,
- output_map: vd.MappingFunction = None):
- fft(buffer, graph=graph, print_shader=print_shader, axis=axis, name=name, inverse=True, normalize_inverse=normalize, input_map=input_map, output_map=output_map)
+ output_map: vd.MappingFunction = None,
+ output_type: vd.dtype = None,
+ input_type: vd.dtype = None,
+ compute_type: vd.dtype = None):
+ fft(
+ buffer,
+ graph=graph,
+ print_shader=print_shader,
+ axis=axis,
+ name=name,
+ inverse=True,
+ normalize_inverse=normalize,
+ input_map=input_map,
+ output_map=output_map,
+ output_type=output_type,
+ input_type=input_type,
+ compute_type=compute_type,
+ )
-def ifft2(buffer: vd.Buffer, graph: vd.CommandGraph = None, print_shader: bool = False, normalize: bool = True, input_map: vd.MappingFunction = None, output_map: vd.MappingFunction = None):
+def ifft2(
+ buffer: vd.Buffer,
+ graph: vd.CommandGraph = None,
+ print_shader: bool = False,
+ normalize: bool = True,
+ input_map: vd.MappingFunction = None,
+ output_map: vd.MappingFunction = None,
+ output_type: vd.dtype = None,
+ input_type: vd.dtype = None,
+ compute_type: vd.dtype = None,
+):
assert len(buffer.shape) == 2 or len(buffer.shape) == 3, 'Buffer must have 2 or 3 dimensions'
- ifft(buffer, graph=graph, print_shader=print_shader, axis=len(buffer.shape) - 2, normalize=normalize, input_map=input_map)
- ifft(buffer, graph=graph, print_shader=print_shader, axis=len(buffer.shape) - 1, normalize=normalize, output_map=output_map)
+ ifft(
+ buffer,
+ graph=graph,
+ print_shader=print_shader,
+ axis=len(buffer.shape) - 2,
+ normalize=normalize,
+ input_map=input_map,
+ output_type=output_type,
+ input_type=input_type,
+ compute_type=compute_type,
+ )
+ ifft(
+ buffer,
+ graph=graph,
+ print_shader=print_shader,
+ axis=len(buffer.shape) - 1,
+ normalize=normalize,
+ output_map=output_map,
+ output_type=output_type if output_map is None else None,
+ input_type=input_type if output_map is None else None,
+ compute_type=compute_type,
+ )
-def ifft3(buffer: vd.Buffer, graph: vd.CommandGraph = None, print_shader: bool = False, normalize: bool = True, input_map: vd.MappingFunction = None, output_map: vd.MappingFunction = None):
+def ifft3(
+ buffer: vd.Buffer,
+ graph: vd.CommandGraph = None,
+ print_shader: bool = False,
+ normalize: bool = True,
+ input_map: vd.MappingFunction = None,
+ output_map: vd.MappingFunction = None,
+ output_type: vd.dtype = None,
+ input_type: vd.dtype = None,
+ compute_type: vd.dtype = None,
+):
assert len(buffer.shape) == 3, 'Buffer must have 3 dimensions'
- ifft(buffer, graph=graph, print_shader=print_shader, axis=0, normalize=normalize, input_map=input_map)
- ifft(buffer, graph=graph, print_shader=print_shader, axis=1, normalize=normalize)
- ifft(buffer, graph=graph, print_shader=print_shader, axis=2, normalize=normalize, output_map=output_map)
+ ifft(
+ buffer,
+ graph=graph,
+ print_shader=print_shader,
+ axis=0,
+ normalize=normalize,
+ input_map=input_map,
+ output_type=output_type,
+ input_type=input_type,
+ compute_type=compute_type,
+ )
+ ifft(
+ buffer,
+ graph=graph,
+ print_shader=print_shader,
+ axis=1,
+ normalize=normalize,
+ output_type=output_type,
+ input_type=input_type,
+ compute_type=compute_type,
+ )
+ ifft(
+ buffer,
+ graph=graph,
+ print_shader=print_shader,
+ axis=2,
+ normalize=normalize,
+ output_map=output_map,
+ output_type=output_type if output_map is None else None,
+ input_type=input_type if output_map is None else None,
+ compute_type=compute_type,
+ )
-def rfft(buffer: vd.RFFTBuffer, graph: vd.CommandGraph = None, print_shader: bool = False, name: str = None):
- fft(buffer, buffer_shape=buffer.real_shape, graph=graph, print_shader=print_shader, name=name, r2c=True)
+def rfft(
+ buffer: vd.RFFTBuffer,
+ graph: vd.CommandGraph = None,
+ print_shader: bool = False,
+ name: str = None,
+ compute_type: vd.dtype = None,
+):
+ fft(
+ buffer,
+ buffer_shape=buffer.real_shape,
+ graph=graph,
+ print_shader=print_shader,
+ name=name,
+ r2c=True,
+ output_type=buffer.var_type,
+ input_type=buffer.var_type,
+ compute_type=compute_type,
+ )
-def rfft2(buffer: vd.RFFTBuffer, graph: vd.CommandGraph = None, print_shader: bool = False):
+def rfft2(buffer: vd.RFFTBuffer, graph: vd.CommandGraph = None, print_shader: bool = False, compute_type: vd.dtype = None):
assert len(buffer.real_shape) == 2 or len(buffer.real_shape) == 3, 'Buffer must have 2 or 3 dimensions'
- rfft(buffer, graph=graph, print_shader=print_shader)
- fft(buffer, graph=graph, print_shader=print_shader, axis=len(buffer.real_shape) - 2)
+ rfft(buffer, graph=graph, print_shader=print_shader, compute_type=compute_type)
+ fft(
+ buffer,
+ graph=graph,
+ print_shader=print_shader,
+ axis=len(buffer.real_shape) - 2,
+ output_type=buffer.var_type,
+ input_type=buffer.var_type,
+ compute_type=compute_type,
+ )
-def rfft3(buffer: vd.RFFTBuffer, graph: vd.CommandGraph = None, print_shader: bool = False):
+def rfft3(buffer: vd.RFFTBuffer, graph: vd.CommandGraph = None, print_shader: bool = False, compute_type: vd.dtype = None):
assert len(buffer.real_shape) == 3, 'Buffer must have 3 dimensions'
- rfft(buffer, graph=graph, print_shader=print_shader)
- fft(buffer, graph=graph, print_shader=print_shader, axis=1)
- fft(buffer, graph=graph, print_shader=print_shader, axis=0)
+ rfft(buffer, graph=graph, print_shader=print_shader, compute_type=compute_type)
+ fft(
+ buffer,
+ graph=graph,
+ print_shader=print_shader,
+ axis=1,
+ output_type=buffer.var_type,
+ input_type=buffer.var_type,
+ compute_type=compute_type,
+ )
+ fft(
+ buffer,
+ graph=graph,
+ print_shader=print_shader,
+ axis=0,
+ output_type=buffer.var_type,
+ input_type=buffer.var_type,
+ compute_type=compute_type,
+ )
-def irfft(buffer: vd.RFFTBuffer, graph: vd.CommandGraph = None, print_shader: bool = False, name: str = None, normalize: bool = True):
- fft(buffer, buffer_shape=buffer.real_shape, graph=graph, print_shader=print_shader, name=name, inverse=True, normalize_inverse=normalize, r2c=True)
+def irfft(
+ buffer: vd.RFFTBuffer,
+ graph: vd.CommandGraph = None,
+ print_shader: bool = False,
+ name: str = None,
+ normalize: bool = True,
+ compute_type: vd.dtype = None,
+):
+ fft(
+ buffer,
+ buffer_shape=buffer.real_shape,
+ graph=graph,
+ print_shader=print_shader,
+ name=name,
+ inverse=True,
+ normalize_inverse=normalize,
+ r2c=True,
+ output_type=buffer.var_type,
+ input_type=buffer.var_type,
+ compute_type=compute_type,
+ )
-def irfft2(buffer: vd.RFFTBuffer, graph: vd.CommandGraph = None, print_shader: bool = False, normalize: bool = True):
+def irfft2(buffer: vd.RFFTBuffer, graph: vd.CommandGraph = None, print_shader: bool = False, normalize: bool = True, compute_type: vd.dtype = None):
assert len(buffer.real_shape) == 2 or len(buffer.real_shape) == 3, 'Buffer must have 2 or 3 dimensions'
- ifft(buffer, graph=graph, print_shader=print_shader, axis=len(buffer.real_shape) - 2, normalize=normalize)
- irfft(buffer, graph=graph, print_shader=print_shader, normalize=normalize)
+ ifft(
+ buffer,
+ graph=graph,
+ print_shader=print_shader,
+ axis=len(buffer.real_shape) - 2,
+ normalize=normalize,
+ output_type=buffer.var_type,
+ input_type=buffer.var_type,
+ compute_type=compute_type,
+ )
+ irfft(buffer, graph=graph, print_shader=print_shader, normalize=normalize, compute_type=compute_type)
-def irfft3(buffer: vd.RFFTBuffer, graph: vd.CommandGraph = None, print_shader: bool = False, normalize: bool = True):
+def irfft3(buffer: vd.RFFTBuffer, graph: vd.CommandGraph = None, print_shader: bool = False, normalize: bool = True, compute_type: vd.dtype = None):
assert len(buffer.real_shape) == 3, 'Buffer must have 3 dimensions'
- ifft(buffer, graph=graph, print_shader=print_shader, axis=0, normalize=normalize)
- ifft(buffer, graph=graph, print_shader=print_shader, axis=1, normalize=normalize)
- irfft(buffer, graph=graph, print_shader=print_shader, normalize=normalize)
+ ifft(
+ buffer,
+ graph=graph,
+ print_shader=print_shader,
+ axis=0,
+ normalize=normalize,
+ output_type=buffer.var_type,
+ input_type=buffer.var_type,
+ compute_type=compute_type,
+ )
+ ifft(
+ buffer,
+ graph=graph,
+ print_shader=print_shader,
+ axis=1,
+ normalize=normalize,
+ output_type=buffer.var_type,
+ input_type=buffer.var_type,
+ compute_type=compute_type,
+ )
+ irfft(buffer, graph=graph, print_shader=print_shader, normalize=normalize, compute_type=compute_type)
def convolve(
*buffers: vd.Buffer,
@@ -123,10 +479,43 @@ def convolve(
kernel_inner_only: bool = False,
input_map: vd.MappingFunction = None,
output_map: vd.MappingFunction = None,
+ output_type: vd.dtype = None,
+ input_type: vd.dtype = None,
+ kernel_type: vd.dtype = None,
+ compute_type: vd.dtype = None,
input_signal_range: Union[Tuple[Optional[int], Optional[int]], None] = None):
+ assert len(buffers) >= 1, "At least one buffer must be provided"
+
+ if kernel_map is None and len(buffers) < 2:
+ raise ValueError("convolve() requires at least an output buffer and kernel buffer")
+
if buffer_shape is None:
buffer_shape = buffers[0].shape
+ resolved_output_type = _resolve_output_precision(buffers, output_map, output_type)
+ resolved_input_type = _resolve_input_precision(buffers, input_map, output_map, input_type, resolved_output_type)
+ resolved_kernel_type = _resolve_kernel_precision(buffers, kernel_map, kernel_type)
+
+ io_precisions: List[vd.dtype] = []
+
+ if output_map is None:
+ io_precisions.append(resolved_output_type)
+ else:
+ _validate_map_argument_annotations(output_map, "output_map")
+
+ if input_map is None:
+ if resolved_input_type is not None:
+ io_precisions.append(resolved_input_type)
+ else:
+ _validate_map_argument_annotations(input_map, "input_map")
+
+ if kernel_map is None:
+ io_precisions.append(resolved_kernel_type)
+ else:
+ _validate_map_argument_annotations(kernel_map, "kernel_map")
+
+ resolved_compute_type = resolve_compute_precision(io_precisions, compute_type)
+
fft_shader = make_convolution_shader(
tuple(buffer_shape),
kernel_map,
@@ -137,6 +526,10 @@ def convolve(
normalize=normalize,
input_map=input_map,
output_map=output_map,
+ input_type=resolved_input_type,
+ output_type=resolved_output_type,
+ kernel_type=resolved_kernel_type,
+ compute_type=resolved_compute_type,
input_signal_range=input_signal_range)
if print_shader:
@@ -155,7 +548,11 @@ def convolve2D(
transposed_kernel: bool = False,
kernel_inner_only: bool = False,
input_map: vd.MappingFunction = None,
- output_map: vd.MappingFunction = None):
+ output_map: vd.MappingFunction = None,
+ output_type: vd.dtype = None,
+ input_type: vd.dtype = None,
+ kernel_type: vd.dtype = None,
+ compute_type: vd.dtype = None):
assert len(buffer.shape) == 2 or len(buffer.shape) == 3, 'Buffer must have 2 or 3 dimensions'
@@ -168,7 +565,15 @@ def convolve2D(
if output_map is not None:
output_buffers.append(buffer)
- fft(*input_buffers, graph=graph, print_shader=print_shader, input_map=input_map)
+ fft(
+ *input_buffers,
+ graph=graph,
+ print_shader=print_shader,
+ input_map=input_map,
+ output_type=output_type,
+ input_type=input_type,
+ compute_type=compute_type,
+ )
convolve(
buffer,
kernel,
@@ -179,9 +584,22 @@ def convolve2D(
kernel_inner_only=kernel_inner_only,
print_shader=print_shader,
axis=len(buffer.shape) - 2,
- normalize=normalize
+ normalize=normalize,
+ output_type=output_type,
+ input_type=input_type,
+ kernel_type=kernel_type,
+ compute_type=compute_type,
+ )
+ ifft(
+ *output_buffers,
+ graph=graph,
+ print_shader=print_shader,
+ normalize=normalize,
+ output_map=output_map,
+ output_type=output_type if output_map is None else None,
+ input_type=input_type if output_map is None else None,
+ compute_type=compute_type,
)
- ifft(*output_buffers, graph=graph, print_shader=print_shader, normalize=normalize, output_map=output_map)
def convolve2DR(
buffer: vd.RFFTBuffer,
@@ -192,11 +610,12 @@ def convolve2DR(
kernel_inner_only: bool = False,
graph: vd.CommandGraph = None,
print_shader: bool = False,
- normalize: bool = True):
+ normalize: bool = True,
+ compute_type: vd.dtype = None):
assert len(buffer.shape) == 2 or len(buffer.shape) == 3, 'Buffer must have 2 or 3 dimensions'
- rfft(buffer, graph=graph, print_shader=print_shader)
+ rfft(buffer, graph=graph, print_shader=print_shader, compute_type=compute_type)
convolve(
buffer,
kernel,
@@ -207,9 +626,13 @@ def convolve2DR(
kernel_inner_only=kernel_inner_only,
print_shader=print_shader,
axis=len(buffer.shape) - 2,
- normalize=normalize
+ normalize=normalize,
+ output_type=buffer.var_type,
+ input_type=buffer.var_type,
+ kernel_type=kernel.var_type,
+ compute_type=compute_type,
)
- irfft(buffer, graph=graph, print_shader=print_shader, normalize=normalize)
+ irfft(buffer, graph=graph, print_shader=print_shader, normalize=normalize, compute_type=compute_type)
def transpose(
in_buffer: vd.Buffer,
@@ -218,25 +641,54 @@ def transpose(
out_buffer: vd.Buffer = None,
graph: vd.CommandGraph = None,
kernel_inner_only: bool = False,
- print_shader: bool = False) -> vd.Buffer:
-
+ print_shader: bool = False,
+ input_type: vd.dtype = None,
+ output_type: vd.dtype = None,
+ compute_type: vd.dtype = None) -> vd.Buffer:
+
+ resolved_input_type = in_buffer.var_type if input_type is None else input_type
+ validate_complex_precision(resolved_input_type, arg_name="input_type")
+ ensure_supported_complex_precision(resolved_input_type, role="Input")
+
+ resolved_output_type = (
+ out_buffer.var_type if (out_buffer is not None and output_type is None)
+ else in_buffer.var_type if output_type is None
+ else output_type
+ )
+ validate_complex_precision(resolved_output_type, arg_name="output_type")
+ ensure_supported_complex_precision(resolved_output_type, role="Output")
+
+ resolved_compute_type = resolve_compute_precision(
+ [resolved_input_type, resolved_output_type],
+ compute_type,
+ )
+
transposed_size = get_transposed_size(
tuple(in_buffer.shape),
- axis=axis
+ axis=axis,
+ compute_type=resolved_compute_type,
)
if out_buffer is None:
- out_buffer = vd.Buffer((transposed_size,), var_type=in_buffer.var_type)
+ out_buffer = vd.Buffer((transposed_size,), var_type=resolved_output_type)
+ else:
+ if out_buffer.var_type != resolved_output_type:
+ raise ValueError(
+ f"out_buffer type ({out_buffer.var_type.name}) does not match output_type ({resolved_output_type.name})"
+ )
assert out_buffer.size >= transposed_size, f"Output buffer size {out_buffer.size} does not match expected transposed size {transposed_size}"
if conv_shape is None:
conv_shape = in_buffer.shape
-
+
transpose_shader = make_transpose_shader(
tuple(conv_shape),
axis=axis,
- kernel_inner_only=kernel_inner_only
+ kernel_inner_only=kernel_inner_only,
+ input_type=resolved_input_type,
+ output_type=resolved_output_type,
+ compute_type=resolved_compute_type,
)
if print_shader:
@@ -244,4 +696,4 @@ def transpose(
transpose_shader(out_buffer, in_buffer, graph=graph)
- return out_buffer
\ No newline at end of file
+ return out_buffer
diff --git a/vkdispatch/fft/global_memory_iterators.py b/vkdispatch/fft/global_memory_iterators.py
index e897846a..c621f6b6 100644
--- a/vkdispatch/fft/global_memory_iterators.py
+++ b/vkdispatch/fft/global_memory_iterators.py
@@ -7,6 +7,13 @@
from .registers import FFTRegisters
from .memory_iterators import memory_reads_iterator, memory_writes_iterator, MemoryOp
+
+def _cast_if_needed(value: vc.ShaderVariable, dst_type):
+ if value.var_type == dst_type:
+ return value
+
+ return vc.to_dtype(dst_type, value)
+
def global_batch_offset(
registers: FFTRegisters,
r2c: bool = False,
@@ -57,7 +64,7 @@ def from_memory_op(cls,
inverse=inverse)
def write_to_buffer(self,
- buffer: vc.Buff[vc.c64],
+ buffer: vc.Buffer,
register: Optional[vc.ShaderVariable] = None,
io_index: Optional[vc.ShaderVariable] = None):
if register is None:
@@ -67,21 +74,18 @@ def write_to_buffer(self,
io_index = self.io_index
if not self.r2c:
- buffer[io_index] = register
+ buffer[io_index] = _cast_if_needed(register, buffer.var_type)
return
if not self.inverse:
vc.if_statement(self.fft_index < (self.fft_size // 2) + 1)
- buffer[io_index] = register
+ buffer[io_index] = _cast_if_needed(register, buffer.var_type)
vc.end()
return
- packed_value = buffer[io_index // 2]
- vc.if_statement((io_index % 2) == 0)
- packed_value.real = register.real
- vc.else_statement()
- packed_value.imag = register.real
- vc.end()
+ out_scalar_type = buffer.var_type.child_type
+ out_real = _cast_if_needed(register.real, out_scalar_type)
+ buffer[io_index // 2][io_index % 2] = out_real
def global_writes_iterator(
registers: FFTRegisters,
@@ -171,11 +175,11 @@ def signal_range_end(self, register: vc.ShaderVariable):
return
vc.else_statement()
- register[:] = vc.to_complex(0)
+ register[:] = vc.to_dtype(register.var_type, 0)
vc.end()
def read_from_buffer(self,
- buffer: vc.Buff[vc.c64],
+ buffer: vc.Buffer,
register: Optional[vc.ShaderVariable] = None,
io_index: Optional[vc.ShaderVariable] = None):
self.check_in_signal_range()
@@ -187,26 +191,23 @@ def read_from_buffer(self,
register = self.register
if not self.r2c:
- register[:] = buffer[io_index]
+ register[:] = _cast_if_needed(buffer[io_index], register.var_type)
self.signal_range_end(register)
return
if not self.inverse:
- packed_value = buffer[io_index // 2]
- vc.if_statement((io_index % 2) == 0)
- register[:] = vc.to_complex(packed_value.real)
- vc.else_statement()
- register[:] = vc.to_complex(packed_value.imag)
- vc.end()
+ packed_real = buffer[io_index // 2][io_index % 2]
+ packed_complex = vc.to_complex(packed_real)
+ register[:] = _cast_if_needed(packed_complex, register.var_type)
self.signal_range_end(register)
return
vc.if_statement(self.fft_index >= (self.fft_size // 2) + 1)
self.io_index_2[:] = self.r2c_inverse_offset - io_index
- register[:] = buffer[self.io_index_2]
+ register[:] = _cast_if_needed(buffer[self.io_index_2], register.var_type)
register.imag = -register.imag
vc.else_statement()
- register[:] = buffer[io_index]
+ register[:] = _cast_if_needed(buffer[io_index], register.var_type)
vc.end()
self.signal_range_end(register)
@@ -302,7 +303,7 @@ def from_memory_op(cls,
)
def write_to_buffer(self,
- buffer: vc.Buff[vc.c64],
+ buffer: vc.Buffer,
register: Optional[vc.ShaderVariable] = None,
io_index: Optional[vc.ShaderVariable] = None):
if io_index is None:
@@ -311,7 +312,7 @@ def write_to_buffer(self,
if register is None:
register = self.register
- buffer[io_index] = register
+ buffer[io_index] = _cast_if_needed(register, buffer.var_type)
def global_trasposed_write_iterator(registers: FFTRegisters, inner_only: bool = False):
vc.comment("""Writing registers to global memory in transposed order.
diff --git a/vkdispatch/fft/grid_manager.py b/vkdispatch/fft/grid_manager.py
index 22d642af..5d6aa4e9 100644
--- a/vkdispatch/fft/grid_manager.py
+++ b/vkdispatch/fft/grid_manager.py
@@ -6,7 +6,7 @@
from .config import FFTConfig
from .prime_utils import prime_factors
-from .._compat import numpy_compat as npc
+from ..compat import numpy_compat as npc
def allocation_valid(workgroup_size: int, shared_memory_size: int):
valid_workgroup = workgroup_size <= vd.get_context().max_workgroup_invocations
@@ -16,11 +16,12 @@ def allocation_valid(workgroup_size: int, shared_memory_size: int):
def allocate_inline_batches(
batch_num: int,
batch_threads: int,
- N: int,
+ shared_elements: int,
+ element_size: int,
max_workgroup_size: int,
max_total_threads: int):
- shared_memory_allocation = N * vd.complex64.item_size
+ shared_memory_allocation = shared_elements * element_size
batch_num_primes = prime_factors(batch_num)
prime_index = 0
workgroup_size = batch_threads
@@ -157,6 +158,7 @@ def __init__(self, config: FFTConfig, force_sdata: bool = False, declare_variabl
config.batch_inner_count,
config.batch_threads,
config.sdata_allocation if make_sdata_buffer else 0,
+ config.compute_type.item_size,
min(vd.get_context().max_workgroup_size[0], 4),
vd.get_context().max_workgroup_invocations)
@@ -171,6 +173,7 @@ def __init__(self, config: FFTConfig, force_sdata: bool = False, declare_variabl
config.batch_outer_count,
config.batch_threads * self.inline_batches_inner,
config.sdata_allocation * self.inline_batches_inner if make_sdata_buffer else 0,
+ config.compute_type.item_size,
vd.get_context().max_workgroup_size[
1 if self.inline_batches_inner == 1 else 2
],
diff --git a/vkdispatch/fft/io_manager.py b/vkdispatch/fft/io_manager.py
index 1f54fc99..b91d6bd9 100644
--- a/vkdispatch/fft/io_manager.py
+++ b/vkdispatch/fft/io_manager.py
@@ -1,5 +1,6 @@
import vkdispatch as vd
import vkdispatch.codegen as vc
+import vkdispatch.base.dtype as dtypes
from typing import Optional, Tuple
@@ -55,11 +56,22 @@ def __init__(self,
default_registers: FFTRegisters,
shader_context: vd.ShaderContext,
output_map: Optional[vd.MappingFunction],
+ output_type: dtypes.dtype = vd.complex64,
+ input_type: Optional[dtypes.dtype] = None,
input_map: Optional[vd.MappingFunction] = None,
kernel_map: Optional[vd.MappingFunction] = None):
self.default_registers = default_registers
- self.output_proxy = IOProxy(vd.complex64 if output_map is None else output_map, "Output")
- self.input_proxy = IOProxy(input_map, "Input")
+ self.output_proxy = IOProxy(output_type if output_map is None else output_map, "Output")
+
+ if input_map is not None:
+ self.input_proxy = IOProxy(input_map, "Input")
+ elif output_map is not None:
+ if input_type is None:
+ raise ValueError("input_type must be provided when output_map is used without input_map")
+ self.input_proxy = IOProxy(input_type, "Input")
+ else:
+ self.input_proxy = IOProxy(None, "Input")
+
self.kernel_proxy = IOProxy(kernel_map, "Kernel")
output_types = self.output_proxy.buffer_types
@@ -163,4 +175,4 @@ def read_kernel(self, registers: Optional[FFTRegisters] = None, format_transpose
registers,
format_transposed=format_transposed,
inner_only=inner_only
- )
\ No newline at end of file
+ )
diff --git a/vkdispatch/fft/memory_iterators.py b/vkdispatch/fft/memory_iterators.py
index 4c85e046..a7793ab7 100644
--- a/vkdispatch/fft/memory_iterators.py
+++ b/vkdispatch/fft/memory_iterators.py
@@ -22,14 +22,14 @@ def memory_reads_iterator(resources: FFTResources, stage_index: int = 0):
resources.stage_begin(stage_index)
index_list = list(range(resources.config.register_count))
- invocations = resources.invocations[stage_index]
+ invocations = resources.config.stages[stage_index].invocations
for ii, invocation in enumerate(invocations):
resources.invocation_gaurd(stage_index, ii)
register_indicies = index_list[invocation.register_selection]
- offset = invocation.instance_id
+ offset = invocation.get_offset(resources.tid)
stride = resources.config.N // resources.config.stages[stage_index].fft_length
for i in range(len(register_indicies)):
@@ -58,14 +58,14 @@ def memory_writes_iterator(resources: FFTResources, stage_index: int = -1):
index_list = list(range(resources.config.register_count))
element_count = resources.config.stages[stage_index].fft_length
- invocations = resources.invocations[stage_index]
+ invocations = resources.config.stages[stage_index].invocations
for i in range(element_count):
for ii, invocation in enumerate(invocations):
resources.invocation_gaurd(stage_index, ii)
- offset = invocation.sub_sequence_offset
- stride = resources.output_strides[stage_index]
+ offset = invocation.get_sub_sequence_offset(resources.tid)
+ stride = resources.config.stages[stage_index].input_stride
fft_index = offset + i * stride
diff --git a/vkdispatch/fft/precision.py b/vkdispatch/fft/precision.py
new file mode 100644
index 00000000..d9d6d640
--- /dev/null
+++ b/vkdispatch/fft/precision.py
@@ -0,0 +1,99 @@
+import vkdispatch as vd
+
+from typing import Iterable, List, Optional
+
+
+_COMPLEX_PRECISION_ORDER = (vd.complex32, vd.complex64, vd.complex128)
+_COMPLEX_PRECISION_RANK = {dtype: rank for rank, dtype in enumerate(_COMPLEX_PRECISION_ORDER)}
+
+
+def is_complex_precision(dtype) -> bool:
+ return dtype in _COMPLEX_PRECISION_RANK
+
+
+def validate_complex_precision(dtype, *, arg_name: str) -> None:
+ if not is_complex_precision(dtype):
+ raise ValueError(f"{arg_name} must be one of complex32, complex64, or complex128 (got {dtype})")
+
+
+def promote_complex_precisions(dtypes: Iterable) -> vd.dtype:
+ candidates = list(dtypes)
+ if len(candidates) == 0:
+ raise ValueError("At least one complex dtype is required for promotion")
+
+ for candidate in candidates:
+ validate_complex_precision(candidate, arg_name="dtype")
+
+ return max(candidates, key=lambda dtype: _COMPLEX_PRECISION_RANK[dtype])
+
+
+def default_compute_precision(io_precisions: Iterable) -> vd.dtype:
+ promoted = promote_complex_precisions(io_precisions)
+
+ # Default to at least complex64 for numerical stability.
+ if _COMPLEX_PRECISION_RANK[promoted] < _COMPLEX_PRECISION_RANK[vd.complex64]:
+ return vd.complex64
+
+ return promoted
+
+
+def supports_complex_precision(dtype) -> bool:
+ validate_complex_precision(dtype, arg_name="dtype")
+ scalar_type = dtype.child_type
+
+ for device in vd.get_context().device_infos:
+ if scalar_type == vd.float16:
+ if device.float_16_support != 1:
+ return False
+
+ # Half precision in storage buffers typically needs one of these capabilities.
+ if (
+ device.storage_buffer_16_bit_access != 1
+ and device.uniform_and_storage_buffer_16_bit_access != 1
+ ):
+ return False
+
+ if scalar_type == vd.float64 and device.float_64_support != 1:
+ return False
+
+ return True
+
+
+def ensure_supported_complex_precision(dtype, *, role: str) -> None:
+ if not supports_complex_precision(dtype):
+ raise ValueError(f"{role} precision '{dtype.name}' is not supported on the active device set")
+
+
+def resolve_compute_precision(io_precisions: List, compute_precision: Optional[vd.dtype]) -> vd.dtype:
+ if compute_precision is not None:
+ validate_complex_precision(compute_precision, arg_name="compute_type")
+ ensure_supported_complex_precision(compute_precision, role="Compute")
+ return compute_precision
+
+ for io_precision in io_precisions:
+ validate_complex_precision(io_precision, arg_name="io_precision")
+
+ if len(io_precisions) == 0:
+ for candidate in (vd.complex64, vd.complex32):
+ if supports_complex_precision(candidate):
+ return candidate
+
+ raise ValueError(
+ "Unable to resolve a default compute precision supported by all active devices"
+ )
+
+ target = default_compute_precision(io_precisions)
+ if supports_complex_precision(target):
+ return target
+
+ # Auto fallback: drop from complex128 to complex64 when fp64 is unsupported.
+ for candidate in (vd.complex64, vd.complex32):
+ if (
+ _COMPLEX_PRECISION_RANK[candidate] <= _COMPLEX_PRECISION_RANK[target]
+ and supports_complex_precision(candidate)
+ ):
+ return candidate
+
+ raise ValueError(
+ "Unable to resolve an auto compute precision supported by all active devices"
+ )
diff --git a/vkdispatch/fft/prime_utils.py b/vkdispatch/fft/prime_utils.py
index 2db85020..ee1624fa 100644
--- a/vkdispatch/fft/prime_utils.py
+++ b/vkdispatch/fft/prime_utils.py
@@ -1,13 +1,10 @@
from typing import List
import vkdispatch as vd
-from .._compat import numpy_compat as npc
+from ..compat import numpy_compat as npc
def default_register_limit():
- if vd.get_devices()[0].is_nvidia():
- return 16
-
- return 15
+ return 16
def default_max_prime():
return 13
diff --git a/vkdispatch/fft/registers.py b/vkdispatch/fft/registers.py
index 6fe671b3..31c79e32 100644
--- a/vkdispatch/fft/registers.py
+++ b/vkdispatch/fft/registers.py
@@ -32,7 +32,7 @@ def __init__(self, resources: FFTResources, count: int, name: str):
self.config = resources.config
self.registers = [
- vc.new_complex_register(var_name=f"{name}_reg_{i}") for i in range(count)
+ vc.new_register(self.config.compute_type, var_name=f"{name}_reg_{i}") for i in range(count)
]
self.count = count
@@ -53,40 +53,13 @@ def __setitem__(self, index: int, value: vc.ShaderVariable):
self.registers[index][:] = value
def normalize(self):
+ normalization = vc.to_dtype(self.config.compute_type.child_type, self.config.N)
for i in range(self.count):
- self.registers[i][:] = self.registers[i] / self.config.N
-
- def get_input_format(self, stage_index: int = 0) -> Dict[int, int]:
- in_format = {}
-
- stride = self.config.N // self.config.stages[stage_index].fft_length
-
- register_count = len(self.registers)
- register_index_list = list(range(register_count))
-
- for invocation in self.resources.invocations[stage_index]:
- sub_registers = register_index_list[invocation.register_selection]
-
- for i in range(len(sub_registers)):
- in_format[invocation.get_read_index(stride * i)] = sub_registers[i]
-
- return in_format
-
- def get_output_format(self, stage_index: int = -1) -> Dict[int, int]:
- out_format = {}
-
- register_count = len(self.registers)
- register_index_list = list(range(register_count))
-
- for jj in range(self.config.stages[stage_index].fft_length):
- for invocation in self.resources.invocations[stage_index]:
- out_format[invocation.get_write_index(jj)] = register_index_list[invocation.register_selection][jj]
-
- return out_format
+ self.registers[i][:] = self.registers[i] / normalization
def try_shuffle(self, output_stage: int = -1, input_stage: int = 0) -> bool:
- out_format = self.get_output_format(output_stage)
- in_format = self.get_input_format(input_stage)
+ out_format = self.config.stages[output_stage].get_output_format(len(self.registers))
+ in_format = self.config.stages[input_stage].get_input_format(len(self.registers))
if out_format.keys() != in_format.keys():
return False
diff --git a/vkdispatch/fft/resources.py b/vkdispatch/fft/resources.py
index 17b2085d..f63bd04e 100644
--- a/vkdispatch/fft/resources.py
+++ b/vkdispatch/fft/resources.py
@@ -8,59 +8,6 @@
from .config import FFTConfig
from .grid_manager import FFTGridManager
-@dataclasses.dataclass
-class FFTRegisterStageInvocation:
- output_stride: int
- block_width: int
- inner_block_offset: vc.ShaderVariable
- sub_sequence_offset: vc.ShaderVariable
- register_selection: slice
-
- instance_id: vc.ShaderVariable
-
- instance_id0: int
- inner_block_offset0: int
- sub_sequence_offset0: int
-
- def __init__(self,
- stage_fft_length: int,
- stage_instance_count: int,
- output_stride: int,
- instance_index: int,
- tid: vc.ShaderVariable,
- N: int):
- self.output_stride = output_stride
-
- self.block_width = output_stride * stage_fft_length
-
- instance_index_stride = N // (stage_fft_length * stage_instance_count)
-
- self.instance_id = tid + instance_index_stride * instance_index
-
- self.inner_block_offset = self.instance_id % output_stride
-
- if output_stride == 1:
- self.inner_block_offset = 0
-
- self.sub_sequence_offset = self.instance_id * stage_fft_length - self.inner_block_offset * (stage_fft_length - 1)
-
- # pretend tid is 0, used for calculating register shuffles
- self.instance_id0 = instance_index_stride * instance_index
- self.inner_block_offset0 = self.instance_id0 % output_stride
- self.sub_sequence_offset0 = self.instance_id0 * stage_fft_length - self.inner_block_offset0 * (stage_fft_length - 1)
-
- if self.block_width == N:
- self.inner_block_offset = self.instance_id
- self.sub_sequence_offset = self.inner_block_offset
-
- self.register_selection = slice(instance_index * stage_fft_length, (instance_index + 1) * stage_fft_length)
-
- def get_write_index(self, fft_index: int):
- return self.sub_sequence_offset0 + fft_index * self.output_stride
-
- def get_read_index(self, offset: int):
- return self.instance_id0 + offset
-
@dataclasses.dataclass
class FFTResources:
input_batch_offset: vc.ShaderVariable
@@ -78,49 +25,21 @@ class FFTResources:
config: FFTConfig
- output_strides: List[int]
- invocations: List[List[FFTRegisterStageInvocation]]
-
def __init__(self, config: FFTConfig, grid: FFTGridManager):
self.tid = grid.tid
self.grid = grid
self.config = config
self.input_batch_offset = vc.new_uint_register(var_name="input_batch_offset")
self.output_batch_offset = vc.new_uint_register(var_name="output_batch_offset")
- self.omega_register = vc.new_complex_register(var_name="omega_register")
+ self.omega_register = vc.new_register(config.compute_type, var_name="omega_register")
self.subsequence_offset = vc.new_uint_register(var_name="subsequence_offset")
self.io_index = vc.new_uint_register(var_name="io_index")
self.io_index_2 = vc.new_uint_register(var_name="io_index_2")
self.radix_registers = [
- vc.new_complex_register(var_name=f"radix_register_{i}") for i in range(config.max_prime_radix)
+ vc.new_register(config.compute_type, var_name=f"radix_register_{i}") for i in range(config.max_prime_radix)
]
- self.output_strides = []
- self.invocations = []
-
- output_stride = 1
- stage_count = len(config.stages)
-
- for i in range(stage_count):
- stage = config.stages[i]
- stage_invocations = []
-
- for ii in range(stage.instance_count):
- stage_invocations.append(FFTRegisterStageInvocation(
- stage.fft_length,
- stage.instance_count,
- output_stride,
- ii,
- self.tid,
- config.N
- ))
-
- self.output_strides.append(output_stride)
- self.invocations.append(stage_invocations)
-
- output_stride *= config.stages[i].fft_length
-
def stage_begin(self, stage_index: int):
thread_count = self.config.stages[stage_index].thread_count
@@ -144,4 +63,3 @@ def invocation_end(self, stage_index: int):
if stage.remainder_offset == 1:
vc.end()
-
diff --git a/vkdispatch/fft/sdata_manager.py b/vkdispatch/fft/sdata_manager.py
index f7e41fa7..d00ff31e 100644
--- a/vkdispatch/fft/sdata_manager.py
+++ b/vkdispatch/fft/sdata_manager.py
@@ -11,7 +11,7 @@
from .memory_iterators import memory_reads_iterator, memory_writes_iterator
class FFTSDataManager:
- sdata: vc.Buff[vc.c64]
+ sdata: vc.Buffer
sdata_offset: Union[vc.Const[vc.u32], Literal[0]]
sdata_row_size: int
@@ -46,7 +46,7 @@ def __init__(self, config: FFTConfig, grid: FFTGridManager, default_registers: F
total_inner_batches = grid.inline_batches_inner * grid.inline_batches_outer
self.sdata = vc.shared_buffer(
- vd.complex64,
+ config.compute_type,
config.sdata_allocation * total_inner_batches,
var_name="sdata")
@@ -90,7 +90,7 @@ def read_from_sdata(self, registers: Optional[FFTRegisters] = None, stage_index:
def write_to_sdata(self, registers: Optional[FFTRegisters] = None, stage_index: int = -1):
self.op_write()
- self.use_padding = self.padding_enabled and self.resources.output_strides[stage_index] < 32
+ self.use_padding = self.padding_enabled and self.resources.config.stages[stage_index].input_stride < 32
if registers is None:
registers = self.default_registers
@@ -101,4 +101,4 @@ def write_to_sdata(self, registers: Optional[FFTRegisters] = None, stage_index:
if self.use_padding:
self.resources.io_index[:] = self.resources.io_index + (self.resources.io_index // self.sdata_row_size)
- self.sdata[self.resources.io_index] = registers[write_op.register_id]
\ No newline at end of file
+ self.sdata[self.resources.io_index] = registers[write_op.register_id]
diff --git a/vkdispatch/fft/shader_factories.py b/vkdispatch/fft/shader_factories.py
index 7ccf92c7..28a481fd 100644
--- a/vkdispatch/fft/shader_factories.py
+++ b/vkdispatch/fft/shader_factories.py
@@ -2,7 +2,7 @@
import vkdispatch.codegen as vc
from vkdispatch.codegen.abreviations import *
-from .._compat import numpy_compat as npc
+from ..compat import numpy_compat as npc
from typing import Tuple, Optional
from functools import lru_cache
@@ -17,12 +17,28 @@ def make_fft_shader(
r2c: bool = False,
input_map: vd.MappingFunction = None,
output_map: vd.MappingFunction = None,
+ input_type: vd.dtype = None,
+ output_type: vd.dtype = None,
+ compute_type: vd.dtype = None,
input_signal_range: Optional[Tuple[Optional[int], Optional[int]]] = None) -> vd.ShaderFunction:
- with vd.fft.fft_context(buffer_shape, axis=axis) as ctx:
+ if output_type is None:
+ output_type = vd.complex64
+
+ if input_type is None and input_map is None:
+ input_type = output_type
+
+ if compute_type is None:
+ compute_type = vd.complex64
+
+ name = f"fft_shader_{buffer_shape}_{axis}_{inverse}_{normalize_inverse}_{r2c}"
+
+ with vd.fft.fft_context(buffer_shape, axis=axis, compute_type=compute_type, name=name) as ctx:
io_manager = ctx.make_io_manager(
input_map=input_map,
- output_map=output_map
+ output_map=output_map,
+ output_type=output_type,
+ input_type=input_type,
)
io_manager.read_input(
@@ -46,9 +62,10 @@ def make_fft_shader(
@lru_cache(maxsize=None)
def get_transposed_size(
buffer_shape: Tuple,
- axis: int = None) -> vd.ShaderFunction:
+ axis: int = None,
+ compute_type: vd.dtype = vd.complex64) -> vd.ShaderFunction:
- config = vd.fft.FFTConfig(buffer_shape, axis)
+ config = vd.fft.FFTConfig(buffer_shape, axis, compute_type=compute_type)
grid = vd.fft.FFTGridManager(config, True, False)
return npc.prod(grid.local_size) * npc.prod(grid.workgroup_count) * config.register_count
@@ -57,10 +74,13 @@ def get_transposed_size(
def make_transpose_shader(
buffer_shape: Tuple,
axis: int = None,
- kernel_inner_only: bool = False) -> vd.ShaderFunction:
+ kernel_inner_only: bool = False,
+ input_type: vd.dtype = vd.complex64,
+ output_type: vd.dtype = vd.complex64,
+ compute_type: vd.dtype = vd.complex64) -> vd.ShaderFunction:
- with vd.fft.fft_context(buffer_shape, axis=axis) as ctx:
- args = ctx.declare_shader_args([vc.Buffer[c64], vc.Buffer[c64]])
+ with vd.fft.fft_context(buffer_shape, axis=axis, compute_type=compute_type) as ctx:
+ args = ctx.declare_shader_args([vc.Buffer[output_type], vc.Buffer[input_type]])
if kernel_inner_only:
vc.if_statement(ctx.grid.global_outer_offset == 0)
@@ -95,23 +115,43 @@ def make_convolution_shader(
kernel_inner_only: bool = False,
input_map: vd.MappingFunction = None,
output_map: vd.MappingFunction = None,
+ input_type: vd.dtype = None,
+ output_type: vd.dtype = None,
+ kernel_type: vd.dtype = None,
+ compute_type: vd.dtype = None,
input_signal_range: Optional[Tuple[Optional[int], Optional[int]]] = None) -> vd.ShaderFunction:
+ if output_type is None:
+ output_type = vd.complex64
+
+ if input_type is None and input_map is None:
+ input_type = output_type
+
+ if kernel_type is None:
+ kernel_type = vd.complex64
+
+ if compute_type is None:
+ compute_type = vd.complex64
+
if kernel_map is None:
- def kernel_map_func(kernel_buffer: vc.Buffer[c64]):
+ def kernel_map_func(kernel_buffer: vc.Buffer[kernel_type]):
read_op = vd.fft.read_op()
- kernel_val = vc.new_complex_register()
+ kernel_val = vc.new_register(compute_type)
read_op.read_from_buffer(kernel_buffer, register=kernel_val)
read_op.register[:] = vc.mult_complex(read_op.register, kernel_val.conjugate())
- kernel_map = vd.map(kernel_map_func, input_types=[vc.Buffer[c64]])
+ kernel_map = vd.map(kernel_map_func, input_types=[vc.Buffer[kernel_type]])
+
+ name = f"convolution_shader_{buffer_shape}_{axis}"
- with vd.fft.fft_context(buffer_shape, axis=axis) as ctx:
+ with vd.fft.fft_context(buffer_shape, axis=axis, compute_type=compute_type, name=name) as ctx:
io_manager = ctx.make_io_manager(
input_map=input_map,
output_map=output_map,
+ output_type=output_type,
+ input_type=input_type,
kernel_map=kernel_map
)
@@ -123,10 +163,6 @@ def kernel_map_func(kernel_buffer: vc.Buffer[c64]):
ctx.execute(inverse=False)
ctx.register_shuffle()
- vc.comment("""Convolution pipeline phase 2/3.
-Apply one or more frequency-domain kernels to the transformed input spectrum.
-For multi-kernel runs, restore from backup registers so each kernel sees
-identical FFT-domain source values before inverse transformation.""")
backup_registers = None
if kernel_num > 1:
@@ -134,17 +170,19 @@ def kernel_map_func(kernel_buffer: vc.Buffer[c64]):
backup_registers.read_from_registers(ctx.registers)
for kern_index in range(kernel_num):
- vc.comment(f"""Convolution pipeline phase 3/3. Kernel {kern_index + 1}/{kernel_num}.
-Map this kernel onto the current spectrum.
-Run inverse FFT back to the spatial domain, optionally normalize by length,
-and write this kernel's output slice to global memory.""")
+ vc.comment(f"""Convolution pipeline phase 2/3. Kernel {kern_index + 1}/{kernel_num}.
+Map this kernel onto the current spectrum.""")
if backup_registers is not None:
ctx.registers.read_from_registers(backup_registers)
set_global_kernel_index(kern_index)
io_manager.read_kernel(format_transposed=transposed_kernel, inner_only=kernel_inner_only)
-
+
+ vc.comment(f"""Convolution pipeline phase 3/3.
+Run inverse FFT back to the spatial domain, optionally normalize by length,
+and write this kernel's output slice to global memory.""")
+
ctx.execute(inverse=True)
if normalize:
diff --git a/vkdispatch/fft/src_functions.py b/vkdispatch/fft/src_functions.py
new file mode 100644
index 00000000..e8952bb3
--- /dev/null
+++ b/vkdispatch/fft/src_functions.py
@@ -0,0 +1,342 @@
+import vkdispatch as vd
+
+from .shader_factories import make_fft_shader, make_convolution_shader, make_transpose_shader, get_transposed_size
+
+from typing import Tuple, Union, Optional
+
+def fft_src(
+ buffer_shape: Tuple,
+ axis: int = None,
+ inverse: bool = False,
+ normalize_inverse: bool = True,
+ r2c: bool = False,
+ input_map: vd.MappingFunction = None,
+ output_map: vd.MappingFunction = None,
+ input_signal_range: Union[Tuple[Optional[int], Optional[int]], None] = None,
+ line_numbers: bool = False) -> vd.ShaderSource:
+
+ fft_shader = make_fft_shader(
+ tuple(buffer_shape),
+ axis,
+ inverse=inverse,
+ normalize_inverse=normalize_inverse,
+ r2c=r2c,
+ input_map=input_map,
+ output_map=output_map,
+ input_signal_range=input_signal_range)
+
+ return fft_shader.get_src(line_numbers=line_numbers)
+
+def fft2_src(buffer_shape: Tuple, input_map: vd.MappingFunction = None, output_map: vd.MappingFunction = None):
+ assert len(buffer_shape) == 2 or len(buffer_shape) == 3, 'Buffer Shape must have 2 or 3 dimensions'
+
+ return (
+ fft_src(buffer_shape, axis=len(buffer_shape) - 2, input_map=input_map),
+ fft_src(buffer_shape, axis=len(buffer_shape) - 1, output_map=output_map)
+ )
+
+def fft3_src(buffer_shape: Tuple, input_map: vd.MappingFunction = None, output_map: vd.MappingFunction = None):
+ assert len(buffer_shape) == 3, 'Buffer must have 3 dimensions'
+
+ return (
+ fft_src(buffer_shape, axis=0, input_map=input_map),
+ fft_src(buffer_shape, axis=1),
+ fft_src(buffer_shape, axis=2, output_map=output_map)
+ )
+
+
+def ifft_src(
+ buffer_shape: Tuple,
+ axis: int = None,
+ normalize: bool = True,
+ input_map: vd.MappingFunction = None,
+ output_map: vd.MappingFunction = None):
+ return fft_src(buffer_shape, axis=axis, inverse=True, normalize_inverse=normalize, input_map=input_map, output_map=output_map)
+
+def ifft2_src(buffer_shape: Tuple, normalize: bool = True, input_map: vd.MappingFunction = None, output_map: vd.MappingFunction = None):
+ assert len(buffer_shape) == 2 or len(buffer_shape) == 3, 'Buffer must have 2 or 3 dimensions'
+
+ return (
+ ifft_src(buffer_shape, axis=len(buffer_shape) - 2, normalize=normalize, input_map=input_map),
+ ifft_src(buffer_shape, axis=len(buffer_shape) - 1, normalize=normalize, output_map=output_map)
+ )
+
+def ifft3_src(buffer_shape: Tuple, normalize: bool = True, input_map: vd.MappingFunction = None, output_map: vd.MappingFunction = None):
+ assert len(buffer_shape) == 3, 'Buffer must have 3 dimensions'
+
+ return (
+ ifft_src(buffer_shape, axis=0, normalize=normalize, input_map=input_map),
+ ifft_src(buffer_shape, axis=1, normalize=normalize),
+ ifft_src(buffer_shape, axis=2, normalize=normalize, output_map=output_map)
+ )
+
+
+def rfft_src(buffer_shape: Tuple):
+ return fft_src(buffer_shape, r2c=True)
+
+def rfft2_src(buffer_shape: Tuple):
+ assert len(buffer_shape) == 2 or len(buffer_shape) == 3, 'Buffer must have 2 or 3 dimensions'
+
+ return (
+ rfft_src(buffer_shape),
+ fft_src(buffer_shape, axis=len(buffer_shape) - 2)
+ )
+
+def rfft3_src(buffer_shape: Tuple):
+ assert len(buffer_shape) == 3, 'Buffer must have 3 dimensions'
+
+ return (
+ rfft_src(buffer_shape),
+ fft_src(buffer_shape, axis=1),
+ fft_src(buffer_shape, axis=0)
+ )
+
+def irfft_src(buffer_shape: Tuple, normalize: bool = True):
+ return fft_src(buffer_shape, inverse=True, normalize_inverse=normalize, r2c=True)
+
+def irfft2_src(buffer_shape: Tuple, normalize: bool = True):
+ assert len(buffer_shape) == 2 or len(buffer_shape) == 3, 'Buffer must have 2 or 3 dimensions'
+
+ return (
+ ifft_src(buffer_shape, axis=len(buffer_shape) - 2, normalize=normalize),
+ irfft_src(buffer_shape, normalize=normalize)
+ )
+
+def irfft3_src(buffer_shape: Tuple, normalize: bool = True):
+ assert len(buffer_shape) == 3, 'Buffer must have 3 dimensions'
+
+ return (
+ ifft_src(buffer_shape, axis=0, normalize=normalize),
+ ifft_src(buffer_shape, axis=1, normalize=normalize),
+ irfft_src(buffer_shape, normalize=normalize)
+ )
+
+def convolve_src(
+ buffer_shape: Tuple,
+ kernel_map: vd.MappingFunction = None,
+ kernel_num: int = 1,
+ axis: int = None,
+ normalize: bool = True,
+ transposed_kernel: bool = False,
+ kernel_inner_only: bool = False,
+ input_map: vd.MappingFunction = None,
+ output_map: vd.MappingFunction = None,
+ input_signal_range: Union[Tuple[Optional[int], Optional[int]], None] = None,
+ line_numbers: bool = False) -> vd.ShaderSource:
+
+ fft_shader = make_convolution_shader(
+ tuple(buffer_shape),
+ kernel_map,
+ kernel_num,
+ axis,
+ transposed_kernel=transposed_kernel,
+ kernel_inner_only=kernel_inner_only,
+ normalize=normalize,
+ input_map=input_map,
+ output_map=output_map,
+ input_signal_range=input_signal_range)
+
+ return fft_shader.get_src(line_numbers=line_numbers)
+
+def convolve2D_src(
+ buffer_shape: Tuple,
+ kernel_map: vd.MappingFunction = None,
+ normalize: bool = True,
+ transposed_kernel: bool = False,
+ kernel_inner_only: bool = False,
+ input_map: vd.MappingFunction = None,
+ output_map: vd.MappingFunction = None):
+
+ assert len(buffer_shape) == 2 or len(buffer_shape) == 3, 'Buffer must have 2 or 3 dimensions'
+
+ return (
+ fft_src(buffer_shape, input_map=input_map),
+ convolve_src(
+ buffer_shape,
+ kernel_map=kernel_map,
+ transposed_kernel=transposed_kernel,
+ kernel_inner_only=kernel_inner_only,
+ axis=len(buffer_shape) - 2,
+ normalize=normalize
+ ),
+ ifft_src(buffer_shape, normalize=normalize, output_map=output_map)
+ )
+
+def convolve2DR_src(
+ buffer_shape: Tuple,
+ kernel_map: vd.MappingFunction = None,
+ transposed_kernel: bool = False,
+ kernel_inner_only: bool = False,
+ normalize: bool = True):
+
+ assert len(buffer_shape) == 2 or len(buffer_shape) == 3, 'Buffer must have 2 or 3 dimensions'
+
+ return (
+ rfft_src(buffer_shape),
+ convolve_src(
+ buffer_shape,
+ kernel_map=kernel_map,
+ transposed_kernel=transposed_kernel,
+ kernel_inner_only=kernel_inner_only,
+ axis=len(buffer_shape) - 2,
+ normalize=normalize
+ ),
+ irfft_src(buffer_shape, normalize=normalize)
+ )
+
+def transpose_src(
+ buffer_shape: Tuple,
+ axis: int = None,
+ kernel_inner_only: bool = False,
+ line_numbers: bool = False) -> vd.ShaderSource:
+
+ transpose_shader = make_transpose_shader(
+ tuple(buffer_shape),
+ axis=axis,
+ kernel_inner_only=kernel_inner_only
+ )
+
+ return transpose_shader.get_src(line_numbers=line_numbers)
+
+
+def fft_print_src(
+ buffer_shape: Tuple,
+ axis: int = None,
+ inverse: bool = False,
+ normalize_inverse: bool = True,
+ r2c: bool = False,
+ input_map: vd.MappingFunction = None,
+ output_map: vd.MappingFunction = None,
+ input_signal_range: Union[Tuple[Optional[int], Optional[int]], None] = None,
+ line_numbers: bool = False) -> None:
+
+ print(fft_src(
+ buffer_shape,
+ axis,
+ inverse=inverse,
+ normalize_inverse=normalize_inverse,
+ r2c=r2c,
+ input_map=input_map,
+ output_map=output_map,
+ input_signal_range=input_signal_range,
+ line_numbers=line_numbers))
+
+def fft2_print_src(buffer_shape: Tuple, input_map: vd.MappingFunction = None, output_map: vd.MappingFunction = None):
+ srcs = fft2_src(buffer_shape, input_map=input_map, output_map=output_map)
+ print(f"// FFT Stage 1 (axis {len(buffer_shape) - 2}):\n{srcs[0]}\n// FFT Stage 2 (axis {len(buffer_shape) - 1}):\n{srcs[1]}")
+
+def fft3_print_src(buffer_shape: Tuple, input_map: vd.MappingFunction = None, output_map: vd.MappingFunction = None):
+ srcs = fft3_src(buffer_shape, input_map=input_map, output_map=output_map)
+ print(f"// FFT Stage 1 (axis 0):\n{srcs[0]}\n// FFT Stage 2 (axis 1):\n{srcs[1]}\n// FFT Stage 3 (axis 2):\n{srcs[2]}")
+
+def ifft_print_src(
+ buffer_shape: Tuple,
+ axis: int = None,
+ normalize: bool = True,
+ input_map: vd.MappingFunction = None,
+ output_map: vd.MappingFunction = None):
+ print(ifft_src(buffer_shape, axis=axis, normalize=normalize, input_map=input_map, output_map=output_map))
+
+def ifft2_print_src(buffer_shape: Tuple, normalize: bool = True, input_map: vd.MappingFunction = None, output_map: vd.MappingFunction = None):
+ srcs = ifft2_src(buffer_shape, normalize=normalize, input_map=input_map, output_map=output_map)
+ print(f"// IFFT Stage 1 (axis {len(buffer_shape) - 2}):\n{srcs[0]}\n// IFFT Stage 2 (axis {len(buffer_shape) - 1}):\n{srcs[1]}")
+
+def ifft3_print_src(buffer_shape: Tuple, normalize: bool = True, input_map: vd.MappingFunction = None, output_map: vd.MappingFunction = None):
+ srcs = ifft3_src(buffer_shape, normalize=normalize, input_map=input_map, output_map=output_map)
+ print(f"// IFFT Stage 1 (axis 0):\n{srcs[0]}\n// IFFT Stage 2 (axis 1):\n{srcs[1]}\n// IFFT Stage 3 (axis 2):\n{srcs[2]}")
+
+def rfft_print_src(buffer_shape: Tuple):
+ print(rfft_src(buffer_shape))
+
+def rfft2_print_src(buffer_shape: Tuple):
+ srcs = rfft2_src(buffer_shape)
+ print(f"// RFFT Stage 1:\n{srcs[0]}\n// RFFT Stage 2 (axis {len(buffer_shape) - 2}):\n{srcs[1]}")
+
+def rfft3_print_src(buffer_shape: Tuple):
+ srcs = rfft3_src(buffer_shape)
+ print(f"// RFFT Stage 1:\n{srcs[0]}\n// RFFT Stage 2 (axis 1):\n{srcs[1]}\n// RFFT Stage 3 (axis 0):\n{srcs[2]}")
+
+def irfft_print_src(buffer_shape: Tuple, normalize: bool = True):
+ print(irfft_src(buffer_shape, normalize=normalize))
+
+def irfft2_print_src(buffer_shape: Tuple, normalize: bool = True):
+ srcs = irfft2_src(buffer_shape, normalize=normalize)
+ print(f"// IRFFT Stage 1 (axis {len(buffer_shape) - 2}):\n{srcs[0]}\n// IRFFT Stage 2:\n{srcs[1]}")
+
+def irfft3_print_src(buffer_shape: Tuple, normalize: bool = True):
+ srcs = irfft3_src(buffer_shape, normalize=normalize)
+ print(f"// IRFFT Stage 1 (axis 0):\n{srcs[0]}\n// IRFFT Stage 2 (axis 1):\n{srcs[1]}\n// IRFFT Stage 3:\n{srcs[2]}")
+
+def convolve_print_src(
+ buffer_shape: Tuple,
+ kernel_map: vd.MappingFunction = None,
+ kernel_num: int = 1,
+ axis: int = None,
+ normalize: bool = True,
+ transposed_kernel: bool = False,
+ kernel_inner_only: bool = False,
+ input_map: vd.MappingFunction = None,
+ output_map: vd.MappingFunction = None,
+ input_signal_range: Union[Tuple[Optional[int], Optional[int]], None] = None,
+ line_numbers: bool = False) -> None:
+
+ print(convolve_src(
+ buffer_shape,
+ kernel_map=kernel_map,
+ kernel_num=kernel_num,
+ axis=axis,
+ normalize=normalize,
+ transposed_kernel=transposed_kernel,
+ kernel_inner_only=kernel_inner_only,
+ input_map=input_map,
+ output_map=output_map,
+ input_signal_range=input_signal_range,
+ line_numbers=line_numbers
+ ))
+
+def convolve2D_print_src(
+ buffer_shape: Tuple,
+ kernel_map: vd.MappingFunction = None,
+ normalize: bool = True,
+ transposed_kernel: bool = False,
+ kernel_inner_only: bool = False,
+ input_map: vd.MappingFunction = None,
+ output_map: vd.MappingFunction = None):
+ srcs = convolve2D_src(
+ buffer_shape,
+ kernel_map=kernel_map,
+ normalize=normalize,
+ transposed_kernel=transposed_kernel,
+ kernel_inner_only=kernel_inner_only,
+ input_map=input_map,
+ output_map=output_map
+ )
+ print(f"// FFT Stage (axis {len(buffer_shape) - 2}):\n{srcs[0]}\n// Convolution Stage (axis {len(buffer_shape) - 2}):\n{srcs[1]}\n// IFFT Stage:\n{srcs[2]}")
+
+def convolve2DR_print_src(
+ buffer_shape: Tuple,
+ kernel_map: vd.MappingFunction = None,
+ transposed_kernel: bool = False,
+ kernel_inner_only: bool = False,
+ normalize: bool = True):
+ srcs = convolve2DR_src(
+ buffer_shape,
+ kernel_map=kernel_map,
+ transposed_kernel=transposed_kernel,
+ kernel_inner_only=kernel_inner_only,
+ normalize=normalize
+ )
+ print(f"// RFFT Stage:\n{srcs[0]}\n// Convolution Stage (axis {len(buffer_shape) - 2}):\n{srcs[1]}\n// IRFFT Stage:\n{srcs[2]}")
+
+def transpose_print_src(
+ buffer_shape: Tuple,
+ axis: int = None,
+ kernel_inner_only: bool = False,
+ line_numbers: bool = False) -> None:
+
+ print(transpose_src(
+ buffer_shape,
+ axis=axis,
+ kernel_inner_only=kernel_inner_only,
+ line_numbers=line_numbers
+ ))
\ No newline at end of file
diff --git a/vkdispatch/fft/stages.py b/vkdispatch/fft/stages.py
new file mode 100644
index 00000000..0cb348fd
--- /dev/null
+++ b/vkdispatch/fft/stages.py
@@ -0,0 +1,198 @@
+import vkdispatch as vd
+import vkdispatch.codegen as vc
+import dataclasses
+from typing import List, Tuple, Dict
+
+from ..compat import numpy_compat as npc
+from .prime_utils import prime_factors
+
+@dataclasses.dataclass
+class FFTStagePlanInvocation:
+ fft_length: int
+ input_stride: int
+ instance_index: int
+ instance_index_stride: int
+ block_width: int
+ full_width_block: bool
+ instance_id0: int
+ inner_block_offset0: int
+ sub_sequence_offset0: int
+ register_selection: slice
+
+ def __init__(self,
+ stage_fft_length: int,
+ stage_instance_count: int,
+ input_stride: int,
+ instance_index: int,
+ N: int):
+ self.fft_length = stage_fft_length
+ self.input_stride = input_stride
+ self.instance_index = instance_index
+ self.block_width = input_stride * stage_fft_length
+ self.instance_index_stride = N // (stage_fft_length * stage_instance_count)
+
+ self.full_width_block = self.block_width == N
+
+ # pretend tid is 0, used for calculating register shuffles
+ self.instance_id0 = self.instance_index_stride * instance_index
+ self.inner_block_offset0 = self.instance_id0 % input_stride
+ self.sub_sequence_offset0 = self.instance_id0 * stage_fft_length - self.inner_block_offset0 * (stage_fft_length - 1)
+
+ self.register_selection = slice(instance_index * stage_fft_length, (instance_index + 1) * stage_fft_length)
+
+ def get_offset(self, tid: vc.ShaderVariable):
+ return tid + self.instance_index_stride * self.instance_index
+
+ def get_inner_block_offset(self, tid: vc.ShaderVariable):
+ if self.input_stride == 1:
+ return 0
+
+ if self.full_width_block:
+ return self.get_offset(tid)
+
+ return self.get_offset(tid) % self.input_stride
+
+ def get_sub_sequence_offset(self, tid: vc.ShaderVariable):
+ if self.full_width_block:
+ return self.get_offset(tid)
+
+ return self.get_offset(tid) * self.fft_length - self.get_inner_block_offset(tid) * (self.fft_length - 1)
+
+ def get_write_index(self, fft_index: int):
+ return self.sub_sequence_offset0 + fft_index * self.input_stride
+
+ def get_read_index(self, offset: int):
+ return self.instance_id0 + offset
+
+@dataclasses.dataclass
+class FFTRegisterStageConfig:
+ """
+ Configuration for an FFT register stage.
+
+ Attributes:
+
+ primes (Tuple[int]): The prime numbers used for factorization.
+ fft_length (int): The length of each FFT stage.
+ instance_count (int): The number of instances required to achieve the desired level of parallelism.
+ registers_used (int): The total number of registers used by the FFT stage.
+ remainder (int): The remainder of `N` divided by `registers_used`.
+ remainder_offset (int): A flag indicating whether the remainder is non-zero.
+ extra_ffts (int): The additional number of FFT stages required to process the remainder.
+ thread_count (int): The total number of threads used in the computation.
+ sdata_size (int): The size of the shared memory buffer used to store intermediate results.
+ sdata_width (int): The width of each element in the shared memory buffer.
+ sdata_width_padded (int): The padded width of each element in the shared memory buffer.
+
+ """
+
+ N: int
+ primes: Tuple[int]
+ fft_length: int
+ instance_count: int
+ registers_used: int
+ remainder: int
+ remainder_offset: int
+ extra_ffts: int
+ thread_count: int
+ sdata_size: int
+ sdata_width: int
+ sdata_width_padded: int
+ input_stride: int
+ output_stride: int
+ invocations: Tuple[FFTStagePlanInvocation]
+
+ def __init__(self, primes: List[int],
+ max_register_count: int,
+ N: int,
+ compute_item_size: int,
+ input_stride: int):
+ """
+ Initializes the FFTRegisterStageConfig object.
+
+ Parameters:
+
+ primes (List[int]): The prime numbers to use for factorization.
+ max_register_count (int): The maximum number of registers allowed per thread.
+ N (int): The length of the input data.
+
+ """
+ self.N = N
+ self.primes = tuple(primes)
+ self.input_stride = input_stride
+ self.fft_length = int(round(npc.prod(primes)))
+ self.output_stride = self.input_stride * self.fft_length
+ instance_primes = prime_factors(N // self.fft_length)
+
+ self.instance_count = 1
+
+ while len(instance_primes) > 0:
+ if self.instance_count * self.fft_length * instance_primes[0] > max_register_count:
+ break
+ self.instance_count *= instance_primes[0]
+ instance_primes = instance_primes[1:]
+
+ self.registers_used = self.fft_length * self.instance_count
+
+ self.remainder = N % self.registers_used
+ assert self.remainder % self.fft_length == 0, "Remainder must be divisible by the FFT length"
+ self.remainder_offset = 1 if self.remainder != 0 else 0
+ self.extra_ffts = self.remainder // self.fft_length
+
+ self.thread_count = N // self.registers_used + self.remainder_offset
+
+ self.sdata_width = self.registers_used
+
+ threads_primes = prime_factors(self.thread_count)
+
+ while self.sdata_width < 16 and len(threads_primes) > 0:
+ self.sdata_width *= threads_primes[0]
+ threads_primes = threads_primes[1:]
+
+ self.sdata_width_padded = self.sdata_width
+
+ if self.sdata_width_padded % 2 == 0:
+ self.sdata_width_padded += 1
+
+ self.sdata_size = self.sdata_width_padded * int(npc.prod(threads_primes))
+
+ if self.sdata_size > vd.get_context().max_shared_memory // compute_item_size:
+ self.sdata_width_padded = self.sdata_width
+ self.sdata_size = self.sdata_width_padded * int(npc.prod(threads_primes))
+
+ invocations = []
+ for instance_index in range(self.instance_count):
+ invocations.append(FFTStagePlanInvocation(
+ stage_fft_length=self.fft_length,
+ stage_instance_count=self.instance_count,
+ input_stride=input_stride,
+ instance_index=instance_index,
+ N=N
+ ))
+
+ self.invocations = tuple(invocations)
+
+ def get_input_format(self, register_count: int) -> Dict[int, int]:
+ in_format = {}
+
+ stride = self.N // self.fft_length
+
+ register_index_list = list(range(register_count))
+
+ for invocation in self.invocations:
+ sub_registers = register_index_list[invocation.register_selection]
+
+ for i in range(len(sub_registers)):
+ in_format[invocation.get_read_index(stride * i)] = sub_registers[i]
+
+ return in_format
+
+ def get_output_format(self, register_count: int) -> Dict[int, int]:
+ out_format = {}
+
+ register_index_list = list(range(register_count))
+
+ for jj in range(self.fft_length):
+ for invocation in self.invocations:
+ out_format[invocation.get_write_index(jj)] = register_index_list[invocation.register_selection][jj]
+
+ return out_format
\ No newline at end of file
diff --git a/vkdispatch/reduce/operations.py b/vkdispatch/reduce/operations.py
index 9cabb583..4081982b 100644
--- a/vkdispatch/reduce/operations.py
+++ b/vkdispatch/reduce/operations.py
@@ -7,6 +7,8 @@
from typing import Union
from typing import Optional
+
+
@dataclasses.dataclass
class ReduceOp:
name: str
@@ -31,14 +33,14 @@ class ReduceOp:
SubgroupMin = ReduceOp(
name="min",
reduction=lambda x, y: vc.min(x, y),
- identity=vc.inf_f32,
+ identity="inf",
subgroup_reduction=vc.subgroup_min
)
SubgroupMax = ReduceOp(
name="max",
reduction=lambda x, y: vc.max(x, y),
- identity=vc.ninf_f32,
+ identity="-inf",
subgroup_reduction=vc.subgroup_max
)
@@ -61,4 +63,4 @@ class ReduceOp:
reduction=lambda x, y: x ^ y,
identity=0,
subgroup_reduction=vc.subgroup_xor
-)
\ No newline at end of file
+)
diff --git a/vkdispatch/reduce/reduce_function.py b/vkdispatch/reduce/reduce_function.py
index 6691b141..e8438498 100644
--- a/vkdispatch/reduce/reduce_function.py
+++ b/vkdispatch/reduce/reduce_function.py
@@ -6,7 +6,7 @@
from typing import List, Optional
-from .._compat import numpy_compat as npc
+from ..compat import numpy_compat as npc
class ReduceFunction:
def __init__(self,
@@ -49,11 +49,26 @@ def make_stages(self):
self.group_size,
True,
)
+
+ def get_src(self, line_numbers: bool = None) -> List[str]:
+ self.make_stages()
+
+ return [
+ self.stage1.get_src(line_numbers),
+ self.stage2.get_src(line_numbers)
+ ]
+
+ def print_src(self, line_numbers: bool = None):
+ srcs = self.get_src(line_numbers)
+
+ print(f"// Reduction Stage 1:\n{srcs[0]}\n// Reduction Stage 2:\n{srcs[1]}")
def __repr__(self) -> str:
self.make_stages()
- return f"Stage 1:\n{self.stage1}\nStage 2:\n{self.stage2}"
+ srcs = self.get_src()
+
+ return f"// Reduction Stage 1:\n{srcs[0]}\n// Reduction Stage 2:\n{srcs[1]}"
def __call__(self, *args, **kwargs) -> vd.Buffer:
self.make_stages()
diff --git a/vkdispatch/reduce/stage.py b/vkdispatch/reduce/stage.py
index a9c91770..3bce6759 100644
--- a/vkdispatch/reduce/stage.py
+++ b/vkdispatch/reduce/stage.py
@@ -8,16 +8,16 @@
@dataclasses.dataclass
class ReductionParams:
- input_offset: vd.int32
- input_size: vd.int32
- input_stride: vd.int32
- input_y_batch_stride: vd.int32
- input_z_batch_stride: vd.int32
+ input_offset: vd.uint32
+ input_size: vd.uint32
+ input_stride: vd.uint32
+ input_y_batch_stride: vd.uint32
+ input_z_batch_stride: vd.uint32
- output_offset: vd.int32
- output_stride: vd.int32
- output_y_batch_stride: vd.int32
- output_z_batch_stride: vd.int32
+ output_offset: vd.uint32
+ output_stride: vd.uint32
+ output_y_batch_stride: vd.uint32
+ output_z_batch_stride: vd.uint32
__static_global_io_index: vc.ShaderVariable = None
@@ -36,7 +36,14 @@ def global_reduce(
map_func: Optional[vd.MappingFunction] = None):
ind = (vc.global_invocation_id().x * params.input_stride).to_register("ind")
- reduction_aggregate = vc.new_register(out_type, reduction.identity, var_name="reduction_aggregate")
+
+ reduction_identity = reduction.identity
+ if reduction_identity == "inf":
+ reduction_identity = vc.inf_f32() if out_type == vd.float32 else vc.inf_f64()
+ elif reduction_identity == "-inf":
+ reduction_identity = vc.ninf_f32() if out_type == vd.float32 else vc.ninf_f64()
+
+ reduction_aggregate = vc.new_register(out_type, reduction_identity, var_name="reduction_aggregate")
batch_offset = vc.workgroup_id().y * params.input_y_batch_stride
inside_batch_offset = vc.workgroup_id().z * params.input_z_batch_stride
@@ -76,15 +83,25 @@ def workgroup_reduce(
sdata[tid] = reduction_aggregate
vc.barrier()
+
+ subgroup_reduce_size = vd.get_context().subgroup_size
+
+ if not vd.get_context().subgroup_enabled:
+ subgroup_reduce_size = 1
current_size = group_size // 2
- while current_size > vd.get_context().subgroup_size:
+ while current_size > subgroup_reduce_size:
vc.if_statement(tid < current_size)
sdata[tid] = reduction.reduction(sdata[tid], sdata[tid + current_size])
- if current_size // 2 > vd.get_context().subgroup_size:
+ if current_size // 2 > subgroup_reduce_size:
vc.end()
else:
- vc.else_if_statement(tid < 2*vc.subgroup_size())
+ tid_limit = 2
+
+ if subgroup_reduce_size != 1:
+ tid_limit = 2*vc.subgroup_size()
+
+ vc.else_if_statement(tid < tid_limit)
sdata[tid] = vc.new_register(out_type, 0)
vc.end()
@@ -99,22 +116,28 @@ def subgroup_reduce(
reduction: ReduceOp,
group_size: int):
tid = vc.local_invocation_id().x
- subgroup_size = vd.get_context().subgroup_size
+ subgroup_reduce_size = vd.get_context().subgroup_size
- if group_size > subgroup_size:
- vc.if_all(tid < subgroup_size)
- sdata[tid] = reduction.reduction(sdata[tid], sdata[tid + subgroup_size])
+ if not vd.get_context().subgroup_enabled:
+ subgroup_reduce_size = 1
+
+ if group_size > subgroup_reduce_size:
+ vc.if_statement(tid < subgroup_reduce_size)
+ sdata[tid] = reduction.reduction(sdata[tid], sdata[tid + subgroup_reduce_size])
vc.end()
+
+ if subgroup_reduce_size == 1:
+ return sdata[tid].to_register("local_var")
+
vc.subgroup_barrier()
-
if reduction.subgroup_reduction is not None:
local_var = sdata[tid].to_register("local_var")
local_var[:] = reduction.subgroup_reduction(local_var)
return local_var
else:
- current_size = subgroup_size // 2
+ current_size = subgroup_reduce_size // 2
while current_size > 1:
vc.if_statement(tid < current_size)
sdata[tid] = reduction.reduction(sdata[tid], sdata[tid + current_size])
@@ -134,6 +157,8 @@ def make_reduction_stage(
output_is_input: bool,
map_func: Optional[vd.MappingFunction] = None,
input_types: List = None) -> vd.ShaderFunction:
+
+ name = f"reduction_stage_{reduction.name}_{out_type.name}_{input_types}_{group_size}"
with vd.shader_context() as context:
signature_type_array = []
@@ -162,4 +187,4 @@ def make_reduction_stage(
input_variables[0][batch_offset + output_offset + params.output_offset] = local_var
vc.end()
- return context.get_function(local_size=(group_size, 1, 1))
+ return context.get_function(local_size=(group_size, 1, 1), name=name)
diff --git a/vkdispatch/shader/context.py b/vkdispatch/shader/context.py
index 74688e63..9bd5713c 100644
--- a/vkdispatch/shader/context.py
+++ b/vkdispatch/shader/context.py
@@ -1,9 +1,8 @@
import vkdispatch as vd
import vkdispatch.codegen as vc
-from .signature import ShaderSignature
-
-from typing import List
+from .signature import ShaderSignature, ShaderArgumentType
+from typing import List, Optional, Any
import contextlib
@@ -15,21 +14,64 @@ class ShaderContext:
def __init__(self, builder: vc.ShaderBuilder):
self.builder = builder
self.signature = None
-
+ self.shader_function = None
+
def get_function(self,
local_size=None,
workgroups=None,
- exec_count=None) -> vd.ShaderFunction:
- return vd.ShaderFunction.from_description(
- self.builder.build("shader"),
+ exec_count=None,
+ name: Optional[str] = None) -> vd.ShaderFunction:
+ if self.shader_function is not None:
+ return self.shader_function
+
+ description = self.builder.build("shader" if name is None else name)
+
+ # Resource bindings are declared before final shader layout is known.
+ # For some shader construction paths (e.g. from_description), signatures are
+ # pre-populated and still hold logical bindings assuming a reserved UBO at 0.
+ binding_shift = description.resource_binding_base - 1
+ if binding_shift != 0:
+ binding_access_len = len(description.binding_access)
+ needs_remap = False
+
+ for shader_arg in self.signature.arguments:
+ if (
+ shader_arg.binding is not None
+ and (
+ shader_arg.arg_type == ShaderArgumentType.BUFFER
+ or shader_arg.arg_type == ShaderArgumentType.IMAGE
+ )
+ and shader_arg.binding >= binding_access_len
+ ):
+ needs_remap = True
+ break
+
+ if needs_remap:
+ for shader_arg in self.signature.arguments:
+ if (
+ shader_arg.binding is not None
+ and (
+ shader_arg.arg_type == ShaderArgumentType.BUFFER
+ or shader_arg.arg_type == ShaderArgumentType.IMAGE
+ )
+ ):
+ shader_arg.binding += binding_shift
+
+ self.shader_function = vd.ShaderFunction(
+ description,
self.signature,
local_size=local_size,
workgroups=workgroups,
exec_count=exec_count
)
+
+ return self.shader_function
- def declare_input_arguments(self, annotations: List):
- self.signature = ShaderSignature.from_type_annotations(self.builder, annotations)
+ def declare_input_arguments(self,
+ annotations: List,
+ names: Optional[List[str]] = None,
+ defaults: Optional[List[Any]] = None):
+ self.signature = ShaderSignature.from_type_annotations(self.builder, annotations, names, defaults)
return self.signature.get_variables()
@contextlib.contextmanager
diff --git a/vkdispatch/shader/decorator.py b/vkdispatch/shader/decorator.py
index 88e2ab8e..0dbe5239 100644
--- a/vkdispatch/shader/decorator.py
+++ b/vkdispatch/shader/decorator.py
@@ -1,9 +1,12 @@
import vkdispatch as vd
import vkdispatch.codegen as vc
+import dataclasses
import inspect
from typing import Callable, TypeVar
+from .context import shader_context
+
import sys
if sys.version_info >= (3, 10):
@@ -12,6 +15,31 @@
else:
P = ... # Placeholder for older Python versions
+def inspect_function_signature(func: Callable):
+ func_signature = inspect.signature(func)
+
+ annotations = []
+ names = []
+ defaults = []
+
+ for param in func_signature.parameters.values():
+ if param.annotation == inspect.Parameter.empty:
+ raise ValueError("All parameters must be annotated")
+
+
+ if not dataclasses.is_dataclass(param.annotation): # issubclass(param.annotation.__origin__, dataclasses.dataclass):
+ if not hasattr(param.annotation, '__args__'):
+ raise TypeError(f"Argument '{param.name}: vd.{param.annotation}' must have a type annotation")
+
+ if len(param.annotation.__args__) != 1:
+ raise ValueError(f"Type '{param.name}: vd.{param.annotation.__name__}' must have exactly one type argument")
+
+ annotations.append(param.annotation)
+ names.append(param.name)
+ defaults.append(param.default if param.default != inspect.Parameter.empty else None)
+
+ return annotations, names, defaults
+
def shader(
exec_size=None,
local_size=None,
@@ -43,12 +71,11 @@ def shader(
raise ValueError("Cannot specify both 'workgroups' and 'exec_size'")
def decorator_callback(func: Callable[P, None]) -> Callable[P, None]:
- return vd.ShaderFunction(
- func,
- local_size=local_size,
- workgroups=workgroups,
- exec_count=exec_size,
- flags=flags
- )
+ with shader_context(flags=flags) as context:
+ annotations, names, defaults = inspect_function_signature(func)
+ args = context.declare_input_arguments(annotations, names, defaults)
+ func(*args)
+
+ return context.get_function(local_size=local_size, workgroups=workgroups, exec_count=exec_size, name=func.__name__)
return decorator_callback
diff --git a/vkdispatch/shader/shader_function.py b/vkdispatch/shader/shader_function.py
index 72c9ee83..18e135ab 100644
--- a/vkdispatch/shader/shader_function.py
+++ b/vkdispatch/shader/shader_function.py
@@ -12,12 +12,10 @@
from .signature import ShaderArgumentType, ShaderSignature
import uuid
-import sys
import dataclasses
-from .._compat import numpy_compat as npc
-from ..base.backend import BACKEND_PYCUDA, BACKEND_VULKAN
+from ..compat import numpy_compat as npc
class LaunchParametersHolder:
def __init__(self, names_and_defaults, args, kwargs) -> None:
@@ -58,7 +56,7 @@ class ExectionBounds:
def __init__(self, names_and_defaults, local_size, workgroups, exec_size) -> None:
self.names_and_defaults = names_and_defaults
- self.local_size = local_size
+ self.local_size = tuple(local_size)
self.workgroups = workgroups
self.exec_size = exec_size
@@ -134,56 +132,49 @@ def get_blocks_and_limits(self, args, kwargs) -> Tuple[Tuple[int, int, int], Tup
return (my_blocks, my_limits)
+@dataclasses.dataclass
+class ShaderSource:
+ name: str
+ code: str
+ local_size: Tuple[int, int, int]
+
+ def __repr__(self):
+ return f"// ====== Source Code for '{self.name}', workgroup_size: {self.local_size} ======\n{self.code}"
+
class ShaderFunction:
plan: ComputePlan
- func: Callable
shader_description: vc.ShaderDescription
shader_signature: ShaderSignature
bounds: ExectionBounds
ready: bool
+ name: str
source: str
flags: vc.ShaderFlags
+ local_size: Union[Tuple[int, int, int], Callable, None]
+ workgroups: Union[Tuple[int, int, int], Callable, None]
+ exec_size: Union[Tuple[int, int, int], Callable, None]
def __init__(self,
- func: Callable,
+ shader_description: vc.ShaderDescription,
+ shader_signature: ShaderSignature,
local_size=None,
workgroups=None,
exec_count=None,
- flags: vc.ShaderFlags = vc.ShaderFlags.NONE) -> None:
+ flags: vc.ShaderFlags = vc.ShaderFlags.NONE,
+ name: str = None) -> None:
self.plan = None
- self.func = func
- self.shader_description = None
- self.shader_signature = None
+ self.shader_description = shader_description
+ self.shader_signature = shader_signature
self.bounds = None
self.ready = False
+ self.name = name if name is not None else None
self.source = None
self.local_size = local_size
self.workgroups = workgroups
self.exec_size = exec_count
self.flags = flags
- def from_description(
- shader_description: vc.ShaderDescription,
- shader_signature: ShaderSignature,
- local_size=None,
- workgroups=None,
- exec_count=None,
-
- ) -> "ShaderFunction":
- shader_obj = ShaderFunction(
- func=None,
- local_size=local_size,
- workgroups=workgroups,
- exec_count=exec_count,
- flags=vc.ShaderFlags.NONE
- )
-
- shader_obj.shader_description = shader_description
- shader_obj.shader_signature = shader_signature
-
- return shader_obj
-
def build(self):
if self.ready:
return
@@ -194,70 +185,65 @@ def build(self):
else [vd.get_context().max_workgroup_size[0], 1, 1]
)
- if self.shader_description is None or self.shader_signature is None:
- assert self.shader_description is None and self.shader_signature is None, "Shader description and signature must both be set or both be None!"
- assert self.func is not None, "Cannot build a shader without a function!"
-
- builder = vc.ShaderBuilder(
- flags=self.flags,
- is_apple_device=vd.get_context().is_apple()
- )
- old_builder = vc.set_builder(builder)
-
- signature = ShaderSignature.from_inspectable_function(builder, self.func)
-
- self.func(*signature.get_variables())
-
- vc.set_builder(old_builder)
-
- self.shader_description = builder.build(self.func.__module__ + "." + self.func.__name__)
- self.shader_signature = signature
-
self.bounds = ExectionBounds(self.shader_signature.get_names_and_defaults(), my_local_size, self.workgroups, self.exec_size)
- if not sys.implementation.name == "Brython":
- runtime_backend = vd.get_backend()
- shader_backend_name = (
- self.shader_description.backend.name
- if self.shader_description.backend is not None
- else "glsl"
- )
-
- if runtime_backend == BACKEND_PYCUDA and shader_backend_name != "cuda":
- raise RuntimeError(
- "PyCUDA runtime backend requires CUDA codegen output. "
- "Call vd.initialize(backend='pycuda') before building shaders."
- )
+ shader_backend_name = (
+ self.shader_description.backend.name
+ if self.shader_description.backend is not None
+ else "glsl"
+ )
- if runtime_backend == BACKEND_VULKAN and shader_backend_name == "cuda":
- raise RuntimeError(
- "Vulkan runtime backend cannot execute CUDA codegen output. "
- "Use GLSL codegen or initialize with backend='pycuda'."
- )
+ if vd.is_dummy():
+ pass
+ elif vd.is_cuda() and shader_backend_name != "cuda":
+ raise RuntimeError(
+ "The selected CUDA runtime backend requires CUDA codegen output. "
+ "Call vd.initialize(backend='cuda') "
+ "before building shaders."
+ )
+ elif vd.is_opencl() and shader_backend_name != "opencl":
+ raise RuntimeError(
+ "The selected OpenCL runtime backend requires OpenCL codegen output. "
+ "Call vd.initialize(backend='opencl') "
+ "before building shaders."
+ )
+ elif vd.is_vulkan() and shader_backend_name == "cuda":
+ raise RuntimeError(
+ "Vulkan runtime backend cannot execute CUDA codegen output. "
+ "Use GLSL codegen or initialize with backend='cuda'."
+ )
+ elif vd.is_vulkan() and shader_backend_name == "opencl":
+ raise RuntimeError(
+ "Vulkan runtime backend cannot execute OpenCL codegen output. "
+ "Use GLSL codegen or initialize with backend='opencl'."
+ )
self.source = self.shader_description.make_source(
my_local_size[0], my_local_size[1], my_local_size[2]
)
try:
- self.plan = ComputePlan(
- self.source,
- self.shader_description.binding_type_list,
- self.shader_description.pc_size,
- self.shader_description.name
- )
+ if not vd.is_dummy():
+ self.plan = ComputePlan(
+ self.source,
+ self.shader_description.binding_type_list,
+ self.shader_description.pc_size,
+ self.shader_description.name
+ )
except Exception as e:
print(f"Error building shader: {e}")
- print(self.make_repr())
+ print(self.get_src(build=False, line_numbers=True))
raise e
self.ready = True
def __repr__(self) -> str:
- self.build()
- return self.make_repr()
+ return self.get_src().__repr__()
- def make_repr(self, line_numbers: bool = None) -> str:
+ def get_src(self, line_numbers: bool = None, build: bool = True) -> ShaderSource:
+ if build:
+ self.build()
+
result = ""
if line_numbers is None:
@@ -268,9 +254,14 @@ def make_repr(self, line_numbers: bool = None) -> str:
result += f"{line_prefix}{line}\n"
- return result
+ return ShaderSource(name=self.name, code=result, local_size=self.bounds.local_size)
+
+ def print_src(self, line_numbers: bool = None):
+ print(self.get_src(line_numbers))
def __call__(self, *args, **kwargs):
+ assert not vd.is_dummy(), "Cannot execute shader functions with dummy backend!"
+
self.build()
if not self.ready:
diff --git a/vkdispatch/shader/signature.py b/vkdispatch/shader/signature.py
index c9cb53b7..8d6f4a46 100644
--- a/vkdispatch/shader/signature.py
+++ b/vkdispatch/shader/signature.py
@@ -19,6 +19,16 @@
import enum
+_PUSH_CONSTANT_UNSUPPORTED_BACKENDS = set()
+
+
+def _push_constant_not_supported_error(backend_name: str) -> str:
+ return (
+ f"Push Constants are not supported for the {backend_name.upper()} backend. "
+ "Use Const instead."
+ )
+
+
class ShaderArgumentType(enum.Enum):
BUFFER = 0
IMAGE = 1
@@ -139,6 +149,9 @@ def from_type_annotations(cls,
value_name = shader_param.raw_name
arg_type = ShaderArgumentType.CONSTANT
elif(issubclass(annotations[i].__origin__, vc.Variable)):
+ if builder.backend.name in _PUSH_CONSTANT_UNSUPPORTED_BACKENDS:
+ raise NotImplementedError(_push_constant_not_supported_error(builder.backend.name))
+
shader_param = builder.declare_variable(annotations[i].__args__[0])
arg_type = ShaderArgumentType.VARIABLE
value_name = shader_param.raw_name
diff --git a/vkdispatch/vkfft/vkfft_plan.py b/vkdispatch/vkfft/vkfft_plan.py
index 64f201f3..0ad12dea 100644
--- a/vkdispatch/vkfft/vkfft_plan.py
+++ b/vkdispatch/vkfft/vkfft_plan.py
@@ -1,4 +1,4 @@
-from vkdispatch.base.backend import native
+from vkdispatch.backends.backend_selection import native
import vkdispatch as vd