diff --git a/.github/workflows/python-publish.yml b/.github/workflows/python-publish.yml index 5589de9c..f6f99017 100644 --- a/.github/workflows/python-publish.yml +++ b/.github/workflows/python-publish.yml @@ -15,12 +15,12 @@ on: jobs: - build_wheels: - name: Build wheels on ${{ matrix.os }} + build_native_wheels: + name: Build native wheels on ${{ matrix.os }} runs-on: ${{ matrix.os }} strategy: matrix: - os: [ubuntu-latest, windows-latest, macos-13, macos-14] + os: [ubuntu-latest, windows-latest, macos-15-intel, macos-15] steps: - uses: actions/checkout@v4 @@ -28,15 +28,17 @@ jobs: # Used to host cibuildwheel - uses: actions/setup-python@v5 - - name: Install cibuildwheel + - name: Install cibuildwheel and native deps run: | python -m pip install --upgrade pip python -m pip install cibuildwheel==3.2.1 python fetch_dependencies.py - - name: Build wheels + - name: Build native wheels env: CIBW_SKIP: 'pp* manylinux_i686 musllinux*' + VKDISPATCH_BUILD_TARGET: native + CIBW_ENVIRONMENT: VKDISPATCH_BUILD_TARGET=native run: python -m cibuildwheel --output-dir wheelhouse # to supply options, put them in 'env', like: @@ -47,28 +49,44 @@ jobs: with: name: cibw-wheels-${{ matrix.os }}-${{ strategy.job-index }} path: ./wheelhouse/*.whl - build_sdist: - name: Build source distribution + build_python_dists: + name: Build native/core/meta sdists and pure wheels runs-on: ubuntu-latest steps: - uses: actions/checkout@v4 - - name: Install dependencies + - uses: actions/setup-python@v5 + + - name: Install build tooling run: | python -m pip install --upgrade pip + python -m pip install build + + - name: Build native source distribution + env: + VKDISPATCH_BUILD_TARGET: native + run: | python fetch_dependencies.py + python -m build --sdist --outdir dist + + - name: Build core wheel and source distribution + env: + VKDISPATCH_BUILD_TARGET: core + run: python -m build --wheel --sdist --outdir dist - - name: Build sdist - run: pipx run build --sdist + - name: Build meta wheel and source 
distribution + env: + VKDISPATCH_BUILD_TARGET: meta + run: python -m build --wheel --sdist --outdir dist - uses: actions/upload-artifact@v4 with: - name: cibw-sdist - path: dist/*.tar.gz + name: cibw-python-dists + path: dist/* publish-to-pypi: name: Publish Python package to PyPI # if: startsWith(github.ref, 'refs/tags/') # only publish to PyPI on tag pushes - needs: [build_wheels, build_sdist] + needs: [build_native_wheels, build_python_dists] runs-on: ubuntu-latest environment: name: pypi diff --git a/docs/Makefile b/docs/Makefile index 4bf195e2..4c660da8 100644 --- a/docs/Makefile +++ b/docs/Makefile @@ -27,11 +27,9 @@ bundle_lib: @rm -rf "$(LIB_DEST)/vkdispatch" @mkdir -p "$(LIB_DEST)" @rm -f "$(LIB_DEST)/$(LIB_BUNDLE)" - @rm -f "$(LIB_DEST)/vkdispatch_native.brython.js" @rm -rf "$(LIB_STAGE)" @mkdir -p "$(LIB_STAGE)" @cp -r ../vkdispatch "$(LIB_STAGE)/vkdispatch" - @cp -r special_pages/libs/vkdispatch_native "$(LIB_STAGE)/vkdispatch_native" @cd "$(LIB_STAGE)" && $(PYTHON) -m brython make_package vkdispatch \ --src-dir . \ --output-path "$(CURDIR)/$(LIB_DEST)/$(LIB_BUNDLE)" diff --git a/docs/special/brython_shader_lab.rst b/docs/special/brython_shader_lab.rst deleted file mode 100644 index aeeffe87..00000000 --- a/docs/special/brython_shader_lab.rst +++ /dev/null @@ -1,16 +0,0 @@ -Brython Shader Lab -================== - -This page redirects to a standalone HTML app page. - -.. raw:: html - - - -

- Redirecting to the Brython shader lab page. - If you are not redirected, open - the standalone HTML page. -

diff --git a/docs/special/index.rst b/docs/special/index.rst index da840951..370fe451 100644 --- a/docs/special/index.rst +++ b/docs/special/index.rst @@ -6,4 +6,4 @@ Standalone pages integrated into the docs navigation. .. toctree:: :maxdepth: 1 - brython_shader_lab + shader_playground diff --git a/docs/special/shader_playground.rst b/docs/special/shader_playground.rst new file mode 100644 index 00000000..364329e1 --- /dev/null +++ b/docs/special/shader_playground.rst @@ -0,0 +1,16 @@ +Shader Playground +================== + +This page redirects to a standalone HTML app page. + +.. raw:: html + + + +

+ Redirecting to the shader playground page. + If you are not redirected, open + the standalone HTML page. +

diff --git a/docs/special_pages/libs/vkdispatch_native/__init__.py b/docs/special_pages/libs/vkdispatch_native/__init__.py deleted file mode 100644 index 673b054f..00000000 --- a/docs/special_pages/libs/vkdispatch_native/__init__.py +++ /dev/null @@ -1,1107 +0,0 @@ -"""Brython-friendly pure-Python shim for ``vkdispatch_native``. - -This module mirrors the Cython-exposed API used by ``vkdispatch`` and provides -an in-memory fake runtime suitable for docs execution and shader-source -compilation paths. -""" - -# NOTE: Keep this file dependency-light so it works under Brython. - -LOG_LEVEL_VERBOSE = 0 -LOG_LEVEL_INFO = 1 -LOG_LEVEL_WARNING = 2 -LOG_LEVEL_ERROR = 3 - -# Descriptor type enum values mirrored from vkdispatch_native/stages_extern.pxd. -DESCRIPTOR_TYPE_STORAGE_BUFFER = 1 -DESCRIPTOR_TYPE_STORAGE_IMAGE = 2 -DESCRIPTOR_TYPE_UNIFORM_BUFFER = 3 -DESCRIPTOR_TYPE_UNIFORM_IMAGE = 4 -DESCRIPTOR_TYPE_SAMPLER = 5 - -# Image format block sizes for formats exposed in vkdispatch.base.image.image_format. -_IMAGE_BLOCK_SIZES = { - 13: 1, - 14: 1, - 20: 2, - 21: 2, - 27: 3, - 28: 3, - 41: 4, - 42: 4, - 74: 2, - 75: 2, - 76: 2, - 81: 4, - 82: 4, - 83: 4, - 88: 6, - 89: 6, - 90: 6, - 95: 8, - 96: 8, - 97: 8, - 98: 4, - 99: 4, - 100: 4, - 101: 8, - 102: 8, - 103: 8, - 104: 12, - 105: 12, - 106: 12, - 107: 16, - 108: 16, - 109: 16, - 110: 8, - 111: 8, - 112: 8, - 113: 16, - 114: 16, - 115: 16, - 116: 24, - 117: 24, - 118: 24, - 119: 32, - 120: 32, - 121: 32, -} - -# --- Runtime state --- - -_initialized = False -_debug_mode = False -_log_level = LOG_LEVEL_WARNING -_error_string = None -_next_handle = 1 - -_contexts = {} -_signals = {} -_buffers = {} -_command_lists = {} -_compute_plans = {} -_descriptor_sets = {} -_images = {} -_samplers = {} -_fft_plans = {} - -# Device limits exposed through get_devices(); mutable so docs UI can tune them. 
-_DEFAULT_SUBGROUP_SIZE = 32 -_DEFAULT_MAX_WORKGROUP_SIZE = (1024, 1024, 64) -_DEFAULT_MAX_WORKGROUP_INVOCATIONS = 1024 -_DEFAULT_MAX_WORKGROUP_COUNT = (65535, 65535, 65535) -_DEFAULT_MAX_COMPUTE_SHARED_MEMORY_SIZE = 64 * 1024 - -_device_subgroup_size = _DEFAULT_SUBGROUP_SIZE -_device_max_workgroup_size = _DEFAULT_MAX_WORKGROUP_SIZE -_device_max_workgroup_invocations = _DEFAULT_MAX_WORKGROUP_INVOCATIONS -_device_max_workgroup_count = _DEFAULT_MAX_WORKGROUP_COUNT -_device_max_compute_shared_memory_size = _DEFAULT_MAX_COMPUTE_SHARED_MEMORY_SIZE - - -# --- Internal objects --- - -class _Signal: - __slots__ = ("done",) - - def __init__(self, done=True): - self.done = bool(done) - - -class _Context: - __slots__ = ( - "device_indices", - "queue_families", - "queue_count", - "queue_to_device", - "stopped", - ) - - def __init__(self, device_indices, queue_families): - self.device_indices = list(device_indices) - self.queue_families = [list(fam) for fam in queue_families] - - normalized = [] - for fam in self.queue_families: - normalized.append(fam if len(fam) > 0 else [0]) - self.queue_families = normalized - - self.queue_count = sum(len(fam) for fam in self.queue_families) - if self.queue_count <= 0: - self.queue_families = [[0]] - self.queue_count = 1 - - queue_to_device = [] - for dev_idx, fam in enumerate(self.queue_families): - for _ in fam: - queue_to_device.append(dev_idx) - - if len(queue_to_device) == 0: - queue_to_device = [0] - - self.queue_to_device = queue_to_device - self.stopped = False - - -class _Buffer: - __slots__ = ( - "context_handle", - "size", - "device_data", - "staging_data", - "signal_handles", - ) - - def __init__(self, context_handle, queue_count, size): - self.context_handle = context_handle - self.size = int(size) - - if queue_count <= 0: - queue_count = 1 - - self.device_data = [bytearray(self.size) for _ in range(queue_count)] - self.staging_data = [bytearray(self.size) for _ in range(queue_count)] - - signal_handles = [] - for _ in 
range(queue_count): - signal_handles.append(_new_handle(_signals, _Signal(done=True))) - self.signal_handles = signal_handles - - -class _CommandList: - __slots__ = ("context_handle", "commands", "compute_instance_size") - - def __init__(self, context_handle): - self.context_handle = context_handle - self.commands = [] - self.compute_instance_size = 0 - - -class _ComputePlan: - __slots__ = ("context_handle", "shader_source", "bindings", "pc_size", "shader_name") - - def __init__(self, context_handle, shader_source, bindings, pc_size, shader_name): - self.context_handle = context_handle - self.shader_source = shader_source - self.bindings = list(bindings) - self.pc_size = int(pc_size) - self.shader_name = shader_name - - -class _DescriptorSet: - __slots__ = ("plan_handle", "buffer_bindings", "image_bindings") - - def __init__(self, plan_handle): - self.plan_handle = plan_handle - self.buffer_bindings = {} - self.image_bindings = {} - - -class _Image: - __slots__ = ( - "context_handle", - "extent", - "layers", - "format", - "type", - "view_type", - "generate_mips", - "block_size", - "queue_data", - ) - - def __init__( - self, - context_handle, - queue_count, - extent, - layers, - format_, - image_type, - view_type, - generate_mips, - ): - self.context_handle = context_handle - self.extent = tuple(extent) - self.layers = int(layers) - self.format = int(format_) - self.type = int(image_type) - self.view_type = int(view_type) - self.generate_mips = int(generate_mips) - - self.block_size = image_format_block_size(self.format) - - if queue_count <= 0: - queue_count = 1 - - width = max(1, int(self.extent[0])) - height = max(1, int(self.extent[1])) - depth = max(1, int(self.extent[2])) - layer_count = max(1, self.layers) - total_bytes = width * height * depth * layer_count * self.block_size - - self.queue_data = [bytearray(total_bytes) for _ in range(queue_count)] - - -class _Sampler: - __slots__ = ( - "context_handle", - "mag_filter", - "min_filter", - "mip_mode", - 
"address_mode", - "mip_lod_bias", - "min_lod", - "max_lod", - "border_color", - ) - - def __init__( - self, - context_handle, - mag_filter, - min_filter, - mip_mode, - address_mode, - mip_lod_bias, - min_lod, - max_lod, - border_color, - ): - self.context_handle = context_handle - self.mag_filter = int(mag_filter) - self.min_filter = int(min_filter) - self.mip_mode = int(mip_mode) - self.address_mode = int(address_mode) - self.mip_lod_bias = float(mip_lod_bias) - self.min_lod = float(min_lod) - self.max_lod = float(max_lod) - self.border_color = int(border_color) - - -class _FFTPlan: - __slots__ = ( - "context_handle", - "dims", - "axes", - "buffer_size", - "input_buffer_size", - "kernel_num", - ) - - def __init__( - self, - context_handle, - dims, - axes, - buffer_size, - input_buffer_size, - kernel_num, - ): - self.context_handle = context_handle - self.dims = list(dims) - self.axes = list(axes) - self.buffer_size = int(buffer_size) - self.input_buffer_size = int(input_buffer_size) - self.kernel_num = int(kernel_num) - - -# --- Internal helpers --- - - -def _new_handle(registry, obj): - global _next_handle - handle = _next_handle - _next_handle += 1 - registry[handle] = obj - return handle - - -def _to_bytes(value): - if value is None: - return b"" - if isinstance(value, bytes): - return value - if isinstance(value, bytearray): - return bytes(value) - if isinstance(value, memoryview): - return value.tobytes() - try: - return bytes(value) - except Exception: - return b"" - - -def _normalize_extent(extent): - values = list(extent) - if len(values) < 3: - values.extend([1] * (3 - len(values))) - return (int(values[0]), int(values[1]), int(values[2])) - - -def _queue_indices(ctx, queue_index, all_on_negative=False): - if ctx is None or ctx.queue_count <= 0: - return [] - - if queue_index is None: - return [0] - - queue_index = int(queue_index) - - if all_on_negative and queue_index in (-1, -2): - return list(range(ctx.queue_count)) - - if 0 <= queue_index < 
ctx.queue_count: - return [queue_index] - - return [] - - -def _set_error(message): - global _error_string - _error_string = str(message) - - -def _clear_error(): - global _error_string - _error_string = None - - -def _as_positive_int(name, value): - try: - parsed = int(value) - except Exception as exc: - raise ValueError("%s must be an integer" % name) from exc - - if parsed <= 0: - raise ValueError("%s must be greater than zero" % name) - - return parsed - - -def _as_positive_triplet(name, value): - try: - parts = list(value) - except Exception as exc: - raise ValueError("%s must contain exactly 3 integers" % name) from exc - - if len(parts) != 3: - raise ValueError("%s must contain exactly 3 integers" % name) - - return ( - _as_positive_int("%s[0]" % name, parts[0]), - _as_positive_int("%s[1]" % name, parts[1]), - _as_positive_int("%s[2]" % name, parts[2]), - ) - - -# --- API: context/init/errors/logging --- - - -def reset_device_options(): - global _device_subgroup_size - global _device_max_workgroup_size - global _device_max_workgroup_invocations - global _device_max_workgroup_count - global _device_max_compute_shared_memory_size - - _device_subgroup_size = _DEFAULT_SUBGROUP_SIZE - _device_max_workgroup_size = _DEFAULT_MAX_WORKGROUP_SIZE - _device_max_workgroup_invocations = _DEFAULT_MAX_WORKGROUP_INVOCATIONS - _device_max_workgroup_count = _DEFAULT_MAX_WORKGROUP_COUNT - _device_max_compute_shared_memory_size = _DEFAULT_MAX_COMPUTE_SHARED_MEMORY_SIZE - - -def set_device_options( - subgroup_size=None, - max_workgroup_size=None, - max_workgroup_invocations=None, - max_workgroup_count=None, - max_compute_shared_memory_size=None, -): - global _device_subgroup_size - global _device_max_workgroup_size - global _device_max_workgroup_invocations - global _device_max_workgroup_count - global _device_max_compute_shared_memory_size - - if subgroup_size is not None: - _device_subgroup_size = _as_positive_int("subgroup_size", subgroup_size) - - if max_workgroup_size is not 
None: - _device_max_workgroup_size = _as_positive_triplet( - "max_workgroup_size", - max_workgroup_size, - ) - - if max_workgroup_invocations is not None: - _device_max_workgroup_invocations = _as_positive_int( - "max_workgroup_invocations", - max_workgroup_invocations, - ) - - if max_workgroup_count is not None: - _device_max_workgroup_count = _as_positive_triplet( - "max_workgroup_count", - max_workgroup_count, - ) - - if max_compute_shared_memory_size is not None: - _device_max_compute_shared_memory_size = _as_positive_int( - "max_compute_shared_memory_size", - max_compute_shared_memory_size, - ) - - -def init(debug, log_level): - global _initialized, _debug_mode, _log_level - _initialized = True - _debug_mode = bool(debug) - _log_level = int(log_level) - _clear_error() - - -def log(log_level, text, file_str, line_str): - # Keep logging quiet in docs/brython by default. - # Function kept for API compatibility. - _ = log_level - _ = text - _ = file_str - _ = line_str - - -def set_log_level(log_level): - global _log_level - _log_level = int(log_level) - - -def get_devices(): - if not _initialized: - init(False, _log_level) - - # One plausible fake discrete GPU with compute+graphics queue families. 
- device_tuple = ( - 0, # version_variant - 1, # version_major - 3, # version_minor - 0, # version_patch - 1001000, # driver_version - 0x1BAD, # vendor_id - 0x0001, # device_id - 2, # device_type (Discrete GPU) - "VKDispatch Web Dummy GPU", - 1, # shader_buffer_float32_atomics - 1, # shader_buffer_float32_atomic_add - 1, # float_64_support - 1, # float_16_support - 1, # int_64_support - 1, # int_16_support - 1, # storage_buffer_16_bit_access - 1, # uniform_and_storage_buffer_16_bit_access - 1, # storage_push_constant_16 - 1, # storage_input_output_16 - _device_max_workgroup_size, # max_workgroup_size - _device_max_workgroup_invocations, # max_workgroup_invocations - _device_max_workgroup_count, # max_workgroup_count - 8, # max_descriptor_set_count - 256, # max_push_constant_size - 1 << 30, # max_storage_buffer_range - 65536, # max_uniform_buffer_range - 16, # uniform_buffer_alignment - _device_subgroup_size, # subgroup_size - 0x7FFFFFFF, # supported_stages - 0x7FFFFFFF, # supported_operations - 1, # quad_operations_in_all_stages - _device_max_compute_shared_memory_size, # max_compute_shared_memory_size - [ - (8, 0x006), # compute + transfer - (4, 0x007), # graphics + compute + transfer - ], - 1, # scalar_block_layout - 1, # timeline_semaphores - bytes((0x56, 0x4B, 0x44, 0x30, 0x57, 0x45, 0x42, 0x31, 0x44, 0x55, 0x4D, 0x4D, 0x59, 0x00, 0x00, 0x01)), - ) - - return [device_tuple] - - -def context_create(device_indicies, queue_families): - try: - ctx = _Context(device_indicies, queue_families) - return _new_handle(_contexts, ctx) - except Exception as exc: - _set_error("Failed to create context: %s" % exc) - return 0 - - -def signal_wait(signal_ptr, wait_for_timestamp, queue_index): - _ = wait_for_timestamp - _ = queue_index - signal_obj = _signals.get(int(signal_ptr)) - if signal_obj is None: - return True - return bool(signal_obj.done) - - -def signal_insert(context, queue_index): - _ = context - _ = queue_index - return _new_handle(_signals, _Signal(done=True)) - - 
-def signal_destroy(signal_ptr): - _signals.pop(int(signal_ptr), None) - - -def context_destroy(context): - _contexts.pop(int(context), None) - - -def get_error_string(): - if _error_string is None: - return 0 - return _error_string - - -def context_stop_threads(context): - ctx = _contexts.get(int(context)) - if ctx is not None: - ctx.stopped = True - - -# --- API: buffers --- - - -def buffer_create(context, size, per_device): - _ = per_device - ctx = _contexts.get(int(context)) - if ctx is None: - _set_error("Invalid context handle for buffer_create") - return 0 - - size = int(size) - if size < 0: - size = 0 - - return _new_handle(_buffers, _Buffer(int(context), ctx.queue_count, size)) - - -def buffer_destroy(buffer): - obj = _buffers.pop(int(buffer), None) - if obj is None: - return - - for signal_handle in obj.signal_handles: - _signals.pop(signal_handle, None) - - -def buffer_get_queue_signal(buffer, queue_index): - obj = _buffers.get(int(buffer)) - if obj is None: - return _new_handle(_signals, _Signal(done=True)) - - queue_index = int(queue_index) - if queue_index < 0 or queue_index >= len(obj.signal_handles): - queue_index = 0 - - return obj.signal_handles[queue_index] - - -def buffer_wait_staging_idle(buffer, queue_index): - _ = buffer - _ = queue_index - return True - - -def buffer_write_staging(buffer, queue_index, data, size): - obj = _buffers.get(int(buffer)) - if obj is None: - return - - queue_index = int(queue_index) - if queue_index < 0 or queue_index >= len(obj.staging_data): - return - - payload = _to_bytes(data) - size = min(int(size), len(payload), obj.size) - if size <= 0: - return - - obj.staging_data[queue_index][:size] = payload[:size] - - -def buffer_read_staging(buffer, queue_index, size): - obj = _buffers.get(int(buffer)) - if obj is None: - return bytes(int(size)) - - queue_index = int(queue_index) - if queue_index < 0 or queue_index >= len(obj.staging_data): - return bytes(int(size)) - - size = int(size) - if size <= 0: - return b"" - - 
data = obj.staging_data[queue_index] - if size <= len(data): - return bytes(data[:size]) - - return bytes(data) + bytes(size - len(data)) - - -def buffer_write(buffer, offset, size, index): - obj = _buffers.get(int(buffer)) - if obj is None: - return - - offset = int(offset) - size = int(size) - - if size <= 0 or offset < 0: - return - - ctx = _contexts.get(obj.context_handle) - if ctx is None: - return - - queue_indices = _queue_indices(ctx, index, all_on_negative=True) - if len(queue_indices) == 0: - return - - for queue_index in queue_indices: - if queue_index >= len(obj.device_data) or queue_index >= len(obj.staging_data): - continue - - end = min(offset + size, obj.size) - copy_size = end - offset - if copy_size <= 0: - continue - - obj.device_data[queue_index][offset:end] = obj.staging_data[queue_index][:copy_size] - - signal_handle = obj.signal_handles[queue_index] - signal_obj = _signals.get(signal_handle) - if signal_obj is not None: - signal_obj.done = True - - -def buffer_read(buffer, offset, size, index): - obj = _buffers.get(int(buffer)) - if obj is None: - return - - offset = int(offset) - size = int(size) - - if size <= 0 or offset < 0: - return - - queue_index = int(index) - if queue_index < 0 or queue_index >= len(obj.device_data): - return - - end = min(offset + size, obj.size) - copy_size = end - offset - if copy_size <= 0: - return - - obj.staging_data[queue_index][:copy_size] = obj.device_data[queue_index][offset:end] - - signal_handle = obj.signal_handles[queue_index] - signal_obj = _signals.get(signal_handle) - if signal_obj is not None: - signal_obj.done = True - - -# --- API: command lists --- - - -def command_list_create(context): - if int(context) not in _contexts: - _set_error("Invalid context handle for command_list_create") - return 0 - - return _new_handle(_command_lists, _CommandList(int(context))) - - -def command_list_destroy(command_list): - _command_lists.pop(int(command_list), None) - - -def 
command_list_get_instance_size(command_list): - obj = _command_lists.get(int(command_list)) - if obj is None: - return 0 - - return int(obj.compute_instance_size) - - -def command_list_reset(command_list): - obj = _command_lists.get(int(command_list)) - if obj is None: - return - - obj.commands = [] - obj.compute_instance_size = 0 - - -def command_list_submit(command_list, data, instance_count, index): - _ = data - _ = instance_count - _ = index - - obj = _command_lists.get(int(command_list)) - if obj is None: - return True - - # No-op fake execution path: commands are accepted but not executed. - # Keep the command list intact (native keeps it until reset/destroy). - _ = obj.commands - return True - - -# --- API: descriptor sets --- - - -def descriptor_set_create(plan): - if int(plan) not in _compute_plans: - _set_error("Invalid compute plan handle for descriptor_set_create") - return 0 - - return _new_handle(_descriptor_sets, _DescriptorSet(int(plan))) - - -def descriptor_set_destroy(descriptor_set): - _descriptor_sets.pop(int(descriptor_set), None) - - -def descriptor_set_write_buffer( - descriptor_set, - binding, - object, - offset, - range, - uniform, - read_access, - write_access, -): - ds = _descriptor_sets.get(int(descriptor_set)) - if ds is None: - return - - ds.buffer_bindings[int(binding)] = ( - int(object), - int(offset), - int(range), - int(uniform), - int(read_access), - int(write_access), - ) - - -def descriptor_set_write_image( - descriptor_set, - binding, - object, - sampler_obj, - read_access, - write_access, -): - ds = _descriptor_sets.get(int(descriptor_set)) - if ds is None: - return - - ds.image_bindings[int(binding)] = ( - int(object), - int(sampler_obj), - int(read_access), - int(write_access), - ) - - -# --- API: images/samplers --- - - -def image_create(context, extent, layers, format, type, view_type, generate_mips): - ctx = _contexts.get(int(context)) - if ctx is None: - _set_error("Invalid context handle for image_create") - return 0 - 
- norm_extent = _normalize_extent(extent) - obj = _Image( - int(context), - ctx.queue_count, - norm_extent, - int(layers), - int(format), - int(type), - int(view_type), - int(generate_mips), - ) - - return _new_handle(_images, obj) - - -def image_destroy(image): - _images.pop(int(image), None) - - -def image_create_sampler( - context, - mag_filter, - min_filter, - mip_mode, - address_mode, - mip_lod_bias, - min_lod, - max_lod, - border_color, -): - if int(context) not in _contexts: - _set_error("Invalid context handle for image_create_sampler") - return 0 - - sampler = _Sampler( - int(context), - mag_filter, - min_filter, - mip_mode, - address_mode, - mip_lod_bias, - min_lod, - max_lod, - border_color, - ) - return _new_handle(_samplers, sampler) - - -def image_destroy_sampler(sampler): - _samplers.pop(int(sampler), None) - - -def image_write(image, data, offset, extent, baseLayer, layerCount, device_index): - _ = offset - _ = baseLayer - - obj = _images.get(int(image)) - if obj is None: - return - - payload = _to_bytes(data) - - extent = _normalize_extent(extent) - layer_count = max(1, int(layerCount)) - region_size = max(0, extent[0] * extent[1] * extent[2] * layer_count * obj.block_size) - if region_size <= 0: - return - - copy_size = min(region_size, len(payload)) - if copy_size <= 0: - return - - ctx = _contexts.get(obj.context_handle) - if ctx is None: - return - - queue_indices = _queue_indices(ctx, device_index, all_on_negative=True) - if len(queue_indices) == 0: - return - - for queue_index in queue_indices: - if queue_index < 0 or queue_index >= len(obj.queue_data): - continue - obj.queue_data[queue_index][:copy_size] = payload[:copy_size] - - -def image_format_block_size(format): - return int(_IMAGE_BLOCK_SIZES.get(int(format), 4)) - - -def image_read(image, out_size, offset, extent, baseLayer, layerCount, device_index): - _ = offset - _ = extent - _ = baseLayer - _ = layerCount - - obj = _images.get(int(image)) - out_size = max(0, int(out_size)) - - if 
obj is None: - return bytes(out_size) - - queue_index = int(device_index) - if queue_index < 0 or queue_index >= len(obj.queue_data): - queue_index = 0 - - data = obj.queue_data[queue_index] - if out_size <= len(data): - return bytes(data[:out_size]) - - return bytes(data) + bytes(out_size - len(data)) - - -# --- API: compute stage --- - - -def stage_compute_plan_create(context, shader_source, bindings, pc_size, shader_name): - if int(context) not in _contexts: - _set_error("Invalid context handle for stage_compute_plan_create") - return 0 - - source_bytes = _to_bytes(shader_source) - name_bytes = _to_bytes(shader_name) - - plan = _ComputePlan(int(context), source_bytes, list(bindings), int(pc_size), name_bytes) - return _new_handle(_compute_plans, plan) - - -def stage_compute_plan_destroy(plan): - _compute_plans.pop(int(plan), None) - - -def stage_compute_record(command_list, plan, descriptor_set, blocks_x, blocks_y, blocks_z): - cl = _command_lists.get(int(command_list)) - cp = _compute_plans.get(int(plan)) - - if cl is None or cp is None: - return - - cl.commands.append( - { - "type": "compute", - "plan": int(plan), - "descriptor_set": int(descriptor_set), - "blocks": (int(blocks_x), int(blocks_y), int(blocks_z)), - } - ) - cl.compute_instance_size += max(0, int(cp.pc_size)) - - -# --- API: FFT stage --- - - -def stage_fft_plan_create( - context, - dims, - axes, - buffer_size, - do_r2c, - normalize, - pad_left, - pad_right, - frequency_zeropadding, - kernel_num, - kernel_convolution, - conjugate_convolution, - convolution_features, - input_buffer_size, - num_batches, - single_kernel_multiple_batches, - keep_shader_code, -): - _ = do_r2c - _ = normalize - _ = pad_left - _ = pad_right - _ = frequency_zeropadding - _ = kernel_convolution - _ = conjugate_convolution - _ = convolution_features - _ = num_batches - _ = single_kernel_multiple_batches - _ = keep_shader_code - - if int(context) not in _contexts: - _set_error("Invalid context handle for 
stage_fft_plan_create") - return 0 - - plan = _FFTPlan( - int(context), - list(dims), - list(axes), - int(buffer_size), - int(input_buffer_size), - int(kernel_num), - ) - - return _new_handle(_fft_plans, plan) - - -def stage_fft_plan_destroy(plan): - _fft_plans.pop(int(plan), None) - - -def stage_fft_record(command_list, plan, buffer, inverse, kernel, input_buffer): - _ = buffer - _ = inverse - _ = kernel - _ = input_buffer - - cl = _command_lists.get(int(command_list)) - if cl is None or int(plan) not in _fft_plans: - return - - cl.commands.append( - { - "type": "fft", - "plan": int(plan), - } - ) - - -__all__ = [ - "reset_device_options", - "set_device_options", - "init", - "log", - "set_log_level", - "get_devices", - "context_create", - "signal_wait", - "signal_insert", - "signal_destroy", - "context_destroy", - "get_error_string", - "context_stop_threads", - "buffer_create", - "buffer_destroy", - "buffer_get_queue_signal", - "buffer_wait_staging_idle", - "buffer_write_staging", - "buffer_read_staging", - "buffer_write", - "buffer_read", - "command_list_create", - "command_list_destroy", - "command_list_get_instance_size", - "command_list_reset", - "command_list_submit", - "descriptor_set_create", - "descriptor_set_destroy", - "descriptor_set_write_buffer", - "descriptor_set_write_image", - "image_create", - "image_destroy", - "image_create_sampler", - "image_destroy_sampler", - "image_write", - "image_format_block_size", - "image_read", - "stage_compute_plan_create", - "stage_compute_plan_destroy", - "stage_compute_record", - "stage_fft_plan_create", - "stage_fft_plan_destroy", - "stage_fft_record", - "LOG_LEVEL_VERBOSE", - "LOG_LEVEL_INFO", - "LOG_LEVEL_WARNING", - "LOG_LEVEL_ERROR", - "DESCRIPTOR_TYPE_STORAGE_BUFFER", - "DESCRIPTOR_TYPE_STORAGE_IMAGE", - "DESCRIPTOR_TYPE_UNIFORM_BUFFER", - "DESCRIPTOR_TYPE_UNIFORM_IMAGE", - "DESCRIPTOR_TYPE_SAMPLER", -] diff --git a/docs/special_pages/brython_shader_lab.html b/docs/special_pages/shader_playground.html 
similarity index 91% rename from docs/special_pages/brython_shader_lab.html rename to docs/special_pages/shader_playground.html index 0e9e057c..9fb37d17 100644 --- a/docs/special_pages/brython_shader_lab.html +++ b/docs/special_pages/shader_playground.html @@ -401,6 +401,7 @@
+

VkDispatch Shader Playground

@@ -455,15 +456,17 @@

VkDispatch Shader Playground

- + +
- Shader Print Line Numbers + Subgroup Ops Enabled
-
--> +
@@ -676,7 +679,7 @@

VkDispatch Shader Playground

btn.classList.add("active"); /* Switch output highlighting mode */ - if (backend === "cuda") { + if (backend === "cuda" || backend === "opencl") { window.cmOutput.setOption("mode", "text/x-csrc"); } else { window.cmOutput.setOption("mode", "text/x-glsl"); @@ -690,6 +693,7 @@

VkDispatch Shader Playground

/* ── URL hash restore ── */ var deviceFields = [ { key: "ss", id: "opt-subgroup-size" }, + { key: "sse", id: "opt-subgroup-enabled" }, { key: "wsx", id: "opt-wg-size-x" }, { key: "wsy", id: "opt-wg-size-y" }, { key: "wsz", id: "opt-wg-size-z" }, @@ -720,11 +724,11 @@

VkDispatch Shader Playground

/* Restore backend from URL */ if (params.has("be")) { var be = params.get("be").toLowerCase(); - if (be === "cuda") { - window.currentBackend = "cuda"; + if (be === "cuda" || be === "opencl") { + window.currentBackend = be; toggleButtons.forEach(function (b) { b.classList.remove("active"); - if (b.getAttribute("data-backend") === "cuda") { + if (b.getAttribute("data-backend") === be) { b.classList.add("active"); } }); @@ -734,7 +738,15 @@

VkDispatch Shader Playground

deviceFields.forEach(function (f) { if (params.has(f.key)) { - document.getElementById(f.id).value = params.get(f.key); + if(f.id === "opt-subgroup-enabled") { + document.getElementById(f.id).checked = + params.get(f.key) === "1" || + params.get(f.key).toLowerCase() === "true" || + params.get(f.key).toLowerCase() === "yes" || + params.get(f.key).toLowerCase() === "on"; + } else { + document.getElementById(f.id).value = params.get(f.key); + } } }); toggleFields.forEach(function (f) { @@ -792,43 +804,74 @@

VkDispatch Shader Playground

}); /* ── share button ── */ - document - .getElementById("share-btn") - .addEventListener("click", function () { - var code = window.cmCode.getValue(); - var encoded = btoa(unescape(encodeURIComponent(code))); - - var hashParts = ["code=" + encoded]; - - /* Include backend in share link */ - hashParts.push("be=" + encodeURIComponent(window.currentBackend)); - - deviceFields.forEach(function (f) { - var val = document.getElementById(f.id).value.trim(); - if (val !== "") { - hashParts.push( - encodeURIComponent(f.key) + - "=" + - encodeURIComponent(val) - ); - } - }); - toggleFields.forEach(function (f) { - var el = document.getElementById(f.id); - if (!el) return; - var checked = el.checked ? "1" : "0"; + function buildPlaygroundHash() { + var code = window.cmCode.getValue(); + var encoded = btoa(unescape(encodeURIComponent(code))); + + var hashParts = ["code=" + encoded]; + + /* Include backend in share/runtime URL */ + hashParts.push("be=" + encodeURIComponent(window.currentBackend)); + + deviceFields.forEach(function (f) { + var val = document.getElementById(f.id).value.trim(); + + if(f.id === "opt-subgroup-enabled") { + val = document.getElementById(f.id).checked ? "1" : "0"; + } + + if (val !== "") { hashParts.push( encodeURIComponent(f.key) + "=" + - encodeURIComponent(checked) + encodeURIComponent(val) ); - }); + } + }); + toggleFields.forEach(function (f) { + var el = document.getElementById(f.id); + if (!el) return; + var checked = el.checked ? 
"1" : "0"; + hashParts.push( + encodeURIComponent(f.key) + + "=" + + encodeURIComponent(checked) + ); + }); + + return hashParts.join("&"); + } + + window.updatePlaygroundUrlState = function () { + var hash = buildPlaygroundHash(); + var nextUrl = + window.location.pathname + + window.location.search + + "#" + + hash; + + if (window.location.hash.slice(1) !== hash) { + if (window.history && window.history.replaceState) { + window.history.replaceState(null, "", nextUrl); + } else { + window.location.hash = hash; + } + } + + return hash; + }; + + document + .getElementById("share-btn") + .addEventListener("click", function () { + var hash = window.updatePlaygroundUrlState(); var url = window.location.origin + window.location.pathname + + window.location.search + "#" + - hashParts.join("&"); + hash; copyToClipboard(url).then(function () { showToast("Share link copied to clipboard."); @@ -917,13 +960,14 @@

VkDispatch Shader Playground

import sys import traceback -import vkdispatch_native +import vkdispatch as vd import vkdispatch.base.context as vd_context import vkdispatch.base.init as vd_init import vkdispatch.execution_pipeline.command_graph as vd_command_graph import vkdispatch.fft.shader_factories as vd_fft_shader_factories import vkdispatch.codegen as vc +vd.initialize(backend="dummy") class OutputBuffer: def __init__(self): @@ -961,6 +1005,7 @@

VkDispatch Shader Playground

def _read_device_options(): return { "subgroup_size": _parse_positive_int("opt-subgroup-size", "Subgroup Size"), + "subgroup_enabled": document["opt-subgroup-enabled"].checked, "max_workgroup_size": ( _parse_positive_int("opt-wg-size-x", "Max Workgroup Size X"), _parse_positive_int("opt-wg-size-y", "Max Workgroup Size Y"), @@ -984,16 +1029,8 @@

VkDispatch Shader Playground

def _reset_vkdispatch_runtime(): context = getattr(vd_context, "__context", None) - if context is not None: - if hasattr(vd_context, "set_running"): - vd_context.set_running(False) - - handles_list = list(context.handles_dict.values()) - for handle in handles_list: - handle.destroy() - - vkdispatch_native.context_destroy(context._handle) - vd_context.__context = None + #if context is not None: + # vd_context.destroy_context() vd_init.__initilized_instance = False vd_init.__device_infos = None @@ -1008,6 +1045,9 @@

VkDispatch Shader Playground

code = window.cmCode.getValue() window.cmOutput.setValue("") + if event is not None and hasattr(window, "updatePlaygroundUrlState"): + window.updatePlaygroundUrlState() + stdout_buffer = OutputBuffer() stderr_buffer = OutputBuffer() @@ -1017,14 +1057,18 @@

VkDispatch Shader Playground

try: options = _read_device_options() - vkdispatch_native.set_device_options( + _reset_vkdispatch_runtime() + + vd.initialize(backend="dummy") + vd.get_context() + vd.set_dummy_context_params( subgroup_size=options["subgroup_size"], + subgroup_enabled=options["subgroup_enabled"], max_workgroup_size=options["max_workgroup_size"], max_workgroup_invocations=options["max_workgroup_invocations"], max_workgroup_count=options["max_workgroup_count"], - max_compute_shared_memory_size=options["max_compute_shared_memory_size"], + max_shared_memory=options["max_compute_shared_memory_size"], ) - _reset_vkdispatch_runtime() # Set codegen backend based on toggle state backend = str(window.currentBackend) diff --git a/docs/tutorials/reductions_and_fft.rst b/docs/tutorials/reductions_and_fft.rst index b078503b..6b77430a 100644 --- a/docs/tutorials/reductions_and_fft.rst +++ b/docs/tutorials/reductions_and_fft.rst @@ -162,6 +162,14 @@ For advanced workflows (for example padded 2D cross-correlation), use ``input_ma ``output_map`` to remap FFT I/O indices and ``input_signal_range`` to skip inactive regions. +Map argument annotations do not determine FFT compute precision. ``read_op.register`` +and ``write_op.register`` always use the internal FFT compute type; map callbacks should +cast user-chosen buffer values to and from that register type as needed. If both FFT I/O +paths are mapped and ``compute_type`` is not provided, ``vd.fft`` defaults to +``complex64`` (falling back to ``complex32`` when required by device support). +When ``output_map`` is provided without ``input_map``, pass an explicit input buffer +argument after the ``output_map`` arguments so read and write phases use different proxies. + .. 
code-block:: python import vkdispatch.codegen as vc diff --git a/examples/pytorch_cuda_graph_cuda_python.py b/examples/pytorch_cuda_graph_cuda_python.py new file mode 100644 index 00000000..51a949f9 --- /dev/null +++ b/examples/pytorch_cuda_graph_cuda_python.py @@ -0,0 +1,80 @@ +#!/usr/bin/env python3 +"""Capture and replay a vkdispatch CUDA kernel inside a PyTorch CUDA Graph. + +This example uses: + - vkdispatch runtime backend: "cuda" + - a custom vkdispatch shader recorded into CommandGraph + - torch.cuda.CUDAGraph capture + replay + - zero-copy tensor sharing via __cuda_array_interface__ +""" + +import torch + +import vkdispatch as vd +import vkdispatch.codegen as vc +from vkdispatch.codegen.abreviations import Buff, Const, f32 + + +@vd.shader(exec_size=lambda args: args.x.size) +def custom_shader(out: Buff[f32], x: Buff[f32], bias: Const[f32]): + tid = vc.global_invocation_id().x + out[tid] = x[tid] * 1.5 + vc.sin(x[tid]) + bias + + +def main() -> None: + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required for this example.") + + torch.cuda.set_device(0) + torch.manual_seed(0) + + vd.initialize(backend="cuda") + vd.make_context(device_ids=torch.cuda.current_device()) + + n = 16 + bias = 0.25 + + # Static allocations are required for CUDA Graph replay. + x = torch.empty(n, device="cuda", dtype=torch.float32) + out = torch.empty_like(x) + x.fill_(0.0) + + x_vd = vd.from_cuda_array(x) + out_vd = vd.from_cuda_array(out) + + cmd_graph = vd.CommandGraph() + + # Record one vkdispatch kernel launch into the command graph. + # For backend="cuda-python", Const/Var payloads are fixed at record time. + custom_shader(out=out_vd, x=x_vd, bias=bias, graph=cmd_graph) + + torch.cuda.synchronize() + # Pre-stage internal uniform uploads outside torch capture so only dispatch is captured. + #cmd_graph.prepare_for_cuda_graph_capture() + + graph = torch.cuda.CUDAGraph() + with torch.cuda.graph(graph): + # torch.cuda.graph(...) 
may switch to an internal capture stream. + # Bind vkdispatch to the active stream from inside that context. + with vd.cuda_graph_capture(torch.cuda.current_stream()): + print("Submitting vkdispatch CommandGraph to CUDA Graph...") + cmd_graph.submit() + print("Done recording.") + + replay_inputs = [0.0, 1.0, 2.0, 3.0] + for i, value in enumerate(replay_inputs, start=1): + x.fill_(value) + graph.replay() + torch.cuda.synchronize() + + expected = x * 1.5 + torch.sin(x) + bias + torch.testing.assert_close(out, expected, rtol=1e-5, atol=1e-5) + print( + f"replay {i} input={value:.1f} output[:8]={out[:8].detach().cpu().tolist()}" + ) + + print("Done.") + + +if __name__ == "__main__": + main() diff --git a/pyproject.toml b/pyproject.toml index fc741656..7379c159 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -2,36 +2,7 @@ requires = [ "setuptools>=59.0", "wheel", - "Cython" + "Cython", + "packaging" ] build-backend = "setuptools.build_meta" - -[project] -name = "vkdispatch" -version = "0.0.30" -authors = [ - { name="Shahar Sandhaus", email="shahar.sandhaus@gmail.com" }, -] -description = "A Python module for orchestrating and dispatching large computations across multi-GPU systems using Vulkan." 
-readme = "README.md" -requires-python = ">=3.6" -classifiers = [ - "Programming Language :: Python :: 3", - "License :: OSI Approved :: MIT License", - "Operating System :: OS Independent", - "Development Status :: 2 - Pre-Alpha", -] -dependencies = [ - "setuptools>=59.0", -] -scripts = { vdlist = 'vkdispatch.cli:cli_entrypoint' } - -[project.urls] -Homepage = "https://github.com/sharhar/vkdispatch" -Issues = "https://github.com/sharhar/vkdispatch/issues" - -[project.optional-dependencies] -cli = ["Click"] -cuda = ["cuda-python"] -pycuda = ["pycuda"] -numpy = ["numpy"] diff --git a/setup.py b/setup.py index 38f19dfc..422495ce 100644 --- a/setup.py +++ b/setup.py @@ -1,239 +1,129 @@ import os import platform +import re import subprocess +from pathlib import Path from setuptools import Extension +from setuptools import find_packages from setuptools import setup from setuptools.command.build_ext import build_ext -import re - -# Typically you'll put `packaging` in your setup_requires or pyproject.toml if needed. 
try: from packaging.version import Version except ImportError: - # As a fallback, if you absolutely can't rely on `packaging`, - # you could use distutils: from distutils.version import LooseVersion as Version print("Warning: 'packaging' not found; version comparisons might be less accurate.") from distutils.version import LooseVersion as Version -system = platform.system() +BUILD_TARGET_FULL = "full" +BUILD_TARGET_CORE = "core" +BUILD_TARGET_NATIVE = "native" +BUILD_TARGET_META = "meta" +VALID_BUILD_TARGETS = { + BUILD_TARGET_FULL, + BUILD_TARGET_CORE, + BUILD_TARGET_NATIVE, + BUILD_TARGET_META, +} -proj_root = os.path.abspath(os.path.dirname(__file__)) -molten_vk_path = "./deps/MoltenVK/MoltenVK/MoltenVK/static/MoltenVK.xcframework/macos-arm64_x86_64/" -vulkan_sdk_root = os.environ.get('VULKAN_SDK') -platform_name_dict = { - "Darwin": "MACOS", - "Windows": "WINDOWS", - "Linux": "LINUX" -} +def get_build_target() -> str: + target = os.environ.get("VKDISPATCH_BUILD_TARGET", BUILD_TARGET_FULL).strip().lower() + if target not in VALID_BUILD_TARGETS: + valid = ", ".join(sorted(VALID_BUILD_TARGETS)) + raise RuntimeError( + f"Invalid VKDISPATCH_BUILD_TARGET={target!r}. 
Expected one of: {valid}" + ) + return target -platform_library_dirs = [] -platform_define_macros = [] -platform_link_libraries = [] -platform_extra_link_args = [] -platform_extra_compile_args = ( - ["/W3", "/GL", "/DNDEBUG", "/MD", "/EHsc", "/std:c++17"] - if system == "Windows" - else [ - "-O2", - "-g", - "-std=c++17", - ] -) -include_directories = [ - proj_root + "/deps/VMA/include", - proj_root + "/deps/volk", - proj_root + "/deps/VkFFT/vkFFT", -] +BUILD_TARGET = get_build_target() -if os.name == "posix": - platform_extra_link_args.append("-g") - platform_extra_link_args.append("-O0") - platform_extra_link_args.append("-fno-omit-frame-pointer") - platform_link_libraries.extend(["dl", "pthread"]) - - -if vulkan_sdk_root is None: - include_directories.extend([ - proj_root + "/include_ext", - proj_root + "/deps/Vulkan-Headers/include", - proj_root + "/deps/Vulkan-Utility-Libraries/include", - proj_root + "/deps/glslang", - proj_root + "/deps/glslang/glslang/Include", - ]) - - if system == "Darwin": - platform_library_dirs.append(molten_vk_path) - platform_link_libraries.append("MoltenVK") - platform_extra_link_args.extend([ - "-framework", "Metal", - "-framework", "AVFoundation", - "-framework", "AppKit" - ]) - platform_extra_compile_args.append("-mmacosx-version-min=10.15") - else: - platform_define_macros.append(("VKDISPATCH_USE_VOLK", 1)) -else: - include_directories.extend([ - vulkan_sdk_root + '/include', - vulkan_sdk_root + '/include/utility', - vulkan_sdk_root + '/include/glslang/Include', - ]) +proj_root = Path(__file__).resolve().parent +system = platform.system() +molten_vk_path = "./deps/MoltenVK/MoltenVK/MoltenVK/static/MoltenVK.xcframework/macos-arm64_x86_64/" +vulkan_sdk_root = os.environ.get("VULKAN_SDK") - platform_define_macros.append(("VKDISPATCH_USE_VOLK", 1)) - platform_define_macros.append(("VKDISPATCH_LOADER_PATH", '"' + os.path.abspath(f"{vulkan_sdk_root}") + '/"')) - #if os.name == "posix": - # platform_link_libraries.append("vulkan") - 
#else: - # platform_link_libraries.append("vulkan-1") +def read_version() -> str: + init_path = proj_root / "vkdispatch" / "__init__.py" + text = init_path.read_text(encoding="utf-8") + match = re.search(r'^__version__\s*=\s*"([^"]+)"', text, re.MULTILINE) + if not match: + raise RuntimeError(f"Could not find __version__ in {init_path}") + return match.group(1) - platform_library_dirs.append(vulkan_sdk_root + '/lib') - platform_link_libraries.extend([ - "glslang", - "SPIRV", - "MachineIndependent", - "GenericCodeGen", - "SPIRV-Tools-opt", - "SPIRV-Tools-link", - "SPIRV-Tools-reduce", - "SPIRV-Tools", - "glslang-default-resource-limits" - ]) +def read_readme() -> str: + return (proj_root / "README.md").read_text(encoding="utf-8") -sources = [] +VERSION = read_version() -def append_to_sources(prefix, source_list): - global sources +COMMON_CLASSIFIERS = [ + "Programming Language :: Python :: 3", + "License :: OSI Approved :: MIT License", + "Operating System :: OS Independent", + "Development Status :: 2 - Pre-Alpha", +] + +COMMON_PROJECT_URLS = { + "Homepage": "https://github.com/sharhar/vkdispatch", + "Issues": "https://github.com/sharhar/vkdispatch/issues", +} + +COMMON_EXTRAS = { + "cuda": ["cuda-python"], + "opencl": ["pyopencl", "numpy"], + "pycuda": ["pycuda"], + "numpy": ["numpy"], +} - for source in source_list: - sources.append(prefix + source) - - -sources.append("vkdispatch_native/wrapper.pyx") - -append_to_sources("vkdispatch_native/", [ - "context/init.cpp", - "context/context.cpp", - "context/errors.cpp", - "context/handles.cpp", - - "objects/buffer.cpp", - "objects/image.cpp", - "objects/command_list.cpp", - "objects/descriptor_set.cpp", - - "stages/stage_fft.cpp", - "stages/stage_compute.cpp", - - "queue/queue.cpp", - "queue/signal.cpp", - "queue/work_queue.cpp", - "queue/barrier_manager.cpp", - - "libs/VMAImpl.cpp", - "libs/VolkImpl.cpp" -]) - -if vulkan_sdk_root is None: - append_to_sources("deps/glslang/glslang/", [ - 
"CInterface/glslang_c_interface.cpp", - "GenericCodeGen/CodeGen.cpp", - "GenericCodeGen/Link.cpp", - "MachineIndependent/glslang_tab.cpp", - "MachineIndependent/attribute.cpp", - "MachineIndependent/Constant.cpp", - "MachineIndependent/iomapper.cpp", - "MachineIndependent/InfoSink.cpp", - "MachineIndependent/Initialize.cpp", - "MachineIndependent/IntermTraverse.cpp", - "MachineIndependent/Intermediate.cpp", - "MachineIndependent/ParseContextBase.cpp", - "MachineIndependent/ParseHelper.cpp", - "MachineIndependent/PoolAlloc.cpp", - "MachineIndependent/RemoveTree.cpp", - "MachineIndependent/Scan.cpp", - "MachineIndependent/ShaderLang.cpp", - "MachineIndependent/SpirvIntrinsics.cpp", - "MachineIndependent/SymbolTable.cpp", - "MachineIndependent/Versions.cpp", - "MachineIndependent/intermOut.cpp", - "MachineIndependent/limits.cpp", - "MachineIndependent/linkValidate.cpp", - "MachineIndependent/parseConst.cpp", - "MachineIndependent/reflection.cpp", - "MachineIndependent/preprocessor/Pp.cpp", - "MachineIndependent/preprocessor/PpAtom.cpp", - "MachineIndependent/preprocessor/PpContext.cpp", - "MachineIndependent/preprocessor/PpScanner.cpp", - "MachineIndependent/preprocessor/PpTokens.cpp", - "MachineIndependent/propagateNoContraction.cpp", - "ResourceLimits/ResourceLimits.cpp", - "ResourceLimits/resource_limits_c.cpp" - ]) - - append_to_sources("deps/glslang/SPIRV/", [ - "GlslangToSpv.cpp", - "InReadableOrder.cpp", - "Logger.cpp", - "SpvBuilder.cpp", - "SpvPostProcess.cpp", - "doc.cpp", - "SpvTools.cpp", - "disassemble.cpp", - "CInterface/spirv_c_interface.cpp" - ]) def parse_compiler_version(version_output): if not isinstance(version_output, str): return None - - # Try to match either clang or gcc version string - clang_match = re.search(r'clang version ([^\s]+)', version_output) - gcc_match = re.search(r'gcc.+?([\d.]+(?:-[a-zA-Z0-9]+)?)', version_output, re.IGNORECASE) - + + clang_match = re.search(r"clang version ([^\s]+)", version_output) + gcc_match = re.search( + 
r"gcc.+?([\d.]+(?:-[a-zA-Z0-9]+)?)", version_output, re.IGNORECASE + ) + match = clang_match or gcc_match if not match: return None try: return Version(match.group(1)) - except Exception as e: - print(f"Invalid version: {e}") + except Exception as exc: + print(f"Invalid version: {exc}") return None + def detect_unix_compiler(compiler_exe): - """ - Given the 'compiler_exe' (like 'gcc', 'clang', etc.), returns a string - denoting the compiler family: 'clang', 'gcc', or 'unknown'. - """ try: - # Run e.g. `gcc --version` or `clang --version` - version_output = subprocess.check_output([compiler_exe, '--version'], - stderr=subprocess.STDOUT, - universal_newlines=True) - - if 'clang' in version_output: - return 'clang', parse_compiler_version(version_output) - elif 'gcc' in version_output or 'Free Software Foundation' in version_output: - return 'gcc', parse_compiler_version(version_output) - else: - return 'unknown', None + version_output = subprocess.check_output( + [compiler_exe, "--version"], + stderr=subprocess.STDOUT, + universal_newlines=True, + ) + + if "clang" in version_output: + return "clang", parse_compiler_version(version_output) + if "gcc" in version_output or "Free Software Foundation" in version_output: + return "gcc", parse_compiler_version(version_output) + return "unknown", None except Exception: - return 'unknown', None - + return "unknown", None + + class CustomBuildExt(build_ext): def build_extensions(self): compiler_type = self.compiler.compiler_type print(f"Detected compiler type: {compiler_type}") - if compiler_type == 'unix': + if compiler_type == "unix": print(f"Detected compiler: {self.compiler.compiler}") compiler_family, version = detect_unix_compiler(self.compiler.compiler[0]) print(f"Detected compiler family: {compiler_family}") @@ -241,50 +131,290 @@ def build_extensions(self): if version is not None: for ext in self.extensions: - if compiler_family == 'clang' and version < Version('9.0'): - ext.libraries.append('c++fs') - elif 
compiler_family == 'gcc' and version < Version('9.1'): - ext.libraries.append('stdc++fs') + if compiler_family == "clang" and version < Version("9.0"): + ext.libraries.append("c++fs") + elif compiler_family == "gcc" and version < Version("9.1"): + ext.libraries.append("stdc++fs") else: - print("WARNING: Unknown compiler family, not adding filesystem library") + print( + "WARNING: Unknown compiler family, not adding filesystem library" + ) - # Now actually build the extensions super().build_extensions() -setup( - name="vkdispatch", - packages=[ - "vkdispatch", - "vkdispatch.base", - "vkdispatch.backends", - "vkdispatch._compat", - "vkdispatch.codegen", - "vkdispatch.codegen.backends", - "vkdispatch.codegen.functions", - "vkdispatch.codegen.functions.base_functions", - "vkdispatch.codegen.variables", - "vkdispatch.execution_pipeline", - "vkdispatch.shader", - "vkdispatch.reduce", - "vkdispatch.vkfft", - "vkdispatch.fft" - ], - ext_modules=[ - Extension( - "vkdispatch_native", - sources=sources, - language="c++", - define_macros=platform_define_macros, - library_dirs=platform_library_dirs, - libraries=platform_link_libraries, - extra_compile_args=platform_extra_compile_args, - extra_link_args=platform_extra_link_args, - include_dirs=include_directories, + +def append_to_sources(prefix, source_list, out_sources): + for source in source_list: + out_sources.append(prefix + source) + + +def build_native_extension(): + platform_library_dirs = [] + platform_define_macros = [] + platform_link_libraries = [] + platform_extra_link_args = [] + platform_extra_compile_args = ( + ["/W3", "/GL", "/DNDEBUG", "/MD", "/EHsc", "/std:c++17"] + if system == "Windows" + else ["-O2", "-g", "-std=c++17"] + ) + + include_directories = [ + str(proj_root / "deps" / "VMA" / "include"), + str(proj_root / "deps" / "volk"), + str(proj_root / "deps" / "VkFFT" / "vkFFT"), + ] + + if os.name == "posix": + platform_extra_link_args.extend(["-g", "-O0", "-fno-omit-frame-pointer"]) + 
platform_link_libraries.extend(["dl", "pthread"]) + + if vulkan_sdk_root is None: + include_directories.extend( + [ + str(proj_root / "include_ext"), + str(proj_root / "deps" / "Vulkan-Headers" / "include"), + str(proj_root / "deps" / "Vulkan-Utility-Libraries" / "include"), + str(proj_root / "deps" / "glslang"), + str(proj_root / "deps" / "glslang" / "glslang" / "Include"), + ] + ) + + if system == "Darwin": + platform_library_dirs.append(molten_vk_path) + platform_link_libraries.append("MoltenVK") + platform_extra_link_args.extend( + [ + "-framework", + "Metal", + "-framework", + "AVFoundation", + "-framework", + "AppKit", + ] + ) + platform_extra_compile_args.append("-mmacosx-version-min=10.15") + else: + platform_define_macros.append(("VKDISPATCH_USE_VOLK", 1)) + else: + include_directories.extend( + [ + vulkan_sdk_root + "/include", + vulkan_sdk_root + "/include/utility", + vulkan_sdk_root + "/include/glslang/Include", + ] ) - ], - cmdclass={ - 'build_ext': CustomBuildExt, - }, - version="0.0.30", - zip_safe=False, -) + + platform_define_macros.append(("VKDISPATCH_USE_VOLK", 1)) + platform_define_macros.append( + ("VKDISPATCH_LOADER_PATH", '"' + os.path.abspath(vulkan_sdk_root) + '/"') + ) + + platform_library_dirs.append(vulkan_sdk_root + "/lib") + platform_link_libraries.extend( + [ + "glslang", + "SPIRV", + "MachineIndependent", + "GenericCodeGen", + "SPIRV-Tools-opt", + "SPIRV-Tools-link", + "SPIRV-Tools-reduce", + "SPIRV-Tools", + "glslang-default-resource-limits", + ] + ) + + sources = [] + sources.append("vkdispatch_native/wrapper.pyx") + + append_to_sources( + "vkdispatch_native/", + [ + "context/init.cpp", + "context/context.cpp", + "context/errors.cpp", + "context/handles.cpp", + "objects/buffer.cpp", + "objects/image.cpp", + "objects/command_list.cpp", + "objects/descriptor_set.cpp", + "stages/stage_fft.cpp", + "stages/stage_compute.cpp", + "queue/queue.cpp", + "queue/signal.cpp", + "queue/work_queue.cpp", + "queue/barrier_manager.cpp", + 
"libs/VMAImpl.cpp", + "libs/VolkImpl.cpp", + ], + sources, + ) + + if vulkan_sdk_root is None: + append_to_sources( + "deps/glslang/glslang/", + [ + "CInterface/glslang_c_interface.cpp", + "GenericCodeGen/CodeGen.cpp", + "GenericCodeGen/Link.cpp", + "MachineIndependent/glslang_tab.cpp", + "MachineIndependent/attribute.cpp", + "MachineIndependent/Constant.cpp", + "MachineIndependent/iomapper.cpp", + "MachineIndependent/InfoSink.cpp", + "MachineIndependent/Initialize.cpp", + "MachineIndependent/IntermTraverse.cpp", + "MachineIndependent/Intermediate.cpp", + "MachineIndependent/ParseContextBase.cpp", + "MachineIndependent/ParseHelper.cpp", + "MachineIndependent/PoolAlloc.cpp", + "MachineIndependent/RemoveTree.cpp", + "MachineIndependent/Scan.cpp", + "MachineIndependent/ShaderLang.cpp", + "MachineIndependent/SpirvIntrinsics.cpp", + "MachineIndependent/SymbolTable.cpp", + "MachineIndependent/Versions.cpp", + "MachineIndependent/intermOut.cpp", + "MachineIndependent/limits.cpp", + "MachineIndependent/linkValidate.cpp", + "MachineIndependent/parseConst.cpp", + "MachineIndependent/reflection.cpp", + "MachineIndependent/preprocessor/Pp.cpp", + "MachineIndependent/preprocessor/PpAtom.cpp", + "MachineIndependent/preprocessor/PpContext.cpp", + "MachineIndependent/preprocessor/PpScanner.cpp", + "MachineIndependent/preprocessor/PpTokens.cpp", + "MachineIndependent/propagateNoContraction.cpp", + "ResourceLimits/ResourceLimits.cpp", + "ResourceLimits/resource_limits_c.cpp", + ], + sources, + ) + + append_to_sources( + "deps/glslang/SPIRV/", + [ + "GlslangToSpv.cpp", + "InReadableOrder.cpp", + "Logger.cpp", + "SpvBuilder.cpp", + "SpvPostProcess.cpp", + "doc.cpp", + "SpvTools.cpp", + "disassemble.cpp", + "CInterface/spirv_c_interface.cpp", + ], + sources, + ) + + return Extension( + "vkdispatch_vulkan_native", + sources=sources, + language="c++", + define_macros=platform_define_macros, + library_dirs=platform_library_dirs, + libraries=platform_link_libraries, + 
extra_compile_args=platform_extra_compile_args, + extra_link_args=platform_extra_link_args, + include_dirs=include_directories, + ) + + +def base_setup_kwargs(): + return { + "version": VERSION, + "author": "Shahar Sandhaus", + "author_email": "shahar.sandhaus@gmail.com", + "description": ( + "A Python module for orchestrating and dispatching large computations " + "across multi-GPU systems using Vulkan." + ), + "long_description": read_readme(), + "long_description_content_type": "text/markdown", + "python_requires": ">=3.6", + "classifiers": COMMON_CLASSIFIERS, + "project_urls": COMMON_PROJECT_URLS, + "zip_safe": False, + } + + +def core_packages(): + return find_packages(include=["vkdispatch", "vkdispatch.*"]) + + +def setup_for_target(target: str): + kwargs = base_setup_kwargs() + + if target == BUILD_TARGET_FULL: + kwargs.update( + { + "name": "vkdispatch", + "packages": core_packages(), + "install_requires": ["setuptools>=59.0"], + "extras_require": { + "cli": ["Click"], + **COMMON_EXTRAS, + }, + "entry_points": { + "console_scripts": [ + "vdlist=vkdispatch.cli:cli_entrypoint", + ] + }, + "ext_modules": [build_native_extension()], + "cmdclass": {"build_ext": CustomBuildExt}, + } + ) + return kwargs + + if target == BUILD_TARGET_CORE: + kwargs.update( + { + "name": "vkdispatch-core", + "packages": core_packages(), + "install_requires": ["setuptools>=59.0"], + "extras_require": dict(COMMON_EXTRAS), + } + ) + return kwargs + + if target == BUILD_TARGET_NATIVE: + kwargs.update( + { + "name": "vkdispatch-vulkan-native", + "packages": [], + "py_modules": [], + "install_requires": [], + "ext_modules": [build_native_extension()], + "cmdclass": {"build_ext": CustomBuildExt}, + } + ) + return kwargs + + if target == BUILD_TARGET_META: + kwargs.update( + { + "name": "vkdispatch", + "packages": [], + "py_modules": [], + "install_requires": [ + f"vkdispatch-core=={VERSION}", + f"vkdispatch-vulkan-native=={VERSION}", + ], + "extras_require": { + "cli": ["Click"], + 
**COMMON_EXTRAS, + }, + "entry_points": { + "console_scripts": [ + "vdlist=vkdispatch.cli:cli_entrypoint", + ] + }, + } + ) + return kwargs + + raise AssertionError(f"Unhandled build target: {target}") + + +setup(**setup_for_target(BUILD_TARGET)) diff --git a/shader_run.py b/shader_run.py new file mode 100644 index 00000000..8c34a024 --- /dev/null +++ b/shader_run.py @@ -0,0 +1,89 @@ +import vkdispatch as vd + +from vkdispatch.base.command_list import CommandList +from vkdispatch.base.compute_plan import ComputePlan +from vkdispatch.base.descriptor_set import DescriptorSet + +import numpy as np + +def load_shader(path: str) -> ComputePlan: + shader_source = open(path, 'r').read() + + return ComputePlan( + shader_source=shader_source, + binding_type_list=[1, 1, 1], + pc_size=0, + shader_name=f"shader_{path.split('/')[-1].split('.')[0]}" + ) + +def make_descriptor(plan: ComputePlan, out_buff: vd.Buffer, in_buff: vd.Buffer, kern_buff: vd.Buffer): + descriptor_set = DescriptorSet(plan) + + descriptor_set.bind_buffer(out_buff, 0) + descriptor_set.bind_buffer(in_buff, 1) + descriptor_set.bind_buffer(kern_buff, 2) + + return descriptor_set + +def numpy_convolution(signal: np.ndarray, kernel: np.ndarray) -> np.ndarray: + return np.fft.ifft( + np.fft.fft(signal, axis=1).astype(np.complex64) + * + kernel.conjugate(), + axis=1 + ) + +BUFF_SHAPE = (4, 512, 257) + +np.random.seed(1337) + +in_data = (np.random.rand(*BUFF_SHAPE) + 1j * np.random.rand(*BUFF_SHAPE)).astype(np.complex64) +kern_data = (np.random.rand(*BUFF_SHAPE) + 1j * np.random.rand(*BUFF_SHAPE)).astype(np.complex64) + +reference_result_data = numpy_convolution(in_data, kern_data[0]) + +out_buff = vd.buffer_c64(BUFF_SHAPE) +in_buff = vd.buffer_c64(BUFF_SHAPE) +kern_buff = vd.buffer_c64(BUFF_SHAPE) + +in_buff.write(in_data) +kern_buff.write(kern_data) + +block_count = (1028, 32, 1) + +plan_bad = load_shader("conv_bad.comp") +plan_good = load_shader("conv_good.comp") + +cmd_list_bad = CommandList() + 
+cmd_list_bad.record_compute_plan( + plan_bad, + make_descriptor(plan_bad, out_buff, in_buff, kern_buff), + block_count +) + +cmd_list_bad.submit(instance_count=1) + +result_data_bad = out_buff.read(0) + +cmd_list_good = CommandList() + +cmd_list_good.record_compute_plan( + plan_good, + make_descriptor(plan_good, out_buff, in_buff, kern_buff), + block_count +) + +cmd_list_good.submit(instance_count=1) + +result_data_good = out_buff.read(0) + +for i in range(BUFF_SHAPE[0]): + np.save(f"result_bad_{i}.npy", result_data_bad[i]) + np.save(f"result_good_{i}.npy", result_data_good[i]) + np.save(f"reference_result_{i}.npy", reference_result_data[i]) + np.save(f"diff_bad_{i}.npy", result_data_bad[i] - reference_result_data[i]) + np.save(f"diff_good_{i}.npy", result_data_good[i] - reference_result_data[i]) + np.save(f"diff_{i}.npy", result_data_good[i] - result_data_bad[i]) + +assert np.allclose(result_data_good, result_data_bad, atol=1e-3) diff --git a/test.py b/test.py index 320b68e5..b7f21622 100644 --- a/test.py +++ b/test.py @@ -1,51 +1,17 @@ import vkdispatch as vd import vkdispatch.codegen as vc -import numpy as np -from typing import Tuple +vd.initialize(backend="vulkan", log_level=vd.LogLevel.INFO) +vc.set_codegen_backend("glsl") -vd.initialize(backend="pycuda") +SIZE = 4096 -def make_shape(fft_size: int, data_size: int) -> Tuple[int, ...]: - total_square_size = fft_size * fft_size - assert data_size % total_square_size == 0, "Data size must be a multiple of fft_size squared" - return (data_size // total_square_size, fft_size, fft_size) +buff_shape = (2, SIZE, SIZE) -def make_random_data(fft_size: int, run_index: int, data_size: int, seed: int = 1337) -> np.ndarray: - shape = make_shape(fft_size, data_size) - rng = np.random.default_rng(seed + fft_size * 1000 + run_index) +buff = vd.Buffer(buff_shape, var_type=vd.complex64) - real = rng.standard_normal(shape).astype(np.float32) - imag = rng.standard_normal(shape).astype(np.float32) - return (real + 1j * 
imag).astype(np.complex64) +vd.vkfft.fft(buff, axis=1) #, print_shader=True) -def compute_metrics(reference: np.ndarray, result: np.ndarray): - reference64 = reference.astype(np.complex128, copy=False) - result64 = result.astype(np.complex128, copy=False) +vd.queue_wait_idle() - delta = result64 - reference64 - abs_delta = np.abs(delta) - abs_reference = np.abs(reference64) - - eps = 1e-12 - relative_l2 = np.linalg.norm(delta.ravel()) / max(np.linalg.norm(reference64.ravel()), eps) - max_relative = np.max(abs_delta / np.maximum(abs_reference, eps)) - max_absolute = np.max(abs_delta) - - return float(relative_l2), float(max_relative), float(max_absolute) - -fft_size = 64 -data_size = 16 * 1024 * 1024 - -input_data = make_random_data(fft_size, 0, data_size) -reference = np.fft.fft(input_data) - -shape = make_shape(fft_size, data_size) - -buffer = vd.buffer_c64(shape) #Buffer(shape, var_type=vd.complex64) - -buffer.write(input_data) -vd.fft.fft(buffer) #, print_shader=True) -result_data = buffer.read(0) - -print(compute_metrics(reference, result_data)) \ No newline at end of file +#print(vd.fft.fft_src(buff_shape, axis=1).code) \ No newline at end of file diff --git a/test2.py b/test2.py index 813a205e..5f494e18 100644 --- a/test2.py +++ b/test2.py @@ -1,304 +1,109 @@ import vkdispatch as vd import vkdispatch.codegen as vc -vd.initialize(debug_mode=True, backend="pycuda") #, log_level=vd.LogLevel.INFO) +#vd.initialize(debug_mode=True, backend="cuda") +#vc.set_codegen_backend("cuda") -vc.set_codegen_backend("cuda") - -import dataclasses -import enum - -from typing import List -from typing import Any -from typing import Dict -from typing import Tuple - -#vd.initialize(debug_mode=True) -vd.make_context(use_cpu=True) - -from vkdispatch.base.compute_plan import ComputePlan -from vkdispatch.base.descriptor_set import DescriptorSet -from vkdispatch.base.command_list import CommandList +from typing import Callable, Union, Tuple import numpy as np -class 
CommandType(enum.Enum): - ADD_VALUE = 0 - SUB_VALUE = 1 - MULT_VALUE = 2 - DIV_VALUE = 3 - SIN_VALUE = 4 - COS_VALUE = 5 - -valid_commands = [ - CommandType.ADD_VALUE, - CommandType.SUB_VALUE, -] - -command_type_to_str = { - CommandType.ADD_VALUE: "ADD", - CommandType.SUB_VALUE: "SUB", - CommandType.MULT_VALUE: "MULT", - CommandType.DIV_VALUE: "DIV", - CommandType.SIN_VALUE: "SIN", - CommandType.COS_VALUE: "COS" -} - -@dataclasses.dataclass -class ProgramCommand: - command_type: CommandType - value: float +import time +import dataclasses @dataclasses.dataclass -class RunConfig: - buffer_count: int - buffer_sizes: List[int] - - program_count: int - program_commands: List[List[ProgramCommand]] - - def __repr__(self): - commands_repr = "" - - for commands in self.program_commands: - commands_repr += "\n" - - for command in commands: - command_name = command_type_to_str[command.command_type] - - commands_repr += f" {command_name} {command.value}\n" - - return f"""RunConfig( - buffer_count={self.buffer_count}, - buffer_sizes={self.buffer_sizes}, - program_count={self.program_count}, - program_commands=[{commands_repr} -])""" - -def make_random_config() -> RunConfig: - buffer_count = np.random.randint(10, 50) - buffer_sizes = np.random.randint(500, 2500, size=buffer_count).tolist() - - program_count = np.random.randint(10, 50) - program_commands = [] - - for _ in range(program_count): - command_count = np.random.randint(10, 50) - commands = [] - - for _ in range(command_count): - command_type = np.random.choice(valid_commands) - value = np.random.uniform(-10, 10) - - commands.append(ProgramCommand(command_type, value)) - - program_commands.append(commands) - - return RunConfig( - buffer_count=buffer_count, - buffer_sizes=buffer_sizes, - program_count=program_count, - program_commands=program_commands - ) - -buffer_cache: Dict[int, vd.Buffer] = {} - -def get_buffer(index: int, config: RunConfig) -> vd.Buffer: - global buffer_cache +class Config: + data_size: int + 
iter_count: int + iter_batch: int + run_count: int + signal_factor: int + warmup: int = 10 + + def make_shape(self, fft_size: int) -> Tuple[int, ...]: + total_square_size = fft_size * fft_size + assert self.data_size % total_square_size == 0, "Data size must be a multiple of fft_size squared" + return (self.data_size // total_square_size, fft_size, fft_size) - if index not in buffer_cache: - buffer_cache[index] = vd.asbuffer( - np.zeros( - shape=(config.buffer_sizes[index],), - dtype=np.float32 - ) - ) - - return buffer_cache[index] + def make_random_data(self, fft_size: int): + shape = self.make_shape(fft_size) + return np.random.rand(*shape).astype(np.complex64) -array_cache: Dict[int, np.ndarray] = {} +def run_vkdispatch(config: Config, + fft_size: int, + io_count: Union[int, Callable], + gpu_function: Callable) -> float: + shape = config.make_shape(fft_size) -def get_array(index: int, config: RunConfig) -> np.ndarray: - global array_cache + buffer = vd.Buffer(shape, var_type=vd.complex64) + kernel = vd.Buffer(shape, var_type=vd.complex64) - if index not in array_cache: - array_cache[index] = np.zeros( - shape=(config.buffer_sizes[index],), - dtype=np.float32 - ) - - return array_cache[index] - -def make_source(commands: List[ProgramCommand]): - local_size_x = vd.get_context().max_workgroup_size[0] - - header = """ -#version 450 -#extension GL_ARB_separate_shader_objects : enable -//#extension GL_EXT_debug_printf : enable - -layout(push_constant) uniform PushConstant { - uint exec_count; -} PC; - -layout(set = 0, binding = 0) buffer Buffer0 { float data[]; } bufOut; -layout(set = 0, binding = 1) buffer Buffer1 { float data[]; } bufIn; -""" + f""" -layout(local_size_x = {local_size_x}, local_size_y = 1, local_size_z = 1) in; -""" + """ -void main() { - if(PC.exec_count <= gl_GlobalInvocationID.x) { - return ; - } - - uint tid = gl_GlobalInvocationID.x; - - float value = bufIn.data[tid]; -""" - - body = "" - - for command in commands: - if command.command_type == 
CommandType.ADD_VALUE: - body += f" value += {command.value};\n" - elif command.command_type == CommandType.SUB_VALUE: - body += f" value -= {command.value};\n" - elif command.command_type == CommandType.MULT_VALUE: - body += f" value *= {command.value};\n" - elif command.command_type == CommandType.DIV_VALUE: - body += f" value /= {command.value};\n" - elif command.command_type == CommandType.SIN_VALUE: - body += f" value = sin(value);\n" - elif command.command_type == CommandType.COS_VALUE: - body += f" value = cos(value);\n" - - ending = """ - bufOut.data[tid] = value; -} -""" - - return header + body + ending - -program_cache: Dict[int, ComputePlan] = {} - -def get_program(index: int, config: RunConfig) -> ComputePlan: - global program_cache - - if index not in program_cache: - program_cache[index] = ComputePlan( - shader_source=make_source(config.program_commands[index]), - binding_type_list=[1, 1], - pc_size=4, - shader_name=f"program_{index}" - ) - - return program_cache[index] - -descriptor_set_cache: Dict[Tuple[int, int, int], DescriptorSet] = {} - -def get_descriptor_set(out_buffer: int, in_buffer: int, program: ComputePlan, config: RunConfig) -> DescriptorSet: - global descriptor_set_cache - - dict_key = (out_buffer, in_buffer, program._handle) - - if dict_key not in descriptor_set_cache: - output_buffer = get_buffer(out_buffer, config) - input_buffer = get_buffer(in_buffer, config) - - descriptor_set = DescriptorSet(program) - descriptor_set.bind_buffer(output_buffer, 0) - descriptor_set.bind_buffer(input_buffer, 1) - - descriptor_set_cache[dict_key] = descriptor_set + graph = vd.CommandGraph() + old_graph = vd.set_global_graph(graph) + + gpu_function(config, fft_size, buffer, kernel) - return descriptor_set_cache[dict_key] + vd.set_global_graph(old_graph) -def clear_caches(): - global buffer_cache - global array_cache - global program_cache - global descriptor_set_cache + for _ in range(config.warmup): + graph.submit(config.iter_batch) - 
buffer_cache.clear() - array_cache.clear() - program_cache.clear() - descriptor_set_cache.clear() + vd.queue_wait_idle() -def do_vkdispatch_command(cmd_list: CommandList, out_buffer: int, in_buffer: int, program: int, config: RunConfig): - compute_plan = get_program(program, config) - descriptor_set = get_descriptor_set(out_buffer, in_buffer, compute_plan, config) + if callable(io_count): + io_count = io_count(buffer.size, fft_size) - cmd_list.reset() + gb_byte_count = io_count * 8 * buffer.size / (1024 * 1024 * 1024) - local_size = vd.get_context().max_workgroup_size[0] + start_time = time.perf_counter() - total_exec_size = min(config.buffer_sizes[out_buffer], config.buffer_sizes[in_buffer]) + for _ in range(config.iter_count // config.iter_batch): + graph.submit(config.iter_batch) - block_count = (total_exec_size + local_size - 1) // local_size + vd.queue_wait_idle() - cmd_list.record_compute_plan(compute_plan, descriptor_set, [block_count, 1, 1]) + elapsed_time = time.perf_counter() - start_time - cmd_list.submit(data=np.array([total_exec_size], dtype=np.uint32).tobytes()) + buffer.destroy() + kernel.destroy() + graph.destroy() + vd.fft.cache_clear() -def do_numpy_command(out_buffer: int, in_buffer: int, program: int, config: RunConfig): - output_array = get_array(out_buffer, config) - input_array = get_array(in_buffer, config) + time.sleep(1) - total_exec_size = min(config.buffer_sizes[out_buffer], config.buffer_sizes[in_buffer]) + vd.queue_wait_idle() - temp_array = np.zeros(shape=(total_exec_size,), dtype=np.float32) - temp_array[:] = input_array[:total_exec_size] + return gb_byte_count, elapsed_time - commands = config.program_commands[program] - for command in commands: - if command.command_type == CommandType.ADD_VALUE: - temp_array += command.value - temp_array = temp_array.astype(np.float32) - elif command.command_type == CommandType.SUB_VALUE: - temp_array -= command.value - temp_array = temp_array.astype(np.float32) - elif command.command_type == 
CommandType.MULT_VALUE: - temp_array *= command.value - temp_array = temp_array.astype(np.float32) - elif command.command_type == CommandType.DIV_VALUE: - temp_array /= command.value - temp_array = temp_array.astype(np.float32) - elif command.command_type == CommandType.SIN_VALUE: - temp_array = np.sin(temp_array) - temp_array = temp_array.astype(np.float32) - elif command.command_type == CommandType.COS_VALUE: - temp_array = np.cos(temp_array) - temp_array = temp_array.astype(np.float32) +def run_test(config: Config, + io_count: Union[int, Callable], + gpu_function: Callable): + fft_sizes = [8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096] - output_array[:total_exec_size] = temp_array + for fft_size in fft_sizes: + rates = [] -def test_async_commands(): - for _ in range(50): - clear_caches() - - config = make_random_config() + for _ in range(config.run_count): + gb_byte_count, elapsed_time = run_vkdispatch(config, fft_size, io_count, gpu_function) + gb_per_second = config.iter_count * gb_byte_count / elapsed_time - cmd_list = CommandList() + print(f"FFT Size: {fft_size}, Throughput: {gb_per_second:.4f} GB/s") + rates.append(gb_per_second) - exec_count = np.random.randint(1, 250) +def do_fft(config: Config, + fft_size: int, + buffer: vd.Buffer, + kernel: vd.Buffer): + vd.fft.fft(buffer) - input_buffers = np.random.randint(0, config.buffer_count, size=exec_count) - output_buffers = np.random.randint(0, config.buffer_count, size=exec_count) - programs = np.random.randint(0, config.program_count, size=exec_count) - for input_buffer, output_buffer, program in zip(input_buffers, output_buffers, programs): - do_vkdispatch_command(cmd_list, output_buffer, input_buffer, program, config) - - for input_buffer, output_buffer, program in zip(input_buffers, output_buffers, programs): - do_numpy_command(output_buffer, input_buffer, program, config) - - for i in range(config.buffer_count): - numpy_buffer = get_array(i, config) - vkbuffer = get_buffer(i, config).read(0) - - assert 
np.allclose(vkbuffer, numpy_buffer, atol=1e-3) - - clear_caches() +conf = Config( + data_size=2**26, + iter_count=80, + iter_batch=10, + run_count=1, + signal_factor=8 +) -test_async_commands() \ No newline at end of file +run_test(conf, 2, do_fft) \ No newline at end of file diff --git a/test3.py b/test3.py deleted file mode 100644 index 7b29f4eb..00000000 --- a/test3.py +++ /dev/null @@ -1,470 +0,0 @@ - -import pycuda.autoinit -import pycuda.driver as cuda -import numpy as np -from pycuda.compiler import SourceModule - -import struct - - -cuda_kernel = """ -// Expected local size: (8, 1, 1) -#define VKDISPATCH_EXPECTED_LOCAL_SIZE_X 8 -#define VKDISPATCH_EXPECTED_LOCAL_SIZE_Y 1 -#define VKDISPATCH_EXPECTED_LOCAL_SIZE_Z 1 - -#include -#include -#include - -#define VKDISPATCH_ENABLE_SUBGROUP_OPS 1 -#define VKDISPATCH_ENABLE_PRINTF 1 - -__device__ __forceinline__ float2 operator+(float2 a, float2 b) { return make_float2(a.x + b.x, a.y + b.y); } -__device__ __forceinline__ float2 operator-(float2 a, float2 b) { return make_float2(a.x - b.x, a.y - b.y); } -__device__ __forceinline__ float2 operator*(float2 a, float2 b) { return make_float2(a.x * b.x, a.y * b.y); } -__device__ __forceinline__ float2 operator*(float s, float2 v) { return make_float2(s * v.x, s * v.y); } -__device__ __forceinline__ float2 operator*(float2 v, float s) { return make_float2(v.x * s, v.y * s); } - -__device__ __forceinline__ float2 vkdispatch_make_float2(float x, float y) { return make_float2(x, y); } -__device__ __forceinline__ float2 vkdispatch_make_float2(float x) { return make_float2(x, x); } -template __device__ __forceinline__ float2 vkdispatch_make_float2(TVec v) { return make_float2((float)v.x, (float)v.y); } - -__device__ __forceinline__ uint3 vkdispatch_local_invocation_id() { - return make_uint3((unsigned int)threadIdx.x, (unsigned int)threadIdx.y, (unsigned int)threadIdx.z); -} - -__device__ __forceinline__ uint3 vkdispatch_workgroup_id() { - return make_uint3((unsigned 
int)blockIdx.x, (unsigned int)blockIdx.y, (unsigned int)blockIdx.z); -} - -__device__ __forceinline__ unsigned int vkdispatch_local_invocation_index() { - return (unsigned int)(threadIdx.x + blockDim.x * (threadIdx.y + blockDim.y * threadIdx.z)); -} - -__shared__ float2 sdata[68]; - -struct UniformObjectBuffer { - uint4 exec_count; - int4 sdata_shape; - int4 buf1_shape; -}; -struct Buffer1 { float2* data; }; - -extern "C" __global__ void vkdispatch_main(const UniformObjectBuffer* vkdispatch_uniform_ptr, float2* vkdispatch_binding_1_ptr) { - const UniformObjectBuffer& UBO = *vkdispatch_uniform_ptr; - Buffer1 buf1 = {vkdispatch_binding_1_ptr}; - unsigned int workgroup_index = ((unsigned int)(vkdispatch_workgroup_id().x)); - unsigned int tid = vkdispatch_local_invocation_id().x; - unsigned int input_batch_offset = ((unsigned int)(0)); - unsigned int output_batch_offset = ((unsigned int)(0)); - float2 omega_register = vkdispatch_make_float2(0.0f); - unsigned int subsequence_offset = ((unsigned int)(0)); - unsigned int io_index = ((unsigned int)(0)); - unsigned int io_index_2 = ((unsigned int)(0)); - float2 radix_register_0 = vkdispatch_make_float2(0.0f); - float2 radix_register_1 = vkdispatch_make_float2(0.0f); - float2 fft_reg_0 = vkdispatch_make_float2(0.0f); - float2 fft_reg_1 = vkdispatch_make_float2(0.0f); - float2 fft_reg_2 = vkdispatch_make_float2(0.0f); - float2 fft_reg_3 = vkdispatch_make_float2(0.0f); - float2 fft_reg_4 = vkdispatch_make_float2(0.0f); - float2 fft_reg_5 = vkdispatch_make_float2(0.0f); - float2 fft_reg_6 = vkdispatch_make_float2(0.0f); - float2 fft_reg_7 = vkdispatch_make_float2(0.0f); - - /* Reading input samples from global memory into FFT registers. 
*/ - input_batch_offset = ((workgroup_index + vkdispatch_local_invocation_id().y) << 6); - io_index = (tid + input_batch_offset); - fft_reg_0 = buf1.data[io_index]; - io_index = ((tid + 8) + input_batch_offset); - fft_reg_1 = buf1.data[io_index]; - io_index = ((tid + 16) + input_batch_offset); - fft_reg_2 = buf1.data[io_index]; - io_index = ((tid + 24) + input_batch_offset); - fft_reg_3 = buf1.data[io_index]; - io_index = ((tid + 32) + input_batch_offset); - fft_reg_4 = buf1.data[io_index]; - io_index = ((tid + 40) + input_batch_offset); - fft_reg_5 = buf1.data[io_index]; - io_index = ((tid + 48) + input_batch_offset); - fft_reg_6 = buf1.data[io_index]; - io_index = ((tid + 56) + input_batch_offset); - fft_reg_7 = buf1.data[io_index]; - - /* - * FFT stage 1/2. - * Prime group (2, 2, 2): execute 1 radix-8 sub-FFTs per invocation. - * Register-group coverage this stage: 8. - */ - - /* - * Starting mixed-radix FFT decomposition for this invocation on 8 register samples. - * Radix factorization sequence: (2, 2, 2). - * At each level: partition lanes into stage-local sub-sequences, apply twiddles, - * run radix-P butterflies, then reassemble in stride-consistent order for downstream stages. 
- */ - /* Radix-2 butterfly base case */ - radix_register_0 = fft_reg_4; - fft_reg_4 = (fft_reg_0 - radix_register_0); - fft_reg_0 = (fft_reg_0 + radix_register_0); - /* Radix-2 butterfly base case */ - radix_register_0 = fft_reg_5; - fft_reg_5 = (fft_reg_1 - radix_register_0); - fft_reg_1 = (fft_reg_1 + radix_register_0); - /* Radix-2 butterfly base case */ - radix_register_0 = fft_reg_6; - fft_reg_6 = (fft_reg_2 - radix_register_0); - fft_reg_2 = (fft_reg_2 + radix_register_0); - /* Radix-2 butterfly base case */ - radix_register_0 = fft_reg_7; - fft_reg_7 = (fft_reg_3 - radix_register_0); - fft_reg_3 = (fft_reg_3 + radix_register_0); - /* Radix-2 butterfly base case */ - radix_register_0 = fft_reg_2; - fft_reg_2 = (fft_reg_0 - radix_register_0); - fft_reg_0 = (fft_reg_0 + radix_register_0); - - /* - * Applying Cooley-Tukey inter-stage twiddle factors before the next butterfly pass. - * Twiddle domain size: N = 4. Twiddle index source: 1. - * For each non-DC lane i>0, multiply by W_N^(i * twiddle_index). - * This phase-aligns each sub-FFT with its parent decomposition stage. - */ - radix_register_0.x = fft_reg_6.x; - fft_reg_6.x = fft_reg_6.y; - fft_reg_6.y = (-radix_register_0.x); - /* Radix-2 butterfly base case */ - radix_register_0 = fft_reg_6; - fft_reg_6 = (fft_reg_4 - radix_register_0); - fft_reg_4 = (fft_reg_4 + radix_register_0); - /* Radix-2 butterfly base case */ - radix_register_0 = fft_reg_3; - fft_reg_3 = (fft_reg_1 - radix_register_0); - fft_reg_1 = (fft_reg_1 + radix_register_0); - - /* - * Applying Cooley-Tukey inter-stage twiddle factors before the next butterfly pass. - * Twiddle domain size: N = 4. Twiddle index source: 1. - * For each non-DC lane i>0, multiply by W_N^(i * twiddle_index). - * This phase-aligns each sub-FFT with its parent decomposition stage. 
- */ - radix_register_0.x = fft_reg_7.x; - fft_reg_7.x = fft_reg_7.y; - fft_reg_7.y = (-radix_register_0.x); - /* Radix-2 butterfly base case */ - radix_register_0 = fft_reg_7; - fft_reg_7 = (fft_reg_5 - radix_register_0); - fft_reg_5 = (fft_reg_5 + radix_register_0); - /* Radix-2 butterfly base case */ - radix_register_0 = fft_reg_1; - fft_reg_1 = (fft_reg_0 - radix_register_0); - fft_reg_0 = (fft_reg_0 + radix_register_0); - - /* - * Applying Cooley-Tukey inter-stage twiddle factors before the next butterfly pass. - * Twiddle domain size: N = 8. Twiddle index source: 1. - * For each non-DC lane i>0, multiply by W_N^(i * twiddle_index). - * This phase-aligns each sub-FFT with its parent decomposition stage. - */ - radix_register_0 = vkdispatch_make_float2(fmaf(fft_reg_5.x, 0.7071067811865476, ((-fft_reg_5.y) * -0.7071067811865475)), fmaf(fft_reg_5.x, -0.7071067811865475, (fft_reg_5.y * 0.7071067811865476))); - fft_reg_5 = radix_register_0; - /* Radix-2 butterfly base case */ - radix_register_0 = fft_reg_5; - fft_reg_5 = (fft_reg_4 - radix_register_0); - fft_reg_4 = (fft_reg_4 + radix_register_0); - - /* - * Applying Cooley-Tukey inter-stage twiddle factors before the next butterfly pass. - * Twiddle domain size: N = 8. Twiddle index source: 2. - * For each non-DC lane i>0, multiply by W_N^(i * twiddle_index). - * This phase-aligns each sub-FFT with its parent decomposition stage. - */ - radix_register_0.x = fft_reg_3.x; - fft_reg_3.x = fft_reg_3.y; - fft_reg_3.y = (-radix_register_0.x); - /* Radix-2 butterfly base case */ - radix_register_0 = fft_reg_3; - fft_reg_3 = (fft_reg_2 - radix_register_0); - fft_reg_2 = (fft_reg_2 + radix_register_0); - - /* - * Applying Cooley-Tukey inter-stage twiddle factors before the next butterfly pass. - * Twiddle domain size: N = 8. Twiddle index source: 3. - * For each non-DC lane i>0, multiply by W_N^(i * twiddle_index). - * This phase-aligns each sub-FFT with its parent decomposition stage. 
- */ - radix_register_0 = vkdispatch_make_float2(fmaf(fft_reg_7.x, -0.7071067811865475, ((-fft_reg_7.y) * -0.7071067811865476)), fmaf(fft_reg_7.x, -0.7071067811865476, (fft_reg_7.y * -0.7071067811865475))); - fft_reg_7 = radix_register_0; - /* Radix-2 butterfly base case */ - radix_register_0 = fft_reg_7; - fft_reg_7 = (fft_reg_6 - radix_register_0); - fft_reg_6 = (fft_reg_6 + radix_register_0); - - /* - * FFT stage 2/2. - * Prime group (2, 2, 2): execute 1 radix-8 sub-FFTs per invocation. - * Register-group coverage this stage: 8. - */ - /* Register shuffle not possible, falling back to shared memory shuffle. */ - io_index = (tid * 8); - io_index = (io_index + (io_index >> 4)); - sdata[io_index] = fft_reg_0; - io_index = (tid * 8 + 1); - io_index = (io_index + (io_index >> 4)); - sdata[io_index] = fft_reg_4; - io_index = (tid * 8 + 2); - io_index = (io_index + (io_index >> 4)); - sdata[io_index] = fft_reg_2; - io_index = (tid * 8 + 3); - io_index = (io_index + (io_index >> 4)); - sdata[io_index] = fft_reg_6; - io_index = (tid * 8 + 4); - io_index = (io_index + (io_index >> 4)); - sdata[io_index] = fft_reg_1; - io_index = (tid * 8 + 5); - io_index = (io_index + (io_index >> 4)); - sdata[io_index] = fft_reg_5; - io_index = (tid * 8 + 6); - io_index = (io_index + (io_index >> 4)); - sdata[io_index] = fft_reg_3; - io_index = (tid * 8 + 7); - io_index = (io_index + (io_index >> 4)); - sdata[io_index] = fft_reg_7; - __syncthreads(); - io_index = tid; - io_index = (io_index + (io_index >> 4)); - fft_reg_0 = sdata[io_index]; - io_index = (tid + 8); - io_index = (io_index + (io_index >> 4)); - fft_reg_4 = sdata[io_index]; - io_index = (tid + 16); - io_index = (io_index + (io_index >> 4)); - fft_reg_2 = sdata[io_index]; - io_index = (tid + 24); - io_index = (io_index + (io_index >> 4)); - fft_reg_6 = sdata[io_index]; - io_index = (tid + 32); - io_index = (io_index + (io_index >> 4)); - fft_reg_1 = sdata[io_index]; - io_index = (tid + 40); - io_index = (io_index + (io_index 
>> 4)); - fft_reg_5 = sdata[io_index]; - io_index = (tid + 48); - io_index = (io_index + (io_index >> 4)); - fft_reg_3 = sdata[io_index]; - io_index = (tid + 56); - io_index = (io_index + (io_index >> 4)); - fft_reg_7 = sdata[io_index]; - - /* - * Starting mixed-radix FFT decomposition for this invocation on 8 register samples. - * Radix factorization sequence: (2, 2, 2). - * At each level: partition lanes into stage-local sub-sequences, apply twiddles, - * run radix-P butterflies, then reassemble in stride-consistent order for downstream stages. - */ - - /* - * Applying Cooley-Tukey inter-stage twiddle factors before the next butterfly pass. - * Twiddle domain size: N = 64. Twiddle index source: tid. - * For each non-DC lane i>0, multiply by W_N^(i * twiddle_index). - * This phase-aligns each sub-FFT with its parent decomposition stage. - */ - omega_register.x = (tid * -0.09817477042468103); - omega_register = vkdispatch_make_float2(cos(omega_register.x), sin(omega_register.x)); - radix_register_0 = vkdispatch_make_float2(fmaf(fft_reg_4.x, omega_register.x, ((-fft_reg_4.y) * omega_register.y)), fmaf(fft_reg_4.x, omega_register.y, (fft_reg_4.y * omega_register.x))); - fft_reg_4 = radix_register_0; - omega_register.x = (tid * -0.19634954084936207); - omega_register = vkdispatch_make_float2(cos(omega_register.x), sin(omega_register.x)); - radix_register_0 = vkdispatch_make_float2(fmaf(fft_reg_2.x, omega_register.x, ((-fft_reg_2.y) * omega_register.y)), fmaf(fft_reg_2.x, omega_register.y, (fft_reg_2.y * omega_register.x))); - fft_reg_2 = radix_register_0; - omega_register.x = (tid * -0.2945243112740431); - omega_register = vkdispatch_make_float2(cos(omega_register.x), sin(omega_register.x)); - radix_register_0 = vkdispatch_make_float2(fmaf(fft_reg_6.x, omega_register.x, ((-fft_reg_6.y) * omega_register.y)), fmaf(fft_reg_6.x, omega_register.y, (fft_reg_6.y * omega_register.x))); - fft_reg_6 = radix_register_0; - omega_register.x = (tid * -0.39269908169872414); - 
omega_register = vkdispatch_make_float2(cos(omega_register.x), sin(omega_register.x)); - radix_register_0 = vkdispatch_make_float2(fmaf(fft_reg_1.x, omega_register.x, ((-fft_reg_1.y) * omega_register.y)), fmaf(fft_reg_1.x, omega_register.y, (fft_reg_1.y * omega_register.x))); - fft_reg_1 = radix_register_0; - omega_register.x = (tid * -0.4908738521234052); - omega_register = vkdispatch_make_float2(cos(omega_register.x), sin(omega_register.x)); - radix_register_0 = vkdispatch_make_float2(fmaf(fft_reg_5.x, omega_register.x, ((-fft_reg_5.y) * omega_register.y)), fmaf(fft_reg_5.x, omega_register.y, (fft_reg_5.y * omega_register.x))); - fft_reg_5 = radix_register_0; - omega_register.x = (tid * -0.5890486225480862); - omega_register = vkdispatch_make_float2(cos(omega_register.x), sin(omega_register.x)); - radix_register_0 = vkdispatch_make_float2(fmaf(fft_reg_3.x, omega_register.x, ((-fft_reg_3.y) * omega_register.y)), fmaf(fft_reg_3.x, omega_register.y, (fft_reg_3.y * omega_register.x))); - fft_reg_3 = radix_register_0; - omega_register.x = (tid * -0.6872233929727672); - omega_register = vkdispatch_make_float2(cos(omega_register.x), sin(omega_register.x)); - radix_register_0 = vkdispatch_make_float2(fmaf(fft_reg_7.x, omega_register.x, ((-fft_reg_7.y) * omega_register.y)), fmaf(fft_reg_7.x, omega_register.y, (fft_reg_7.y * omega_register.x))); - fft_reg_7 = radix_register_0; - /* Radix-2 butterfly base case */ - radix_register_0 = fft_reg_1; - fft_reg_1 = (fft_reg_0 - radix_register_0); - fft_reg_0 = (fft_reg_0 + radix_register_0); - /* Radix-2 butterfly base case */ - radix_register_0 = fft_reg_5; - fft_reg_5 = (fft_reg_4 - radix_register_0); - fft_reg_4 = (fft_reg_4 + radix_register_0); - /* Radix-2 butterfly base case */ - radix_register_0 = fft_reg_3; - fft_reg_3 = (fft_reg_2 - radix_register_0); - fft_reg_2 = (fft_reg_2 + radix_register_0); - /* Radix-2 butterfly base case */ - radix_register_0 = fft_reg_7; - fft_reg_7 = (fft_reg_6 - radix_register_0); - fft_reg_6 = 
(fft_reg_6 + radix_register_0); - /* Radix-2 butterfly base case */ - radix_register_0 = fft_reg_2; - fft_reg_2 = (fft_reg_0 - radix_register_0); - fft_reg_0 = (fft_reg_0 + radix_register_0); - - /* - * Applying Cooley-Tukey inter-stage twiddle factors before the next butterfly pass. - * Twiddle domain size: N = 4. Twiddle index source: 1. - * For each non-DC lane i>0, multiply by W_N^(i * twiddle_index). - * This phase-aligns each sub-FFT with its parent decomposition stage. - */ - radix_register_0.x = fft_reg_3.x; - fft_reg_3.x = fft_reg_3.y; - fft_reg_3.y = (-radix_register_0.x); - /* Radix-2 butterfly base case */ - radix_register_0 = fft_reg_3; - fft_reg_3 = (fft_reg_1 - radix_register_0); - fft_reg_1 = (fft_reg_1 + radix_register_0); - /* Radix-2 butterfly base case */ - radix_register_0 = fft_reg_6; - fft_reg_6 = (fft_reg_4 - radix_register_0); - fft_reg_4 = (fft_reg_4 + radix_register_0); - - /* - * Applying Cooley-Tukey inter-stage twiddle factors before the next butterfly pass. - * Twiddle domain size: N = 4. Twiddle index source: 1. - * For each non-DC lane i>0, multiply by W_N^(i * twiddle_index). - * This phase-aligns each sub-FFT with its parent decomposition stage. - */ - radix_register_0.x = fft_reg_7.x; - fft_reg_7.x = fft_reg_7.y; - fft_reg_7.y = (-radix_register_0.x); - /* Radix-2 butterfly base case */ - radix_register_0 = fft_reg_7; - fft_reg_7 = (fft_reg_5 - radix_register_0); - fft_reg_5 = (fft_reg_5 + radix_register_0); - /* Radix-2 butterfly base case */ - radix_register_0 = fft_reg_4; - fft_reg_4 = (fft_reg_0 - radix_register_0); - fft_reg_0 = (fft_reg_0 + radix_register_0); - - /* - * Applying Cooley-Tukey inter-stage twiddle factors before the next butterfly pass. - * Twiddle domain size: N = 8. Twiddle index source: 1. - * For each non-DC lane i>0, multiply by W_N^(i * twiddle_index). - * This phase-aligns each sub-FFT with its parent decomposition stage. 
- */ - radix_register_0 = vkdispatch_make_float2(fmaf(fft_reg_5.x, 0.7071067811865476, ((-fft_reg_5.y) * -0.7071067811865475)), fmaf(fft_reg_5.x, -0.7071067811865475, (fft_reg_5.y * 0.7071067811865476))); - fft_reg_5 = radix_register_0; - /* Radix-2 butterfly base case */ - radix_register_0 = fft_reg_5; - fft_reg_5 = (fft_reg_1 - radix_register_0); - fft_reg_1 = (fft_reg_1 + radix_register_0); - - /* - * Applying Cooley-Tukey inter-stage twiddle factors before the next butterfly pass. - * Twiddle domain size: N = 8. Twiddle index source: 2. - * For each non-DC lane i>0, multiply by W_N^(i * twiddle_index). - * This phase-aligns each sub-FFT with its parent decomposition stage. - */ - radix_register_0.x = fft_reg_6.x; - fft_reg_6.x = fft_reg_6.y; - fft_reg_6.y = (-radix_register_0.x); - /* Radix-2 butterfly base case */ - radix_register_0 = fft_reg_6; - fft_reg_6 = (fft_reg_2 - radix_register_0); - fft_reg_2 = (fft_reg_2 + radix_register_0); - - /* - * Applying Cooley-Tukey inter-stage twiddle factors before the next butterfly pass. - * Twiddle domain size: N = 8. Twiddle index source: 3. - * For each non-DC lane i>0, multiply by W_N^(i * twiddle_index). - * This phase-aligns each sub-FFT with its parent decomposition stage. - */ - radix_register_0 = vkdispatch_make_float2(fmaf(fft_reg_7.x, -0.7071067811865475, ((-fft_reg_7.y) * -0.7071067811865476)), fmaf(fft_reg_7.x, -0.7071067811865476, (fft_reg_7.y * -0.7071067811865475))); - fft_reg_7 = radix_register_0; - /* Radix-2 butterfly base case */ - radix_register_0 = fft_reg_7; - fft_reg_7 = (fft_reg_3 - radix_register_0); - fft_reg_3 = (fft_reg_3 + radix_register_0); - - /* - * Writing register-resident FFT outputs to global memory. - * Addressing uses computed batch offsets plus FFT-lane stride. 
- */ - output_batch_offset = ((workgroup_index + vkdispatch_local_invocation_id().y) << 6); - io_index = (tid + output_batch_offset); - buf1.data[io_index] = fft_reg_0; - io_index = ((tid + 8) + output_batch_offset); - buf1.data[io_index] = fft_reg_1; - io_index = ((tid + 16) + output_batch_offset); - buf1.data[io_index] = fft_reg_2; - io_index = ((tid + 24) + output_batch_offset); - buf1.data[io_index] = fft_reg_3; - io_index = ((tid + 32) + output_batch_offset); - buf1.data[io_index] = fft_reg_4; - io_index = ((tid + 40) + output_batch_offset); - buf1.data[io_index] = fft_reg_5; - io_index = ((tid + 48) + output_batch_offset); - buf1.data[io_index] = fft_reg_6; - io_index = ((tid + 56) + output_batch_offset); - buf1.data[io_index] = fft_reg_7; -}""" - - -mod = SourceModule(cuda_kernel, no_extern_c=True) -kernel = mod.get_function("vkdispatch_main") - -# --- Set up UniformObjectBuffer on device --- -# uint4 = 4x uint32 (16 bytes), int4 = 4x int32 (16 bytes) -# Total: 48 bytes, 16-byte aligned - -n = 64 -ubo_bytes = struct.pack( - "4I 4i 4i", - # exec_count (uint4) - n, 1, 1, 0, - # sdata_shape (int4) - n, 1, 1, 1, - # buf1_shape (int4) - n, 1, 1, 1, -) - -ubo_gpu = cuda.mem_alloc(len(ubo_bytes)) -cuda.memcpy_htod(ubo_gpu, ubo_bytes) - -# --- Set up Buffer1 data (float2 = 2x float32 per element) --- - -buf1_data = np.random.randn(n).astype(np.complex64) -buf1_gpu = cuda.mem_alloc(buf1_data.nbytes) -cuda.memcpy_htod(buf1_gpu, buf1_data) - -# --- Pack the Buffer1 struct (just a device pointer, 8 bytes) --- -# Buffer1 { float2* data } is passed BY VALUE, so we pack the pointer - -buf1_struct = struct.pack("P", int(buf1_gpu)) # "P" = pointer-sized uint - -# --- Launch --- - -kernel( - ubo_gpu, - buf1_gpu, - block=(8, 1, 1), - grid=(1, 1), -) - -# --- Verify --- - -print(buf1_data.shape) - -result = np.empty_like(buf1_data) -cuda.memcpy_dtoh(result, buf1_gpu) -assert np.allclose(result, np.fft.fft(buf1_data)) -print("Success:", result[:4]) \ No newline at end of file 
diff --git a/tests/test_async_processing.py b/tests/test_async_processing.py index bad805fc..83082142 100644 --- a/tests/test_async_processing.py +++ b/tests/test_async_processing.py @@ -129,8 +129,9 @@ def get_array(index: int, config: RunConfig) -> np.ndarray: def make_source(commands: List[ProgramCommand]): local_size_x = vd.get_context().max_workgroup_size[0] + is_cuda_python = vd.is_cuda() - if vd.get_backend() == "pycuda": + if is_cuda_python: header = ( f"#define VKDISPATCH_EXPECTED_LOCAL_SIZE_X {local_size_x}\n" "#define VKDISPATCH_EXPECTED_LOCAL_SIZE_Y 1\n" @@ -193,7 +194,7 @@ def make_source(commands: List[ProgramCommand]): elif command.command_type == CommandType.COS_VALUE: body += f" value = cos(value);\n" - if vd.get_backend() == "pycuda": + if is_cuda_python: ending = """ vkdispatch_binding_0_ptr[tid] = value; } @@ -301,6 +302,9 @@ def do_numpy_command(out_buffer: int, in_buffer: int, program: int, config: RunC output_array[:total_exec_size] = temp_array def test_async_commands(): + if not vd.is_vulkan(): + return + for _ in range(50): clear_caches() diff --git a/tests/test_fft_mixed_precision.py b/tests/test_fft_mixed_precision.py new file mode 100644 index 00000000..62dd969f --- /dev/null +++ b/tests/test_fft_mixed_precision.py @@ -0,0 +1,320 @@ +import numpy as np +import pytest +from types import SimpleNamespace + +import vkdispatch as vd +import vkdispatch.codegen as vc +import vkdispatch.fft.functions as fft_functions + + +@pytest.fixture(autouse=True) +def _clear_fft_cache(): + yield + try: + vd.fft.cache_clear() + except Exception: + pass + + +def _require_runtime_context(): + try: + context = vd.get_context() + except Exception as exc: + pytest.skip(f"No runtime backend available for mixed-precision FFT tests: {exc}") + + is_dummy = getattr(vd, "is_dummy", None) + if callable(is_dummy) and is_dummy(): + pytest.skip("Dummy backend is codegen-only and cannot execute FFT kernels.") + + return context + + +def _supports_complex32(context) -> 
bool: + for device in context.device_infos: + if device.float_16_support != 1: + return False + if ( + device.storage_buffer_16_bit_access != 1 + and device.uniform_and_storage_buffer_16_bit_access != 1 + ): + return False + return True + + +def _supports_complex128(context) -> bool: + return all(device.float_64_support == 1 for device in context.device_infos) + + +def _require_complex32_support(context): + if not _supports_complex32(context): + pytest.skip("Active device set does not support complex32 (fp16) FFT buffers.") + + +def _require_complex128_support(context): + if not _supports_complex128(context): + pytest.skip("Active device set does not support complex128 (fp64) FFT buffers.") + + +def _quantize_to_complex32(values: np.ndarray) -> np.ndarray: + real = values.real.astype(np.float16).astype(np.float32) + imag = values.imag.astype(np.float16).astype(np.float32) + return (real + (1j * imag)).astype(np.complex64) + + +def _write_complex32(buffer: vd.Buffer, values: np.ndarray): + packed = np.empty(values.shape + (2,), dtype=np.float16) + packed[..., 0] = values.real.astype(np.float16) + packed[..., 1] = values.imag.astype(np.float16) + buffer.write(np.ascontiguousarray(packed)) + + +def test_fft_complex32_io_with_complex64_compute(): + context = _require_runtime_context() + _require_complex32_support(context) + + rng = np.random.default_rng(7) + data = ( + rng.standard_normal(64) + 1j * rng.standard_normal(64) + ).astype(np.complex64) + quantized = _quantize_to_complex32(data) + + test_buffer = vd.Buffer(data.shape, vd.complex32) + _write_complex32(test_buffer, data) + + vd.fft.fft(test_buffer, compute_type=vd.complex64) + + result = test_buffer.read(0).astype(np.complex64) + reference = np.fft.fft(quantized).astype(np.complex64) + + assert np.allclose(result, reference, atol=3e-1, rtol=2e-2) + + +def test_fft_map_complex32_input_to_complex128_output_auto_compute(): + context = _require_runtime_context() + _require_complex32_support(context) + 
_require_complex128_support(context) + + rng = np.random.default_rng(11) + data = ( + rng.standard_normal(32) + 1j * rng.standard_normal(32) + ).astype(np.complex64) + quantized = _quantize_to_complex32(data) + + input_buffer = vd.Buffer(data.shape, vd.complex32) + _write_complex32(input_buffer, data) + output_buffer = vd.Buffer(data.shape, vd.complex128) + + def input_map(buffer: vc.Buffer[vd.complex32]): + vd.fft.read_op().read_from_buffer(buffer) + + def output_map(buffer: vc.Buffer[vd.complex128]): + vd.fft.write_op().write_to_buffer(buffer) + + vd.fft.fft( + output_buffer, + input_buffer, + input_map=vd.map(input_map), + output_map=vd.map(output_map), + ) + + result = output_buffer.read(0) + reference = np.fft.fft(quantized).astype(np.complex128) + + assert np.allclose(result, reference, atol=3e-1, rtol=2e-2) + + +def test_fft_input_output_maps_allow_float32_buffers(): + _require_runtime_context() + + rng = np.random.default_rng(23) + data = rng.standard_normal(64).astype(np.float32) + + input_buffer = vd.asbuffer(data) + output_buffer = vd.Buffer(data.shape, vd.float32) + + def input_map(buffer: vc.Buffer[vd.float32]): + read_op = vd.fft.read_op() + value = vc.to_dtype(read_op.register.var_type.child_type, buffer[read_op.io_index]) + read_op.register.real = value + read_op.register.imag = vc.to_dtype(read_op.register.var_type.child_type, 0) + + def output_map(buffer: vc.Buffer[vd.float32]): + write_op = vd.fft.write_op() + buffer[write_op.io_index] = vc.to_dtype(buffer.var_type, write_op.register.real) + + vd.fft.fft( + output_buffer, + input_buffer, + input_map=vd.map(input_map), + output_map=vd.map(output_map), + ) + + result = output_buffer.read(0).astype(np.float32) + reference = np.fft.fft(data.astype(np.complex64)).real.astype(np.float32) + + assert np.allclose(result, reference, atol=2e-3, rtol=1e-3) + + +def test_convolve_kernel_map_allows_float32_buffer(): + _require_runtime_context() + + rng = np.random.default_rng(31) + data = ( + 
rng.standard_normal(64) + 1j * rng.standard_normal(64) + ).astype(np.complex64) + scale = np.float32(0.5) + + signal_buffer = vd.asbuffer(data.copy()) + scale_buffer = vd.asbuffer(np.full(data.shape, scale, dtype=np.float32)) + + def kernel_map(scale_values: vc.Buffer[vd.float32]): + read_op = vd.fft.read_op() + scale_value = vc.to_dtype( + read_op.register.var_type, + vc.to_complex(scale_values[read_op.io_index]), + ) + read_op.register[:] = vc.mult_complex(read_op.register, scale_value) + + vd.fft.convolve( + signal_buffer, + scale_buffer, + kernel_map=vd.map(kernel_map), + ) + + result = signal_buffer.read(0).astype(np.complex64) + reference = (data * scale).astype(np.complex64) + + assert np.allclose(result, reference, atol=2e-3, rtol=1e-3) + + +def test_fft_output_map_without_input_map_uses_explicit_input_buffer(): + _require_runtime_context() + + rng = np.random.default_rng(37) + data = ( + rng.standard_normal(64) + 1j * rng.standard_normal(64) + ).astype(np.complex64) + + input_buffer = vd.asbuffer(data.copy()) + output_buffer = vd.Buffer(data.shape, vd.complex64) + + @vd.map + def output_map(buffer: vc.Buffer[vd.complex64]): + vd.fft.write_op().write_to_buffer(buffer) + + vd.fft.fft( + output_buffer, + input_buffer, + output_map=output_map, + ) + + result = output_buffer.read(0).astype(np.complex64) + reference = np.fft.fft(data).astype(np.complex64) + + assert np.allclose(result, reference, atol=2e-3, rtol=1e-3) + + +def test_convolve_output_map_without_input_map_uses_explicit_input_buffer(): + if True: + return + _require_runtime_context() + + rng = np.random.default_rng(41) + data = ( + rng.standard_normal(64) + 1j * rng.standard_normal(64) + ).astype(np.complex64) + + input_buffer = vd.asbuffer(data.copy()) + output_buffer = vd.Buffer(data.shape, vd.complex64) + + @vd.map + def kernel_map(): + # Identity map: keep spectrum unchanged. 
+ return + + @vd.map + def output_map(buffer: vc.Buffer[vd.complex64]): + vd.fft.write_op().write_to_buffer(buffer) + + vd.fft.convolve( + output_buffer, + input_buffer, + kernel_map=kernel_map, + output_map=output_map, + ) + + result = output_buffer.read(0).astype(np.complex64) + reference = data.astype(np.complex64) + + assert np.allclose(result, reference, atol=2e-3, rtol=1e-3) + + +def test_fft_complex64_io_with_complex128_compute(): + context = _require_runtime_context() + _require_complex128_support(context) + + rng = np.random.default_rng(29) + data = ( + rng.standard_normal(64) + 1j * rng.standard_normal(64) + ).astype(np.complex64) + + test_buffer = vd.asbuffer(data) + vd.fft.fft(test_buffer, compute_type=vd.complex128) + + result = test_buffer.read(0).astype(np.complex64) + reference = np.fft.fft(data).astype(np.complex64) + + assert np.allclose(result, reference, atol=2e-3, rtol=1e-3) + + +def test_resolve_input_precision_output_map_infers_input_from_post_map_argument(monkeypatch): + monkeypatch.setattr( + fft_functions, + "ensure_supported_complex_precision", + lambda dtype, role: None, + ) + + class _FakeBuffer: + def __init__(self, var_type): + self.var_type = var_type + + output_map = SimpleNamespace( + buffer_types=[vc.Buffer[vd.complex64], vc.Buffer[vd.float32]], + ) + + resolved = fft_functions._resolve_input_precision( + ( + _FakeBuffer(vd.complex64), + _FakeBuffer(vd.float32), + _FakeBuffer(vd.complex128), + ), + input_map=None, + output_map=output_map, + input_type=None, + output_precision=None, + ) + + assert resolved is vd.complex128 + + +def test_resolve_input_precision_output_map_requires_input_buffer_after_map_args(monkeypatch): + monkeypatch.setattr( + fft_functions, + "ensure_supported_complex_precision", + lambda dtype, role: None, + ) + + class _FakeBuffer: + def __init__(self, var_type): + self.var_type = var_type + + output_map = SimpleNamespace(buffer_types=[vc.Buffer[vd.complex64]]) + + with pytest.raises(ValueError, match="input 
buffer argument must be provided"): + fft_functions._resolve_input_precision( + (_FakeBuffer(vd.complex64),), + input_map=None, + output_map=output_map, + input_type=None, + output_precision=None, + ) diff --git a/tests/test_image.py b/tests/test_image.py index 0b6a0c06..2a03478c 100644 --- a/tests/test_image.py +++ b/tests/test_image.py @@ -8,6 +8,9 @@ vd.initialize(log_level=vd.LogLevel.WARNING, debug_mode=True) def test_1d_image_creation(): + if not vd.is_vulkan(): + return + # Create a 1D image signal = np.sin(np.array([i/8 for i in range(0, 50, 1)])).astype(np.float32) @@ -17,6 +20,8 @@ def test_1d_image_creation(): assert np.allclose(test_line.read(0), signal) def test_2d_image_creation(): + if not vd.is_vulkan(): + return # Create a 2D image signal_2d = np.sin(np.array([[i/8 + j/17 for i in range(0, 50, 1)] for j in range(0, 50, 1)])).astype(np.float32) @@ -26,6 +31,8 @@ def test_2d_image_creation(): assert np.allclose(test_img.read(0), signal_2d) def test_3d_image_creation(): + if not vd.is_vulkan(): + return # Create a 3D image signal_3d = np.sin(np.array([[[i/8 + j/17 + k/23 for i in range(0, 50, 1)] for j in range(0, 50, 1)] for k in range(0, 50, 1)])).astype(np.float32) @@ -35,6 +42,8 @@ def test_3d_image_creation(): assert np.allclose(test_img.read(0), signal_3d) def test_1d_image_linear_sampling(): + if not vd.is_vulkan(): + return # Create a 1D image signal = np.sin(np.array([i/8 for i in range(0, 50, 1)])).astype(np.float32) @@ -57,6 +66,8 @@ def do_approx(buff: Buff[f32], line: Img1[f32]): assert np.allclose(result_arr.read()[0], signal_full, atol=0.002) def test_2d_image_linear_sampling(): + if not vd.is_vulkan(): + return # Create a 2D image signal_2d = np.sin(np.array([[i/8 + j/17 for i in range(0, 50, 1)] for j in range(0, 50, 1)])).astype(np.float32) sample_factor = 10 diff --git a/tests/test_reductions.py b/tests/test_reductions.py index 06ad2fbe..3bed232d 100644 --- a/tests/test_reductions.py +++ b/tests/test_reductions.py @@ -160,4 +160,63 
@@ def sum_map(buffer: Buff[f32]) -> f32: read_data = res_buf.read(0)[0] # Check that the data is the same - assert np.allclose([np.sin(data).sum(axis=1)], [read_data]) \ No newline at end of file + assert np.allclose([np.sin(data).sum(axis=1)], [read_data]) + +def test_mapped_reductions_min(): + # Create a buffer + buf = vd.Buffer((1024,), vd.float32) + + # Create a numpy array + data = np.random.randn(1024).astype(np.float32) + + # Write the data to the buffer + buf.write(data) + + @vd.reduce.map_reduce(vd.reduce.SubgroupMin) + def min_map(buffer: Buff[f32]) -> f32: + return buffer[vd.reduce.mapped_io_index()] + + res_buf = min_map(buf) + + # Read the data from the buffer + read_data = res_buf.read(0) + + # Check that the data is the same + assert np.allclose([data.min()], [read_data[0]]) + +def test_mapped_reductions_max(): + # Create a buffer + buf = vd.Buffer((1024,), vd.float32) + + # Create a numpy array + data = np.random.randn(1024).astype(np.float32) + + # Write the data to the buffer + buf.write(data) + + @vd.reduce.map_reduce(vd.reduce.SubgroupMax) + def max_map(buffer: Buff[f32]) -> f32: + return buffer[vd.reduce.mapped_io_index()] + + res_buf = max_map(buf) + + # Read the data from the buffer + read_data = res_buf.read(0) + + # Check that the data is the same + assert np.allclose([data.max()], [read_data[0]]) + +def test_min_max_codegen_stage_creation(): + @vd.reduce.map_reduce(vd.reduce.SubgroupMin) + def min_map(buffer: Buff[f32]) -> f32: + return buffer[vd.reduce.mapped_io_index()] + + @vd.reduce.map_reduce(vd.reduce.SubgroupMax) + def max_map(buffer: Buff[f32]) -> f32: + return buffer[vd.reduce.mapped_io_index()] + + min_src_stage1, min_src_stage2 = min_map.get_src() + max_src_stage1, max_src_stage2 = max_map.get_src() + + assert min_src_stage1 and min_src_stage2 + assert max_src_stage1 and max_src_stage2 diff --git a/tests/test_vkfft.py b/tests/test_vkfft.py index 49b2bf70..caf8a480 100644 --- a/tests/test_vkfft.py +++ b/tests/test_vkfft.py @@ 
-20,6 +20,8 @@ def check_fft_dims(fft_dims: List[int], max_fft_size: int): return all([dim <= max_fft_size for dim in fft_dims]) and np.prod(fft_dims) * vd.complex64.item_size < 2 ** 20 def test_fft_1d(): + if not vd.is_vulkan(): + return max_fft_size = vd.get_context().max_shared_memory // vd.complex64.item_size max_fft_size = min(max_fft_size, vd.get_context().max_workgroup_size[0]) @@ -44,6 +46,8 @@ def test_fft_1d(): vd.vkfft.clear_plan_cache() def test_fft_2d(): + if not vd.is_vulkan(): + return max_fft_size = vd.get_context().max_shared_memory // vd.complex64.item_size max_fft_size = min(max_fft_size, vd.get_context().max_workgroup_size[0]) @@ -67,6 +71,8 @@ def test_fft_2d(): vd.vkfft.clear_plan_cache() def test_fft_3d(): + if not vd.is_vulkan(): + return max_fft_size = vd.get_context().max_shared_memory // vd.complex64.item_size max_fft_size = min(max_fft_size, vd.get_context().max_workgroup_size[0]) @@ -90,6 +96,8 @@ def test_fft_3d(): vd.vkfft.clear_plan_cache() def test_ifft_1d(): + if not vd.is_vulkan(): + return max_fft_size = vd.get_context().max_shared_memory // vd.complex64.item_size max_fft_size = min(max_fft_size, vd.get_context().max_workgroup_size[0]) @@ -114,6 +122,8 @@ def test_ifft_1d(): vd.vkfft.clear_plan_cache() def test_ifft_2d(): + if not vd.is_vulkan(): + return max_fft_size = vd.get_context().max_shared_memory // vd.complex64.item_size max_fft_size = min(max_fft_size, vd.get_context().max_workgroup_size[0]) @@ -137,6 +147,8 @@ def test_ifft_2d(): vd.vkfft.clear_plan_cache() def test_ifft_3d(): + if not vd.is_vulkan(): + return max_fft_size = vd.get_context().max_shared_memory // vd.complex64.item_size max_fft_size = min(max_fft_size, vd.get_context().max_workgroup_size[0]) @@ -160,6 +172,8 @@ def test_ifft_3d(): vd.vkfft.clear_plan_cache() def test_rfft_1d(): + if not vd.is_vulkan(): + return max_fft_size = vd.get_context().max_shared_memory // vd.complex64.item_size max_fft_size = min(max_fft_size, 
vd.get_context().max_workgroup_size[0]) @@ -183,6 +197,8 @@ def test_rfft_1d(): vd.vkfft.clear_plan_cache() def test_rfft_2d(): + if not vd.is_vulkan(): + return max_fft_size = vd.get_context().max_shared_memory // vd.complex64.item_size max_fft_size = min(max_fft_size, vd.get_context().max_workgroup_size[0]) @@ -206,6 +222,8 @@ def test_rfft_2d(): vd.vkfft.clear_plan_cache() def test_rfft_3d(): + if not vd.is_vulkan(): + return max_fft_size = vd.get_context().max_shared_memory // vd.complex64.item_size max_fft_size = min(max_fft_size, vd.get_context().max_workgroup_size[0]) @@ -229,6 +247,8 @@ def test_rfft_3d(): vd.vkfft.clear_plan_cache() def test_irfft_1d(): + if not vd.is_vulkan(): + return max_fft_size = vd.get_context().max_shared_memory // vd.complex64.item_size max_fft_size = min(max_fft_size, vd.get_context().max_workgroup_size[0]) @@ -252,6 +272,8 @@ def test_irfft_1d(): vd.vkfft.clear_plan_cache() def test_irfft_2d(): + if not vd.is_vulkan(): + return max_fft_size = vd.get_context().max_shared_memory // vd.complex64.item_size max_fft_size = min(max_fft_size, vd.get_context().max_workgroup_size[0]) @@ -275,6 +297,8 @@ def test_irfft_2d(): vd.vkfft.clear_plan_cache() def test_irfft_3d(): + if not vd.is_vulkan(): + return max_fft_size = vd.get_context().max_shared_memory // vd.complex64.item_size max_fft_size = min(max_fft_size, vd.get_context().max_workgroup_size[0]) diff --git a/tests/test_vkfft_conv.py b/tests/test_vkfft_conv.py index cc56d7eb..a4404c80 100644 --- a/tests/test_vkfft_conv.py +++ b/tests/test_vkfft_conv.py @@ -30,6 +30,8 @@ def check_fft_dims(fft_dims: List[int], max_fft_size: int): return all([dim <= max_fft_size for dim in fft_dims]) and np.prod(fft_dims) * vd.complex64.item_size < 2 ** 29 def test_convolution_2d_powers_of_2(): + if not vd.is_vulkan(): + return max_fft_size = vd.get_context().max_shared_memory // vd.complex64.item_size buffer_cache = {} diff --git a/vkdispatch/__init__.py b/vkdispatch/__init__.py index 
7f6e2229..27e99e2a 100644 --- a/vkdispatch/__init__.py +++ b/vkdispatch/__init__.py @@ -1,26 +1,42 @@ from .base.init import DeviceInfo from .base.init import LogLevel from .base.init import get_devices -from .base.init import get_backend +from .base.init import get_backend, is_vulkan, is_cuda, is_opencl, is_dummy from .base.init import initialize from .base.init import is_initialized from .base.init import log, log_error, log_warning, log_info, log_verbose, set_log_level from .base.dtype import dtype -from .base.dtype import float32, int32, uint32, complex64 -from .base.dtype import vec2, vec3, vec4, ivec2, ivec3, ivec4, uvec2, uvec3, uvec4 +from .base.dtype import float16, float32, float64, int16, uint16, int32, uint32, int64, uint64 +from .base.dtype import complex32, complex64, complex128 +from .base.dtype import hvec2, hvec3, hvec4 +from .base.dtype import vec2, vec3, vec4 +from .base.dtype import dvec2, dvec3, dvec4 +from .base.dtype import ihvec2, ihvec3, ihvec4 +from .base.dtype import ivec2, ivec3, ivec4 +from .base.dtype import uhvec2, uhvec3, uhvec4 +from .base.dtype import uvec2, uvec3, uvec4 from .base.dtype import mat2, mat3, mat4 from .base.context import get_context, queue_wait_idle, Signal from .base.context import get_context_handle -from .base.context import make_context, select_queue_families +from .base.context import make_context, select_queue_families, set_dummy_context_params from .base.context import is_context_initialized from .base.buffer import asbuffer -from .base.buffer import Buffer, buffer_u32, buffer_i32, buffer_f32, buffer_c64 +from .base.buffer import from_cuda_array +from .base.buffer import Buffer from .base.buffer import asrfftbuffer from .base.buffer import RFFTBuffer +from .base.buffer_allocators import buffer_u32, buffer_uv2, buffer_uv3, buffer_uv4 +from .base.buffer_allocators import buffer_i32, buffer_iv2, buffer_iv3, buffer_iv4 +from .base.buffer_allocators import buffer_f32, buffer_v2, buffer_v3, buffer_v4, buffer_c64 
+from .base.buffer_allocators import buffer_u16, buffer_uhv2, buffer_uhv3, buffer_uhv4 +from .base.buffer_allocators import buffer_i16, buffer_ihv2, buffer_ihv3, buffer_ihv4 +from .base.buffer_allocators import buffer_f16, buffer_hv2, buffer_hv3, buffer_hv4 +from .base.buffer_allocators import buffer_f64, buffer_dv2, buffer_dv3, buffer_dv4 + from .base.image import image_format from .base.image import image_type from .base.image import image_view_type @@ -36,8 +52,9 @@ from .execution_pipeline.command_graph import CommandGraph, BufferBindInfo, ImageBindInfo from .execution_pipeline.command_graph import global_graph, set_global_graph, default_graph +from .execution_pipeline.cuda_graph_capture import cuda_graph_capture, get_cuda_capture, CUDAGraphCapture -from .shader.shader_function import ShaderFunction +from .shader.shader_function import ShaderFunction, ShaderSource from .shader.context import ShaderContext, shader_context from .shader.map import map, MappingFunction from .shader.decorator import shader @@ -46,4 +63,4 @@ import vkdispatch.fft as fft import vkdispatch.reduce as reduce -__version__ = "0.0.30" +__version__ = "0.0.34" diff --git a/vkdispatch/backends/backend_selection.py b/vkdispatch/backends/backend_selection.py new file mode 100644 index 00000000..6a3836b9 --- /dev/null +++ b/vkdispatch/backends/backend_selection.py @@ -0,0 +1,129 @@ +from __future__ import annotations + +import importlib +from types import ModuleType +from typing import Dict, Optional + +import os + +BACKEND_VULKAN = "vulkan" +BACKEND_CUDA = "cuda" +BACKEND_OPENCL = "opencl" +BACKEND_DUMMY = "dummy" + +_VALID_BACKENDS = {BACKEND_VULKAN, BACKEND_CUDA, BACKEND_OPENCL, BACKEND_DUMMY} +_active_backend_name: Optional[str] = None +_backend_modules: Dict[str, ModuleType] = {} + + +class BackendUnavailableError(ImportError): + def __init__(self, backend_name: str, message: str): + super().__init__(message) + self.backend_name = backend_name + + +def normalize_backend_name(backend: 
Optional[str]) -> str: + if backend is None: + return BACKEND_VULKAN + + backend_name = backend.strip().lower() + if backend_name not in _VALID_BACKENDS: + valid = ", ".join(sorted(_VALID_BACKENDS)) + raise ValueError(f"Unknown backend '{backend}'. Expected one of: {valid}") + + return backend_name + + +def set_active_backend(backend: str) -> str: + global _active_backend_name + + backend_name = normalize_backend_name(backend) + + if _active_backend_name is not None and _active_backend_name != backend_name: + raise RuntimeError( + f"Backend is already set to '{_active_backend_name}' and cannot be changed to '{backend_name}' in this process." + ) + + _active_backend_name = backend_name + return _active_backend_name + + +def clear_active_backend() -> None: + global _active_backend_name + _active_backend_name = None + +def get_environment_backend() -> Optional[str]: + env_backend = os.environ.get("VKDISPATCH_BACKEND") + if env_backend is not None: + return normalize_backend_name(env_backend) + return None + +def get_active_backend_name(default: Optional[str] = None) -> str: + if _active_backend_name is not None: + return _active_backend_name + + if default is not None: + return normalize_backend_name(default) + + env_backend = get_environment_backend() + + if env_backend is not None: + return env_backend + + return BACKEND_VULKAN + + +def _load_backend_module(backend_name: str) -> ModuleType: + if backend_name in _backend_modules: + return _backend_modules[backend_name] + + try: + if backend_name == BACKEND_VULKAN: + module = importlib.import_module("vkdispatch_vulkan_native") + elif backend_name == BACKEND_CUDA: + module = importlib.import_module("vkdispatch.backends.cuda_backend") + elif backend_name == BACKEND_OPENCL: + module = importlib.import_module("vkdispatch.backends.opencl_backend") + elif backend_name == BACKEND_DUMMY: + module = importlib.import_module("vkdispatch.backends.dummy_backend") + else: + # Defensive guard for future refactors. 
+ raise ValueError(f"Unsupported backend '{backend_name}'") + except ImportError as exc: + if backend_name == BACKEND_VULKAN: + raise BackendUnavailableError( + backend_name, + "Vulkan backend is unavailable because the 'vkdispatch_native' package " + f"could not be imported ({exc}).", + ) from exc + if backend_name == BACKEND_CUDA: + raise BackendUnavailableError( + backend_name, + "CUDA Python backend is unavailable because the " + "'vkdispatch.backends.cuda_backend' module could not be imported " + f"({exc}).", + ) from exc + if backend_name == BACKEND_OPENCL: + raise BackendUnavailableError( + backend_name, + "OpenCL backend is unavailable because the " + "'vkdispatch.backends.opencl_backend' module could not be imported " + f"({exc}).", + ) from exc + raise + + _backend_modules[backend_name] = module + return module + + +def get_backend_module(backend: Optional[str] = None) -> ModuleType: + backend_name = normalize_backend_name(backend) if backend is not None else get_active_backend_name() + return _load_backend_module(backend_name) + + +class _BackendProxy: + def __getattr__(self, name: str): + return getattr(get_backend_module(), name) + + +native = _BackendProxy() diff --git a/vkdispatch/backends/cuda_backend/__init__.py b/vkdispatch/backends/cuda_backend/__init__.py new file mode 100644 index 00000000..a4bf6927 --- /dev/null +++ b/vkdispatch/backends/cuda_backend/__init__.py @@ -0,0 +1,111 @@ +"""cuda-python-backed runtime shim mirroring the vkdispatch_native API surface. + +This module intentionally matches the function names exposed by the Cython +extension so existing Python runtime objects can call into either backend. 
+""" + +from __future__ import annotations + +from .api_buffer import ( + buffer_create, + buffer_create_external, + buffer_destroy, + buffer_get_queue_signal, + buffer_read, + buffer_read_staging, + buffer_wait_staging_idle, + buffer_write, + buffer_write_staging, +) +from .api_command_list import ( + command_list_create, + command_list_destroy, + command_list_get_instance_size, + command_list_reset, + command_list_submit, + stage_compute_record +) +from .api_compute import ( + stage_compute_plan_create, + stage_compute_plan_destroy, +) +from .api_context import ( + context_create, + context_destroy, + context_stop_threads, + cuda_stream_override_begin, + cuda_stream_override_end, + get_devices, + get_error_string, + init, + log, + set_log_level, +) +from .descriptor_sets import ( + descriptor_set_create, + descriptor_set_destroy, + descriptor_set_write_buffer, + descriptor_set_write_image, + descriptor_set_write_inline_uniform, +) +from .image_fft_stubs import ( + image_create, + image_create_sampler, + image_destroy, + image_destroy_sampler, + image_format_block_size, + image_read, + image_write, + stage_fft_plan_create, + stage_fft_plan_destroy, + stage_fft_record, +) +from .signal import signal_destroy, signal_insert, signal_wait + +__all__ = [ + "init", + "log", + "set_log_level", + "get_devices", + "context_create", + "signal_wait", + "signal_insert", + "signal_destroy", + "context_destroy", + "get_error_string", + "context_stop_threads", + "cuda_stream_override_begin", + "cuda_stream_override_end", + "buffer_create", + "buffer_create_external", + "buffer_destroy", + "buffer_get_queue_signal", + "buffer_wait_staging_idle", + "buffer_write_staging", + "buffer_read_staging", + "buffer_write", + "buffer_read", + "command_list_create", + "command_list_destroy", + "command_list_get_instance_size", + "command_list_reset", + "command_list_submit", + "descriptor_set_create", + "descriptor_set_destroy", + "descriptor_set_write_buffer", + "descriptor_set_write_image", 
+ "descriptor_set_write_inline_uniform", + "image_create", + "image_destroy", + "image_create_sampler", + "image_destroy_sampler", + "image_write", + "image_format_block_size", + "image_read", + "stage_compute_plan_create", + "stage_compute_plan_destroy", + "stage_compute_record", + "stage_fft_plan_create", + "stage_fft_plan_destroy", + "stage_fft_record", +] diff --git a/vkdispatch/backends/cuda_backend/api_buffer.py b/vkdispatch/backends/cuda_backend/api_buffer.py new file mode 100644 index 00000000..3502fe96 --- /dev/null +++ b/vkdispatch/backends/cuda_backend/api_buffer.py @@ -0,0 +1,238 @@ +from __future__ import annotations + +from . import state as state +from .cuda_primitives import cuda +from .helpers import ( + activate_context, + allocate_staging_storage, + buffer_device_ptr, + context_from_handle, + new_handle, + queue_indices, + set_error, + stream_for_queue, + to_bytes, +) +from .state import CUDABuffer + +from .signal import CUDASignal, signal_destroy + +def buffer_create(context, size, per_device): + _ = per_device + + ctx = context_from_handle(int(context)) + if ctx is None: + return 0 + + size = int(size) + if size <= 0: + set_error("Buffer size must be greater than zero") + return 0 + + try: + with activate_context(ctx): + allocation = cuda.mem_alloc(size) + + signal_handles = [ + CUDASignal(context_handle=int(context), queue_index=i, done=True).handle + for i in range(ctx.queue_count) + ] + + obj = CUDABuffer( + context_handle=int(context), + size=size, + device_ptr=int(allocation), + device_allocation=allocation, + owns_allocation=True, + staging_data=[allocate_staging_storage(size) for _ in range(ctx.queue_count)], + signal_handles=signal_handles, + ) + return new_handle(state.buffers, obj) + except Exception as exc: + set_error(f"Failed to create CUDA buffer: {exc}") + return 0 + + +def buffer_create_external(context, size, device_ptr): + ctx = context_from_handle(int(context)) + if ctx is None: + return 0 + + size = int(size) + device_ptr = 
int(device_ptr) + + if size <= 0: + set_error("External buffer size must be greater than zero") + return 0 + + if device_ptr == 0: + set_error("External buffer device pointer must be non-zero") + return 0 + + try: + signal_handles = [ + CUDASignal(context_handle=int(context), queue_index=i, done=True).handle + for i in range(ctx.queue_count) + ] + + obj = CUDABuffer( + context_handle=int(context), + size=size, + device_ptr=device_ptr, + device_allocation=None, + owns_allocation=False, + staging_data=[allocate_staging_storage(size) for _ in range(ctx.queue_count)], + signal_handles=signal_handles, + ) + return new_handle(state.buffers, obj) + except Exception as exc: + set_error(f"Failed to create external CUDA buffer alias: {exc}") + return 0 + + +def buffer_destroy(buffer): + obj = state.buffers.pop(int(buffer), None) + if obj is None: + return + + for signal_handle in obj.signal_handles: + signal_destroy(signal_handle) + + ctx = state.contexts.get(obj.context_handle) + if ctx is None or not obj.owns_allocation or obj.device_allocation is None: + return + + try: + with activate_context(ctx): + obj.device_allocation.free() + except Exception: + pass + + +def buffer_get_queue_signal(buffer, queue_index): + obj = state.buffers.get(int(buffer)) + if obj is None: + return CUDASignal(context_handle=0, queue_index=0, done=True).handle + + queue_index = int(queue_index) + if queue_index < 0 or queue_index >= len(obj.signal_handles): + queue_index = 0 + + return obj.signal_handles[queue_index] + + +def buffer_wait_staging_idle(buffer, queue_index): + signal_handle = buffer_get_queue_signal(buffer, queue_index) + signal_obj = CUDASignal.from_handle(signal_handle) + if signal_obj is None: + return True + return signal_obj.query() + + +def buffer_write_staging(buffer, queue_index, data, size): + obj = state.buffers.get(int(buffer)) + if obj is None: + return + + queue_index = int(queue_index) + if queue_index < 0 or queue_index >= len(obj.staging_data): + return + + payload = 
to_bytes(data) + size = min(int(size), len(payload), obj.size) + if size <= 0: + return + + payload_view = memoryview(payload)[:size] + staging_view = memoryview(obj.staging_data[queue_index]) + staging_view[:size] = payload_view + + +def buffer_read_staging(buffer, queue_index, size): + obj = state.buffers.get(int(buffer)) + if obj is None: + return bytes(int(size)) + + queue_index = int(queue_index) + if queue_index < 0 or queue_index >= len(obj.staging_data): + return bytes(int(size)) + + size = max(0, int(size)) + staging = obj.staging_data[queue_index] + + if size <= len(staging): + return bytes(staging[:size]) + + return bytes(staging) + bytes(size - len(staging)) + + +def buffer_write(buffer, offset, size, index): + obj = state.buffers.get(int(buffer)) + if obj is None: + return + + ctx = state.contexts.get(obj.context_handle) + if ctx is None: + set_error(f"Missing context for buffer handle {buffer}") + return + + offset = int(offset) + size = int(size) + if size <= 0 or offset < 0: + return + + try: + with activate_context(ctx): + for queue_index in queue_indices(ctx, int(index), all_on_negative=True): + stream = stream_for_queue(ctx, queue_index) + end = min(offset + size, obj.size) + copy_size = end - offset + if copy_size <= 0: + continue + + src_view = memoryview(obj.staging_data[queue_index])[:copy_size] + cuda.memcpy_htod_async(buffer_device_ptr(obj) + offset, src_view, stream) + + signal = CUDASignal.from_handle(obj.signal_handles[queue_index]) + if signal is not None: + signal.record(stream) + except Exception as exc: + set_error(f"Failed to write CUDA buffer: {exc}") + + +def buffer_read(buffer, offset, size, index): + obj = state.buffers.get(int(buffer)) + if obj is None: + return + + ctx = state.contexts.get(obj.context_handle) + if ctx is None: + set_error(f"Missing context for buffer handle {buffer}") + return + + queue_index = int(index) + if queue_index < 0 or queue_index >= ctx.queue_count: + set_error(f"Invalid queue index {queue_index} 
for buffer read") + return + + offset = int(offset) + size = int(size) + if size <= 0 or offset < 0: + return + + try: + with activate_context(ctx): + stream = stream_for_queue(ctx, queue_index) + end = min(offset + size, obj.size) + copy_size = end - offset + if copy_size <= 0: + return + + dst_view = memoryview(obj.staging_data[queue_index])[:copy_size] + cuda.memcpy_dtoh_async(dst_view, buffer_device_ptr(obj) + offset, stream) + + signal = CUDASignal.from_handle(obj.signal_handles[queue_index]) + if signal is not None: + signal.record(stream) + except Exception as exc: + set_error(f"Failed to read CUDA buffer: {exc}") diff --git a/vkdispatch/backends/cuda_backend/api_command_list.py b/vkdispatch/backends/cuda_backend/api_command_list.py new file mode 100644 index 00000000..8c80c102 --- /dev/null +++ b/vkdispatch/backends/cuda_backend/api_command_list.py @@ -0,0 +1,205 @@ +from __future__ import annotations + +from typing import List, Optional, Tuple + +from . import state as state +from .helpers import ( + activate_context, + build_kernel_args_template, + estimate_kernel_param_size_bytes, + new_handle, + queue_indices, + set_error, + stream_for_queue, + to_bytes, +) +from .state import CUDACommandList, CUDAComputePlan, CUDACommandRecord + +from .descriptor_sets import CUDADescriptorSet + +import dataclasses + +@dataclasses.dataclass +class CUDAResolvedLaunch: + plan: CUDAComputePlan + blocks: Tuple[int, int, int] + descriptor_set: Optional[CUDADescriptorSet] + pc_size: int + pc_offset: int + static_args: Optional[Tuple[object, ...]] = None + +def command_list_create(context): + if int(context) not in state.contexts: + set_error("Invalid context handle for command_list_create") + return 0 + + return new_handle(state.command_lists, CUDACommandList(context_handle=int(context))) + + +def command_list_destroy(command_list): + obj = state.command_lists.pop(int(command_list), None) + if obj is None: + return + + ctx = state.contexts.get(obj.context_handle) + if ctx is 
None: + return + + +def command_list_get_instance_size(command_list): + obj = state.command_lists.get(int(command_list)) + if obj is None: + return 0 + + return int(sum(int(command.pc_size) for command in obj.commands)) + + +def command_list_reset(command_list): + obj = state.command_lists.get(int(command_list)) + if obj is None: + return + + obj.commands = [] + + +def command_list_submit(command_list, data, instance_count, index): + obj = state.command_lists.get(int(command_list)) + if obj is None: + return True + + ctx = state.contexts.get(obj.context_handle) + if ctx is None: + set_error(f"Missing context for command list {command_list}") + return True + + instance_count = int(instance_count) + if instance_count <= 0: + return True + + instance_size = command_list_get_instance_size(command_list) + payload = to_bytes(data) + expected_payload_size = int(instance_size) * int(instance_count) + + if expected_payload_size == 0: + if len(payload) != 0: + set_error( + f"Unexpected push-constant data for command list with instance_size=0 " + f"(got {len(payload)} bytes)." + ) + return True + elif len(payload) != expected_payload_size: + set_error( + f"Push-constant data size mismatch. Expected {expected_payload_size} bytes " + f"(instance_size={instance_size}, instance_count={instance_count}) but got {len(payload)} bytes." 
+ ) + return True + + queue_targets = queue_indices(ctx, int(index), all_on_negative=True) + if len(queue_targets) == 0: + queue_targets = [0] + + try: + with activate_context(ctx): + for queue_index in queue_targets: + stream = stream_for_queue(ctx, queue_index) + resolved_launches: List[CUDAResolvedLaunch] = [] + per_instance_offset = 0 + + for command in obj.commands: + plan = state.compute_plans.get(command.plan_handle) + if plan is None: + raise RuntimeError(f"Invalid compute plan handle {command.plan_handle}") + + descriptor_set = None + if command.descriptor_set_handle != 0: + descriptor_set = CUDADescriptorSet.from_handle(command.descriptor_set_handle) + if descriptor_set is None: + raise RuntimeError( + f"Invalid descriptor set handle {command.descriptor_set_handle}" + ) + + command_pc_size = int(command.pc_size) + first_instance_payload = b"" + if command_pc_size > 0 and len(payload) > 0: + first_instance_payload = payload[per_instance_offset: per_instance_offset + command_pc_size] + + static_args = None + if command_pc_size == 0: + static_args = build_kernel_args_template(plan, descriptor_set, b"") + size_check_args = static_args + else: + size_check_args = build_kernel_args_template( + plan, + descriptor_set, + first_instance_payload, + ) + + estimated_param_size = estimate_kernel_param_size_bytes(size_check_args) + if estimated_param_size > int(ctx.max_kernel_param_size): + shader_name = plan.shader_name.decode("utf-8", errors="replace") + raise RuntimeError( + f"Kernel '{shader_name}' launch parameters require " + f"{estimated_param_size} bytes, exceeding device limit " + f"{ctx.max_kernel_param_size} bytes. " + "Reduce by-value uniform/push-constant payload size or switch large " + "uniform data to buffer-backed arguments." 
+ ) + resolved_launches.append( + CUDAResolvedLaunch( + plan=plan, + blocks=command.blocks, + descriptor_set=descriptor_set, + pc_size=command_pc_size, + pc_offset=per_instance_offset, + static_args=static_args, + ) + ) + per_instance_offset += command_pc_size + + if per_instance_offset != instance_size: + raise RuntimeError( + f"Internal command list size mismatch: computed {per_instance_offset} bytes, " + f"expected {instance_size} bytes." + ) + + for instance_index in range(instance_count): + instance_base_offset = instance_index * instance_size + for launch in resolved_launches: + if launch.static_args is not None: + args = launch.static_args + else: + pc_start = instance_base_offset + launch.pc_offset + pc_end = pc_start + launch.pc_size + pc_payload = payload[pc_start:pc_end] + args = build_kernel_args_template( + launch.plan, + launch.descriptor_set, + pc_payload, + ) + + launch.plan.function( + *args, + block=launch.plan.local_size, + grid=launch.blocks, + stream=stream, + ) + except Exception as exc: + set_error(f"Failed to submit CUDA command list: {exc}") + + return True + +def stage_compute_record(command_list, plan, descriptor_set, blocks_x, blocks_y, blocks_z): + cl = state.command_lists.get(int(command_list)) + cp = state.compute_plans.get(int(plan)) + if cl is None or cp is None: + set_error("Invalid command list or compute plan handle for stage_compute_record") + return + + cl.commands.append( + CUDACommandRecord( + plan_handle=int(plan), + descriptor_set_handle=int(descriptor_set), + blocks=(int(blocks_x), int(blocks_y), int(blocks_z)), + pc_size=int(cp.pc_size), + ) + ) diff --git a/vkdispatch/backends/cuda_backend/api_compute.py b/vkdispatch/backends/cuda_backend/api_compute.py new file mode 100644 index 00000000..8db48b43 --- /dev/null +++ b/vkdispatch/backends/cuda_backend/api_compute.py @@ -0,0 +1,76 @@ +from __future__ import annotations + +from . 
import state as state +from .cuda_primitives import SourceModule, cuda +from .helpers import ( + activate_context, + context_from_handle, + new_handle, + parse_kernel_params, + parse_local_size, + set_error, + to_bytes, +) +from .state import CUDAComputePlan + + +def _nvrtc_compile_options(ctx): + options = ["-w"] + + try: + dev = cuda.Device(ctx.device_index) + cc_major, cc_minor = dev.compute_capability() + options.append(f"--gpu-architecture=sm_{int(cc_major)}{int(cc_minor)}") + except Exception: + pass + + return options + + +def stage_compute_plan_create(context, shader_source, bindings, pc_size, shader_name): + ctx = context_from_handle(int(context)) + if ctx is None: + return 0 + + source_bytes = to_bytes(shader_source) + shader_name_bytes = to_bytes(shader_name) + source_text = source_bytes.decode("utf-8", errors="replace") + + try: + with activate_context(ctx): + module = SourceModule( + source_text, + no_extern_c=True, + options=_nvrtc_compile_options(ctx), + ) + function = module.get_function("vkdispatch_main") + except Exception as exc: + set_error(f"Failed to compile CUDA kernel '{shader_name_bytes.decode(errors='ignore')}': {exc}") + return 0 + + try: + params = parse_kernel_params(source_text) + local_size = parse_local_size(source_text) + except Exception as exc: + set_error(f"Failed to parse CUDA kernel metadata: {exc}") + return 0 + + plan = CUDAComputePlan( + context_handle=int(context), + shader_source=source_bytes, + bindings=[int(x) for x in bindings], + shader_name=shader_name_bytes, + module=module, + function=function, + local_size=local_size, + params=params, + pc_size=int(pc_size), + ) + + return new_handle(state.compute_plans, plan) + + +def stage_compute_plan_destroy(plan): + if plan is None: + return + state.compute_plans.pop(int(plan), None) diff --git a/vkdispatch/backends/cuda_backend/api_context.py b/vkdispatch/backends/cuda_backend/api_context.py new file mode 100644 index 00000000..7232b2c5 --- /dev/null +++ 
b/vkdispatch/backends/cuda_backend/api_context.py @@ -0,0 +1,250 @@ +from __future__ import annotations + +import hashlib + +from . import state as state +from .cuda_primitives import cuda +from .helpers import ( + activate_context, + clear_error, + coerce_stream_handle, + new_handle, + query_max_kernel_param_size, + set_error, + stream_override_stack, +) +from .state import CUDAContext + + +def init(debug, log_level): + state.debug_mode = bool(debug) + state.log_level = int(log_level) + clear_error() + + if state.initialized: + return + + cuda.init() + state.initialized = True + + +def log(log_level, text, file_str, line_str): + _ = log_level + _ = text + _ = file_str + _ = line_str + + +def set_log_level(log_level): + state.log_level = int(log_level) + + +def get_devices(): + if not state.initialized: + init(False, state.log_level) + + try: + device_count = cuda.Device.count() + except Exception as exc: + set_error(f"Failed to enumerate CUDA devices: {exc}") + return [] + + driver_version = 0 + try: + driver_version = int(cuda.get_driver_version()) + except Exception: + driver_version = 0 + + devices = [] + + for index in range(device_count): + dev = cuda.Device(index) + attrs = dev.get_attributes() + cc_major, cc_minor = dev.compute_capability() + total_memory = int(dev.total_memory()) + + max_workgroup_size = ( + int(attrs.get(cuda.device_attribute.MAX_BLOCK_DIM_X, 0)), + int(attrs.get(cuda.device_attribute.MAX_BLOCK_DIM_Y, 0)), + int(attrs.get(cuda.device_attribute.MAX_BLOCK_DIM_Z, 0)), + ) + + max_workgroup_count = ( + int(attrs.get(cuda.device_attribute.MAX_GRID_DIM_X, 0)), + int(attrs.get(cuda.device_attribute.MAX_GRID_DIM_Y, 0)), + int(attrs.get(cuda.device_attribute.MAX_GRID_DIM_Z, 0)), + ) + + subgroup_size = int(attrs.get(cuda.device_attribute.WARP_SIZE, 0)) + max_shared_memory = int( + attrs.get(cuda.device_attribute.MAX_SHARED_MEMORY_PER_BLOCK, 0) + ) + + try: + bus_id = str(dev.pci_bus_id()) + except Exception: + bus_id = f"cuda-device-{index}" + + 
uuid_bytes = hashlib.md5(bus_id.encode("utf-8")).digest() + + devices.append( + ( + 0, # Vulkan variant + int(cc_major), # major + int(cc_minor), # minor + 0, # patch + driver_version, + 0, # vendor id unknown in this API layer + index, # device id + 2, # discrete gpu + str(dev.name()), + 1, # shader_buffer_float32_atomics + 1, # shader_buffer_float32_atomic_add + 1, # float64 support + 1 if (cc_major > 5 or (cc_major == 5 and cc_minor >= 3)) else 0, # float16 support + 1, # int64 + 1, # int16 + 1, # storage_buffer_16_bit_access + 1, # uniform_and_storage_buffer_16_bit_access + 1, # storage_push_constant_16 + 1, # storage_input_output_16 + max_workgroup_size, + int(attrs.get(cuda.device_attribute.MAX_THREADS_PER_BLOCK, 0)), + max_workgroup_count, + 8, # max descriptor sets (virtualized for parity) + 4096, # max push constant size + min(total_memory, (1 << 31) - 1), + 65536, + 16, + subgroup_size, + 0x7FFFFFFF, # supported stages (virtualized for parity) + 0x7FFFFFFF, # supported operations (virtualized for parity) + 1, + max_shared_memory, + [(1, 0x002)], # compute queue + 1, # scalar block layout + 1, # timeline semaphores equivalent + uuid_bytes, + ) + ) + + return devices + + +def context_create(device_indicies, queue_families): + if not state.initialized: + init(False, state.log_level) + + try: + device_ids = [int(x) for x in device_indicies] + except Exception: + set_error("context_create expected a list of integer device indices") + return 0 + + if len(device_ids) != 1: + set_error("CUDA Python backend currently supports exactly one device") + return 0 + + if len(queue_families) != 1 or len(queue_families[0]) != 1: + set_error("CUDA Python backend currently supports exactly one queue") + return 0 + + device_index = device_ids[0] + + cuda_context = None + context_pushed = False + + try: + if device_index < 0 or device_index >= cuda.Device.count(): + set_error(f"Invalid CUDA device index {device_index}") + return 0 + + dev = cuda.Device(device_index) + 
cc_major, _cc_minor = dev.compute_capability() + max_kernel_param_size = query_max_kernel_param_size(dev.device_raw, cc_major) + uses_primary_context = False + + if hasattr(dev, "retain_primary_context"): + cuda_context = dev.retain_primary_context() + uses_primary_context = True + cuda_context.push() + else: # pragma: no cover - fallback for older CUDA Python + cuda_context = dev.make_context() + context_pushed = True + stream = cuda.Stream() + + ctx = CUDAContext( + device_index=device_index, + cuda_context=cuda_context, + streams=[stream], + queue_count=1, + queue_to_device=[0], + max_kernel_param_size=int(max_kernel_param_size), + uses_primary_context=uses_primary_context, + stopped=False, + ) + handle = new_handle(state.contexts, ctx) + + # Leave no context current after creation. + cuda.Context.pop() + context_pushed = False + return handle + except Exception as exc: + if context_pushed: + try: + cuda.Context.pop() + except Exception: + pass + + if cuda_context is not None: + try: + cuda_context.detach() + except Exception: + pass + + set_error(f"Failed to create CUDA Python context: {exc}") + return 0 + + +def context_destroy(context): + ctx = state.contexts.pop(int(context), None) + if ctx is None: + return + + try: + with activate_context(ctx): + for stream in ctx.streams: + stream.synchronize() + except Exception: + pass + + try: + ctx.cuda_context.detach() + except Exception: + pass + + +def context_stop_threads(context): + ctx = state.contexts.get(int(context)) + if ctx is not None: + ctx.stopped = True + + +def get_error_string(): + if state.error_string is None: + return 0 + return state.error_string + + +def cuda_stream_override_begin(stream_obj): + try: + stack = stream_override_stack() + stack.append(coerce_stream_handle(stream_obj)) + except Exception as exc: + set_error(f"Failed to activate external CUDA stream override: {exc}") + + +def cuda_stream_override_end(): + stack = stream_override_stack() + if len(stack) > 0: + stack.pop() diff --git 
a/vkdispatch/backends/cuda_backend/bindings.py b/vkdispatch/backends/cuda_backend/bindings.py new file mode 100644 index 00000000..be7d82ee --- /dev/null +++ b/vkdispatch/backends/cuda_backend/bindings.py @@ -0,0 +1,326 @@ +from __future__ import annotations + +import ctypes +import importlib.util +import os +from pathlib import Path +import shutil +import sys +from typing import List, Optional + +try: + import numpy as np +except Exception as exc: # pragma: no cover - import failure path + raise ImportError( + "The CUDA Python backend requires both 'cuda-python' and 'numpy' to be installed." + ) from exc + +try: + from cuda.bindings import driver, nvrtc +except Exception: + try: + from cuda import cuda as driver # type: ignore + from cuda import nvrtc # type: ignore + except Exception as exc: # pragma: no cover - import failure path + raise ImportError( + "The CUDA Python backend requires the NVIDIA cuda-python package " + "(`pip install cuda-python`)." + ) from exc + + +def to_int(value) -> int: + if isinstance(value, int): + return int(value) + + if hasattr(value, "value"): + try: + return int(value.value) + except Exception: + pass + + return int(value) + + +def drv_call(names, *args): + if isinstance(names, str): + names = [names] + + last_error = None + for name in names: + fn = getattr(driver, name, None) + if fn is not None: + try: + return fn(*args) + except TypeError as exc: + last_error = exc + continue + + if last_error is not None: + raise RuntimeError(f"CUDA Driver call failed for {names}: {last_error}") from last_error + raise RuntimeError(f"CUDA Driver symbol not found: {names}") + + +def nvrtc_call(names, *args): + if isinstance(names, str): + names = [names] + + last_error = None + for name in names: + fn = getattr(nvrtc, name, None) + if fn is not None: + try: + return fn(*args) + except TypeError as exc: + last_error = exc + continue + + if last_error is not None: + raise RuntimeError(f"NVRTC call failed for {names}: {last_error}") from 
last_error + raise RuntimeError(f"NVRTC symbol not found: {names}") + + +def status_success(status) -> bool: + try: + return to_int(status) == 0 + except Exception: + return str(status).endswith("CUDA_SUCCESS") or str(status).endswith("NVRTC_SUCCESS") + + +def drv_error_string(status) -> str: + try: + name_res = drv_call("cuGetErrorName", status) + string_res = drv_call("cuGetErrorString", status) + _name_status = name_res[0] if isinstance(name_res, tuple) else 1 + _string_status = string_res[0] if isinstance(string_res, tuple) else 1 + if status_success(_name_status) and status_success(_string_status): + name = name_res[1] if isinstance(name_res, tuple) and len(name_res) > 1 else name_res + text = string_res[1] if isinstance(string_res, tuple) and len(string_res) > 1 else string_res + if isinstance(name, (bytes, bytearray)): + name = name.decode("utf-8", errors="replace") + if isinstance(text, (bytes, bytearray)): + text = text.decode("utf-8", errors="replace") + return f"{name}: {text}" + except Exception: + pass + + return str(status) + + +def drv_check(result, op_name: str): + if isinstance(result, tuple): + status = result[0] + payload = result[1:] + else: + status = result + payload = () + + if not status_success(status): + raise RuntimeError(f"{op_name} failed ({drv_error_string(status)})") + + if len(payload) == 0: + return None + + if len(payload) == 1: + return payload[0] + + return payload + + +def nvrtc_check(result, op_name: str): + if isinstance(result, tuple): + status = result[0] + payload = result[1:] + else: + status = result + payload = () + + if not status_success(status): + raise RuntimeError(f"{op_name} failed ({status})") + + if len(payload) == 0: + return None + + if len(payload) == 1: + return payload[0] + + return payload + + +def nvrtc_read_bytes(program, size_api: str, read_api: str) -> bytes: + raw_size = nvrtc_check(nvrtc_call(size_api, program), size_api) + size = int(to_int(raw_size)) + if size <= 0: + return b"" + + def 
_normalize_output(data) -> Optional[bytes]: + if data is None: + return None + + if isinstance(data, memoryview): + data = data.tobytes() + elif isinstance(data, str): + data = data.encode("utf-8", errors="replace") + + if isinstance(data, (bytes, bytearray)): + raw = bytes(data) + if len(raw) >= size: + return raw[:size] + return raw + (b"\x00" * (size - len(raw))) + + if isinstance(data, (tuple, list)): + for item in data: + normalized = _normalize_output(item) + if normalized is not None: + return normalized + + return None + + try: + direct_data = nvrtc_check(nvrtc_call(read_api, program), read_api) + normalized = _normalize_output(direct_data) + if normalized is not None: + return normalized + except Exception: + pass + + out_c = ctypes.create_string_buffer(size) + out_bytearray = bytearray(size) + out_bytes = bytes(size) + + for out_candidate in (out_bytes, out_bytearray, out_c): + try: + call_result = nvrtc_check(nvrtc_call(read_api, program, out_candidate), read_api) + normalized_result = _normalize_output(call_result) + if normalized_result is not None: + return normalized_result + + if isinstance(out_candidate, bytearray): + return bytes(out_candidate) + + if out_candidate is out_c: + return bytes(out_c.raw) + except Exception: + continue + + return bytes(out_c.raw) + + +def discover_cuda_include_dirs() -> List[str]: + include_dirs: List[str] = [] + seen = set() + + def add_dir(path_like) -> None: + if path_like is None: + return + try: + resolved = str(Path(path_like).resolve()) + except Exception: + resolved = str(path_like) + if resolved in seen: + return + header_path = Path(resolved) / "cuda_runtime.h" + if header_path.exists(): + seen.add(resolved) + include_dirs.append(resolved) + + # Standard CUDA environment variables. 
+ for env_name in ( + "CUDA_HOME", + "CUDA_PATH", + "CUDA_ROOT", + "CUDA_TOOLKIT_ROOT_DIR", + "CUDAToolkit_ROOT", + ): + root = os.environ.get(env_name) + if root: + add_dir(Path(root) / "include") + + # CUDA toolkit from nvcc location. + nvcc_path = shutil.which("nvcc") + if nvcc_path: + try: + nvcc_root = Path(nvcc_path).resolve().parent.parent + add_dir(nvcc_root / "include") + except Exception: + pass + + # Common Unix install locations. + add_dir("/usr/local/cuda/include") + add_dir("/opt/cuda/include") + add_dir("/usr/include") + + # Conda cudatoolkit layouts. + conda_prefix = os.environ.get("CONDA_PREFIX") + if conda_prefix: + add_dir(Path(conda_prefix) / "include") + add_dir(Path(conda_prefix) / "targets" / "x86_64-linux" / "include") + add_dir(Path(conda_prefix) / "Library" / "include") + + # NVIDIA pip wheel layout. + for base in sys.path: + add_dir(Path(base) / "nvidia" / "cuda_runtime" / "include") + + # Some environments expose this namespace package. + try: + spec = importlib.util.find_spec("nvidia.cuda_runtime") + if spec is not None and spec.submodule_search_locations: + for entry in spec.submodule_search_locations: + add_dir(Path(entry) / "include") + except Exception: + pass + + return include_dirs + + +def prepare_nvrtc_options(options: List[bytes]) -> List[bytes]: + normalized: List[bytes] = [] + has_include_path = False + + for opt in options: + as_str = opt.decode("utf-8", errors="replace") + if as_str.startswith("-I") or as_str.startswith("--include-path"): + has_include_path = True + normalized.append(opt) + + if not has_include_path: + for include_dir in discover_cuda_include_dirs(): + normalized.append(f"--include-path={include_dir}".encode("utf-8")) + + return normalized + + +def as_driver_handle(type_name: str, value): + handle_type = getattr(driver, type_name, None) + if handle_type is None: + return value + + try: + if isinstance(value, handle_type): + return value + except Exception: + pass + + try: + return handle_type(to_int(value)) 
+ except Exception: + return value + + +def writable_host_ptr(view: memoryview): + byte_view = view.cast("B") + try: + c_buffer = (ctypes.c_ubyte * len(byte_view)).from_buffer(byte_view) + return ctypes.addressof(c_buffer), c_buffer + except Exception: + copied = ctypes.create_string_buffer(byte_view.tobytes()) + return ctypes.addressof(copied), copied + + +def readonly_host_ptr(view: memoryview): + byte_view = view.cast("B") + try: + c_buffer = (ctypes.c_ubyte * len(byte_view)).from_buffer(byte_view) + return ctypes.addressof(c_buffer), c_buffer + except Exception: + copied = ctypes.create_string_buffer(byte_view.tobytes()) + return ctypes.addressof(copied), copied diff --git a/vkdispatch/backends/cuda_backend/constants.py b/vkdispatch/backends/cuda_backend/constants.py new file mode 100644 index 00000000..246346be --- /dev/null +++ b/vkdispatch/backends/cuda_backend/constants.py @@ -0,0 +1,23 @@ +from __future__ import annotations + +import re + +# Log level constants mirrored from native bindings. +LOG_LEVEL_VERBOSE = 0 +LOG_LEVEL_INFO = 1 +LOG_LEVEL_WARNING = 2 +LOG_LEVEL_ERROR = 3 + +# Descriptor type enum values mirrored from vkdispatch_native/stages_extern.pxd. 
+DESCRIPTOR_TYPE_STORAGE_BUFFER = 1 +DESCRIPTOR_TYPE_STORAGE_IMAGE = 2 +DESCRIPTOR_TYPE_UNIFORM_BUFFER = 3 +DESCRIPTOR_TYPE_UNIFORM_IMAGE = 4 +DESCRIPTOR_TYPE_SAMPLER = 5 + +LOCAL_X_RE = re.compile(r"#define\s+VKDISPATCH_EXPECTED_LOCAL_SIZE_X\s+(\d+)") +LOCAL_Y_RE = re.compile(r"#define\s+VKDISPATCH_EXPECTED_LOCAL_SIZE_Y\s+(\d+)") +LOCAL_Z_RE = re.compile(r"#define\s+VKDISPATCH_EXPECTED_LOCAL_SIZE_Z\s+(\d+)") +KERNEL_SIGNATURE_RE = re.compile(r"vkdispatch_main\s*\(([^)]*)\)", re.S) +BINDING_PARAM_RE = re.compile(r"vkdispatch_binding_(\d+)_ptr$") +SAMPLER_PARAM_RE = re.compile(r"vkdispatch_sampler_(\d+)$") diff --git a/vkdispatch/backends/cuda_backend/cuda_primitives.py b/vkdispatch/backends/cuda_backend/cuda_primitives.py new file mode 100644 index 00000000..8a3af54a --- /dev/null +++ b/vkdispatch/backends/cuda_backend/cuda_primitives.py @@ -0,0 +1,571 @@ +from __future__ import annotations + +import ctypes +from dataclasses import dataclass +from typing import List, Optional + +from .bindings import ( + np, + driver, + as_driver_handle, + discover_cuda_include_dirs, + drv_call, + drv_check, + nvrtc_call, + nvrtc_check, + nvrtc_read_bytes, + prepare_nvrtc_options, + readonly_host_ptr, + status_success, + to_int, + writable_host_ptr, +) + + +@dataclass +class _ByValueKernelArg: + payload: bytes + raw_name: str + + +class _DeviceAllocation: + def __init__(self, ptr: int): + self.ptr = int(ptr) + self.freed = False + + def __int__(self): + return int(self.ptr) + + def free(self): + if self.freed: + return + + drv_check( + drv_call( + ["cuMemFree", "cuMemFree_v2"], + as_driver_handle("CUdeviceptr", self.ptr), + ), + "cuMemFree", + ) + self.freed = True + + +class _ContextHandle: + def __init__(self, context_raw, device_index: int, uses_primary_context: bool): + self.context_raw = context_raw + self.device_index = int(device_index) + self.uses_primary_context = bool(uses_primary_context) + self._detached = False + + def push(self): + drv_check( + drv_call( + 
"cuCtxPushCurrent", + as_driver_handle("CUcontext", self.context_raw), + ), + "cuCtxPushCurrent", + ) + + def detach(self): + if self._detached: + return + + if self.uses_primary_context: + dev = drv_check(drv_call("cuDeviceGet", int(self.device_index)), "cuDeviceGet") + drv_check(drv_call("cuDevicePrimaryCtxRelease", dev), "cuDevicePrimaryCtxRelease") + else: + drv_check( + drv_call( + ["cuCtxDestroy", "cuCtxDestroy_v2"], + as_driver_handle("CUcontext", self.context_raw), + ), + "cuCtxDestroy", + ) + self._detached = True + + +class _StreamHandle: + def __init__(self, handle: Optional[int] = None, ptr: Optional[int] = None, *args, **kwargs): + _ = kwargs + if handle is None and ptr is None and len(args) == 1: + handle = int(args[0]) + if handle is None and ptr is not None: + handle = int(ptr) + + if handle is None: + stream_raw = drv_check(drv_call("cuStreamCreate", 0), "cuStreamCreate") + self.handle = int(to_int(stream_raw)) + self.owned = True + else: + self.handle = int(handle) + self.owned = False + + def synchronize(self): + drv_check( + drv_call( + "cuStreamSynchronize", + as_driver_handle("CUstream", self.handle), + ), + "cuStreamSynchronize", + ) + + def __int__(self): + return int(self.handle) + + @property + def ptr(self): + return int(self.handle) + + @property + def cuda_stream(self): + return int(self.handle) + + +class _EventHandle: + def __init__(self): + self.event_raw = drv_check(drv_call("cuEventCreate", 0), "cuEventCreate") + + def record(self, stream_obj: Optional["_StreamHandle"]): + stream_handle = 0 if stream_obj is None else int(stream_obj) + drv_check( + drv_call( + "cuEventRecord", + self.event_raw, + as_driver_handle("CUstream", stream_handle), + ), + "cuEventRecord", + ) + + def query(self) -> bool: + res = drv_call("cuEventQuery", self.event_raw) + status = res[0] if isinstance(res, tuple) else res + + if status_success(status): + return True + + status_text = str(status) + if "NOT_READY" in status_text: + return False + + if 
to_int(status) != 0: + return False + + return True + + def synchronize(self): + drv_check(drv_call("cuEventSynchronize", self.event_raw), "cuEventSynchronize") + + +class _KernelFunction: + def __init__(self, function_raw): + self.function_raw = function_raw + + def __call__(self, *args, block, grid, stream=None): + arg_values = [] + + def _dedupe(values): + out = [] + seen = set() + for value in values: + key = f"{type(value).__name__}:{repr(value)}" + if key in seen: + continue + seen.add(key) + out.append(value) + return out + + arg_ptr_values = [] + for arg in args: + if isinstance(arg, _ByValueKernelArg): + payload = arg.payload + if len(payload) == 0: + payload = b"\x00" + + payload_storage = (ctypes.c_ubyte * len(payload)).from_buffer_copy(payload) + arg_values.append(payload_storage) + arg_ptr_values.append(ctypes.addressof(payload_storage)) + continue + + scalar_storage = ctypes.c_uint64(int(arg)) + arg_values.append(scalar_storage) + arg_ptr_values.append(ctypes.addressof(scalar_storage)) + + arg_ptr_array = None + if len(arg_ptr_values) > 0: + arg_ptr_array = (ctypes.c_void_p * len(arg_ptr_values))( + *[ctypes.c_void_p(ptr) for ptr in arg_ptr_values] + ) + + kernel_param_variants = [None, 0, ctypes.c_void_p(0)] + if arg_ptr_array is not None: + array_ptr = ctypes.cast(arg_ptr_array, ctypes.POINTER(ctypes.c_void_p)) + kernel_param_variants = _dedupe( + [ + arg_ptr_array, + array_ptr, + ctypes.cast(array_ptr, ctypes.c_void_p), + ctypes.cast(array_ptr, ctypes.c_void_p).value, + tuple(arg_ptr_values), + list(arg_ptr_values), + ] + ) + + stream_handle = 0 if stream is None else int(stream) + stream_variants = _dedupe( + [ + stream_handle, + as_driver_handle("CUstream", stream_handle), + ] + ) + + function_candidates = [ + self.function_raw, + as_driver_handle("CUfunction", self.function_raw), + ] + try: + function_candidates.append(to_int(self.function_raw)) + except Exception: + pass + function_variants = _dedupe(function_candidates) + + extra_variants = 
[None, 0, ctypes.c_void_p(0)] + last_error = None + + for function_handle in function_variants: + for stream_value in stream_variants: + for kernel_params in kernel_param_variants: + for extra in extra_variants: + try: + drv_check( + drv_call( + "cuLaunchKernel", + function_handle, + int(grid[0]), + int(grid[1]), + int(grid[2]), + int(block[0]), + int(block[1]), + int(block[2]), + 0, + stream_value, + kernel_params, + extra, + ), + "cuLaunchKernel", + ) + return + except Exception as exc: + last_error = exc + + try: + drv_check( + drv_call( + "cuLaunchKernel", + function_handle, + int(grid[0]), + int(grid[1]), + int(grid[2]), + int(block[0]), + int(block[1]), + int(block[2]), + 0, + stream_value, + kernel_params, + ), + "cuLaunchKernel", + ) + return + except Exception as exc: + last_error = exc + continue + + if last_error is None: + raise RuntimeError("cuLaunchKernel failed with no diagnostic.") + raise RuntimeError(f"cuLaunchKernel failed: {last_error}") from last_error + + +class SourceModule: + def __init__(self, source: str, no_extern_c: bool = True, options: Optional[List[str]] = None): + _ = no_extern_c + if options is None: + options = [] + + program_name = b"vkdispatch.cu" + source_bytes = source.encode("utf-8") + program = nvrtc_check( + nvrtc_call( + "nvrtcCreateProgram", + source_bytes, + program_name, + 0, + [], + [], + ), + "nvrtcCreateProgram", + ) + + cubin = b"" + ptx = b"" + build_log = b"" + + try: + encoded_options = [opt.encode("utf-8") if isinstance(opt, str) else bytes(opt) for opt in options] + encoded_options = prepare_nvrtc_options(encoded_options) + compile_result = nvrtc_call("nvrtcCompileProgram", program, len(encoded_options), encoded_options) + compile_status = compile_result[0] if isinstance(compile_result, tuple) else compile_result + + build_log = nvrtc_read_bytes(program, "nvrtcGetProgramLogSize", "nvrtcGetProgramLog") + if not status_success(compile_status): + clean_build_log = build_log.rstrip(b"\x00").decode("utf-8", 
errors="replace") + if 'could not open source file "cuda_runtime.h"' in clean_build_log: + discovered = discover_cuda_include_dirs() + hint = ( + " NVRTC could not find CUDA headers. " + f"Discovered include dirs: {discovered if len(discovered) > 0 else 'none'}. " + "Set CUDA_HOME/CUDA_PATH to your toolkit root or ensure nvcc is on PATH." + ) + else: + hint = "" + raise RuntimeError( + f"NVRTC compilation failed: {clean_build_log}{hint}" + ) + + try: + cubin = nvrtc_read_bytes(program, "nvrtcGetCUBINSize", "nvrtcGetCUBIN") + except Exception: + cubin = b"" + + if len(cubin) == 0: + try: + ptx = nvrtc_read_bytes(program, "nvrtcGetPTXSize", "nvrtcGetPTX") + except Exception: + ptx = b"" + finally: + try: + nvrtc_check(nvrtc_call("nvrtcDestroyProgram", program), "nvrtcDestroyProgram") + except Exception: + pass + + image_data = cubin + if len(image_data) == 0: + image_data = ptx + + if len(image_data) == 0: + raise RuntimeError("NVRTC compilation succeeded but produced neither a CUBIN nor a PTX payload.") + + if len(cubin) == 0 and not image_data.endswith(b"\x00"): + image_data += b"\x00" + + self.module_raw = drv_check( + drv_call(["cuModuleLoadDataEx", "cuModuleLoadData"], image_data), + "cuModuleLoadData", + ) + + def get_function(self, name: str): + func_raw = drv_check( + drv_call("cuModuleGetFunction", self.module_raw, name.encode("utf-8")), + "cuModuleGetFunction", + ) + return _KernelFunction(func_raw) + + +class _CudaDevice: + class device_attribute: + MAX_BLOCK_DIM_X = getattr( + getattr(driver, "CUdevice_attribute", object()), + "CU_DEVICE_ATTRIBUTE_MAX_BLOCK_DIM_X", + 0, + ) + MAX_BLOCK_DIM_Y = getattr( + getattr(driver, "CUdevice_attribute", object()), + "CU_DEVICE_ATTRIBUTE_MAX_BLOCK_DIM_Y", + 0, + ) + MAX_BLOCK_DIM_Z = getattr( + getattr(driver, "CUdevice_attribute", object()), + "CU_DEVICE_ATTRIBUTE_MAX_BLOCK_DIM_Z", + 0, + ) + MAX_THREADS_PER_BLOCK = getattr( + getattr(driver, "CUdevice_attribute", object()), + 
"CU_DEVICE_ATTRIBUTE_MAX_THREADS_PER_BLOCK", + 0, + ) + MAX_GRID_DIM_X = getattr( + getattr(driver, "CUdevice_attribute", object()), + "CU_DEVICE_ATTRIBUTE_MAX_GRID_DIM_X", + 0, + ) + MAX_GRID_DIM_Y = getattr( + getattr(driver, "CUdevice_attribute", object()), + "CU_DEVICE_ATTRIBUTE_MAX_GRID_DIM_Y", + 0, + ) + MAX_GRID_DIM_Z = getattr( + getattr(driver, "CUdevice_attribute", object()), + "CU_DEVICE_ATTRIBUTE_MAX_GRID_DIM_Z", + 0, + ) + WARP_SIZE = getattr( + getattr(driver, "CUdevice_attribute", object()), + "CU_DEVICE_ATTRIBUTE_WARP_SIZE", + 0, + ) + MAX_SHARED_MEMORY_PER_BLOCK = getattr( + getattr(driver, "CUdevice_attribute", object()), + "CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK", + 0, + ) + + class Device: + def __init__(self, index: int): + self.index = int(index) + self.device_raw = drv_check(drv_call("cuDeviceGet", self.index), "cuDeviceGet") + + @staticmethod + def count(): + return int(drv_check(drv_call("cuDeviceGetCount"), "cuDeviceGetCount")) + + def get_attributes(self): + attrs = {} + for attr_name in ( + "MAX_BLOCK_DIM_X", + "MAX_BLOCK_DIM_Y", + "MAX_BLOCK_DIM_Z", + "MAX_THREADS_PER_BLOCK", + "MAX_GRID_DIM_X", + "MAX_GRID_DIM_Y", + "MAX_GRID_DIM_Z", + "WARP_SIZE", + "MAX_SHARED_MEMORY_PER_BLOCK", + ): + attr_enum = getattr(_CudaDevice.device_attribute, attr_name) + try: + val = drv_check( + drv_call("cuDeviceGetAttribute", attr_enum, self.device_raw), + "cuDeviceGetAttribute", + ) + attrs[attr_enum] = int(val) + except Exception: + attrs[attr_enum] = 0 + return attrs + + def compute_capability(self): + major_enum = getattr( + getattr(driver, "CUdevice_attribute", object()), + "CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR", + 0, + ) + minor_enum = getattr( + getattr(driver, "CUdevice_attribute", object()), + "CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR", + 0, + ) + major = drv_check(drv_call("cuDeviceGetAttribute", major_enum, self.device_raw), "cuDeviceGetAttribute") + minor = drv_check(drv_call("cuDeviceGetAttribute", minor_enum, 
self.device_raw), "cuDeviceGetAttribute") + return int(major), int(minor) + + def total_memory(self): + return int(drv_check(drv_call(["cuDeviceTotalMem", "cuDeviceTotalMem_v2"], self.device_raw), "cuDeviceTotalMem")) + + def pci_bus_id(self): + try: + bus_id = drv_check(drv_call("cuDeviceGetPCIBusId", 64, self.device_raw), "cuDeviceGetPCIBusId") + if isinstance(bus_id, (bytes, bytearray)): + return bus_id.decode("utf-8", errors="replace").rstrip("\x00") + return str(bus_id) + except Exception: + return f"cuda-device-{self.index}" + + def name(self): + try: + name = drv_check(drv_call("cuDeviceGetName", 128, self.device_raw), "cuDeviceGetName") + if isinstance(name, (bytes, bytearray)): + return name.decode("utf-8", errors="replace").rstrip("\x00") + return str(name) + except Exception: + return f"CUDA Device {self.index}" + + def retain_primary_context(self): + ctx_raw = drv_check(drv_call("cuDevicePrimaryCtxRetain", self.device_raw), "cuDevicePrimaryCtxRetain") + return _ContextHandle(ctx_raw, self.index, True) + + def make_context(self): + ctx_raw = drv_check( + drv_call(["cuCtxCreate", "cuCtxCreate_v2"], 0, self.device_raw), + "cuCtxCreate", + ) + return _ContextHandle(ctx_raw, self.index, False) + + class Context: + @staticmethod + def pop(): + try: + drv_check(drv_call("cuCtxPopCurrent"), "cuCtxPopCurrent") + return + except Exception: + pass + + popped = ctypes.c_void_p() + drv_check(drv_call("cuCtxPopCurrent", popped), "cuCtxPopCurrent") + + Stream = _StreamHandle + ExternalStream = _StreamHandle + Event = _EventHandle + DeviceAllocation = _DeviceAllocation + device_attribute = device_attribute + + @staticmethod + def init(): + drv_check(drv_call("cuInit", 0), "cuInit") + + @staticmethod + def get_driver_version(): + return int(drv_check(drv_call("cuDriverGetVersion"), "cuDriverGetVersion")) + + @staticmethod + def mem_alloc(size: int): + ptr = drv_check( + drv_call(["cuMemAlloc", "cuMemAlloc_v2"], int(size)), + "cuMemAlloc", + ) + return 
_DeviceAllocation(int(to_int(ptr))) + + @staticmethod + def memcpy_htod_async(dst_ptr, src_obj, stream_obj): + src_view = memoryview(src_obj).cast("B") + host_ptr, _keepalive = readonly_host_ptr(src_view) + stream_handle = 0 if stream_obj is None else int(stream_obj) + drv_check( + drv_call( + ["cuMemcpyHtoDAsync", "cuMemcpyHtoDAsync_v2"], + as_driver_handle("CUdeviceptr", int(dst_ptr)), + host_ptr, + len(src_view), + as_driver_handle("CUstream", stream_handle), + ), + "cuMemcpyHtoDAsync", + ) + + @staticmethod + def memcpy_dtoh_async(dst_obj, src_ptr, stream_obj): + dst_view = memoryview(dst_obj).cast("B") + host_ptr, _keepalive = writable_host_ptr(dst_view) + stream_handle = 0 if stream_obj is None else int(stream_obj) + drv_check( + drv_call( + ["cuMemcpyDtoHAsync", "cuMemcpyDtoHAsync_v2"], + host_ptr, + as_driver_handle("CUdeviceptr", int(src_ptr)), + len(dst_view), + as_driver_handle("CUstream", stream_handle), + ), + "cuMemcpyDtoHAsync", + ) + + @staticmethod + def pagelocked_empty(size: int, dtype): + return np.empty(int(size), dtype=dtype) + + +cuda = _CudaDevice diff --git a/vkdispatch/backends/cuda_backend/descriptor_sets.py b/vkdispatch/backends/cuda_backend/descriptor_sets.py new file mode 100644 index 00000000..10670708 --- /dev/null +++ b/vkdispatch/backends/cuda_backend/descriptor_sets.py @@ -0,0 +1,105 @@ +from __future__ import annotations + +from . 
class CUDADescriptorSet(CUDAHandle):
    """Descriptor-set record registered in the module-level ``_descriptor_sets``.

    Stores the owning plan handle, the per-binding buffer tuples written by
    ``descriptor_set_write_buffer``, an (unused here) image-binding map and
    the inline uniform payload bytes.
    """

    plan_handle: int
    buffer_bindings: Dict[int, Tuple[int, int, int, int, int, int]]
    image_bindings: Dict[int, Tuple[int, int, int, int]]
    inline_uniform_payload: bytes

    def __init__(self, plan_handle: int):
        # The base initializer registers this object and assigns ``self.handle``.
        super().__init__(_descriptor_sets)

        self.plan_handle = plan_handle
        self.buffer_bindings, self.image_bindings = {}, {}
        self.inline_uniform_payload = b""

    @staticmethod
    def from_handle(handle: int) -> Optional["CUDADescriptorSet"]:
        """Look up a descriptor set by handle; ``None`` when not registered."""
        key = int(handle)
        return _descriptor_sets.get(key)

    def resolve_buffer_pointer(self, binding: int) -> int:
        """Return the device address bound at ``binding`` plus its byte offset.

        Raises:
            RuntimeError: if the binding was never written, or the recorded
                buffer handle is not present in ``state.buffers``.
        """
        entry = self.buffer_bindings.get(binding)
        if entry is None:
            raise RuntimeError(f"Missing descriptor buffer binding {binding}")

        # Only the first two fields (handle, offset) matter for address resolution.
        buffer_handle, offset, _range, _uniform, _read, _write = entry

        target = state.buffers.get(int(buffer_handle))
        if target is None:
            raise RuntimeError(f"Invalid buffer handle {buffer_handle} for binding {binding}")

        base_address = buffer_device_ptr(target)
        return base_address + int(offset)
class HandleRegistry:
    """Maps integer handles to objects, drawing ids from ``state.next_handle``."""

    def __init__(self):
        self.registry: Dict[int, object] = dict()

    def new_handle(self, obj: object) -> int:
        """Register ``obj`` under the next shared handle value and return it."""
        issued = state.next_handle
        self.registry[issued] = obj
        # The counter lives in ``state`` so every registry shares one id space.
        state.next_handle += 1
        return issued

    def get(self, handle: int) -> Optional[object]:
        """Return the object stored for ``int(handle)``, or ``None``."""
        key = int(handle)
        return self.registry.get(key)

    def pop(self, handle: int) -> Optional[object]:
        """Remove and return the object for ``int(handle)``; ``None`` if absent."""
        key = int(handle)
        return self.registry.pop(key, None)
def coerce_stream_handle(stream_obj) -> Optional[int]:
    """Extract an integer CUDA stream handle from a stream-like object.

    Resolution order (first success wins):

    1. ``None`` is returned unchanged; ``int`` instances are returned as ``int``.
    2. ``__cuda_stream__`` -- either a method or a plain attribute. The
       protocol value is a ``(version, handle)`` tuple, so the handle is the
       LAST tuple element. (Taking element 0 would yield the protocol
       version, i.e. 0, and silently select the default stream.) A
       one-element tuple or a bare integer is accepted as well.
    3. The attributes ``cuda_stream``, ``ptr`` and ``handle``, in that order.
    4. A nested ``stream`` attribute, resolved recursively.
    5. ``int(stream_obj)``.

    Args:
        stream_obj: ``None``, an integer handle, or an object exposing one
            of the conventions above.

    Returns:
        The integer handle, or ``None`` when ``stream_obj`` is ``None``.

    Raises:
        TypeError: if no strategy produces an integer.
    """
    if stream_obj is None:
        return None

    if isinstance(stream_obj, int):
        return int(stream_obj)

    cuda_stream_protocol = getattr(stream_obj, "__cuda_stream__", None)
    if cuda_stream_protocol is not None:
        try:
            proto_value = cuda_stream_protocol() if callable(cuda_stream_protocol) else cuda_stream_protocol
            if isinstance(proto_value, tuple) and len(proto_value) > 0:
                # (version, handle): the handle is last; also covers 1-tuples.
                proto_value = proto_value[-1]
            return int(proto_value)
        except Exception:
            # Malformed protocol value: fall through to the other strategies.
            pass

    for attr_name in ("cuda_stream", "ptr", "handle"):
        if hasattr(stream_obj, attr_name):
            try:
                return int(getattr(stream_obj, attr_name))
            except Exception:
                pass

    nested = getattr(stream_obj, "stream", None)
    # Guard against objects whose ``stream`` attribute points back at themselves.
    if nested is not None and nested is not stream_obj:
        try:
            return coerce_stream_handle(nested)
        except Exception:
            pass

    try:
        return int(stream_obj)
    except Exception as exc:
        raise TypeError(
            "Unable to extract a CUDA stream handle from the provided object. "
            "Pass an int handle or an object with __cuda_stream__/.cuda_stream/.ptr/.handle."
        ) from exc
def queue_indices(ctx: CUDAContext, queue_index: int, *, all_on_negative: bool = False) -> List[int]:
    """Translate a queue selector into a list of concrete queue indices.

    Args:
        ctx: Object exposing ``queue_count``.
        queue_index: Requested queue; ``None`` selects queue 0.
        all_on_negative: When true, any negative selector expands to every
            queue of the context.

    Returns:
        ``[]`` when the context has no queues or the selector is out of
        range; ``[0]`` for ``None`` or ``-1``; otherwise the single
        requested index (or all indices, see ``all_on_negative``).
    """
    total = ctx.queue_count
    if total <= 0:
        return []

    if queue_index is None:
        return [0]

    requested = int(queue_index)

    if requested < 0:
        if all_on_negative:
            return list(range(total))
        # Without the flag only -1 is meaningful, and it maps to queue 0.
        return [0] if requested == -1 else []

    return [requested] if requested < total else []
def parse_kernel_params(source: str) -> List[CUDAKernelParam]:
    """Classify the parameters of the kernel signature found in ``source``.

    The signature is located with ``KERNEL_SIGNATURE_RE``; its first group is
    split on commas and the trailing identifier of every non-empty
    declaration is taken as the parameter name. Three fixed names map to the
    ``uniform`` / ``uniform_value`` / ``push_constant_value`` kinds; names
    matching ``BINDING_PARAM_RE`` or ``SAMPLER_PARAM_RE`` become ``storage``
    or ``sampler`` parameters carrying the captured integer; everything else
    is recorded as ``unknown``.

    Raises:
        RuntimeError: if no signature is found or a declaration has no
            trailing identifier.
    """
    found = KERNEL_SIGNATURE_RE.search(source)
    if found is None:
        raise RuntimeError("Could not find vkdispatch_main kernel signature in CUDA source")

    blob = found.group(1).strip()
    if len(blob) == 0:
        return []

    # Fixed parameter names -> (kind, binding).
    well_known = {
        "vkdispatch_uniform_ptr": ("uniform", 0),
        "vkdispatch_uniform_value": ("uniform_value", None),
        "vkdispatch_pc_value": ("push_constant_value", None),
    }

    collected: List[CUDAKernelParam] = []

    for piece in blob.split(","):
        declaration = piece.strip()
        if len(declaration) == 0:
            continue

        ident = re.search(r"([A-Za-z_][A-Za-z0-9_]*)\s*$", declaration)
        if ident is None:
            raise RuntimeError(f"Unable to parse kernel parameter declaration '{declaration}'")

        identifier = ident.group(1)

        if identifier in well_known:
            kind, binding = well_known[identifier]
        else:
            kind, binding = "unknown", None
            as_binding = BINDING_PARAM_RE.match(identifier)
            if as_binding is not None:
                kind, binding = "storage", int(as_binding.group(1))
            else:
                as_sampler = SAMPLER_PARAM_RE.match(identifier)
                if as_sampler is not None:
                    kind, binding = "sampler", int(as_sampler.group(1))

        collected.append(CUDAKernelParam(kind, binding, identifier))

    return collected
def align_up(value: int, alignment: int) -> int:
    """Round ``value`` up to the nearest multiple of ``alignment``.

    An ``alignment`` of 1 or less leaves ``value`` untouched.
    """
    if alignment <= 1:
        return value

    overshoot = value % alignment
    if overshoot == 0:
        return value
    return value + alignment - overshoot
def image_format_block_size(format):
    """Stub for image-format block-size queries, which this backend lacks.

    Records an error message through ``set_error`` and then falls off the
    end of the function, so the caller receives ``None``.
    """
    # NOTE(review): the sibling stubs image_create / image_create_sampler
    # explicitly ``return 0`` after set_error; this query returns None
    # implicitly -- confirm what callers expect here.
    _ = format
    set_error("CUDA Python backend does not support image format block size queries yet")
def stage_fft_plan_create(
    context,
    dims,
    axes,
    buffer_size,
    do_r2c,
    normalize,
    pad_left,
    pad_right,
    frequency_zeropadding,
    kernel_num,
    kernel_convolution,
    conjugate_convolution,
    convolution_features,
    input_buffer_size,
    num_batches,
    single_kernel_multiple_batches,
    keep_shader_code,
):
    """Stub: FFT plans are not implemented by this backend.

    All arguments are accepted and ignored. An error message is recorded
    via ``set_error`` and the null handle ``0`` is returned.
    """
    # Touch every argument once so they are visibly, deliberately unused.
    _ = (
        context,
        dims,
        axes,
        buffer_size,
        do_r2c,
        normalize,
        pad_left,
        pad_right,
        frequency_zeropadding,
        kernel_num,
        kernel_convolution,
        conjugate_convolution,
        convolution_features,
        input_buffer_size,
        num_batches,
        single_kernel_multiple_batches,
        keep_shader_code,
    )
    set_error("CUDA Python backend does not support FFT plans yet")
    return 0
class CUDASignal(CUDAHandle):
    """Signal record registered in the module-level ``_signals`` registry.

    Tracks the owning context handle and queue index, an optional event
    object, and the ``submitted`` / ``done`` flags.
    """

    context_handle: int
    queue_index: int
    event: Optional["cuda.Event"] = None
    submitted: bool = True
    done: bool = True

    def __init__(self,
                 context_handle: int,
                 queue_index: int,
                 event: Optional["cuda.Event"] = None,
                 submitted: bool = True,
                 done: bool = True):
        # The base initializer registers this object and assigns ``self.handle``.
        super().__init__(_signals)

        self.context_handle, self.queue_index = context_handle, queue_index
        self.event = event
        self.submitted, self.done = submitted, done

    @staticmethod
    def from_handle(handle: int) -> Optional["CUDASignal"]:
        """Return the signal registered under ``handle``, or ``None``."""
        return _signals.get(handle)

    def record(self, stream: "cuda.Stream"):
        """Mark the signal submitted-but-pending and record its event on ``stream``.

        The event object is created on first use.
        """
        self.submitted, self.done = True, False
        if self.event is None:
            self.event = cuda.Event()
        self.event.record(stream)

    def query(self) -> bool:
        """Report completion, refreshing ``done`` from the event when one exists.

        Without an event the stored ``done`` flag is returned. If the event
        query raises, ``False`` is returned and ``done`` is left unchanged.
        """
        pending_event = self.event
        if pending_event is None:
            return bool(self.done)

        try:
            finished = pending_event.query()
        except Exception:
            return False

        self.done = bool(finished)
        return self.done
def signal_insert(context, queue_index):
    """Create a signal and record it on the stream of the selected queue.

    Args:
        context: Context handle; resolved with ``context_from_handle``.
        queue_index: Queue selector passed to ``queue_indices``; when that
            yields nothing, queue 0 is used.

    Returns:
        The new signal's handle, or ``0`` when the context handle is invalid
        or recording fails (the error text is stored via ``set_error``).
    """
    ctx = context_from_handle(int(context))
    if ctx is None:
        return 0

    selected = queue_indices(ctx, int(queue_index))
    if len(selected) == 0:
        selected = [0]

    # Constructing the signal registers it in ``_signals`` immediately.
    signal = CUDASignal(context_handle=int(context), queue_index=selected[0], submitted=False, done=False)

    try:
        with activate_context(ctx):
            signal.record(stream_for_queue(ctx, selected[0]))
    except Exception as exc:
        # The caller only sees 0 and can never destroy this signal, so
        # unregister it here instead of leaking the registry entry.
        _signals.pop(signal.handle)
        set_error(f"Failed to insert signal: {exc}")
        return 0

    return signal.handle
# Per-context runtime record stored in the module-level ``contexts`` map.
@dataclass
class CUDAContext:
    # Index of the device this context was created for.
    device_index: int
    # Underlying context object; helpers call ``push()`` on it to activate it.
    cuda_context: "cuda.Context"
    # Stream objects, indexed by queue index.
    streams: List["cuda.Stream"]
    # Number of queues; queue selectors are validated against this value.
    queue_count: int
    # Per-queue device index -- presumably parallel to ``streams``; TODO confirm.
    queue_to_device: List[int]
    # Kernel launch-parameter size limit recorded for this context.
    max_kernel_param_size: int
    # NOTE(review): looks like this distinguishes a retained primary context
    # from a freshly created one -- confirm against the context-creation code.
    uses_primary_context: bool = False
    # Stop flag; defaults to False.
    stopped: bool = False
+_DEFAULT_SUBGROUP_SIZE = 32 +_DEFAULT_MAX_WORKGROUP_SIZE = (1024, 1024, 64) +_DEFAULT_MAX_WORKGROUP_INVOCATIONS = 1024 +_DEFAULT_MAX_WORKGROUP_COUNT = (65535, 65535, 65535) +_DEFAULT_MAX_COMPUTE_SHARED_MEMORY_SIZE = 64 * 1024 + +_device_subgroup_size = _DEFAULT_SUBGROUP_SIZE +_device_subgroup_enabled = True +_device_max_workgroup_size = _DEFAULT_MAX_WORKGROUP_SIZE +_device_max_workgroup_invocations = _DEFAULT_MAX_WORKGROUP_INVOCATIONS +_device_max_workgroup_count = _DEFAULT_MAX_WORKGROUP_COUNT +_device_max_compute_shared_memory_size = _DEFAULT_MAX_COMPUTE_SHARED_MEMORY_SIZE + + +# --- Internal objects --- + +class _Signal: + __slots__ = ("done",) + + def __init__(self, done=True): + self.done = bool(done) + + +class _Context: + __slots__ = ( + "device_indices", + "queue_families", + "queue_count", + "queue_to_device", + "stopped", + ) + + def __init__(self, device_indices, queue_families): + self.device_indices = list(device_indices) + self.queue_families = [list(fam) for fam in queue_families] + + normalized = [] + for fam in self.queue_families: + normalized.append(fam if len(fam) > 0 else [0]) + self.queue_families = normalized + + self.queue_count = sum(len(fam) for fam in self.queue_families) + if self.queue_count <= 0: + self.queue_families = [[0]] + self.queue_count = 1 + + queue_to_device = [] + for dev_idx, fam in enumerate(self.queue_families): + for _ in fam: + queue_to_device.append(dev_idx) + + if len(queue_to_device) == 0: + queue_to_device = [0] + + self.queue_to_device = queue_to_device + self.stopped = False + +# --- Internal helpers --- + +def _new_handle(registry, obj): + global _next_handle + handle = _next_handle + _next_handle += 1 + registry[handle] = obj + return handle + +def _set_error(message): + global _error_string + _error_string = str(message) + + +def _clear_error(): + global _error_string + _error_string = None + + +_DUMMY_CODEGEN_ONLY_ERROR = ( + "The 'dummy' backend is codegen-only and does not support runtime GPU " + 
"operations. Use backend='vulkan', backend='pycuda', or backend='cuda-python' for execution." +) + + +def _deny_runtime_native_call(function_name): + raise RuntimeError(f"{_DUMMY_CODEGEN_ONLY_ERROR} (native call: {function_name})") + + +def _as_positive_int(name, value): + try: + parsed = int(value) + except Exception as exc: + raise ValueError("%s must be an integer" % name) from exc + + if parsed <= 0: + raise ValueError("%s must be greater than zero" % name) + + return parsed + + +def _as_positive_triplet(name, value): + try: + parts = list(value) + except Exception as exc: + raise ValueError("%s must contain exactly 3 integers" % name) from exc + + if len(parts) != 3: + raise ValueError("%s must contain exactly 3 integers" % name) + + return ( + _as_positive_int("%s[0]" % name, parts[0]), + _as_positive_int("%s[1]" % name, parts[1]), + _as_positive_int("%s[2]" % name, parts[2]), + ) + + +# --- API: context/init/errors/logging --- + + +def reset_device_options(): + global _device_subgroup_size + global _device_subgroup_enabled + global _device_max_workgroup_size + global _device_max_workgroup_invocations + global _device_max_workgroup_count + global _device_max_compute_shared_memory_size + + _device_subgroup_size = _DEFAULT_SUBGROUP_SIZE + _device_subgroup_enabled = True + _device_max_workgroup_size = _DEFAULT_MAX_WORKGROUP_SIZE + _device_max_workgroup_invocations = _DEFAULT_MAX_WORKGROUP_INVOCATIONS + _device_max_workgroup_count = _DEFAULT_MAX_WORKGROUP_COUNT + _device_max_compute_shared_memory_size = _DEFAULT_MAX_COMPUTE_SHARED_MEMORY_SIZE + + +def set_device_options( + subgroup_size=None, + subgroup_enabled=None, + max_workgroup_size=None, + max_workgroup_invocations=None, + max_workgroup_count=None, + max_compute_shared_memory_size=None, +): + global _device_subgroup_size + global _device_subgroup_enabled + global _device_max_workgroup_size + global _device_max_workgroup_invocations + global _device_max_workgroup_count + global 
def log(log_level, text, file_str, line_str):
    """Accept a log record and discard it.

    Present only for API compatibility; logging stays quiet in docs/brython
    by default, so nothing is emitted and ``None`` is returned.
    """
    # Reference the arguments once so they are visibly, deliberately unused.
    _ = (log_level, text, file_str, line_str)
    return None
def context_create(device_indicies, queue_families):
    """Create a dummy context and return its handle.

    Builds a ``_Context`` from the arguments and registers it in
    ``_contexts``. If either step raises, the message is stored via
    ``_set_error`` and ``0`` is returned instead.
    """
    handle = 0
    try:
        handle = _new_handle(_contexts, _Context(device_indicies, queue_families))
    except Exception as exc:
        _set_error("Failed to create context: %s" % exc)
    return handle
def get_error_string():
    """Return the stored error message, or ``0`` when no error is recorded."""
    return 0 if _error_string is None else _error_string
_deny_runtime_native_call("descriptor_set_destroy") + + +def descriptor_set_write_buffer( + descriptor_set, + binding, + object, + offset, + range, + uniform, + read_access, + write_access, +): + _deny_runtime_native_call("descriptor_set_write_buffer") + + +def descriptor_set_write_image( + descriptor_set, + binding, + object, + sampler_obj, + read_access, + write_access, +): + _deny_runtime_native_call("descriptor_set_write_image") + + +# --- API: images/samplers --- + + +def image_create(context, extent, layers, format, type, view_type, generate_mips): + _deny_runtime_native_call("image_create") + + +def image_destroy(image): + _deny_runtime_native_call("image_destroy") + + +def image_create_sampler( + context, + mag_filter, + min_filter, + mip_mode, + address_mode, + mip_lod_bias, + min_lod, + max_lod, + border_color, +): + _deny_runtime_native_call("image_create_sampler") + + +def image_destroy_sampler(sampler): + _deny_runtime_native_call("image_destroy_sampler") + + +def image_write(image, data, offset, extent, baseLayer, layerCount, device_index): + _deny_runtime_native_call("image_write") + + +def image_format_block_size(format): + _deny_runtime_native_call("image_format_block_size") + + +def image_read(image, out_size, offset, extent, baseLayer, layerCount, device_index): + _deny_runtime_native_call("image_read") + + +# --- API: compute stage --- + + +def stage_compute_plan_create(context, shader_source, bindings, pc_size, shader_name): + _deny_runtime_native_call("stage_compute_plan_create") + + +def stage_compute_plan_destroy(plan): + _deny_runtime_native_call("stage_compute_plan_destroy") + + +def stage_compute_record(command_list, plan, descriptor_set, blocks_x, blocks_y, blocks_z): + _deny_runtime_native_call("stage_compute_record") + + +# --- API: FFT stage --- + + +def stage_fft_plan_create( + context, + dims, + axes, + buffer_size, + do_r2c, + normalize, + pad_left, + pad_right, + frequency_zeropadding, + kernel_num, + kernel_convolution, + 
conjugate_convolution, + convolution_features, + input_buffer_size, + num_batches, + single_kernel_multiple_batches, + keep_shader_code, +): + _deny_runtime_native_call("stage_fft_plan_create") + + +def stage_fft_plan_destroy(plan): + _deny_runtime_native_call("stage_fft_plan_destroy") + + +def stage_fft_record(command_list, plan, buffer, inverse, kernel, input_buffer): + _deny_runtime_native_call("stage_fft_record") + + +__all__ = [ + "reset_device_options", + "set_device_options", + "init", + "log", + "set_log_level", + "get_devices", + "context_create", + "signal_wait", + "signal_insert", + "signal_destroy", + "context_destroy", + "get_error_string", + "context_stop_threads", + "buffer_create", + "buffer_destroy", + "buffer_get_queue_signal", + "buffer_wait_staging_idle", + "buffer_write_staging", + "buffer_read_staging", + "buffer_write", + "buffer_read", + "command_list_create", + "command_list_destroy", + "command_list_get_instance_size", + "command_list_reset", + "command_list_submit", + "descriptor_set_create", + "descriptor_set_destroy", + "descriptor_set_write_buffer", + "descriptor_set_write_image", + "image_create", + "image_destroy", + "image_create_sampler", + "image_destroy_sampler", + "image_write", + "image_format_block_size", + "image_read", + "stage_compute_plan_create", + "stage_compute_plan_destroy", + "stage_compute_record", + "stage_fft_plan_create", + "stage_fft_plan_destroy", + "stage_fft_record" +] diff --git a/vkdispatch/backends/opencl_backend.py b/vkdispatch/backends/opencl_backend.py new file mode 100644 index 00000000..eed638a3 --- /dev/null +++ b/vkdispatch/backends/opencl_backend.py @@ -0,0 +1,2080 @@ +"""pyopencl-backed runtime shim mirroring the vkdispatch_native API surface. + +This module intentionally matches the function names exposed by the Cython +extension so existing Python runtime objects can call into either backend. 
+""" + +from __future__ import annotations + +from dataclasses import dataclass, field +import hashlib +import re +import threading +from typing import Dict, List, Optional, Tuple + +import os +import sys + +try: + import numpy as np +except Exception as exc: # pragma: no cover - import failure path + raise ImportError( + "The OpenCL Python backend requires both 'pyopencl' and 'numpy' to be installed." + ) from exc + +try: + import pyopencl as cl +except Exception as exc: # pragma: no cover - import failure path + raise ImportError( + "The OpenCL runtime backend requires the 'pyopencl' package " + "(`pip install pyopencl`)." + ) from exc + + +# Log level constants mirrored from native bindings. +LOG_LEVEL_VERBOSE = 0 +LOG_LEVEL_INFO = 1 +LOG_LEVEL_WARNING = 2 +LOG_LEVEL_ERROR = 3 + +# Descriptor type enum values mirrored from vkdispatch_native/stages_extern.pxd. +DESCRIPTOR_TYPE_STORAGE_BUFFER = 1 +DESCRIPTOR_TYPE_STORAGE_IMAGE = 2 +DESCRIPTOR_TYPE_UNIFORM_BUFFER = 3 +DESCRIPTOR_TYPE_UNIFORM_IMAGE = 4 +DESCRIPTOR_TYPE_SAMPLER = 5 + +# Image format block sizes for formats exposed in vkdispatch.base.image.image_format. 
+_IMAGE_BLOCK_SIZES = { + 13: 1, + 14: 1, + 20: 2, + 21: 2, + 27: 3, + 28: 3, + 41: 4, + 42: 4, + 74: 2, + 75: 2, + 76: 2, + 81: 4, + 82: 4, + 83: 4, + 88: 6, + 89: 6, + 90: 6, + 95: 8, + 96: 8, + 97: 8, + 98: 4, + 99: 4, + 100: 4, + 101: 8, + 102: 8, + 103: 8, + 104: 12, + 105: 12, + 106: 12, + 107: 16, + 108: 16, + 109: 16, + 110: 8, + 111: 8, + 112: 8, + 113: 16, + 114: 16, + 115: 16, + 116: 24, + 117: 24, + 118: 24, + 119: 32, + 120: 32, + 121: 32, +} + +_LOCAL_X_RE = re.compile(r"#define\s+VKDISPATCH_EXPECTED_LOCAL_SIZE_X\s+(\d+)") +_LOCAL_Y_RE = re.compile(r"#define\s+VKDISPATCH_EXPECTED_LOCAL_SIZE_Y\s+(\d+)") +_LOCAL_Z_RE = re.compile(r"#define\s+VKDISPATCH_EXPECTED_LOCAL_SIZE_Z\s+(\d+)") +_REQD_LOCAL_RE = re.compile(r"reqd_work_group_size\(\s*(\d+)\s*,\s*(\d+)\s*,\s*(\d+)\s*\)") +_KERNEL_SIGNATURE_RE = re.compile(r"vkdispatch_main\s*\(([^)]*)\)", re.S) +_BINDING_PARAM_RE = re.compile(r"vkdispatch_binding_(\d+)_ptr$") +_SAMPLER_PARAM_RE = re.compile(r"vkdispatch_sampler_(\d+)$") +_PUSH_CONSTANT_STRUCT_RE = re.compile( + r"typedef\s+struct\s+PushConstant\s*\{(?P.*?)\}\s*PushConstant\s*;", + re.S, +) +_PUSH_CONSTANT_FIELD_RE = re.compile( + r"(?P[A-Za-z_][A-Za-z0-9_]*)\s+" + r"(?P[A-Za-z_][A-Za-z0-9_]*)" + r"(?:\s*\[\s*(?P\d+)\s*\])?$" +) +_VECTOR_TYPE_RE = re.compile(r"([A-Za-z_][A-Za-z0-9_]*?)([2-4])$") +_OPENCL_VERSION_RE = re.compile(r"OpenCL\s+(\d+)\.(\d+)") +_DIGIT_RE = re.compile(r"(\d+)") +_OPENCL_MAX_INFLIGHT_SUBMISSIONS = 4 +_OPENCL_SUBGROUP_PROBE_SOURCE = """ +__kernel void vkdispatch_subgroup_probe(__global uint *out) { + size_t gid = get_global_id(0); + if (gid == 0) { + out[0] = 0u; + } +} +""" + + +# --- Runtime state --- + +_initialized = False +_debug_mode = False +_log_level = LOG_LEVEL_WARNING +_error_string: Optional[str] = None +_next_handle = 1 + +_contexts: Dict[int, "_Context"] = {} +_signals: Dict[int, "_Signal"] = {} +_buffers: Dict[int, "_Buffer"] = {} +_command_lists: Dict[int, "_CommandList"] = {} +_compute_plans: Dict[int, 
"_ComputePlan"] = {} +_descriptor_sets: Dict[int, "_DescriptorSet"] = {} +_images: Dict[int, object] = {} +_samplers: Dict[int, object] = {} +_fft_plans: Dict[int, object] = {} +_subgroup_size_cache: Dict[Tuple[int, int, str, str], int] = {} + +_marker_helpers = threading.local() + + +# --- Internal objects --- + + +@dataclass(frozen=True) +class _DeviceEntry: + logical_index: int + platform_index: int + device_index: int + platform: object + device: object + + +@dataclass +class _Signal: + context_handle: int + queue_index: int + event: Optional[object] = None + submitted: bool = True + done: bool = True + + +@dataclass +class _Context: + device_index: int + cl_context: object + queues: List[object] + queue_count: int + queue_to_device: List[int] + sub_buffer_alignment: int + submission_events: List[List[object]] = field(default_factory=list) + stopped: bool = False + + +@dataclass +class _Buffer: + context_handle: int + size: int + cl_buffer: object + staging_data: List[bytearray] + signal_handles: List[int] + + +@dataclass +class _CommandRecord: + plan_handle: int + descriptor_set_handle: int + blocks: Tuple[int, int, int] + pc_size: int + + +@dataclass +class _CommandList: + context_handle: int + commands: List[_CommandRecord] = field(default_factory=list) + + +@dataclass +class _KernelParam: + kind: str + binding: Optional[int] + raw_name: str + + +@dataclass(frozen=True) +class _PushConstantTypeLayout: + host_elem_size: int + opencl_elem_size: int + opencl_align: int + + +@dataclass(frozen=True) +class _PushConstantFieldDecl: + type_name: str + field_name: str + count: int + + +@dataclass(frozen=True) +class _PushConstantFieldLayout: + type_name: str + field_name: str + count: int + host_offset: int + opencl_offset: int + host_elem_size: int + opencl_elem_size: int + + +@dataclass(frozen=True) +class _PushConstantLayout: + fields: Tuple[_PushConstantFieldLayout, ...] 
+ host_size: int + opencl_size: int + opencl_alignment: int + needs_repack: bool + + +@dataclass +class _ComputePlan: + context_handle: int + shader_source: bytes + bindings: List[int] + shader_name: bytes + program: object + kernel: object + local_size: Tuple[int, int, int] + params: List[_KernelParam] + pc_size: int + pc_layout: Optional[_PushConstantLayout] = None + + +@dataclass +class _DescriptorSet: + plan_handle: int + buffer_bindings: Dict[int, Tuple[int, int, int, int, int, int]] = field(default_factory=dict) + image_bindings: Dict[int, Tuple[int, int, int, int]] = field(default_factory=dict) + + +# --- Helper utilities --- + + +def _new_handle(registry: Dict[int, object], obj: object) -> int: + global _next_handle + handle = _next_handle + _next_handle += 1 + registry[handle] = obj + return handle + + +def _to_bytes(value) -> bytes: + if value is None: + return b"" + if isinstance(value, bytes): + return value + if isinstance(value, bytearray): + return bytes(value) + if isinstance(value, memoryview): + return value.tobytes() + return bytes(value) + + +def _set_error(message: str) -> None: + global _error_string + _error_string = str(message) + + +def _clear_error() -> None: + global _error_string + _error_string = None + +def _enumerate_opencl_devices() -> List[_DeviceEntry]: + entries: List[_DeviceEntry] = [] + + if ( + sys.platform.startswith("linux") + and "OCL_ICD_VENDORS" not in os.environ + and "OPENCL_VENDOR_PATH" not in os.environ + and os.path.isdir("/etc/OpenCL/vendors") + ): + os.environ["OCL_ICD_VENDORS"] = "/etc/OpenCL/vendors" + + try: + platforms = cl.get_platforms() + except Exception as exc: + raise RuntimeError( + f"Failed to get OpenCL Platform: {exc}" + ) from exc + + logical_index = 0 + for platform_index, platform in enumerate(platforms): + try: + devices = platform.get_devices() + except Exception: + continue + + for device_index, device in enumerate(devices): + entries.append( + _DeviceEntry( + logical_index=logical_index, + 
platform_index=platform_index, + device_index=device_index, + platform=platform, + device=device, + ) + ) + logical_index += 1 + + return entries + + +def _coerce_int(value, fallback: int = 0) -> int: + try: + return int(value) + except Exception: + return int(fallback) + + +def _align_up(value: int, alignment: int) -> int: + if alignment <= 1: + return int(value) + return ((int(value) + alignment - 1) // alignment) * alignment + + +def _opencl_version_components(version_text: str) -> Tuple[int, int]: + if not isinstance(version_text, str): + return (0, 0) + + match = _OPENCL_VERSION_RE.search(version_text) + if match is None: + return (0, 0) + + return (_coerce_int(match.group(1), 0), _coerce_int(match.group(2), 0)) + + +def _driver_version_number(driver_text: str) -> int: + if not isinstance(driver_text, str): + return 0 + + pieces = _DIGIT_RE.findall(driver_text) + if len(pieces) == 0: + return 0 + + folded = 0 + weight = 1_000_000 + for token in pieces[:3]: + folded += _coerce_int(token, 0) * weight + weight = max(1, weight // 1000) + return folded + + +def _device_type_to_vkdispatch(device_type: int) -> int: + if device_type & getattr(cl.device_type, "GPU", 0): + return 2 + if device_type & getattr(cl.device_type, "ACCELERATOR", 0): + return 3 + if device_type & getattr(cl.device_type, "CPU", 0): + return 4 + return 0 + + +def _device_uuid(entry: _DeviceEntry, device_name: str, driver_version: str) -> bytes: + platform_vendor = "" + platform_name = "" + try: + platform_vendor = str(entry.platform.vendor) + except Exception: + platform_vendor = "" + try: + platform_name = str(entry.platform.name) + except Exception: + platform_name = "" + + seed = ( + f"opencl:{entry.platform_index}:{entry.device_index}:" + f"{platform_vendor}:" + f"{platform_name}:" + f"{device_name}:{driver_version}" + ) + return hashlib.md5(seed.encode("utf-8")).digest() + + +def _device_attr(device, attr_name: str, default): + try: + return getattr(device, attr_name) + except Exception: + 
return default + + +def _release_opencl_object(obj: object) -> None: + release = getattr(obj, "release", None) + if callable(release): + try: + release() + except Exception: + pass + + +def _device_identity_key(entry: _DeviceEntry, device_name: str, driver_version: str) -> Tuple[int, int, str, str]: + return (int(entry.platform_index), int(entry.device_index), str(device_name), str(driver_version)) + + +def _kernel_preferred_workgroup_multiple(device) -> Optional[int]: + ctx = None + program = None + kernel = None + + try: + ctx = cl.Context(devices=[device]) + program = cl.Program(ctx, _OPENCL_SUBGROUP_PROBE_SOURCE).build() + kernel = cl.Kernel(program, "vkdispatch_subgroup_probe") + multiple = kernel.get_work_group_info( + cl.kernel_work_group_info.PREFERRED_WORK_GROUP_SIZE_MULTIPLE, + device, + ) + multiple_int = _coerce_int(multiple, 0) + if multiple_int > 0: + return multiple_int + except Exception: + return None + finally: + _release_opencl_object(kernel) + _release_opencl_object(program) + _release_opencl_object(ctx) + + return None + + +def _round_down_power_of_two(value: int) -> int: + value = int(value) + if value <= 1: + return 1 + return 1 << (value.bit_length() - 1) + + +def _vendor_subgroup_fallback( + *, + device_type: int, + vendor_text: str, + platform_name: str, + device_name: str, + max_workgroup_invocations: int, +) -> int: + if device_type == 4: + return 1 + + combined = " ".join( + token.lower() + for token in (vendor_text, platform_name, device_name) + if isinstance(token, str) and len(token) > 0 + ) + + if "nvidia" in combined: + return 32 + + if "advanced micro devices" in combined or " amd" in f" {combined}" or "radeon" in combined: + return 64 + + if "apple" in combined or "m1" in combined or "m2" in combined or "m3" in combined or "m4" in combined: + return 32 + + if "intel" in combined: + return 16 if device_type == 2 else 1 + + if device_type == 2: + bounded = min(max(1, int(max_workgroup_invocations)), 64) + if bounded >= 32: + return 
32 + return _round_down_power_of_two(bounded) + + return 1 + + +def _estimate_subgroup_size( + entry: _DeviceEntry, + device, + *, + device_name: str, + driver_version: str, + device_type: int, + max_workgroup_invocations: int, +) -> int: + cache_key = _device_identity_key(entry, device_name, driver_version) + cached = _subgroup_size_cache.get(cache_key) + if cached is not None: + return cached + + platform_name = str(_device_attr(entry.platform, "name", "")) + vendor_text = str(_device_attr(device, "vendor", _device_attr(entry.platform, "vendor", ""))) + + subgroup_size = _kernel_preferred_workgroup_multiple(device) + if subgroup_size is None: + subgroup_size = _vendor_subgroup_fallback( + device_type=device_type, + vendor_text=vendor_text, + platform_name=platform_name, + device_name=device_name, + max_workgroup_invocations=max_workgroup_invocations, + ) + + subgroup_size = max(1, int(subgroup_size)) + _subgroup_size_cache[cache_key] = subgroup_size + return subgroup_size + + +def _context_from_handle(context_handle: int) -> Optional[_Context]: + ctx = _contexts.get(int(context_handle)) + if ctx is None: + _set_error(f"Invalid context handle {context_handle}") + return ctx + + +def _queue_indices(ctx: _Context, queue_index: int, *, all_on_negative: bool = False) -> List[int]: + if ctx.queue_count <= 0: + return [] + + if queue_index is None: + return [0] + + queue_index = int(queue_index) + + if all_on_negative and queue_index < 0: + return list(range(ctx.queue_count)) + + if queue_index == -1: + return [0] + + if 0 <= queue_index < ctx.queue_count: + return [queue_index] + + return [] + + +def _record_signal(signal: _Signal, event_obj: Optional[object]) -> None: + if signal.event is not None and signal.event is not event_obj: + try: + signal.event.release() + except Exception: + pass + signal.submitted = True + signal.done = event_obj is None + signal.event = event_obj + + +def _query_event_done(event_obj: Optional[object]) -> bool: + if event_obj is None: + 
return True + + try: + complete = int(getattr(getattr(cl, "command_execution_status", object()), "COMPLETE", 0)) + status = _coerce_int(event_obj.command_execution_status, 0) + return status == complete + except Exception: + return False + + +def _query_signal(signal: _Signal) -> bool: + signal.done = _query_event_done(signal.event) if signal.event is not None else bool(signal.done) + return signal.done + + +def _wait_signal(signal: _Signal) -> bool: + if signal.event is None: + return bool(signal.done) + + try: + signal.event.wait() + signal.done = True + return True + except Exception: + return _query_signal(signal) + + +def _parse_local_size(source: str) -> Tuple[int, int, int]: + x_match = _LOCAL_X_RE.search(source) + y_match = _LOCAL_Y_RE.search(source) + z_match = _LOCAL_Z_RE.search(source) + + if x_match is not None and y_match is not None and z_match is not None: + return ( + _coerce_int(x_match.group(1), 1), + _coerce_int(y_match.group(1), 1), + _coerce_int(z_match.group(1), 1), + ) + + reqd_match = _REQD_LOCAL_RE.search(source) + if reqd_match is not None: + return ( + _coerce_int(reqd_match.group(1), 1), + _coerce_int(reqd_match.group(2), 1), + _coerce_int(reqd_match.group(3), 1), + ) + + return (1, 1, 1) + + +def _opencl_device_launch_limits(logical_device_index: int) -> Tuple[Tuple[int, int, int], int]: + entries = _enumerate_opencl_devices() + if logical_device_index < 0 or logical_device_index >= len(entries): + raise RuntimeError( + f"OpenCL device index {logical_device_index} is out of range for launch validation" + ) + + device = entries[logical_device_index].device + max_work_item_sizes = tuple( + _coerce_int(x, 1) + for x in _device_attr(device, "max_work_item_sizes", (1, 1, 1)) + ) + + if len(max_work_item_sizes) < 3: + max_work_item_sizes = (max_work_item_sizes + (1, 1, 1))[:3] + else: + max_work_item_sizes = max_work_item_sizes[:3] + + max_workgroup_size = ( + max(1, int(max_work_item_sizes[0])), + max(1, int(max_work_item_sizes[1])), + 
max(1, int(max_work_item_sizes[2])), + ) + max_workgroup_invocations = max( + 1, + _coerce_int(_device_attr(device, "max_work_group_size", 1), 1), + ) + + return max_workgroup_size, max_workgroup_invocations + + +def _validate_local_size_for_enqueue(ctx: _Context, local_size: Tuple[int, int, int]) -> None: + max_workgroup_size, max_workgroup_invocations = _opencl_device_launch_limits(ctx.device_index) + local_x, local_y, local_z = (max(1, int(dim)) for dim in local_size) + local_invocations = local_x * local_y * local_z + + violations = [] + if local_x > max_workgroup_size[0]: + violations.append(f"x={local_x} exceeds {max_workgroup_size[0]}") + if local_y > max_workgroup_size[1]: + violations.append(f"y={local_y} exceeds {max_workgroup_size[1]}") + if local_z > max_workgroup_size[2]: + violations.append(f"z={local_z} exceeds {max_workgroup_size[2]}") + if local_invocations > max_workgroup_invocations: + violations.append( + f"total invocations={local_invocations} exceeds {max_workgroup_invocations}" + ) + + if violations: + raise RuntimeError( + "OpenCL local size is invalid for the active device: " + f"requested ({local_x}, {local_y}, {local_z}), " + f"device limits {max_workgroup_size} with max_work_group_size=" + f"{max_workgroup_invocations} ({'; '.join(violations)})" + ) + + +_PUSH_CONSTANT_SCALAR_LAYOUTS: Dict[str, Tuple[int, int]] = { + "char": (1, 1), + "uchar": (1, 1), + "short": (2, 2), + "ushort": (2, 2), + "int": (4, 4), + "uint": (4, 4), + "long": (8, 8), + "ulong": (8, 8), + "half": (2, 2), + "float": (4, 4), + "double": (8, 8), +} + +_PUSH_CONSTANT_MATRIX_LAYOUTS: Dict[str, _PushConstantTypeLayout] = { + "vkdispatch_mat2": _PushConstantTypeLayout(host_elem_size=16, opencl_elem_size=16, opencl_align=8), + "vkdispatch_mat3": _PushConstantTypeLayout(host_elem_size=36, opencl_elem_size=36, opencl_align=1), + "vkdispatch_mat4": _PushConstantTypeLayout(host_elem_size=64, opencl_elem_size=64, opencl_align=16), + "vkdispatch_packed_float3": 
_PushConstantTypeLayout(host_elem_size=12, opencl_elem_size=12, opencl_align=1), +} + + +def _extract_push_constant_struct_body(source: str) -> Optional[str]: + struct_match = _PUSH_CONSTANT_STRUCT_RE.search(source) + if struct_match is None: + return None + return struct_match.group("body") + + +def _parse_push_constant_struct_fields(body: str) -> List[_PushConstantFieldDecl]: + fields: List[_PushConstantFieldDecl] = [] + + for raw_decl in body.split(";"): + decl = " ".join(raw_decl.strip().split()) + if len(decl) == 0: + continue + + field_match = _PUSH_CONSTANT_FIELD_RE.fullmatch(decl) + if field_match is None: + raise RuntimeError(f"Unable to parse PushConstant field declaration '{decl}'") + + type_name = field_match.group("type") + field_name = field_match.group("name") + count_token = field_match.group("count") + count = 1 if count_token is None else _coerce_int(count_token, 0) + + if count <= 0: + raise RuntimeError(f"Invalid PushConstant array size for field '{field_name}'") + + fields.append(_PushConstantFieldDecl(type_name=type_name, field_name=field_name, count=count)) + + return fields + + +def _push_constant_type_layout(type_name: str) -> _PushConstantTypeLayout: + matrix_layout = _PUSH_CONSTANT_MATRIX_LAYOUTS.get(type_name) + if matrix_layout is not None: + return matrix_layout + + scalar_layout = _PUSH_CONSTANT_SCALAR_LAYOUTS.get(type_name) + if scalar_layout is not None: + size, align = scalar_layout + return _PushConstantTypeLayout(host_elem_size=size, opencl_elem_size=size, opencl_align=align) + + vector_match = _VECTOR_TYPE_RE.fullmatch(type_name) + if vector_match is not None: + scalar_name = vector_match.group(1) + lane_count = _coerce_int(vector_match.group(2), 0) + scalar_info = _PUSH_CONSTANT_SCALAR_LAYOUTS.get(scalar_name) + if scalar_info is None: + raise RuntimeError(f"Unsupported PushConstant vector scalar type '{scalar_name}'") + + scalar_size, _scalar_align = scalar_info + host_elem_size = scalar_size * lane_count + + if lane_count == 
3: + opencl_elem_size = scalar_size * 4 + opencl_align = scalar_size * 4 + else: + opencl_elem_size = host_elem_size + opencl_align = opencl_elem_size + + return _PushConstantTypeLayout( + host_elem_size=host_elem_size, + opencl_elem_size=opencl_elem_size, + opencl_align=opencl_align, + ) + + raise RuntimeError(f"Unsupported PushConstant field type '{type_name}'") + + +def _compute_push_constant_layout(field_decls: List[_PushConstantFieldDecl]) -> _PushConstantLayout: + host_offset = 0 + opencl_offset = 0 + max_opencl_align = 1 + needs_repack = False + field_layouts: List[_PushConstantFieldLayout] = [] + + for field_decl in field_decls: + type_layout = _push_constant_type_layout(field_decl.type_name) + + opencl_offset = _align_up(opencl_offset, type_layout.opencl_align) + + if type_layout.opencl_align > max_opencl_align: + max_opencl_align = type_layout.opencl_align + + if host_offset != opencl_offset: + needs_repack = True + if type_layout.host_elem_size != type_layout.opencl_elem_size: + needs_repack = True + + field_layouts.append( + _PushConstantFieldLayout( + type_name=field_decl.type_name, + field_name=field_decl.field_name, + count=field_decl.count, + host_offset=host_offset, + opencl_offset=opencl_offset, + host_elem_size=type_layout.host_elem_size, + opencl_elem_size=type_layout.opencl_elem_size, + ) + ) + + host_offset += type_layout.host_elem_size * field_decl.count + opencl_offset += type_layout.opencl_elem_size * field_decl.count + + opencl_size = _align_up(opencl_offset, max_opencl_align) + if opencl_size != host_offset: + needs_repack = True + + return _PushConstantLayout( + fields=tuple(field_layouts), + host_size=host_offset, + opencl_size=opencl_size, + opencl_alignment=max_opencl_align, + needs_repack=needs_repack, + ) + + +def _build_push_constant_layout(source: str, expected_host_size: int) -> Optional[_PushConstantLayout]: + expected_host_size = int(expected_host_size) + if expected_host_size <= 0: + return None + + body = 
_extract_push_constant_struct_body(source) + if body is None: + raise RuntimeError("Could not find PushConstant struct declaration in OpenCL source") + + field_decls = _parse_push_constant_struct_fields(body) + if len(field_decls) == 0: + raise RuntimeError("PushConstant struct declaration is empty") + + layout = _compute_push_constant_layout(field_decls) + if layout.host_size != expected_host_size: + raise RuntimeError( + f"PushConstant host layout mismatch. Expected {expected_host_size} bytes " + f"but parsed {layout.host_size} bytes from OpenCL source." + ) + + return layout + + +def _repack_push_constant_payload( + push_constant_payload: bytes, + layout: Optional[_PushConstantLayout], +) -> bytes: + payload = _to_bytes(push_constant_payload) + + if layout is None or not layout.needs_repack: + return payload + + if len(payload) != int(layout.host_size): + raise RuntimeError( + f"PushConstant payload length mismatch for repack. " + f"Expected {layout.host_size} bytes but got {len(payload)} bytes." + ) + + out = bytearray(int(layout.opencl_size)) + + for field in layout.fields: + if field.host_elem_size > field.opencl_elem_size: + raise RuntimeError( + f"PushConstant field '{field.field_name}' host element size ({field.host_elem_size}) " + f"exceeds OpenCL ABI element size ({field.opencl_elem_size})." 
+ ) + + for element_index in range(int(field.count)): + host_start = field.host_offset + (element_index * field.host_elem_size) + host_end = host_start + field.host_elem_size + opencl_start = field.opencl_offset + (element_index * field.opencl_elem_size) + opencl_end = opencl_start + field.host_elem_size + out[opencl_start:opencl_end] = payload[host_start:host_end] + + return bytes(out) + + +def _parse_kernel_params(source: str) -> List[_KernelParam]: + signature_match = _KERNEL_SIGNATURE_RE.search(source) + if signature_match is None: + raise RuntimeError("Could not find vkdispatch_main kernel signature in OpenCL source") + + signature_blob = signature_match.group(1).strip() + if len(signature_blob) == 0: + return [] + + params: List[_KernelParam] = [] + + for raw_decl in [part.strip() for part in signature_blob.split(",") if len(part.strip()) > 0]: + name_match = re.search(r"([A-Za-z_][A-Za-z0-9_]*)\s*$", raw_decl) + if name_match is None: + raise RuntimeError(f"Unable to parse kernel parameter declaration '{raw_decl}'") + + param_name = name_match.group(1) + + if param_name == "vkdispatch_uniform_ptr": + params.append(_KernelParam("uniform", 0, param_name)) + continue + + if param_name == "vkdispatch_pc_value": + params.append(_KernelParam("push_constant_value", None, param_name)) + continue + + binding_match = _BINDING_PARAM_RE.match(param_name) + if binding_match is not None: + params.append(_KernelParam("storage", _coerce_int(binding_match.group(1), 0), param_name)) + continue + + sampler_match = _SAMPLER_PARAM_RE.match(param_name) + if sampler_match is not None: + params.append(_KernelParam("sampler", _coerce_int(sampler_match.group(1), 0), param_name)) + continue + + params.append(_KernelParam("unknown", None, param_name)) + + return params + + +def _buffer_access_flags(read_access: int, write_access: int) -> int: + read_enabled = int(read_access) != 0 + write_enabled = int(write_access) != 0 + + if read_enabled and not write_enabled: + return 
int(cl.mem_flags.READ_ONLY) + if write_enabled and not read_enabled: + return int(cl.mem_flags.WRITE_ONLY) + return int(cl.mem_flags.READ_WRITE) + + +def _resolve_descriptor_buffer( + descriptor_set: _DescriptorSet, + binding: int, + ctx: _Context, + keepalive: List[object], +): + binding_info = descriptor_set.buffer_bindings.get(int(binding)) + if binding_info is None: + raise RuntimeError(f"Missing descriptor buffer binding {binding}") + + buffer_handle, offset, requested_range, _uniform, read_access, write_access = binding_info + + buffer_obj = _buffers.get(int(buffer_handle)) + if buffer_obj is None: + raise RuntimeError(f"Invalid buffer handle {buffer_handle} for binding {binding}") + + offset = int(offset) + requested_range = int(requested_range) + + if offset < 0: + raise RuntimeError(f"Negative descriptor offset {offset} for binding {binding}") + + max_size = int(buffer_obj.size) + if offset > max_size: + raise RuntimeError(f"Descriptor offset {offset} exceeds buffer size {max_size} for binding {binding}") + + sub_size = max_size - offset if requested_range <= 0 else requested_range + if sub_size < 0: + raise RuntimeError(f"Invalid descriptor range {sub_size} for binding {binding}") + + if offset + sub_size > max_size: + raise RuntimeError( + f"Descriptor range (offset={offset}, size={sub_size}) exceeds buffer size {max_size} for binding {binding}" + ) + + if offset == 0 and sub_size == max_size: + return buffer_obj.cl_buffer + + if (offset % ctx.sub_buffer_alignment) != 0: + raise RuntimeError( + f"Descriptor offset {offset} for binding {binding} is not aligned to " + f"{ctx.sub_buffer_alignment} bytes required by this OpenCL device" + ) + + sub_buffer = buffer_obj.cl_buffer.get_sub_region( + int(offset), + int(sub_size), + _buffer_access_flags(read_access, write_access), + ) + keepalive.append(sub_buffer) + return sub_buffer + + +def _build_kernel_args( + plan: _ComputePlan, + descriptor_set: Optional[_DescriptorSet], + ctx: _Context, + 
push_constant_payload: bytes = b"", +) -> Tuple[List[object], List[object]]: + args: List[object] = [] + keepalive: List[object] = [] + + for param in plan.params: + if param.kind == "uniform": + if descriptor_set is None: + raise RuntimeError("Kernel requires a descriptor set but none was provided") + args.append(_resolve_descriptor_buffer(descriptor_set, 0, ctx, keepalive)) + continue + + if param.kind == "storage": + if descriptor_set is None: + raise RuntimeError("Kernel requires a descriptor set but none was provided") + if param.binding is None: + raise RuntimeError("Storage parameter has no binding index") + args.append(_resolve_descriptor_buffer(descriptor_set, int(param.binding), ctx, keepalive)) + continue + + if param.kind == "push_constant_value": + if int(plan.pc_size) <= 0: + raise RuntimeError( + f"Kernel parameter '{param.raw_name}' expects push-constant data, but this compute plan has pc_size={plan.pc_size}." + ) + + if len(push_constant_payload) == 0: + raise RuntimeError( + "Missing push-constant payload for OpenCL by-value push-constant parameter " + f"'{param.raw_name}'." + ) + + if len(push_constant_payload) != int(plan.pc_size): + raise RuntimeError( + f"Push-constant payload size mismatch for parameter '{param.raw_name}'. " + f"Expected {plan.pc_size} bytes but got {len(push_constant_payload)} bytes." + ) + + args.append(_repack_push_constant_payload(push_constant_payload, plan.pc_layout)) + continue + + if param.kind == "sampler": + raise RuntimeError("OpenCL backend does not support image/sampler bindings") + + raise RuntimeError( + f"Unsupported kernel parameter '{param.raw_name}'. " + "Expected vkdispatch_uniform_ptr / vkdispatch_pc_value / vkdispatch_binding__ptr." 
+ ) + + return args, keepalive + + +def _marker_wait_functions() -> List[object]: + cached = getattr(_marker_helpers, "funcs", None) + if cached is not None: + return cached + + funcs: List[object] = [] + for fn_name in ( + "enqueue_marker", + "enqueue_marker_with_wait_list", + "enqueue_barrier_with_wait_list", + ): + fn = getattr(cl, fn_name, None) + if fn is not None: + funcs.append(fn) + + _marker_helpers.funcs = funcs + return funcs + + +def _insert_queue_marker_event(queue) -> Optional[object]: + for marker_fn in _marker_wait_functions(): + try: + event_obj = marker_fn(queue) + if event_obj is not None: + return event_obj + except TypeError: + try: + event_obj = marker_fn(queue, wait_for=[]) + if event_obj is not None: + return event_obj + except Exception: + continue + except Exception: + continue + + return None + + +def _release_event(event_obj: Optional[object]) -> None: + if event_obj is None: + return + + try: + event_obj.release() + except Exception: + pass + + +def _prune_submission_events(ctx: _Context, queue_index: int) -> int: + pending_events: List[object] = [] + + for event_obj in ctx.submission_events[queue_index]: + if _query_event_done(event_obj): + _release_event(event_obj) + continue + + pending_events.append(event_obj) + + ctx.submission_events[queue_index] = pending_events + return len(pending_events) + + +def _reserve_submission_slot(ctx: _Context, queue_index: int) -> bool: + return _prune_submission_events(ctx, queue_index) < _OPENCL_MAX_INFLIGHT_SUBMISSIONS + + +def _track_submission_completion(ctx: _Context, queue_index: int) -> None: + queue = ctx.queues[queue_index] + marker_event = _insert_queue_marker_event(queue) + + if marker_event is None: + queue.finish() + _prune_submission_events(ctx, queue_index) + return + + ctx.submission_events[queue_index].append(marker_event) + queue.flush() + + +# --- API: context/init/logging --- + + +def init(debug, log_level): + global _initialized, _debug_mode, _log_level + + _debug_mode = 
bool(debug) + _log_level = int(log_level) + _clear_error() + + if _initialized: + return + + _initialized = True + + +def log(log_level, text, file_str, line_str): + _ = log_level + _ = text + _ = file_str + _ = line_str + + +def set_log_level(log_level): + global _log_level + _log_level = int(log_level) + + +def get_devices(): + if not _initialized: + init(False, _log_level) + + entries = _enumerate_opencl_devices() + devices = [] + + for entry in entries: + device = entry.device + opencl_version = _device_attr(device, "version", "") + version_major, version_minor = _opencl_version_components(opencl_version) + version_patch = 0 + + driver_version = str(_device_attr(device, "driver_version", "")) + driver_version_num = _driver_version_number(driver_version) + + vendor_id = _coerce_int(_device_attr(device, "vendor_id", 0), 0) + device_id = int(entry.logical_index) + device_type = _device_type_to_vkdispatch(_coerce_int(_device_attr(device, "type", 0), 0)) + device_name = str(_device_attr(device, "name", f"OpenCL Device {entry.logical_index}")) + + extensions = str(_device_attr(device, "extensions", "")) + float32_atomic_support = ( + "cl_ext_float_atomics" in extensions + or "cl_khr_float_atomics" in extensions + ) + float64_support = "cl_khr_fp64" in extensions or _coerce_int(_device_attr(device, "double_fp_config", 0), 0) != 0 + float16_support = "cl_khr_fp16" in extensions or _coerce_int(_device_attr(device, "half_fp_config", 0), 0) != 0 + int64_support = _coerce_int(_device_attr(device, "address_bits", 0), 0) >= 64 + int16_support = _coerce_int(_device_attr(device, "preferred_vector_width_short", 0), 0) > 0 + + max_work_item_sizes = tuple( + _coerce_int(x, 1) + for x in _device_attr(device, "max_work_item_sizes", (1, 1, 1)) + ) + if len(max_work_item_sizes) < 3: + max_work_item_sizes = ( + max_work_item_sizes + (1, 1, 1) + )[:3] + else: + max_work_item_sizes = max_work_item_sizes[:3] + + max_workgroup_size = ( + max(1, int(max_work_item_sizes[0])), + max(1, 
int(max_work_item_sizes[1])), + max(1, int(max_work_item_sizes[2])), + ) + max_workgroup_invocations = max(1, _coerce_int(_device_attr(device, "max_work_group_size", 1), 1)) + + max_workgroup_count = (2 ** 31 - 1, 2 ** 31 - 1, 2 ** 31 - 1) + + max_storage_buffer_range = max( + 1, + min( + _coerce_int(_device_attr(device, "max_mem_alloc_size", 1), 1), + (1 << 31) - 1, + ), + ) + max_uniform_buffer_range = max(1, _coerce_int(_device_attr(device, "max_constant_buffer_size", 65536), 65536)) + uniform_alignment = max( + 1, + _coerce_int(_device_attr(device, "mem_base_addr_align", 8), 8) // 8, + ) + max_push_constant_size = max(0, _coerce_int(_device_attr(device, "max_parameter_size", 0), 0)) + + subgroup_size = _estimate_subgroup_size( + entry, + device, + device_name=device_name, + driver_version=driver_version, + device_type=device_type, + max_workgroup_invocations=max_workgroup_invocations, + ) + + max_compute_shared_memory_size = max( + 1, + _coerce_int(_device_attr(device, "local_mem_size", 1), 1), + ) + + uuid_bytes = _device_uuid(entry, device_name, driver_version) + + devices.append( + ( + 0, # Vulkan variant + int(version_major), + int(version_minor), + int(version_patch), + int(driver_version_num), + int(vendor_id), + int(device_id), + int(device_type), + str(device_name), + 1 if float32_atomic_support else 0, + 1 if float32_atomic_support else 0, + 1 if float64_support else 0, + 1 if float16_support else 0, + 1 if int64_support else 0, + 1 if int16_support else 0, + 1 if int16_support else 0, # storage_buffer_16_bit_access + 1 if int16_support else 0, # uniform_and_storage_buffer_16_bit_access + 0, # storage_push_constant_16 + 1 if int16_support else 0, # storage_input_output_16 + max_workgroup_size, + int(max_workgroup_invocations), + max_workgroup_count, + 8, # max descriptor sets (virtualized for parity) + int(max_push_constant_size), + int(max_storage_buffer_range), + int(max_uniform_buffer_range), + int(uniform_alignment), + subgroup_size, # subgroup 
size + 0, # subgroup stages + 0, # subgroup operations + 0, # quad operations in all stages + int(max_compute_shared_memory_size), + [(1, 0x006)], # compute + transfer queue + 1, # scalar block layout equivalent + 0, # timeline semaphores equivalent + uuid_bytes, + ) + ) + + return devices + + +def context_create(device_indicies, queue_families): + if not _initialized: + init(False, _log_level) + + try: + device_ids = [int(x) for x in device_indicies] + except Exception: + _set_error("context_create expected a list of integer device indices") + return 0 + + if len(device_ids) != 1: + _set_error("OpenCL backend currently supports exactly one device") + return 0 + + try: + normalized_families = [[int(x) for x in family] for family in queue_families] + except Exception: + _set_error("context_create expected queue_families to be a nested integer list") + return 0 + + if len(normalized_families) != 1 or len(normalized_families[0]) != 1: + _set_error("OpenCL backend currently supports exactly one queue") + return 0 + + entries = _enumerate_opencl_devices() + if len(entries) == 0: + if _error_string is None: + _set_error("No OpenCL devices were found") + return 0 + + logical_device_index = int(device_ids[0]) + if logical_device_index < 0 or logical_device_index >= len(entries): + _set_error( + f"Invalid OpenCL device index {logical_device_index}. 
" + f"Expected range [0, {len(entries) - 1}]" + ) + return 0 + + entry = entries[logical_device_index] + + try: + cl_context = cl.Context(devices=[entry.device]) + queue = cl.CommandQueue(cl_context, device=entry.device) + sub_buffer_alignment = max( + 1, + _coerce_int(_device_attr(entry.device, "mem_base_addr_align", 8), 8) // 8, + ) + ctx = _Context( + device_index=logical_device_index, + cl_context=cl_context, + queues=[queue], + queue_count=1, + queue_to_device=[0], + sub_buffer_alignment=sub_buffer_alignment, + submission_events=[[]], + stopped=False, + ) + return _new_handle(_contexts, ctx) + except Exception as exc: + _set_error(f"Failed to create OpenCL context: {exc}") + return 0 + + +def context_destroy(context): + ctx = _contexts.pop(int(context), None) + if ctx is None: + return + + for queue_events in ctx.submission_events: + for event_obj in queue_events: + _release_event(event_obj) + queue_events.clear() + + for queue in ctx.queues: + try: + queue.finish() + except Exception: + pass + try: + queue.release() + except Exception: + pass + + try: + ctx.cl_context.release() + except Exception: + pass + + +def context_stop_threads(context): + ctx = _contexts.get(int(context)) + if ctx is not None: + ctx.stopped = True + + +def get_error_string(): + if _error_string is None: + return 0 + return _error_string + + +# --- API: signals --- + + +def signal_wait(signal_ptr, wait_for_timestamp, queue_index): + _ = queue_index + + signal_obj = _signals.get(int(signal_ptr)) + if signal_obj is None: + return True + + if not bool(wait_for_timestamp): + if signal_obj.event is None: + return bool(signal_obj.done) + return bool(signal_obj.submitted) + + return _wait_signal(signal_obj) + + +def signal_insert(context, queue_index): + ctx = _context_from_handle(int(context)) + if ctx is None: + return 0 + + selected = _queue_indices(ctx, int(queue_index)) + if len(selected) == 0: + selected = [0] + + signal = _Signal(context_handle=int(context), queue_index=selected[0], 
submitted=False, done=False) + handle = _new_handle(_signals, signal) + + try: + event_obj = _insert_queue_marker_event(ctx.queues[selected[0]]) + if event_obj is None: + ctx.queues[selected[0]].finish() + signal.done = True + signal.submitted = True + else: + _record_signal(signal, event_obj) + except Exception as exc: + _set_error(f"Failed to insert signal: {exc}") + return 0 + + return handle + + +def signal_destroy(signal_ptr): + signal_obj = _signals.pop(int(signal_ptr), None) + if signal_obj is None: + return + + try: + if signal_obj.event is not None: + signal_obj.event.release() + except Exception: + pass + + +# --- API: buffers --- + + +def buffer_create(context, size, per_device): + _ = per_device + + ctx = _context_from_handle(int(context)) + if ctx is None: + return 0 + + size = int(size) + if size <= 0: + _set_error("Buffer size must be greater than zero") + return 0 + + try: + cl_buffer = cl.Buffer(ctx.cl_context, cl.mem_flags.READ_WRITE, size=size) + signal_handles = [ + _new_handle(_signals, _Signal(context_handle=int(context), queue_index=i, done=True)) + for i in range(ctx.queue_count) + ] + obj = _Buffer( + context_handle=int(context), + size=size, + cl_buffer=cl_buffer, + staging_data=[bytearray(size) for _ in range(ctx.queue_count)], + signal_handles=signal_handles, + ) + return _new_handle(_buffers, obj) + except Exception as exc: + _set_error(f"Failed to create OpenCL buffer: {exc}") + return 0 + + +def buffer_create_external(context, size, device_ptr): + _ = context + _ = size + _ = device_ptr + _set_error("OpenCL backend does not support external buffer aliases in MVP") + return 0 + + +def buffer_destroy(buffer): + obj = _buffers.pop(int(buffer), None) + if obj is None: + return + + for signal_handle in obj.signal_handles: + signal_destroy(signal_handle) + + try: + obj.cl_buffer.release() + except Exception: + pass + + +def buffer_get_queue_signal(buffer, queue_index): + obj = _buffers.get(int(buffer)) + if obj is None: + return 
_new_handle(_signals, _Signal(context_handle=0, queue_index=0, done=True)) + + queue_index = int(queue_index) + if queue_index < 0 or queue_index >= len(obj.signal_handles): + queue_index = 0 + + return obj.signal_handles[queue_index] + + +def buffer_wait_staging_idle(buffer, queue_index): + signal_handle = buffer_get_queue_signal(buffer, queue_index) + signal_obj = _signals.get(int(signal_handle)) + if signal_obj is None: + return True + return _query_signal(signal_obj) + + +def buffer_write_staging(buffer, queue_index, data, size): + obj = _buffers.get(int(buffer)) + if obj is None: + return + + queue_index = int(queue_index) + if queue_index < 0 or queue_index >= len(obj.staging_data): + return + + payload = _to_bytes(data) + size = min(int(size), len(payload), obj.size) + if size <= 0: + return + + obj.staging_data[queue_index][:size] = payload[:size] + + +def buffer_read_staging(buffer, queue_index, size): + obj = _buffers.get(int(buffer)) + if obj is None: + return bytes(int(size)) + + queue_index = int(queue_index) + if queue_index < 0 or queue_index >= len(obj.staging_data): + return bytes(int(size)) + + size = max(0, int(size)) + staging = obj.staging_data[queue_index] + + if size <= len(staging): + return bytes(staging[:size]) + + return bytes(staging) + bytes(size - len(staging)) + + +def buffer_write(buffer, offset, size, index): + obj = _buffers.get(int(buffer)) + if obj is None: + return + + ctx = _contexts.get(obj.context_handle) + if ctx is None: + _set_error(f"Missing context for buffer handle {buffer}") + return + + offset = int(offset) + size = int(size) + if size <= 0 or offset < 0: + return + + try: + for queue_index in _queue_indices(ctx, int(index), all_on_negative=True): + queue = ctx.queues[queue_index] + end = min(offset + size, obj.size) + copy_size = end - offset + if copy_size <= 0: + continue + + host_src = np.frombuffer(obj.staging_data[queue_index], dtype=np.uint8, count=copy_size) + event_obj = cl.enqueue_copy( + queue, + 
obj.cl_buffer, + host_src, + dst_offset=offset, + is_blocking=False, + ) + + signal_obj = _signals.get(obj.signal_handles[queue_index]) + if signal_obj is not None: + _record_signal(signal_obj, event_obj) + except Exception as exc: + _set_error(f"Failed to write OpenCL buffer: {exc}") + + +def buffer_read(buffer, offset, size, index): + obj = _buffers.get(int(buffer)) + if obj is None: + return + + ctx = _contexts.get(obj.context_handle) + if ctx is None: + _set_error(f"Missing context for buffer handle {buffer}") + return + + queue_index = int(index) + if queue_index < 0 or queue_index >= ctx.queue_count: + _set_error(f"Invalid queue index {queue_index} for buffer read") + return + + offset = int(offset) + size = int(size) + if size <= 0 or offset < 0: + return + + try: + queue = ctx.queues[queue_index] + end = min(offset + size, obj.size) + copy_size = end - offset + if copy_size <= 0: + return + + host_dst = np.frombuffer(obj.staging_data[queue_index], dtype=np.uint8, count=copy_size) + event_obj = cl.enqueue_copy( + queue, + host_dst, + obj.cl_buffer, + src_offset=offset, + is_blocking=False, + ) + + signal_obj = _signals.get(obj.signal_handles[queue_index]) + if signal_obj is not None: + _record_signal(signal_obj, event_obj) + except Exception as exc: + _set_error(f"Failed to read OpenCL buffer: {exc}") + + +# --- API: command lists --- + + +def command_list_create(context): + if int(context) not in _contexts: + _set_error("Invalid context handle for command_list_create") + return 0 + + return _new_handle(_command_lists, _CommandList(context_handle=int(context))) + + +def command_list_destroy(command_list): + _command_lists.pop(int(command_list), None) + + +def command_list_get_instance_size(command_list): + obj = _command_lists.get(int(command_list)) + if obj is None: + return 0 + + return int(sum(int(command.pc_size) for command in obj.commands)) + + +def command_list_reset(command_list): + obj = _command_lists.get(int(command_list)) + if obj is None: + 
return + + obj.commands = [] + + +def command_list_submit(command_list, data, instance_count, index): + obj = _command_lists.get(int(command_list)) + if obj is None: + return True + + ctx = _contexts.get(obj.context_handle) + if ctx is None: + _set_error(f"Missing context for command list {command_list}") + return True + + instance_count = int(instance_count) + if instance_count <= 0: + return True + + instance_size = command_list_get_instance_size(command_list) + payload = _to_bytes(data) + expected_payload_size = int(instance_size) * int(instance_count) + + if expected_payload_size == 0: + if len(payload) != 0: + _set_error( + f"Unexpected push-constant data for command list with instance_size=0 " + f"(got {len(payload)} bytes)." + ) + return True + elif len(payload) != expected_payload_size: + _set_error( + f"Push-constant data size mismatch. Expected {expected_payload_size} bytes " + f"(instance_size={instance_size}, instance_count={instance_count}) but got {len(payload)} bytes." + ) + return True + + queue_targets = _queue_indices(ctx, int(index), all_on_negative=True) + if len(queue_targets) == 0: + queue_targets = [0] + + try: + for queue_index in queue_targets: + if not _reserve_submission_slot(ctx, queue_index): + return False + + for queue_index in queue_targets: + queue = ctx.queues[queue_index] + for instance_index in range(instance_count): + instance_base_offset = instance_index * instance_size + per_instance_offset = 0 + for command in obj.commands: + plan = _compute_plans.get(command.plan_handle) + if plan is None: + raise RuntimeError(f"Invalid compute plan handle {command.plan_handle}") + + descriptor_set = None + if command.descriptor_set_handle != 0: + descriptor_set = _descriptor_sets.get(command.descriptor_set_handle) + if descriptor_set is None: + raise RuntimeError( + f"Invalid descriptor set handle {command.descriptor_set_handle}" + ) + + command_pc_size = int(command.pc_size) + pc_payload = b"" + if command_pc_size > 0 and len(payload) > 0: 
+ pc_start = instance_base_offset + per_instance_offset + pc_end = pc_start + command_pc_size + pc_payload = payload[pc_start:pc_end] + + args, _keepalive = _build_kernel_args( + plan, + descriptor_set, + ctx, + pc_payload, + ) + + for arg_index, arg_value in enumerate(args): + plan.kernel.set_arg(arg_index, arg_value) + + local_x = max(1, int(plan.local_size[0])) + local_y = max(1, int(plan.local_size[1])) + local_z = max(1, int(plan.local_size[2])) + _validate_local_size_for_enqueue(ctx, (local_x, local_y, local_z)) + + blocks_x = max(1, int(command.blocks[0])) + blocks_y = max(1, int(command.blocks[1])) + blocks_z = max(1, int(command.blocks[2])) + + global_size = ( + blocks_x * local_x, + blocks_y * local_y, + blocks_z * local_z, + ) + + cl.enqueue_nd_range_kernel( + queue, + plan.kernel, + global_size, + (local_x, local_y, local_z), + ) + + per_instance_offset += command_pc_size + + if per_instance_offset != instance_size: + raise RuntimeError( + f"Internal command list size mismatch: computed {per_instance_offset} bytes, " + f"expected {instance_size} bytes." 
+ ) + + _track_submission_completion(ctx, queue_index) + except Exception as exc: + _set_error(f"Failed to submit OpenCL command list: {exc}") + + return True + + +# --- API: descriptor sets --- + + +def descriptor_set_create(plan): + if int(plan) not in _compute_plans: + _set_error("Invalid compute plan handle for descriptor_set_create") + return 0 + + return _new_handle(_descriptor_sets, _DescriptorSet(plan_handle=int(plan))) + + +def descriptor_set_destroy(descriptor_set): + _descriptor_sets.pop(int(descriptor_set), None) + + +def descriptor_set_write_buffer( + descriptor_set, + binding, + object, + offset, + range, + uniform, + read_access, + write_access, +): + ds = _descriptor_sets.get(int(descriptor_set)) + if ds is None: + _set_error("Invalid descriptor set handle for descriptor_set_write_buffer") + return + + ds.buffer_bindings[int(binding)] = ( + int(object), + int(offset), + int(range), + int(uniform), + int(read_access), + int(write_access), + ) + + +def descriptor_set_write_image( + descriptor_set, + binding, + object, + sampler_obj, + read_access, + write_access, +): + _ = descriptor_set + _ = binding + _ = object + _ = sampler_obj + _ = read_access + _ = write_access + _set_error("OpenCL backend does not support image objects in MVP") + + +# --- API: compute stage --- + + +def stage_compute_plan_create(context, shader_source, bindings, pc_size, shader_name): + ctx = _context_from_handle(int(context)) + if ctx is None: + return 0 + + source_bytes = _to_bytes(shader_source) + shader_name_bytes = _to_bytes(shader_name) + source_text = source_bytes.decode("utf-8", errors="replace") + pc_size = int(pc_size) + + try: + program = cl.Program(ctx.cl_context, source_text).build() + kernel = cl.Kernel(program, "vkdispatch_main") + except Exception as exc: + kernel_name = shader_name_bytes.decode("utf-8", errors="replace") + _set_error(f"Failed to compile OpenCL kernel '{kernel_name}': {exc}") + return 0 + + try: + params = _parse_kernel_params(source_text) + 
local_size = _parse_local_size(source_text) + pc_layout = _build_push_constant_layout(source_text, pc_size) + except Exception as exc: + _set_error(f"Failed to parse OpenCL kernel metadata: {exc}") + return 0 + + plan = _ComputePlan( + context_handle=int(context), + shader_source=source_bytes, + bindings=[int(x) for x in bindings], + shader_name=shader_name_bytes, + program=program, + kernel=kernel, + local_size=local_size, + params=params, + pc_size=pc_size, + pc_layout=pc_layout, + ) + + return _new_handle(_compute_plans, plan) + + +def stage_compute_plan_destroy(plan): + plan_obj = _compute_plans.pop(int(plan), None) + if plan_obj is None: + return + + try: + plan_obj.kernel.release() + except Exception: + pass + + try: + plan_obj.program.release() + except Exception: + pass + + +def stage_compute_record(command_list, plan, descriptor_set, blocks_x, blocks_y, blocks_z): + cl_obj = _command_lists.get(int(command_list)) + cp_obj = _compute_plans.get(int(plan)) + if cl_obj is None or cp_obj is None: + _set_error("Invalid command list or compute plan handle for stage_compute_record") + return + + cl_obj.commands.append( + _CommandRecord( + plan_handle=int(plan), + descriptor_set_handle=int(descriptor_set), + blocks=(int(blocks_x), int(blocks_y), int(blocks_z)), + pc_size=int(cp_obj.pc_size), + ) + ) + + +# --- API: images/samplers (MVP unsupported) --- + + +def image_create(context, extent, layers, format, type, view_type, generate_mips): + _ = context + _ = extent + _ = layers + _ = format + _ = type + _ = view_type + _ = generate_mips + _set_error("OpenCL backend does not support image objects in MVP") + return 0 + + +def image_destroy(image): + _images.pop(int(image), None) + + +def image_create_sampler( + context, + mag_filter, + min_filter, + mip_mode, + address_mode, + mip_lod_bias, + min_lod, + max_lod, + border_color, +): + _ = context + _ = mag_filter + _ = min_filter + _ = mip_mode + _ = address_mode + _ = mip_lod_bias + _ = min_lod + _ = max_lod + _ = 
border_color + _set_error("OpenCL backend does not support image samplers in MVP") + return 0 + + +def image_destroy_sampler(sampler): + _samplers.pop(int(sampler), None) + + +def image_write(image, data, offset, extent, baseLayer, layerCount, device_index): + _ = image + _ = data + _ = offset + _ = extent + _ = baseLayer + _ = layerCount + _ = device_index + _set_error("OpenCL backend does not support image writes in MVP") + + +def image_format_block_size(format): + return int(_IMAGE_BLOCK_SIZES.get(int(format), 4)) + + +def image_read(image, out_size, offset, extent, baseLayer, layerCount, device_index): + _ = image + _ = offset + _ = extent + _ = baseLayer + _ = layerCount + _ = device_index + _set_error("OpenCL backend does not support image reads in MVP") + return bytes(max(0, int(out_size))) + + +# --- API: FFT stage (MVP unsupported) --- + + +def stage_fft_plan_create( + context, + dims, + axes, + buffer_size, + do_r2c, + normalize, + pad_left, + pad_right, + frequency_zeropadding, + kernel_num, + kernel_convolution, + conjugate_convolution, + convolution_features, + input_buffer_size, + num_batches, + single_kernel_multiple_batches, + keep_shader_code, +): + _ = context + _ = dims + _ = axes + _ = buffer_size + _ = do_r2c + _ = normalize + _ = pad_left + _ = pad_right + _ = frequency_zeropadding + _ = kernel_num + _ = kernel_convolution + _ = conjugate_convolution + _ = convolution_features + _ = input_buffer_size + _ = num_batches + _ = single_kernel_multiple_batches + _ = keep_shader_code + _set_error("OpenCL backend does not support FFT plans in MVP") + return 0 + + +def stage_fft_plan_destroy(plan): + _fft_plans.pop(int(plan), None) + + +def stage_fft_record(command_list, plan, buffer, inverse, kernel, input_buffer): + _ = command_list + _ = plan + _ = buffer + _ = inverse + _ = kernel + _ = input_buffer + _set_error("OpenCL backend does not support FFT stages in MVP") + + +__all__ = [ + "LOG_LEVEL_VERBOSE", + "LOG_LEVEL_INFO", + "LOG_LEVEL_WARNING", + 
"LOG_LEVEL_ERROR", + "DESCRIPTOR_TYPE_STORAGE_BUFFER", + "DESCRIPTOR_TYPE_STORAGE_IMAGE", + "DESCRIPTOR_TYPE_UNIFORM_BUFFER", + "DESCRIPTOR_TYPE_UNIFORM_IMAGE", + "DESCRIPTOR_TYPE_SAMPLER", + "init", + "log", + "set_log_level", + "get_devices", + "context_create", + "signal_wait", + "signal_insert", + "signal_destroy", + "context_destroy", + "get_error_string", + "context_stop_threads", + "buffer_create", + "buffer_create_external", + "buffer_destroy", + "buffer_get_queue_signal", + "buffer_wait_staging_idle", + "buffer_write_staging", + "buffer_read_staging", + "buffer_write", + "buffer_read", + "command_list_create", + "command_list_destroy", + "command_list_get_instance_size", + "command_list_reset", + "command_list_submit", + "descriptor_set_create", + "descriptor_set_destroy", + "descriptor_set_write_buffer", + "descriptor_set_write_image", + "image_create", + "image_destroy", + "image_create_sampler", + "image_destroy_sampler", + "image_write", + "image_format_block_size", + "image_read", + "stage_compute_plan_create", + "stage_compute_plan_destroy", + "stage_compute_record", + "stage_fft_plan_create", + "stage_fft_plan_destroy", + "stage_fft_record", +] diff --git a/vkdispatch/backends/pycuda_native.py b/vkdispatch/backends/pycuda_native.py deleted file mode 100644 index 5bf4068d..00000000 --- a/vkdispatch/backends/pycuda_native.py +++ /dev/null @@ -1,1371 +0,0 @@ -"""PyCUDA-backed runtime shim mirroring the vkdispatch_native API surface. - -This module intentionally matches the function names exposed by the Cython -extension so existing Python runtime objects can call into either backend. 
-""" - -from __future__ import annotations - -from contextlib import contextmanager -from dataclasses import dataclass, field -import hashlib -import re -from typing import Dict, List, Optional, Tuple - -try: - import numpy as np - import pycuda.driver as cuda - from pycuda.compiler import SourceModule -except Exception as exc: # pragma: no cover - import failure path - raise ImportError( - "The PyCUDA backend requires both 'pycuda' and 'numpy' to be installed." - ) from exc - - -# Log level constants mirrored from native bindings. -LOG_LEVEL_VERBOSE = 0 -LOG_LEVEL_INFO = 1 -LOG_LEVEL_WARNING = 2 -LOG_LEVEL_ERROR = 3 - -# Descriptor type enum values mirrored from vkdispatch_native/stages_extern.pxd. -DESCRIPTOR_TYPE_STORAGE_BUFFER = 1 -DESCRIPTOR_TYPE_STORAGE_IMAGE = 2 -DESCRIPTOR_TYPE_UNIFORM_BUFFER = 3 -DESCRIPTOR_TYPE_UNIFORM_IMAGE = 4 -DESCRIPTOR_TYPE_SAMPLER = 5 - -# Image format block sizes for formats exposed in vkdispatch.base.image.image_format. -_IMAGE_BLOCK_SIZES = { - 13: 1, - 14: 1, - 20: 2, - 21: 2, - 27: 3, - 28: 3, - 41: 4, - 42: 4, - 74: 2, - 75: 2, - 76: 2, - 81: 4, - 82: 4, - 83: 4, - 88: 6, - 89: 6, - 90: 6, - 95: 8, - 96: 8, - 97: 8, - 98: 4, - 99: 4, - 100: 4, - 101: 8, - 102: 8, - 103: 8, - 104: 12, - 105: 12, - 106: 12, - 107: 16, - 108: 16, - 109: 16, - 110: 8, - 111: 8, - 112: 8, - 113: 16, - 114: 16, - 115: 16, - 116: 24, - 117: 24, - 118: 24, - 119: 32, - 120: 32, - 121: 32, -} - -_LOCAL_X_RE = re.compile(r"#define\s+VKDISPATCH_EXPECTED_LOCAL_SIZE_X\s+(\d+)") -_LOCAL_Y_RE = re.compile(r"#define\s+VKDISPATCH_EXPECTED_LOCAL_SIZE_Y\s+(\d+)") -_LOCAL_Z_RE = re.compile(r"#define\s+VKDISPATCH_EXPECTED_LOCAL_SIZE_Z\s+(\d+)") -_KERNEL_SIGNATURE_RE = re.compile(r"vkdispatch_main\s*\(([^)]*)\)", re.S) -_BINDING_PARAM_RE = re.compile(r"vkdispatch_binding_(\d+)_ptr$") -_SAMPLER_PARAM_RE = re.compile(r"vkdispatch_sampler_(\d+)$") - - -# --- Runtime state --- - -_initialized = False -_debug_mode = False -_log_level = LOG_LEVEL_WARNING -_error_string: 
Optional[str] = None -_next_handle = 1 - -_contexts: Dict[int, "_Context"] = {} -_signals: Dict[int, "_Signal"] = {} -_buffers: Dict[int, "_Buffer"] = {} -_command_lists: Dict[int, "_CommandList"] = {} -_compute_plans: Dict[int, "_ComputePlan"] = {} -_descriptor_sets: Dict[int, "_DescriptorSet"] = {} -_images: Dict[int, object] = {} -_samplers: Dict[int, object] = {} -_fft_plans: Dict[int, object] = {} - - -# --- Internal objects --- - - -@dataclass -class _Signal: - context_handle: int - queue_index: int - event: Optional["cuda.Event"] = None - submitted: bool = True - done: bool = True - - -@dataclass -class _Context: - device_index: int - pycuda_context: "cuda.Context" - streams: List["cuda.Stream"] - queue_count: int - queue_to_device: List[int] - stopped: bool = False - - -@dataclass -class _Buffer: - context_handle: int - size: int - device_allocation: "cuda.DeviceAllocation" - staging_data: List[object] - signal_handles: List[int] - - -@dataclass -class _CommandRecord: - plan_handle: int - descriptor_set_handle: int - blocks: Tuple[int, int, int] - pc_size: int - - -@dataclass -class _CommandList: - context_handle: int - commands: List[_CommandRecord] = field(default_factory=list) - compute_instance_size: int = 0 - pc_scratch: Optional["cuda.DeviceAllocation"] = None - pc_scratch_size: int = 0 - - -@dataclass -class _KernelParam: - kind: str - binding: Optional[int] - raw_name: str - - -@dataclass -class _ComputePlan: - context_handle: int - shader_source: bytes - bindings: List[int] - pc_size: int - shader_name: bytes - module: SourceModule - function: object - local_size: Tuple[int, int, int] - params: List[_KernelParam] - - -@dataclass -class _DescriptorSet: - plan_handle: int - buffer_bindings: Dict[int, Tuple[int, int, int, int, int, int]] = field(default_factory=dict) - image_bindings: Dict[int, Tuple[int, int, int, int]] = field(default_factory=dict) - - -@dataclass -class _ResolvedLaunch: - plan: _ComputePlan - blocks: Tuple[int, int, int] - 
pc_offset: int - pc_size: int - args: Tuple[object, ...] - pc_scratch: Optional["cuda.DeviceAllocation"] = None - - -# --- Helper utilities --- - - -def _new_handle(registry: Dict[int, object], obj: object) -> int: - global _next_handle - handle = _next_handle - _next_handle += 1 - registry[handle] = obj - return handle - - -def _to_bytes(value) -> bytes: - if value is None: - return b"" - if isinstance(value, bytes): - return value - if isinstance(value, bytearray): - return bytes(value) - if isinstance(value, memoryview): - return value.tobytes() - return bytes(value) - - -def _set_error(message: str) -> None: - global _error_string - _error_string = str(message) - - -def _clear_error() -> None: - global _error_string - _error_string = None - - -def _queue_indices(ctx: _Context, queue_index: int, *, all_on_negative: bool = False) -> List[int]: - if ctx.queue_count <= 0: - return [] - - if queue_index is None: - return [0] - - queue_index = int(queue_index) - - if all_on_negative and queue_index < 0: - return list(range(ctx.queue_count)) - - if queue_index == -1: - return [0] - - if 0 <= queue_index < ctx.queue_count: - return [queue_index] - - return [] - - -def _context_from_handle(context_handle: int) -> Optional[_Context]: - ctx = _contexts.get(int(context_handle)) - if ctx is None: - _set_error(f"Invalid context handle {context_handle}") - return ctx - - -@contextmanager -def _activate_context(ctx: _Context): - ctx.pycuda_context.push() - try: - yield - finally: - cuda.Context.pop() - - -def _record_signal(signal: _Signal, stream: "cuda.Stream") -> None: - signal.submitted = True - signal.done = False - if signal.event is None: - signal.event = cuda.Event() - signal.event.record(stream) - - -def _query_signal(signal: _Signal) -> bool: - if signal.event is None: - return bool(signal.done) - - try: - done = signal.event.query() - except Exception: - return False - - signal.done = bool(done) - return signal.done - - -def _allocate_staging_storage(size: int): - 
try: - # Pagelocked host memory improves async HtoD/DtoH throughput and overlap. - return cuda.pagelocked_empty(int(size), np.uint8) - except Exception: - return bytearray(int(size)) - - -def _parse_local_size(source: str) -> Tuple[int, int, int]: - x_match = _LOCAL_X_RE.search(source) - y_match = _LOCAL_Y_RE.search(source) - z_match = _LOCAL_Z_RE.search(source) - - x = int(x_match.group(1)) if x_match else 1 - y = int(y_match.group(1)) if y_match else 1 - z = int(z_match.group(1)) if z_match else 1 - - return (x, y, z) - - -def _parse_kernel_params(source: str) -> List[_KernelParam]: - signature_match = _KERNEL_SIGNATURE_RE.search(source) - if signature_match is None: - raise RuntimeError("Could not find vkdispatch_main kernel signature in CUDA source") - - signature_blob = signature_match.group(1).strip() - if len(signature_blob) == 0: - return [] - - params: List[_KernelParam] = [] - - for raw_decl in [part.strip() for part in signature_blob.split(",") if len(part.strip()) > 0]: - name_match = re.search(r"([A-Za-z_][A-Za-z0-9_]*)\s*$", raw_decl) - if name_match is None: - raise RuntimeError(f"Unable to parse kernel parameter declaration '{raw_decl}'") - - param_name = name_match.group(1) - - if param_name == "vkdispatch_uniform_ptr": - params.append(_KernelParam("uniform", 0, param_name)) - continue - - if param_name == "vkdispatch_pc_ptr": - params.append(_KernelParam("push_constant", None, param_name)) - continue - - binding_match = _BINDING_PARAM_RE.match(param_name) - if binding_match is not None: - params.append(_KernelParam("storage", int(binding_match.group(1)), param_name)) - continue - - sampler_match = _SAMPLER_PARAM_RE.match(param_name) - if sampler_match is not None: - params.append(_KernelParam("sampler", int(sampler_match.group(1)), param_name)) - continue - - params.append(_KernelParam("unknown", None, param_name)) - - return params - - -def _resolve_buffer_pointer(descriptor_set: _DescriptorSet, binding: int) -> int: - binding_info = 
descriptor_set.buffer_bindings.get(binding) - if binding_info is None: - raise RuntimeError(f"Missing descriptor buffer binding {binding}") - - buffer_handle, offset, _range, _uniform, _read_access, _write_access = binding_info - - buffer_obj = _buffers.get(int(buffer_handle)) - if buffer_obj is None: - raise RuntimeError(f"Invalid buffer handle {buffer_handle} for binding {binding}") - - return int(buffer_obj.device_allocation) + int(offset) - - -def _ensure_pc_scratch(command_list: _CommandList, required_size: int) -> "cuda.DeviceAllocation": - if required_size <= 0: - required_size = 1 - - if command_list.pc_scratch is not None and command_list.pc_scratch_size >= required_size: - return command_list.pc_scratch - - command_list.pc_scratch = cuda.mem_alloc(required_size) - command_list.pc_scratch_size = required_size - return command_list.pc_scratch - - -def _build_kernel_args( - plan: _ComputePlan, - descriptor_set: Optional[_DescriptorSet], - command_list: _CommandList, - pc_data: bytes, - stream: "cuda.Stream", -) -> List[object]: - args: List[object] = [] - - for param in plan.params: - if param.kind == "uniform": - if descriptor_set is None: - raise RuntimeError("Kernel requires a descriptor set but none was provided") - - args.append(np.uintp(_resolve_buffer_pointer(descriptor_set, 0))) - continue - - if param.kind == "storage": - if descriptor_set is None: - raise RuntimeError("Kernel requires a descriptor set but none was provided") - - if param.binding is None: - raise RuntimeError("Storage parameter has no binding index") - - args.append(np.uintp(_resolve_buffer_pointer(descriptor_set, param.binding))) - continue - - if param.kind == "push_constant": - pc_scratch = _ensure_pc_scratch(command_list, len(pc_data)) - - if len(pc_data) > 0: - cuda.memcpy_htod_async(pc_scratch, pc_data, stream) - - args.append(np.uintp(int(pc_scratch))) - continue - - if param.kind == "sampler": - raise RuntimeError("PyCUDA backend does not support sampled image bindings yet") 
- - raise RuntimeError( - f"Unsupported kernel parameter '{param.raw_name}'. " - "Expected vkdispatch_uniform_ptr / vkdispatch_binding__ptr / vkdispatch_pc_ptr." - ) - - return args - - -def _build_kernel_args_template( - plan: _ComputePlan, - descriptor_set: Optional[_DescriptorSet], - command_list: _CommandList, - pc_size: int, -) -> Tuple[Tuple[object, ...], Optional["cuda.DeviceAllocation"]]: - args: List[object] = [] - pc_scratch: Optional["cuda.DeviceAllocation"] = None - - for param in plan.params: - if param.kind == "uniform": - if descriptor_set is None: - raise RuntimeError("Kernel requires a descriptor set but none was provided") - - args.append(np.uintp(_resolve_buffer_pointer(descriptor_set, 0))) - continue - - if param.kind == "storage": - if descriptor_set is None: - raise RuntimeError("Kernel requires a descriptor set but none was provided") - - if param.binding is None: - raise RuntimeError("Storage parameter has no binding index") - - args.append(np.uintp(_resolve_buffer_pointer(descriptor_set, param.binding))) - continue - - if param.kind == "push_constant": - if pc_scratch is None: - pc_scratch = _ensure_pc_scratch(command_list, int(pc_size)) - args.append(np.uintp(int(pc_scratch))) - continue - - if param.kind == "sampler": - raise RuntimeError("PyCUDA backend does not support sampled image bindings yet") - - raise RuntimeError( - f"Unsupported kernel parameter '{param.raw_name}'. " - "Expected vkdispatch_uniform_ptr / vkdispatch_binding__ptr / vkdispatch_pc_ptr." 
- ) - - return tuple(args), pc_scratch - - -# --- API: context/init/logging --- - - -def init(debug, log_level): - global _initialized, _debug_mode, _log_level - - _debug_mode = bool(debug) - _log_level = int(log_level) - _clear_error() - - if _initialized: - return - - cuda.init() - _initialized = True - - -def log(log_level, text, file_str, line_str): - _ = log_level - _ = text - _ = file_str - _ = line_str - - -def set_log_level(log_level): - global _log_level - _log_level = int(log_level) - - -def get_devices(): - if not _initialized: - init(False, _log_level) - - try: - device_count = cuda.Device.count() - except Exception as exc: - _set_error(f"Failed to enumerate CUDA devices: {exc}") - return [] - - driver_version = 0 - try: - driver_version = int(cuda.get_driver_version()) - except Exception: - driver_version = 0 - - devices = [] - - for index in range(device_count): - dev = cuda.Device(index) - attrs = dev.get_attributes() - cc_major, cc_minor = dev.compute_capability() - total_memory = int(dev.total_memory()) - - max_workgroup_size = ( - int(attrs.get(cuda.device_attribute.MAX_BLOCK_DIM_X, 1024)), - int(attrs.get(cuda.device_attribute.MAX_BLOCK_DIM_Y, 1024)), - int(attrs.get(cuda.device_attribute.MAX_BLOCK_DIM_Z, 64)), - ) - - max_workgroup_count = ( - int(attrs.get(cuda.device_attribute.MAX_GRID_DIM_X, 65535)), - int(attrs.get(cuda.device_attribute.MAX_GRID_DIM_Y, 65535)), - int(attrs.get(cuda.device_attribute.MAX_GRID_DIM_Z, 65535)), - ) - - subgroup_size = int(attrs.get(cuda.device_attribute.WARP_SIZE, 32)) - max_shared_memory = int( - attrs.get(cuda.device_attribute.MAX_SHARED_MEMORY_PER_BLOCK, 48 * 1024) - ) - - try: - bus_id = str(dev.pci_bus_id()) - except Exception: - bus_id = f"cuda-device-{index}" - - uuid_bytes = hashlib.md5(bus_id.encode("utf-8")).digest() - - devices.append( - ( - 0, # Vulkan variant - int(cc_major), # major - int(cc_minor), # minor - 0, # patch - driver_version, - 0, # vendor id unknown in this API layer - index, # device 
id - 2, # discrete gpu - str(dev.name()), - 1, # shader_buffer_float32_atomics - 1, # shader_buffer_float32_atomic_add - 1, # float64 support - 1 if (cc_major > 5 or (cc_major == 5 and cc_minor >= 3)) else 0, # float16 support - 1, # int64 - 1, # int16 - 1, # storage_buffer_16_bit_access - 1, # uniform_and_storage_buffer_16_bit_access - 1, # storage_push_constant_16 - 1, # storage_input_output_16 - max_workgroup_size, - int(attrs.get(cuda.device_attribute.MAX_THREADS_PER_BLOCK, 1024)), - max_workgroup_count, - 8, # max descriptor sets (virtualized for parity) - 4096, # max push constant size - min(total_memory, (1 << 31) - 1), - 65536, - 16, - subgroup_size, - 0x7FFFFFFF, # supported stages (virtualized for parity) - 0x7FFFFFFF, # supported operations (virtualized for parity) - 1, - max_shared_memory, - [(1, 0x002)], # compute queue - 1, # scalar block layout - 1, # timeline semaphores equivalent - uuid_bytes, - ) - ) - - return devices - - -def context_create(device_indicies, queue_families): - if not _initialized: - init(False, _log_level) - - try: - device_ids = [int(x) for x in device_indicies] - except Exception: - _set_error("context_create expected a list of integer device indices") - return 0 - - if len(device_ids) != 1: - _set_error("PyCUDA backend currently supports exactly one device") - return 0 - - if len(queue_families) != 1 or len(queue_families[0]) != 1: - _set_error("PyCUDA backend currently supports exactly one queue") - return 0 - - device_index = device_ids[0] - - pycuda_context = None - context_pushed = False - - try: - if device_index < 0 or device_index >= cuda.Device.count(): - _set_error(f"Invalid CUDA device index {device_index}") - return 0 - - dev = cuda.Device(device_index) - pycuda_context = dev.make_context() - context_pushed = True - stream = cuda.Stream() - - ctx = _Context( - device_index=device_index, - pycuda_context=pycuda_context, - streams=[stream], - queue_count=1, - queue_to_device=[0], - stopped=False, - ) - handle = 
_new_handle(_contexts, ctx) - - # Leave no context current after creation. - cuda.Context.pop() - context_pushed = False - return handle - except Exception as exc: - if context_pushed: - try: - cuda.Context.pop() - except Exception: - pass - - if pycuda_context is not None: - try: - pycuda_context.detach() - except Exception: - pass - - _set_error(f"Failed to create PyCUDA context: {exc}") - return 0 - - -def context_destroy(context): - ctx = _contexts.pop(int(context), None) - if ctx is None: - return - - try: - with _activate_context(ctx): - for stream in ctx.streams: - stream.synchronize() - except Exception: - pass - - try: - ctx.pycuda_context.detach() - except Exception: - pass - - -def context_stop_threads(context): - ctx = _contexts.get(int(context)) - if ctx is not None: - ctx.stopped = True - - -def get_error_string(): - if _error_string is None: - return 0 - return _error_string - - -# --- API: signals --- - - -def signal_wait(signal_ptr, wait_for_timestamp, queue_index): - signal_obj = _signals.get(int(signal_ptr)) - if signal_obj is None: - return True - - if not bool(wait_for_timestamp): - # PyCUDA records signals synchronously on submission; host-side "recorded" waits - # should therefore complete immediately once an event exists. 
- if signal_obj.event is None: - return bool(signal_obj.done) - return bool(signal_obj.submitted) - - if signal_obj.done: - return True - - if signal_obj.event is None: - return bool(signal_obj.done) - - ctx = _contexts.get(signal_obj.context_handle) - if ctx is None: - return _query_signal(signal_obj) - - try: - with _activate_context(ctx): - signal_obj.event.synchronize() - signal_obj.done = True - return True - except Exception: - return _query_signal(signal_obj) - - -def signal_insert(context, queue_index): - ctx = _context_from_handle(int(context)) - if ctx is None: - return 0 - - selected = _queue_indices(ctx, int(queue_index)) - if len(selected) == 0: - selected = [0] - - signal = _Signal(context_handle=int(context), queue_index=selected[0], submitted=False, done=False) - handle = _new_handle(_signals, signal) - - try: - with _activate_context(ctx): - _record_signal(signal, ctx.streams[selected[0]]) - except Exception as exc: - _set_error(f"Failed to insert signal: {exc}") - return 0 - - return handle - - -def signal_destroy(signal_ptr): - _signals.pop(int(signal_ptr), None) - - -# --- API: buffers --- - - -def buffer_create(context, size, per_device): - _ = per_device - - ctx = _context_from_handle(int(context)) - if ctx is None: - return 0 - - size = int(size) - if size <= 0: - _set_error("Buffer size must be greater than zero") - return 0 - - try: - with _activate_context(ctx): - allocation = cuda.mem_alloc(size) - - signal_handles = [ - _new_handle(_signals, _Signal(context_handle=int(context), queue_index=i, done=True)) - for i in range(ctx.queue_count) - ] - - obj = _Buffer( - context_handle=int(context), - size=size, - device_allocation=allocation, - staging_data=[_allocate_staging_storage(size) for _ in range(ctx.queue_count)], - signal_handles=signal_handles, - ) - return _new_handle(_buffers, obj) - except Exception as exc: - _set_error(f"Failed to create CUDA buffer: {exc}") - return 0 - - -def buffer_destroy(buffer): - obj = 
_buffers.pop(int(buffer), None) - if obj is None: - return - - for signal_handle in obj.signal_handles: - _signals.pop(signal_handle, None) - - ctx = _contexts.get(obj.context_handle) - if ctx is None: - return - - try: - with _activate_context(ctx): - obj.device_allocation.free() - except Exception: - pass - - -def buffer_get_queue_signal(buffer, queue_index): - obj = _buffers.get(int(buffer)) - if obj is None: - return _new_handle(_signals, _Signal(context_handle=0, queue_index=0, done=True)) - - queue_index = int(queue_index) - if queue_index < 0 or queue_index >= len(obj.signal_handles): - queue_index = 0 - - return obj.signal_handles[queue_index] - - -def buffer_wait_staging_idle(buffer, queue_index): - signal_handle = buffer_get_queue_signal(buffer, queue_index) - signal_obj = _signals.get(int(signal_handle)) - if signal_obj is None: - return True - return _query_signal(signal_obj) - - -def buffer_write_staging(buffer, queue_index, data, size): - obj = _buffers.get(int(buffer)) - if obj is None: - return - - queue_index = int(queue_index) - if queue_index < 0 or queue_index >= len(obj.staging_data): - return - - payload = _to_bytes(data) - size = min(int(size), len(payload), obj.size) - if size <= 0: - return - - payload_view = memoryview(payload)[:size] - staging_view = memoryview(obj.staging_data[queue_index]) - staging_view[:size] = payload_view - - -def buffer_read_staging(buffer, queue_index, size): - obj = _buffers.get(int(buffer)) - if obj is None: - return bytes(int(size)) - - queue_index = int(queue_index) - if queue_index < 0 or queue_index >= len(obj.staging_data): - return bytes(int(size)) - - size = max(0, int(size)) - staging = obj.staging_data[queue_index] - - if size <= len(staging): - return bytes(staging[:size]) - - return bytes(staging) + bytes(size - len(staging)) - - -def buffer_write(buffer, offset, size, index): - obj = _buffers.get(int(buffer)) - if obj is None: - return - - ctx = _contexts.get(obj.context_handle) - if ctx is None: - 
_set_error(f"Missing context for buffer handle {buffer}") - return - - offset = int(offset) - size = int(size) - if size <= 0 or offset < 0: - return - - try: - with _activate_context(ctx): - for queue_index in _queue_indices(ctx, int(index), all_on_negative=True): - stream = ctx.streams[queue_index] - end = min(offset + size, obj.size) - copy_size = end - offset - if copy_size <= 0: - continue - - src_view = memoryview(obj.staging_data[queue_index])[:copy_size] - cuda.memcpy_htod_async(int(obj.device_allocation) + offset, src_view, stream) - - signal = _signals.get(obj.signal_handles[queue_index]) - if signal is not None: - _record_signal(signal, stream) - except Exception as exc: - _set_error(f"Failed to write CUDA buffer: {exc}") - - -def buffer_read(buffer, offset, size, index): - obj = _buffers.get(int(buffer)) - if obj is None: - return - - ctx = _contexts.get(obj.context_handle) - if ctx is None: - _set_error(f"Missing context for buffer handle {buffer}") - return - - queue_index = int(index) - if queue_index < 0 or queue_index >= ctx.queue_count: - _set_error(f"Invalid queue index {queue_index} for buffer read") - return - - offset = int(offset) - size = int(size) - if size <= 0 or offset < 0: - return - - try: - with _activate_context(ctx): - stream = ctx.streams[queue_index] - end = min(offset + size, obj.size) - copy_size = end - offset - if copy_size <= 0: - return - - dst_view = memoryview(obj.staging_data[queue_index])[:copy_size] - cuda.memcpy_dtoh_async(dst_view, int(obj.device_allocation) + offset, stream) - - signal = _signals.get(obj.signal_handles[queue_index]) - if signal is not None: - _record_signal(signal, stream) - except Exception as exc: - _set_error(f"Failed to read CUDA buffer: {exc}") - - -# --- API: command lists --- - - -def command_list_create(context): - if int(context) not in _contexts: - _set_error("Invalid context handle for command_list_create") - return 0 - - return _new_handle(_command_lists, 
_CommandList(context_handle=int(context))) - - -def command_list_destroy(command_list): - obj = _command_lists.pop(int(command_list), None) - if obj is None: - return - - ctx = _contexts.get(obj.context_handle) - if ctx is None or obj.pc_scratch is None: - return - - try: - with _activate_context(ctx): - obj.pc_scratch.free() - except Exception: - pass - - -def command_list_get_instance_size(command_list): - obj = _command_lists.get(int(command_list)) - if obj is None: - return 0 - return int(obj.compute_instance_size) - - -def command_list_reset(command_list): - obj = _command_lists.get(int(command_list)) - if obj is None: - return - - obj.commands = [] - obj.compute_instance_size = 0 - - -def command_list_submit(command_list, data, instance_count, index): - obj = _command_lists.get(int(command_list)) - if obj is None: - return True - - ctx = _contexts.get(obj.context_handle) - if ctx is None: - _set_error(f"Missing context for command list {command_list}") - return True - - payload = _to_bytes(data) if data is not None else b"" - instance_count = int(instance_count) - if instance_count <= 0: - return True - - instance_size = int(obj.compute_instance_size) - - if instance_size > 0 and len(payload) < instance_size * instance_count: - _set_error( - f"Instance payload is too small ({len(payload)} bytes) for " - f"{instance_count} instances of size {instance_size}" - ) - return True - - queue_targets = _queue_indices(ctx, int(index), all_on_negative=True) - if len(queue_targets) == 0: - queue_targets = [0] - - try: - with _activate_context(ctx): - payload_view = memoryview(payload) if payload else None - - for queue_index in queue_targets: - stream = ctx.streams[queue_index] - resolved_launches: List[_ResolvedLaunch] = [] - pc_offset = 0 - - for command in obj.commands: - plan = _compute_plans.get(command.plan_handle) - if plan is None: - raise RuntimeError(f"Invalid compute plan handle {command.plan_handle}") - - descriptor_set = None - if 
command.descriptor_set_handle != 0: - descriptor_set = _descriptor_sets.get(command.descriptor_set_handle) - if descriptor_set is None: - raise RuntimeError( - f"Invalid descriptor set handle {command.descriptor_set_handle}" - ) - - pc_size = int(command.pc_size) - args, pc_scratch = _build_kernel_args_template(plan, descriptor_set, obj, pc_size) - resolved_launches.append( - _ResolvedLaunch( - plan=plan, - blocks=command.blocks, - pc_offset=pc_offset, - pc_size=pc_size, - args=args, - pc_scratch=pc_scratch, - ) - ) - pc_offset += pc_size - - for instance in range(instance_count): - instance_base = instance * instance_size - - for launch in resolved_launches: - if launch.pc_scratch is not None and launch.pc_size > 0: - start = instance_base + launch.pc_offset - end = start + launch.pc_size - cuda.memcpy_htod_async( - launch.pc_scratch, - payload_view[start:end], - stream, - ) - - launch.plan.function( - *launch.args, - block=launch.plan.local_size, - grid=launch.blocks, - stream=stream, - ) - except Exception as exc: - _set_error(f"Failed to submit CUDA command list: {exc}") - - return True - - -# --- API: descriptor sets --- - - -def descriptor_set_create(plan): - if int(plan) not in _compute_plans: - _set_error("Invalid compute plan handle for descriptor_set_create") - return 0 - - return _new_handle(_descriptor_sets, _DescriptorSet(plan_handle=int(plan))) - - -def descriptor_set_destroy(descriptor_set): - _descriptor_sets.pop(int(descriptor_set), None) - - -def descriptor_set_write_buffer( - descriptor_set, - binding, - object, - offset, - range, - uniform, - read_access, - write_access, -): - ds = _descriptor_sets.get(int(descriptor_set)) - if ds is None: - _set_error("Invalid descriptor set handle for descriptor_set_write_buffer") - return - - ds.buffer_bindings[int(binding)] = ( - int(object), - int(offset), - int(range), - int(uniform), - int(read_access), - int(write_access), - ) - - -def descriptor_set_write_image( - descriptor_set, - binding, - object, - 
sampler_obj, - read_access, - write_access, -): - ds = _descriptor_sets.get(int(descriptor_set)) - if ds is None: - _set_error("Invalid descriptor set handle for descriptor_set_write_image") - return - - ds.image_bindings[int(binding)] = ( - int(object), - int(sampler_obj), - int(read_access), - int(write_access), - ) - - -# --- API: compute stage --- - - -def stage_compute_plan_create(context, shader_source, bindings, pc_size, shader_name): - ctx = _context_from_handle(int(context)) - if ctx is None: - return 0 - - source_bytes = _to_bytes(shader_source) - shader_name_bytes = _to_bytes(shader_name) - source_text = source_bytes.decode("utf-8", errors="replace") - - try: - with _activate_context(ctx): - module = SourceModule( - source_text, - no_extern_c=True, - options=["-w"] - ) - function = module.get_function("vkdispatch_main") - except Exception as exc: - _set_error(f"Failed to compile CUDA kernel '{shader_name_bytes.decode(errors='ignore')}': {exc}") - return 0 - - try: - params = _parse_kernel_params(source_text) - local_size = _parse_local_size(source_text) - except Exception as exc: - _set_error(f"Failed to parse CUDA kernel metadata: {exc}") - return 0 - - plan = _ComputePlan( - context_handle=int(context), - shader_source=source_bytes, - bindings=[int(x) for x in bindings], - pc_size=int(pc_size), - shader_name=shader_name_bytes, - module=module, - function=function, - local_size=local_size, - params=params, - ) - - return _new_handle(_compute_plans, plan) - - -def stage_compute_plan_destroy(plan): - if plan is None: - return - _compute_plans.pop(int(plan), None) - - -def stage_compute_record(command_list, plan, descriptor_set, blocks_x, blocks_y, blocks_z): - cl = _command_lists.get(int(command_list)) - cp = _compute_plans.get(int(plan)) - if cl is None or cp is None: - _set_error("Invalid command list or compute plan handle for stage_compute_record") - return - - cl.commands.append( - _CommandRecord( - plan_handle=int(plan), - 
descriptor_set_handle=int(descriptor_set), - blocks=(int(blocks_x), int(blocks_y), int(blocks_z)), - pc_size=int(cp.pc_size), - ) - ) - cl.compute_instance_size += int(cp.pc_size) - - -# --- API: images/samplers (not yet implemented on PyCUDA backend) --- - - -def image_create(context, extent, layers, format, type, view_type, generate_mips): - _ = context - _ = extent - _ = layers - _ = format - _ = type - _ = view_type - _ = generate_mips - _set_error("PyCUDA backend does not support image objects yet") - return 0 - - -def image_destroy(image): - _images.pop(int(image), None) - - -def image_create_sampler( - context, - mag_filter, - min_filter, - mip_mode, - address_mode, - mip_lod_bias, - min_lod, - max_lod, - border_color, -): - _ = context - _ = mag_filter - _ = min_filter - _ = mip_mode - _ = address_mode - _ = mip_lod_bias - _ = min_lod - _ = max_lod - _ = border_color - _set_error("PyCUDA backend does not support image samplers yet") - return 0 - - -def image_destroy_sampler(sampler): - _samplers.pop(int(sampler), None) - - -def image_write(image, data, offset, extent, baseLayer, layerCount, device_index): - _ = image - _ = data - _ = offset - _ = extent - _ = baseLayer - _ = layerCount - _ = device_index - _set_error("PyCUDA backend does not support image writes yet") - - -def image_format_block_size(format): - return int(_IMAGE_BLOCK_SIZES.get(int(format), 4)) - - -def image_read(image, out_size, offset, extent, baseLayer, layerCount, device_index): - _ = image - _ = offset - _ = extent - _ = baseLayer - _ = layerCount - _ = device_index - _set_error("PyCUDA backend does not support image reads yet") - return bytes(max(0, int(out_size))) - - -# --- API: FFT stage (not yet implemented on PyCUDA backend) --- - - -def stage_fft_plan_create( - context, - dims, - axes, - buffer_size, - do_r2c, - normalize, - pad_left, - pad_right, - frequency_zeropadding, - kernel_num, - kernel_convolution, - conjugate_convolution, - convolution_features, - input_buffer_size, - 
num_batches, - single_kernel_multiple_batches, - keep_shader_code, -): - _ = context - _ = dims - _ = axes - _ = buffer_size - _ = do_r2c - _ = normalize - _ = pad_left - _ = pad_right - _ = frequency_zeropadding - _ = kernel_num - _ = kernel_convolution - _ = conjugate_convolution - _ = convolution_features - _ = input_buffer_size - _ = num_batches - _ = single_kernel_multiple_batches - _ = keep_shader_code - _set_error("PyCUDA backend does not support FFT plans yet") - return 0 - - -def stage_fft_plan_destroy(plan): - _fft_plans.pop(int(plan), None) - - -def stage_fft_record(command_list, plan, buffer, inverse, kernel, input_buffer): - _ = command_list - _ = plan - _ = buffer - _ = inverse - _ = kernel - _ = input_buffer - _set_error("PyCUDA backend does not support FFT stages yet") - - -__all__ = [ - "LOG_LEVEL_VERBOSE", - "LOG_LEVEL_INFO", - "LOG_LEVEL_WARNING", - "LOG_LEVEL_ERROR", - "DESCRIPTOR_TYPE_STORAGE_BUFFER", - "DESCRIPTOR_TYPE_STORAGE_IMAGE", - "DESCRIPTOR_TYPE_UNIFORM_BUFFER", - "DESCRIPTOR_TYPE_UNIFORM_IMAGE", - "DESCRIPTOR_TYPE_SAMPLER", - "init", - "log", - "set_log_level", - "get_devices", - "context_create", - "signal_wait", - "signal_insert", - "signal_destroy", - "context_destroy", - "get_error_string", - "context_stop_threads", - "buffer_create", - "buffer_destroy", - "buffer_get_queue_signal", - "buffer_wait_staging_idle", - "buffer_write_staging", - "buffer_read_staging", - "buffer_write", - "buffer_read", - "command_list_create", - "command_list_destroy", - "command_list_get_instance_size", - "command_list_reset", - "command_list_submit", - "descriptor_set_create", - "descriptor_set_destroy", - "descriptor_set_write_buffer", - "descriptor_set_write_image", - "image_create", - "image_destroy", - "image_create_sampler", - "image_destroy_sampler", - "image_write", - "image_format_block_size", - "image_read", - "stage_compute_plan_create", - "stage_compute_plan_destroy", - "stage_compute_record", - "stage_fft_plan_create", - 
"stage_fft_plan_destroy", - "stage_fft_record", -] diff --git a/vkdispatch/base/backend.py b/vkdispatch/base/backend.py deleted file mode 100644 index cf652eb1..00000000 --- a/vkdispatch/base/backend.py +++ /dev/null @@ -1,79 +0,0 @@ -from __future__ import annotations - -import importlib -from types import ModuleType -from typing import Dict, Optional - -BACKEND_VULKAN = "vulkan" -BACKEND_PYCUDA = "pycuda" - -_VALID_BACKENDS = {BACKEND_VULKAN, BACKEND_PYCUDA} -_active_backend_name: Optional[str] = None -_backend_modules: Dict[str, ModuleType] = {} - - -def normalize_backend_name(backend: Optional[str]) -> str: - if backend is None: - return BACKEND_VULKAN - - backend_name = backend.strip().lower() - if backend_name not in _VALID_BACKENDS: - valid = ", ".join(sorted(_VALID_BACKENDS)) - raise ValueError(f"Unknown backend '{backend}'. Expected one of: {valid}") - - return backend_name - - -def set_active_backend(backend: str) -> str: - global _active_backend_name - - backend_name = normalize_backend_name(backend) - - if _active_backend_name is not None and _active_backend_name != backend_name: - raise RuntimeError( - f"Backend is already set to '{_active_backend_name}' and cannot be changed to '{backend_name}' in this process." 
- ) - - _active_backend_name = backend_name - return _active_backend_name - - -def clear_active_backend() -> None: - global _active_backend_name - _active_backend_name = None - - -def get_active_backend_name(default: Optional[str] = BACKEND_VULKAN) -> str: - if _active_backend_name is not None: - return _active_backend_name - - return normalize_backend_name(default) - - -def _load_backend_module(backend_name: str) -> ModuleType: - if backend_name in _backend_modules: - return _backend_modules[backend_name] - - if backend_name == BACKEND_VULKAN: - module = importlib.import_module("vkdispatch_native") - elif backend_name == BACKEND_PYCUDA: - module = importlib.import_module("vkdispatch.backends.pycuda_native") - else: - # Defensive guard for future refactors. - raise ValueError(f"Unsupported backend '{backend_name}'") - - _backend_modules[backend_name] = module - return module - - -def get_backend_module(backend: Optional[str] = None) -> ModuleType: - backend_name = normalize_backend_name(backend) if backend is not None else get_active_backend_name() - return _load_backend_module(backend_name) - - -class _BackendProxy: - def __getattr__(self, name: str): - return getattr(get_backend_module(), name) - - -native = _BackendProxy() diff --git a/vkdispatch/base/brython_utils.py b/vkdispatch/base/brython_utils.py deleted file mode 100644 index fa4e7b6b..00000000 --- a/vkdispatch/base/brython_utils.py +++ /dev/null @@ -1,4 +0,0 @@ -import sys - -def is_brython() -> bool: - return sys.implementation.name == "Brython" \ No newline at end of file diff --git a/vkdispatch/base/buffer.py b/vkdispatch/base/buffer.py index 8c2ff2a8..6f49b622 100644 --- a/vkdispatch/base/buffer.py +++ b/vkdispatch/base/buffer.py @@ -1,22 +1,42 @@ from typing import Tuple from typing import List from typing import Union +from typing import Optional +from contextlib import nullcontext +from .init import is_cuda from .dtype import dtype from .context import Handle, Signal from .errors import 
check_for_errors -from .dtype import complex64, uint32, int32, float32 +from .dtype import complex64 +from . import dtype as dtypes -from .._compat import numpy_compat as npc +from ..compat import numpy_compat as npc from .dtype import to_numpy_dtype, from_numpy_dtype -from .backend import native +from ..backends.backend_selection import native import typing _ArgType = typing.TypeVar('_ArgType', bound=dtype) +import dataclasses + +def _suspend_cuda_capture_if_needed(): + if not is_cuda(): + return nullcontext() + + from ..execution_pipeline.cuda_graph_capture import suspend_cuda_capture + return suspend_cuda_capture() + +@dataclasses.dataclass +class ExternalBufferInfo: + writable: bool + iface: dict + keepalive: bool + cuda_ptr: int + class Buffer(Handle, typing.Generic[_ArgType]): """ Represents a contiguous block of memory on the GPU (or shared across multiple devices). @@ -37,8 +57,14 @@ class Buffer(Handle, typing.Generic[_ArgType]): size: int mem_size: int signals: List[Signal] - - def __init__(self, shape: Tuple[int, ...], var_type: dtype) -> None: + is_external: bool + owns_memory: bool + is_writable: bool + cuda_ptr: typing.Optional[int] + cuda_source: typing.Any + cuda_array_stream: typing.Optional[typing.Any] + + def __init__(self, shape: Tuple[int, ...], var_type: dtype, external_buffer: ExternalBufferInfo = None) -> None: super().__init__() if isinstance(shape, int): @@ -49,7 +75,6 @@ def __init__(self, shape: Tuple[int, ...], var_type: dtype) -> None: self.var_type: dtype = var_type self.shape: Tuple[int] = shape - #self.size: int = int(np.prod(shape)) size = 1 for dim in shape: @@ -71,11 +96,25 @@ def __init__(self, shape: Tuple[int, ...], var_type: dtype) -> None: self.shader_shape = tuple(shader_shape_internal) self.signals = [] - - handle = native.buffer_create( - self.context._handle, self.mem_size, 0 - ) - check_for_errors() + self.is_external = external_buffer is not None + self.owns_memory = external_buffer is None + self.is_writable = True if 
external_buffer is None else external_buffer.writable + self.cuda_ptr = None if external_buffer is None else external_buffer.cuda_ptr + self.cuda_source = None if external_buffer is None else (external_buffer.iface if external_buffer.keepalive else None) + self.cuda_array_stream = None if external_buffer is None else external_buffer.iface.get("stream") + + with _suspend_cuda_capture_if_needed(): + if external_buffer is not None: + handle = native.buffer_create_external( + self.context._handle, + self.mem_size, + self.cuda_ptr, + ) + else: + handle = native.buffer_create( + self.context._handle, self.mem_size, 0 + ) + check_for_errors() self.signals = [ Signal( @@ -88,6 +127,17 @@ def __init__(self, shape: Tuple[int, ...], var_type: dtype) -> None: self.register_handle(handle) + def __repr__(self): + return f"""Buffer {self._handle}: + shape={self.shape} + var_type={self.var_type.name} + mem_size={self.mem_size} bytes + is_external={self.is_external} + writable={self.is_writable} + cuda_ptr={self.cuda_ptr} + cuda_iface={self.cuda_source} +""" + def _destroy(self) -> None: """Destroy the buffer and all child handles.""" @@ -100,31 +150,33 @@ def __del__(self) -> None: self.destroy() def _wait_staging_idle(self, index: int): - is_idle = native.buffer_wait_staging_idle(self._handle, index) - check_for_errors() + with _suspend_cuda_capture_if_needed(): + is_idle = native.buffer_wait_staging_idle(self._handle, index) + check_for_errors() return is_idle def _do_writes(self, data: bytes, index: int = None): indicies = [index] if index is not None else range(self.context.queue_count) completed_stages = [0] * len(indicies) - while not all(stage == 1 for stage in completed_stages): - for i in range(len(indicies)): - if completed_stages[i] == 1: - continue + with _suspend_cuda_capture_if_needed(): + while not all(stage == 1 for stage in completed_stages): + for i in range(len(indicies)): + if completed_stages[i] == 1: + continue - queue_index = indicies[i] + queue_index = 
indicies[i] - if not self.signals[queue_index].try_wait(True, queue_index): - continue + if not self.signals[queue_index].try_wait(True, queue_index): + continue - completed_stages[i] = 1 + completed_stages[i] = 1 - native.buffer_write_staging(self._handle, queue_index, data, len(data)) - check_for_errors() + native.buffer_write_staging(self._handle, queue_index, data, len(data)) + check_for_errors() - native.buffer_write(self._handle, 0, len(data), queue_index) - check_for_errors() + native.buffer_write(self._handle, 0, len(data), queue_index) + check_for_errors() def write(self, data: Union[bytes, bytearray, memoryview, typing.Any], index: int = None) -> None: """ @@ -143,6 +195,9 @@ def write(self, data: Union[bytes, bytearray, memoryview, typing.Any], index: in assert isinstance(index, int), "Index must be an integer or None!" assert index >= 0 and index < self.context.queue_count, "Index must be valid!" + if not getattr(self, "is_writable", True): + raise ValueError("Cannot write to a read-only buffer alias.") + true_data_object = None if npc.is_array_like(data): @@ -167,29 +222,30 @@ def _do_reads(self, var_type: dtype, shape: List[int], index: int = None) -> byt mem_size = int(npc.prod(shape)) * var_type.item_size - while not all(stage == 2 for stage in completed_stages): - for i in range(len(indicies)): - if completed_stages[i] == 2: - continue + with _suspend_cuda_capture_if_needed(): + while not all(stage == 2 for stage in completed_stages): + for i in range(len(indicies)): + if completed_stages[i] == 2: + continue - queue_index = indicies[i] + queue_index = indicies[i] - if completed_stages[i] == 0: - if self.signals[queue_index].try_wait(False, queue_index): - completed_stages[i] = 1 - native.buffer_read(self._handle, 0, mem_size, queue_index) - check_for_errors() - else: - continue + if completed_stages[i] == 0: + if self.signals[queue_index].try_wait(False, queue_index): + completed_stages[i] = 1 + native.buffer_read(self._handle, 0, mem_size, 
queue_index) + check_for_errors() + else: + continue - if completed_stages[i] == 1: - if self.signals[queue_index].try_wait(True, queue_index): - completed_stages[i] = 2 - else: - continue + if completed_stages[i] == 1: + if self.signals[queue_index].try_wait(True, queue_index): + completed_stages[i] = 2 + else: + continue - bytes_list[i] = native.buffer_read_staging(self._handle, queue_index, mem_size) - check_for_errors() + bytes_list[i] = native.buffer_read_staging(self._handle, queue_index, mem_size) + check_for_errors() host_arrays = [] @@ -231,6 +287,9 @@ def read(self, index: Union[int, None] = None): def asbuffer(array: typing.Any) -> Buffer: """Cast an array-like object to a buffer object.""" + if hasattr(array, "__cuda_array_interface__"): + return from_cuda_array(array) + if not npc.is_array_like(array): raise TypeError("Expected an array-like object") @@ -239,62 +298,151 @@ def asbuffer(array: typing.Any) -> Buffer: return buffer -def buffer_u32(shape: Tuple[int, ...]) -> Buffer: - """Create a buffer of unsigned 32-bit integers with the specified shape.""" - return Buffer(shape, uint32) -def buffer_i32(shape: Tuple[int, ...]) -> Buffer: - """Create a buffer of signed 32-bit integers with the specified shape.""" - return Buffer(shape, int32) +def from_cuda_array( + obj: typing.Any, + var_type: typing.Optional[dtype] = None, + require_contiguous: bool = True, + writable: typing.Optional[bool] = None, + keepalive: bool = True, +) -> Buffer: + assert is_cuda(), "__cuda_array_interface__ is only supported with CUDA backends." 
+ + if not hasattr(obj, "__cuda_array_interface__"): + raise TypeError("Expected an object with __cuda_array_interface__") + + npc.require_numpy("from_cuda_array") + np = npc.numpy_module() + + iface = obj.__cuda_array_interface__ + if not isinstance(iface, dict): + raise TypeError("__cuda_array_interface__ must be a dictionary") + + if "shape" not in iface or "typestr" not in iface or "data" not in iface: + raise ValueError("__cuda_array_interface__ is missing required fields (shape/typestr/data)") -def buffer_f32(shape: Tuple[int, ...]) -> Buffer: - """Create a buffer of 32-bit floating-point numbers with the specified shape.""" - return Buffer(shape, float32) + shape = tuple(int(dim) for dim in iface["shape"]) + if len(shape) == 0: + shape = (1,) -def buffer_c64(shape: Tuple[int, ...]) -> Buffer: - """Create a buffer of 64-bit complex numbers with the specified shape.""" - return Buffer(shape, complex64) + data_entry = iface["data"] + if not (isinstance(data_entry, tuple) and len(data_entry) >= 2): + raise ValueError("__cuda_array_interface__['data'] must be a tuple (ptr, read_only)") + + ptr = int(data_entry[0]) + source_read_only = bool(data_entry[1]) + + inferred_np_dtype = np.dtype(iface["typestr"]) + inferred_var_type = from_numpy_dtype(inferred_np_dtype) + if var_type is None: + var_type = inferred_var_type + + if not (var_type == inferred_var_type): + raise ValueError( + f"CAI dtype ({inferred_np_dtype}) does not match requested vd dtype ({var_type.name})." 
+ ) + + if require_contiguous: + strides = iface.get("strides") + if strides is not None: + expected_strides = [] + stride = int(inferred_np_dtype.itemsize) + for dim in reversed(shape): + expected_strides.insert(0, stride) + stride *= int(dim) + if tuple(int(x) for x in strides) != tuple(expected_strides): + raise ValueError("Only contiguous C-order CUDA arrays are supported in from_cuda_array().") + + buffer_writable = (not source_read_only) if writable is None else bool(writable) + if buffer_writable and source_read_only: + raise ValueError("Requested writable=True for a read-only CUDA array.") + + external_buffer_info = ExternalBufferInfo( + writable=buffer_writable, + iface=iface, + keepalive=keepalive, + cuda_ptr=ptr + ) + + return Buffer(shape, var_type, external_buffer=external_buffer_info) class RFFTBuffer(Buffer): - def __init__(self, shape: Tuple[int, ...]): - super().__init__(tuple(shape[:-1]) + (shape[-1] // 2 + 1,), complex64) + real_shape: Tuple[int, ...] + fourier_shape: Tuple[int, ...] 
+ real_type: dtype + + def __init__(self, shape: Tuple[int, ...], fourier_type: dtype = complex64): + if not dtypes.is_complex(fourier_type): + raise ValueError("RFFTBuffer fourier_type must be complex32, complex64, or complex128") + + if not dtypes.is_float_dtype(fourier_type.child_type): + raise ValueError("RFFTBuffer fourier_type must use a floating-point scalar") + + super().__init__(tuple(shape[:-1]) + (shape[-1] // 2 + 1,), fourier_type) self.real_shape = shape self.fourier_shape = self.shape - + self.real_type = fourier_type.child_type + def read_real(self, index: Union[int, None] = None): npc.require_numpy("RFFTBuffer.read_real") np = npc.numpy_module() - return self.read(index).view(np.float32)[..., :self.real_shape[-1]] + + packed_shape = list(self.shape[:-1]) + [self.shape[-1] * 2] + packed_data = self._do_reads(self.real_type, packed_shape, index) + + if index is None: + packed_data = np.array(packed_data) + + return packed_data[..., :self.real_shape[-1]] def read_fourier(self, index: Union[int, None] = None): return self.read(index) - + def write_real(self, data, index: int = None): npc.require_numpy("RFFTBuffer.write_real") np = npc.numpy_module() assert data.shape == self.real_shape, "Data shape must match real shape!" - assert not np.issubdtype(data.dtype, np.complexfloating) , "Data dtype must be scalar!" + assert not np.issubdtype(data.dtype, np.complexfloating), "Data dtype must be scalar!" 
- true_data = np.zeros(self.shape[:-1] + (self.shape[-1] * 2,), dtype=np.float32) + real_dtype = to_numpy_dtype(self.real_type) + true_data = np.zeros(self.shape[:-1] + (self.shape[-1] * 2,), dtype=real_dtype) true_data[..., :self.real_shape[-1]] = data - self.write(np.ascontiguousarray(true_data).view(np.complex64), index) + self.write(np.ascontiguousarray(true_data), index) def write_fourier(self, data, index: int = None): npc.require_numpy("RFFTBuffer.write_fourier") np = npc.numpy_module() assert data.shape == self.fourier_shape, f"Data shape {data.shape} must match fourier shape {self.fourier_shape}!" - assert np.issubdtype(data.dtype, np.complexfloating) , "Data dtype must be complex!" + assert np.issubdtype(data.dtype, np.complexfloating), "Data dtype must be complex!" + + target_fourier_dtype = to_numpy_dtype(self.var_type) + if npc.is_host_dtype(target_fourier_dtype): + # complex32: pack complex values into float16 real/imag pairs. + complex_data = np.ascontiguousarray(data.astype(np.complex64)) + packed_pairs = np.empty(complex_data.shape + (2,), dtype=np.float16) + packed_pairs[..., 0] = complex_data.real.astype(np.float16) + packed_pairs[..., 1] = complex_data.imag.astype(np.float16) + + packed_real_shape = self.shape[:-1] + (self.shape[-1] * 2,) + self.write(np.ascontiguousarray(packed_pairs).reshape(packed_real_shape), index) + return - self.write(np.ascontiguousarray(data.astype(np.complex64)).view(np.float32), index) + self.write(np.ascontiguousarray(data.astype(target_fourier_dtype)), index) -def asrfftbuffer(data) -> RFFTBuffer: + +def asrfftbuffer(data, fourier_type: Optional[dtype] = None) -> RFFTBuffer: npc.require_numpy("asrfftbuffer") np = npc.numpy_module() assert not np.issubdtype(data.dtype, np.complexfloating), "Data dtype must be scalar!" 
- buffer = RFFTBuffer(data.shape) + if fourier_type is None: + scalar_dtype = from_numpy_dtype(data.dtype) + scalar_dtype = dtypes.make_floating_dtype(scalar_dtype) + fourier_type = dtypes.complex_from_float(scalar_dtype) + + buffer = RFFTBuffer(data.shape, fourier_type=fourier_type) buffer.write_real(data) return buffer diff --git a/vkdispatch/base/buffer_allocators.py b/vkdispatch/base/buffer_allocators.py new file mode 100644 index 00000000..e14fed86 --- /dev/null +++ b/vkdispatch/base/buffer_allocators.py @@ -0,0 +1,119 @@ +from .buffer import Buffer +from . import dtype as dt +from typing import Tuple + +def buffer_u32(shape: Tuple[int, ...]) -> Buffer: + """Create a buffer of unsigned 32-bit integers with the specified shape.""" + return Buffer(shape, dt.uint32) + +def buffer_uv2(shape: Tuple[int, ...]) -> Buffer: + """Create a buffer of unsigned 32-bit integer vectors of size 2 with the specified shape.""" + return Buffer(shape, dt.uvec2) + +def buffer_uv3(shape: Tuple[int, ...]) -> Buffer: + """Create a buffer of unsigned 32-bit integer vectors of size 3 with the specified shape.""" + return Buffer(shape, dt.uvec3) + +def buffer_uv4(shape: Tuple[int, ...]) -> Buffer: + """Create a buffer of unsigned 32-bit integer vectors of size 4 with the specified shape.""" + return Buffer(shape, dt.uvec4) + +def buffer_i32(shape: Tuple[int, ...]) -> Buffer: + """Create a buffer of signed 32-bit integers with the specified shape.""" + return Buffer(shape, dt.int32) + +def buffer_iv2(shape: Tuple[int, ...]) -> Buffer: + """Create a buffer of signed 32-bit integer vectors of size 2 with the specified shape.""" + return Buffer(shape, dt.ivec2) + +def buffer_iv3(shape: Tuple[int, ...]) -> Buffer: + """Create a buffer of signed 32-bit integer vectors of size 3 with the specified shape.""" + return Buffer(shape, dt.ivec3) + +def buffer_iv4(shape: Tuple[int, ...]) -> Buffer: + """Create a buffer of signed 32-bit integer vectors of size 4 with the specified shape.""" + return 
Buffer(shape, dt.ivec4) + +def buffer_f32(shape: Tuple[int, ...]) -> Buffer: + """Create a buffer of 32-bit floating-point numbers with the specified shape.""" + return Buffer(shape, dt.float32) + +def buffer_v2(shape: Tuple[int, ...]) -> Buffer: + """Create a buffer of 32-bit floating-point vectors of size 2 with the specified shape.""" + return Buffer(shape, dt.vec2) + +def buffer_v3(shape: Tuple[int, ...]) -> Buffer: + """Create a buffer of 32-bit floating-point vectors of size 3 with the specified shape.""" + return Buffer(shape, dt.vec3) + +def buffer_v4(shape: Tuple[int, ...]) -> Buffer: + """Create a buffer of 32-bit floating-point vectors of size 4 with the specified shape.""" + return Buffer(shape, dt.vec4) + +def buffer_c64(shape: Tuple[int, ...]) -> Buffer: + """Create a buffer of 64-bit complex numbers with the specified shape.""" + return Buffer(shape, dt.complex64) + +def buffer_u16(shape: Tuple[int, ...]) -> Buffer: + """Create a buffer of unsigned 16-bit integers with the specified shape.""" + return Buffer(shape, dt.uint16) + +def buffer_uhv2(shape: Tuple[int, ...]) -> Buffer: + """Create a buffer of unsigned 16-bit integer vectors of size 2 with the specified shape.""" + return Buffer(shape, dt.uhvec2) + +def buffer_uhv3(shape: Tuple[int, ...]) -> Buffer: + """Create a buffer of unsigned 16-bit integer vectors of size 3 with the specified shape.""" + return Buffer(shape, dt.uhvec3) + +def buffer_uhv4(shape: Tuple[int, ...]) -> Buffer: + """Create a buffer of unsigned 16-bit integer vectors of size 4 with the specified shape.""" + return Buffer(shape, dt.uhvec4) + +def buffer_i16(shape: Tuple[int, ...]) -> Buffer: + """Create a buffer of signed 16-bit integers with the specified shape.""" + return Buffer(shape, dt.int16) + +def buffer_ihv2(shape: Tuple[int, ...]) -> Buffer: + """Create a buffer of signed 16-bit integer vectors of size 2 with the specified shape.""" + return Buffer(shape, dt.ihvec2) + +def buffer_ihv3(shape: Tuple[int, ...]) -> 
Buffer: + """Create a buffer of signed 16-bit integer vectors of size 3 with the specified shape.""" + return Buffer(shape, dt.ihvec3) + +def buffer_ihv4(shape: Tuple[int, ...]) -> Buffer: + """Create a buffer of signed 16-bit integer vectors of size 4 with the specified shape.""" + return Buffer(shape, dt.ihvec4) + +def buffer_f16(shape: Tuple[int, ...]) -> Buffer: + """Create a buffer of 16-bit floating-point numbers with the specified shape.""" + return Buffer(shape, dt.float16) + +def buffer_hv2(shape: Tuple[int, ...]) -> Buffer: + """Create a buffer of 16-bit floating-point vectors of size 2 with the specified shape.""" + return Buffer(shape, dt.hvec2) + +def buffer_hv3(shape: Tuple[int, ...]) -> Buffer: + """Create a buffer of 16-bit floating-point vectors of size 3 with the specified shape.""" + return Buffer(shape, dt.hvec3) + +def buffer_hv4(shape: Tuple[int, ...]) -> Buffer: + """Create a buffer of 16-bit floating-point vectors of size 4 with the specified shape.""" + return Buffer(shape, dt.hvec4) + +def buffer_f64(shape: Tuple[int, ...]) -> Buffer: + """Create a buffer of 64-bit floating-point numbers with the specified shape.""" + return Buffer(shape, dt.float64) + +def buffer_dv2(shape: Tuple[int, ...]) -> Buffer: + """Create a buffer of 64-bit floating-point vectors of size 2 with the specified shape.""" + return Buffer(shape, dt.dvec2) + +def buffer_dv3(shape: Tuple[int, ...]) -> Buffer: + """Create a buffer of 64-bit floating-point vectors of size 3 with the specified shape.""" + return Buffer(shape, dt.dvec3) + +def buffer_dv4(shape: Tuple[int, ...]) -> Buffer: + """Create a buffer of 64-bit floating-point vectors of size 4 with the specified shape.""" + return Buffer(shape, dt.dvec4) \ No newline at end of file diff --git a/vkdispatch/base/command_list.py b/vkdispatch/base/command_list.py index 5ebd7194..99fa2799 100644 --- a/vkdispatch/base/command_list.py +++ b/vkdispatch/base/command_list.py @@ -1,11 +1,14 @@ from typing import Tuple from 
typing import Optional -from .backend import native +from ..backends.backend_selection import native +from .init import is_cuda from .context import Handle from .errors import check_for_errors +from ..execution_pipeline.cuda_graph_capture import get_cuda_capture + from .compute_plan import ComputePlan from .descriptor_set import DescriptorSet @@ -76,7 +79,13 @@ def reset(self) -> None: self.clear_parents() - def submit(self, data: Optional[bytes] = None, queue_index: int = -2, instance_count: Optional[int] = None) -> None: + def submit( + self, + data: Optional[bytes] = None, + queue_index: int = -2, + instance_count: Optional[int] = None, + cuda_stream=None + ) -> None: """ Submits the recorded command list to the GPU queue for execution. @@ -106,9 +115,22 @@ def submit(self, data: Optional[bytes] = None, queue_index: int = -2, instance_c if self.get_instance_size() != 0: assert self.get_instance_size() * instance_count == len(data), "Data length must be the product of the instance size and instance count!" + if cuda_stream is None and get_cuda_capture() is not None: + cuda_stream = get_cuda_capture().cuda_stream + + if cuda_stream is not None: + if not is_cuda(): + raise RuntimeError("cuda_stream=... 
is currently only supported with CUDA backends.") + + native.cuda_stream_override_begin(cuda_stream) + check_for_errors() + done = False while not done: done = native.command_list_submit( self._handle, data, instance_count, queue_index ) check_for_errors() + + if cuda_stream is not None: + native.cuda_stream_override_end() diff --git a/vkdispatch/base/compute_plan.py b/vkdispatch/base/compute_plan.py index fd997705..995ae177 100644 --- a/vkdispatch/base/compute_plan.py +++ b/vkdispatch/base/compute_plan.py @@ -1,4 +1,4 @@ -from .backend import native +from ..backends.backend_selection import native from .context import Handle from .errors import check_for_compute_stage_errors, check_for_errors @@ -34,7 +34,6 @@ def __init__(self, shader_source: str, binding_type_list: list, pc_size: int, sh self.context._handle, shader_source.encode(), self.binding_list, pc_size, shader_name.encode() ) check_for_compute_stage_errors() - self.register_handle(handle) def _destroy(self) -> None: diff --git a/vkdispatch/base/context.py b/vkdispatch/base/context.py index 11aef807..45351b32 100644 --- a/vkdispatch/base/context.py +++ b/vkdispatch/base/context.py @@ -7,12 +7,16 @@ import atexit import weakref +import sys import os, signal from .errors import check_for_errors, set_running -from .init import DeviceInfo, get_backend, get_devices, initialize, set_log_level, LogLevel, log_info -from .backend import BACKEND_PYCUDA, native +from .init import DeviceInfo, is_cuda, is_opencl, is_dummy, get_devices, initialize, log_info +from ..backends.backend_selection import native +VK_SHADER_STAGE_COMPUTE_BIT = 0x00000020 + +VK_SUBGROUP_FEATURE_ARITHMETIC_BIT = 0x00000004 class Handle: context: "Context" @@ -53,6 +57,8 @@ def clear_parents(self) -> None: """ Clears the parent handles. """ + # children_dict uses weak references, so a child key may disappear + # before teardown reaches this point. 
for parent in self.parents.values(): parent.remove_child_handle(self) @@ -71,10 +77,8 @@ def remove_child_handle(self, child: "Handle") -> None: """ Removes a child handle from the current handle. """ - if child._handle not in self.children_dict.keys(): - raise ValueError(f"Child handle {child._handle} does not exist in parent handle!") - - self.children_dict.pop(child._handle) + # Be idempotent to tolerate repeated teardown paths and weakref eviction. + self.children_dict.pop(child._handle, None) def _destroy(self) -> None: raise NotImplementedError("destroy is an abstract method and must be implemented by subclasses.") @@ -84,7 +88,10 @@ def destroy(self) -> None: Destroys the context handle and cleans up resources. """ if self.destroyed: - return + return + + self.destroyed = True + self.clear_parents() child_keys = list(self.children_dict.keys()) @@ -101,13 +108,11 @@ def destroy(self) -> None: check_for_errors() self.canary = True - - self.clear_parents() if self._handle in self.context.handles_dict.keys(): self.context.handles_dict.pop(self._handle) - self.destroyed = True + class Signal: ptr_addr: int @@ -159,6 +164,8 @@ class Context: queue_families: List[List[int]] queue_count: int subgroup_size: int + subgroup_enabled: bool + subgroup_arithmetic: bool max_workgroup_size: Tuple[int] max_workgroup_invocations: int max_workgroup_count: Tuple[int, int, int] @@ -179,7 +186,10 @@ def __init__( self.mapped_device_ids = [dev.dev_index for dev in self.device_infos] self._handle = native.context_create(self.mapped_device_ids, queue_families) check_for_errors() - + + self.refresh_limits_from_device_infos() + + def refresh_limits_from_device_infos(self) -> None: subgroup_sizes = [] max_workgroup_sizes_x = [] max_workgroup_sizes_y = [] @@ -191,6 +201,9 @@ def __init__( uniform_buffer_alignments = [] max_shared_memory = [] + subgroup_enabled = True + subgroup_arithmetic = True + for device in self.device_infos: subgroup_sizes.append(device.sub_group_size) @@ -208,6 
+221,14 @@ def __init__( max_shared_memory.append(device.max_compute_shared_memory_size) + if not device.supported_stages & VK_SHADER_STAGE_COMPUTE_BIT: + subgroup_enabled = False + + if not device.supported_operations & VK_SUBGROUP_FEATURE_ARITHMETIC_BIT: + subgroup_arithmetic = False + + self.subgroup_enabled = subgroup_enabled + self.subgroup_arithmetic = subgroup_arithmetic self.subgroup_size = min(subgroup_sizes) self.max_workgroup_size = ( min(max_workgroup_sizes_x), @@ -371,15 +392,16 @@ def make_context( select_queue_families(dev_index, queue_family_count) ) - if get_backend() == BACKEND_PYCUDA: + if is_cuda() or is_opencl(): + backend_name = "CUDA" if is_cuda() else "OpenCL" if len(device_ids) != 1: raise NotImplementedError( - "The PyCUDA backend currently supports exactly one device." + f"The {backend_name} backend currently supports exactly one device." ) if len(queue_families) != 1 or len(queue_families[0]) != 1: raise NotImplementedError( - "The PyCUDA backend currently supports exactly one queue." + f"The {backend_name} backend currently supports exactly one queue." 
) total_devices = len(get_devices()) @@ -413,6 +435,125 @@ def get_context() -> Context: def get_context_handle() -> int: return get_context()._handle +def _as_positive_int(name: str, value) -> int: + try: + result = int(value) + except Exception as exc: + raise ValueError(f"{name} must be a positive integer") from exc + + if result <= 0: + raise ValueError(f"{name} must be a positive integer") + + return result + +def _as_positive_triplet(name: str, value) -> Tuple[int, int, int]: + try: + parts = list(value) + except Exception as exc: + raise ValueError(f"{name} must contain exactly 3 positive integers") from exc + + if len(parts) != 3: + raise ValueError(f"{name} must contain exactly 3 positive integers") + + return ( + _as_positive_int(f"{name}[0]", parts[0]), + _as_positive_int(f"{name}[1]", parts[1]), + _as_positive_int(f"{name}[2]", parts[2]), + ) + +def set_dummy_context_params( + subgroup_size: int = None, + subgroup_enabled: bool = None, + max_workgroup_size: Tuple[int, int, int] = None, + max_workgroup_invocations: int = None, + max_workgroup_count: Tuple[int, int, int] = None, + max_shared_memory: int = None, +) -> None: + """ + Update cached context/device limit values for the active dummy backend context. + + This only works when a dummy context already exists. + """ + global __context + + if not is_dummy(): + raise RuntimeError( + "set_dummy_context_params() is only supported when running with backend='dummy'." 
+ ) + + if __context is None: + __context = get_context() + + validated_subgroup_size = None + validated_subgroup_enabled = None + validated_max_workgroup_size = None + validated_max_workgroup_invocations = None + validated_max_workgroup_count = None + validated_max_shared_memory = None + + backend_kwargs = {} + + if subgroup_size is not None: + validated_subgroup_size = _as_positive_int("subgroup_size", subgroup_size) + backend_kwargs["subgroup_size"] = validated_subgroup_size + + if subgroup_enabled is not None: + if not isinstance(subgroup_enabled, bool): + raise ValueError("subgroup_enabled must be a boolean value") + validated_subgroup_enabled = bool(subgroup_enabled) + backend_kwargs["subgroup_enabled"] = subgroup_enabled + + if max_workgroup_size is not None: + validated_max_workgroup_size = _as_positive_triplet("max_workgroup_size", max_workgroup_size) + backend_kwargs["max_workgroup_size"] = validated_max_workgroup_size + + if max_workgroup_invocations is not None: + validated_max_workgroup_invocations = _as_positive_int( + "max_workgroup_invocations", + max_workgroup_invocations, + ) + backend_kwargs["max_workgroup_invocations"] = validated_max_workgroup_invocations + + if max_workgroup_count is not None: + validated_max_workgroup_count = _as_positive_triplet("max_workgroup_count", max_workgroup_count) + backend_kwargs["max_workgroup_count"] = validated_max_workgroup_count + + if max_shared_memory is not None: + validated_max_shared_memory = _as_positive_int("max_shared_memory", max_shared_memory) + backend_kwargs["max_compute_shared_memory_size"] = validated_max_shared_memory + + if backend_kwargs: + native.set_device_options(**backend_kwargs) + check_for_errors() + + for device in __context.device_infos: + if validated_subgroup_size is not None: + device.sub_group_size = validated_subgroup_size + + if validated_subgroup_enabled is not None: + if validated_subgroup_enabled: + device.supported_stages |= VK_SHADER_STAGE_COMPUTE_BIT + 
device.supported_operations |= VK_SUBGROUP_FEATURE_ARITHMETIC_BIT + else: + device.supported_stages &= ~VK_SHADER_STAGE_COMPUTE_BIT + device.supported_operations &= ~VK_SUBGROUP_FEATURE_ARITHMETIC_BIT + + if validated_max_workgroup_size is not None: + device.max_workgroup_size = validated_max_workgroup_size + + if validated_max_workgroup_invocations is not None: + device.max_workgroup_invocations = validated_max_workgroup_invocations + + if validated_max_workgroup_count is not None: + device.max_workgroup_count = validated_max_workgroup_count + + if validated_max_shared_memory is not None: + device.max_compute_shared_memory_size = validated_max_shared_memory + + device.uniform_buffer_alignment = 0 + + __context.refresh_limits_from_device_infos() + def queue_wait_idle(queue_index: int = None, context: Context = None) -> None: """ Wait for the specified queue to finish processing. For all queues, leave queue_index as None. @@ -490,8 +631,7 @@ def _sig_handler(signum, frame): signal.signal(signum, signal.SIG_DFL) os.kill(os.getpid(), signum) - -from .brython_utils import is_brython -if not is_brython(): +# No need to register signal handlers in Brython, since it runs in a browser +if not sys.implementation.name == "Brython": signal.signal(signal.SIGINT, _sig_handler) signal.signal(signal.SIGTERM, _sig_handler) diff --git a/vkdispatch/base/descriptor_set.py b/vkdispatch/base/descriptor_set.py index b4512456..e9d2823a 100644 --- a/vkdispatch/base/descriptor_set.py +++ b/vkdispatch/base/descriptor_set.py @@ -1,4 +1,4 @@ -from .backend import native +from ..backends.backend_selection import native from .errors import check_for_errors @@ -8,6 +8,7 @@ from .image import Sampler from .init import log_info +from .init import is_cuda class DescriptorSet(Handle): """TODO: Docstring""" @@ -28,6 +29,9 @@ def __del__(self) -> None: self.destroy() def bind_buffer(self, buffer: Buffer, binding: int, offset: int = 0, range: int = 0, uniform: bool = False, read_access: bool = True, 
write_access: bool = True) -> None: + if write_access and not getattr(buffer, "is_writable", True): + raise ValueError("Cannot bind a read-only buffer with write access enabled.") + self.register_parent(buffer) native.descriptor_set_write_buffer( @@ -54,3 +58,10 @@ def bind_sampler(self, sampler: Sampler, binding: int, read_access: bool = True, 1 if write_access else 0 ) check_for_errors() + + def set_inline_uniform_payload(self, payload: bytes) -> None: + if not is_cuda(): + raise RuntimeError("Inline uniform payloads are currently only supported on CUDA backends.") + + native.descriptor_set_write_inline_uniform(self._handle, payload) + check_for_errors() diff --git a/vkdispatch/base/dtype.py b/vkdispatch/base/dtype.py index fa796001..62ea81d3 100644 --- a/vkdispatch/base/dtype.py +++ b/vkdispatch/base/dtype.py @@ -1,6 +1,6 @@ from typing import Any, Optional -from .._compat import numpy_compat as npc +from ..compat import numpy_compat as npc class dtype: name: str @@ -24,6 +24,18 @@ class _Scalar(dtype): child_type = None scalar = None +class _I16(_Scalar): + name = "int16" + item_size = 2 + glsl_type = "int16_t" + format_str = "%d" + +class _U16(_Scalar): + name = "uint16" + item_size = 2 + glsl_type = "uint16_t" + format_str = "%u" + class _I32(_Scalar): name = "int32" item_size = 4 @@ -36,20 +48,61 @@ class _U32(_Scalar): glsl_type = "uint" format_str = "%u" +class _I64(_Scalar): + name = "int64" + item_size = 8 + glsl_type = "int64_t" + format_str = "%lld" + +class _U64(_Scalar): + name = "uint64" + item_size = 8 + glsl_type = "uint64_t" + format_str = "%llu" + +class _F16(_Scalar): + name = "float16" + item_size = 2 + glsl_type = "float16_t" + format_str = "%f" + class _F32(_Scalar): name = "float32" item_size = 4 glsl_type = "float" format_str = "%f" +class _F64(_Scalar): + name = "float64" + item_size = 8 + glsl_type = "double" + format_str = "%lf" + +int16 = _I16 # type: ignore +uint16 = _U16 # type: ignore int32 = _I32 # type: ignore uint32 = _U32 # 
type: ignore +int64 = _I64 # type: ignore +uint64 = _U64 # type: ignore +float16 = _F16 # type: ignore float32 = _F32 # type: ignore +float64 = _F64 # type: ignore class _Complex(dtype): dimentions = 0 child_count = 2 +class _CF32(_Complex): + name = "complex32" + item_size = 4 + glsl_type = "f16vec2" + format_str = "(%f, %f)" + child_type = float16 + shape = (2,) + numpy_shape = (1,) + true_numpy_shape = () + scalar = None + class _CF64(_Complex): name = "complex64" item_size = 8 @@ -61,11 +114,64 @@ class _CF64(_Complex): true_numpy_shape = () scalar = None +class _CF128(_Complex): + name = "complex128" + item_size = 16 + glsl_type = "dvec2" + format_str = "(%lf, %lf)" + child_type = float64 + shape = (2,) + numpy_shape = (1,) + true_numpy_shape = () + scalar = None + +complex32 = _CF32 # type: ignore complex64 = _CF64 # type: ignore +complex128 = _CF128 # type: ignore class _Vector(dtype): dimentions = 1 +# --- float16 vectors --- + +class _V2F16(_Vector): + name = "hvec2" + item_size = 4 + glsl_type = "f16vec2" + format_str = "(%f, %f)" + child_type = float16 + child_count = 2 + shape = (2,) + numpy_shape = (2,) + true_numpy_shape = (2,) + scalar = float16 + +class _V3F16(_Vector): + name = "hvec3" + item_size = 6 + glsl_type = "f16vec3" + format_str = "(%f, %f, %f)" + child_type = float16 + child_count = 3 + shape = (3,) + numpy_shape = (3,) + true_numpy_shape = (3,) + scalar = float16 + +class _V4F16(_Vector): + name = "hvec4" + item_size = 8 + glsl_type = "f16vec4" + format_str = "(%f, %f, %f, %f)" + child_type = float16 + child_count = 4 + shape = (4,) + numpy_shape = (4,) + true_numpy_shape = (4,) + scalar = float16 + +# --- float32 vectors --- + class _V2F32(_Vector): name = "vec2" item_size = 8 @@ -102,6 +208,84 @@ class _V4F32(_Vector): true_numpy_shape = (4,) scalar = float32 +# --- float64 vectors --- + +class _V2F64(_Vector): + name = "dvec2" + item_size = 16 + glsl_type = "dvec2" + format_str = "(%lf, %lf)" + child_type = float64 + child_count = 2 + 
shape = (2,) + numpy_shape = (2,) + true_numpy_shape = (2,) + scalar = float64 + +class _V3F64(_Vector): + name = "dvec3" + item_size = 24 + glsl_type = "dvec3" + format_str = "(%lf, %lf, %lf)" + child_type = float64 + child_count = 3 + shape = (3,) + numpy_shape = (3,) + true_numpy_shape = (3,) + scalar = float64 + +class _V4F64(_Vector): + name = "dvec4" + item_size = 32 + glsl_type = "dvec4" + format_str = "(%lf, %lf, %lf, %lf)" + child_type = float64 + child_count = 4 + shape = (4,) + numpy_shape = (4,) + true_numpy_shape = (4,) + scalar = float64 + +# --- int16 vectors --- + +class _V2I16(_Vector): + name = "ihvec2" + item_size = 4 + glsl_type = "i16vec2" + format_str = "(%d, %d)" + child_type = int16 + child_count = 2 + shape = (2,) + numpy_shape = (2,) + true_numpy_shape = (2,) + scalar = int16 + +class _V3I16(_Vector): + name = "ihvec3" + item_size = 6 + glsl_type = "i16vec3" + format_str = "(%d, %d, %d)" + child_type = int16 + child_count = 3 + shape = (3,) + numpy_shape = (3,) + true_numpy_shape = (3,) + scalar = int16 + +class _V4I16(_Vector): + name = "ihvec4" + item_size = 8 + glsl_type = "i16vec4" + format_str = "(%d, %d, %d, %d)" + child_type = int16 + child_count = 4 + shape = (4,) + numpy_shape = (4,) + true_numpy_shape = (4,) + scalar = int16 + +# --- int32 vectors --- + class _V2I32(_Vector): name = "ivec2" item_size = 8 @@ -138,6 +322,46 @@ class _V4I32(_Vector): true_numpy_shape = (4,) scalar = int32 +# --- uint16 vectors --- + +class _V2U16(_Vector): + name = "uhvec2" + item_size = 4 + glsl_type = "u16vec2" + format_str = "(%u, %u)" + child_type = uint16 + child_count = 2 + shape = (2,) + numpy_shape = (2,) + true_numpy_shape = (2,) + scalar = uint16 + +class _V3U16(_Vector): + name = "uhvec3" + item_size = 6 + glsl_type = "u16vec3" + format_str = "(%u, %u, %u)" + child_type = uint16 + child_count = 3 + shape = (3,) + numpy_shape = (3,) + true_numpy_shape = (3,) + scalar = uint16 + +class _V4U16(_Vector): + name = "uhvec4" + item_size = 8 + 
glsl_type = "u16vec4" + format_str = "(%u, %u, %u, %u)" + child_type = uint16 + child_count = 4 + shape = (4,) + numpy_shape = (4,) + true_numpy_shape = (4,) + scalar = uint16 + +# --- uint32 vectors --- + class _V2U32(_Vector): name = "uvec2" item_size = 8 @@ -174,12 +398,24 @@ class _V4U32(_Vector): true_numpy_shape = (4,) scalar = uint32 +hvec2 = _V2F16 # type: ignore +hvec3 = _V3F16 # type: ignore +hvec4 = _V4F16 # type: ignore vec2 = _V2F32 # type: ignore vec3 = _V3F32 # type: ignore vec4 = _V4F32 # type: ignore +dvec2 = _V2F64 # type: ignore +dvec3 = _V3F64 # type: ignore +dvec4 = _V4F64 # type: ignore +ihvec2 = _V2I16 # type: ignore +ihvec3 = _V3I16 # type: ignore +ihvec4 = _V4I16 # type: ignore ivec2 = _V2I32 # type: ignore ivec3 = _V3I32 # type: ignore ivec4 = _V4I32 # type: ignore +uhvec2 = _V2U16 # type: ignore +uhvec3 = _V3U16 # type: ignore +uhvec4 = _V4U16 # type: ignore uvec2 = _V2U32 # type: ignore uvec3 = _V3U32 # type: ignore uvec4 = _V4U32 # type: ignore @@ -227,39 +463,25 @@ class _M4F32(_Matrix): mat3 = _M3F32 mat4 = _M4F32 +# Maps scalar dtype -> {count: vector_dtype} +_VECTOR_TABLE = { + int16: {1: int16, 2: ihvec2, 3: ihvec3, 4: ihvec4}, + uint16: {1: uint16, 2: uhvec2, 3: uhvec3, 4: uhvec4}, + int32: {1: int32, 2: ivec2, 3: ivec3, 4: ivec4}, + uint32: {1: uint32, 2: uvec2, 3: uvec3, 4: uvec4}, + float16: {1: float16, 2: hvec2, 3: hvec3, 4: hvec4}, + float32: {1: float32, 2: vec2, 3: vec3, 4: vec4}, + float64: {1: float64, 2: dvec2, 3: dvec3, 4: dvec4}, +} + def to_vector(dtype: dtype, count: int) -> dtype: # type: ignore if count < 1 or count > 4: raise ValueError(f"Unsupported count ({count})!") - if dtype == int32: - if count == 1: - return int32 - elif count == 2: - return ivec2 - elif count == 3: - return ivec3 - elif count == 4: - return ivec4 - elif dtype == uint32: - if count == 1: - return uint32 - elif count == 2: - return uvec2 - elif count == 3: - return uvec3 - elif count == 4: - return uvec4 - elif dtype == float32: - if count 
== 1: - return float32 - elif count == 2: - return vec2 - elif count == 3: - return vec3 - elif count == 4: - return vec4 - else: + table = _VECTOR_TABLE.get(dtype) + if table is None: raise ValueError(f"Unsupported dtype ({dtype})!") + return table[count] def is_dtype(in_type: dtype) -> bool: return issubclass(in_type, dtype) # type: ignore @@ -280,23 +502,65 @@ def is_float_dtype(dtype: dtype) -> bool: if not is_scalar(dtype): dtype = dtype.scalar - return dtype == float32 # or dtype == complex64 + return dtype == float16 or dtype == float32 or dtype == float64 def is_integer_dtype(dtype: dtype) -> bool: if not is_scalar(dtype): dtype = dtype.scalar - return dtype == int32 or dtype == uint32 + return dtype in (int16, uint16, int32, uint32, int64, uint64) + +# Promotion precedence: float64 > float32 > float16 > int64 > int32 > int16 > uint64 > uint32 > uint16 +_SCALAR_RANK = { + uint16: 0, + int16: 1, + uint32: 2, + int32: 3, + uint64: 4, + int64: 5, + float16: 6, + float32: 7, + float64: 8, +} + +_COMPLEX_FROM_FLOAT = { + float16: complex32, + float32: complex64, + float64: complex128, +} + +def complex_from_float(dtype: dtype) -> dtype: + if not is_scalar(dtype): + raise ValueError(f"Unsupported dtype ({dtype})!") + + result = _COMPLEX_FROM_FLOAT.get(dtype) + if result is None: + raise ValueError(f"Unsupported complex base dtype ({dtype})!") + return result + +def _promote_scalar(dtype: dtype) -> dtype: + """Return the floating-point type that matches the width of *dtype*. + + Used by make_floating_dtype to convert integer scalars to their natural + floating counterpart. 
+ """ + if dtype == int16 or dtype == uint16: + return float16 + if dtype == int32 or dtype == uint32: + return float32 + if dtype == int64 or dtype == uint64: + return float64 + return dtype def make_floating_dtype(dtype: dtype) -> dtype: if is_scalar(dtype): - return float32 + return _promote_scalar(dtype) elif is_vector(dtype): - return to_vector(float32, dtype.child_count) + return to_vector(_promote_scalar(dtype.scalar), dtype.child_count) elif is_matrix(dtype): return dtype elif is_complex(dtype): - return complex64 + return dtype else: raise ValueError(f"Unsupported dtype ({dtype})!") @@ -308,14 +572,10 @@ def vector_size(dtype: dtype) -> int: def cross_scalar_scalar(dtype1: dtype, dtype2: dtype) -> dtype: assert is_scalar(dtype1) and is_scalar(dtype2), "Both types must be scalar types!" - - if dtype1 == float32 or dtype2 == float32: - return float32 - - if dtype1 == int32 or dtype2 == int32: - return int32 - - return uint32 + + r1 = _SCALAR_RANK[dtype1] + r2 = _SCALAR_RANK[dtype2] + return dtype1 if r1 >= r2 else dtype2 def cross_vector_scalar(dtype1: dtype, dtype2: dtype) -> dtype: assert is_vector(dtype1) and is_scalar(dtype2), "First type must be vector type and second type must be scalar type!" 
@@ -354,10 +614,10 @@ def cross_matrix(dtype1: dtype, dtype2: dtype) -> dtype: if is_vector(dtype2) or is_complex(dtype2): raise ValueError("Cannot cross matrix and vector/complex types!") - + if is_scalar(dtype2): return dtype1 - + raise ValueError("Second type must be matrix or scalar type!") def cross_type(dtype1: dtype, dtype2: dtype) -> dtype: @@ -370,38 +630,91 @@ def cross_type(dtype1: dtype, dtype2: dtype) -> dtype: return cross_vector(dtype1, dtype2) elif is_vector(dtype2): return cross_vector(dtype2, dtype1) - + if is_complex(dtype1): - return complex64 + if is_complex(dtype2): + return complex_from_float(cross_scalar_scalar(dtype1.child_type, dtype2.child_type)) + if is_scalar(dtype2): + return complex_from_float(cross_scalar_scalar(dtype1.child_type, _promote_scalar(dtype2))) + raise ValueError("Cannot cross complex and non-scalar types!") elif is_complex(dtype2): - return complex64 - + if is_scalar(dtype1): + return complex_from_float(cross_scalar_scalar(dtype2.child_type, _promote_scalar(dtype1))) + raise ValueError("Cannot cross complex and non-scalar types!") + if is_scalar(dtype1) and is_scalar(dtype2): return cross_scalar_scalar(dtype1, dtype2) +def cross_multiply_type(dtype1: dtype, dtype2: dtype) -> dtype: + """Resolve result type for multiplication. + + Unlike ``cross_type``, multiplication is order-sensitive for matrix/vector + combinations and supports ``matN * vecN`` and ``vecN * matN``. + """ + if is_matrix(dtype1) and is_vector(dtype2): + if dtype1.child_count != dtype2.child_count: + raise ValueError( + f"Cannot multiply matrix '{dtype1.name}' and vector '{dtype2.name}' with incompatible dimensions!" 
+ ) + if dtype1.scalar != float32 or dtype2.scalar != float32: + raise ValueError("Matrix/vector multiplication only supports float32 matrix and vector types.") + return dtype2 + + if is_vector(dtype1) and is_matrix(dtype2): + if dtype1.child_count != dtype2.child_count: + raise ValueError( + f"Cannot multiply vector '{dtype1.name}' and matrix '{dtype2.name}' with incompatible dimensions!" + ) + if dtype1.scalar != float32 or dtype2.scalar != float32: + raise ValueError("Matrix/vector multiplication only supports float32 matrix and vector types.") + return dtype1 + + return cross_type(dtype1, dtype2) + def from_numpy_dtype(dtype: Any) -> dtype: dtype_name = npc.host_dtype_name(dtype) - if dtype_name == "int32": - return int32 - elif dtype_name == "uint32": - return uint32 - elif dtype_name == "float32": - return float32 - elif dtype_name == "complex64": - return complex64 - else: + _NAME_MAP = { + "int16": int16, + "uint16": uint16, + "int32": int32, + "uint32": uint32, + "int64": int64, + "uint64": uint64, + "float16": float16, + "float32": float32, + "float64": float64, + "complex32": complex32, + "complex64": complex64, + "complex128": complex128, + } + + result = _NAME_MAP.get(dtype_name) + if result is None: raise ValueError(f"Unsupported dtype ({dtype})!") + return result def to_numpy_dtype(shader_type: dtype) -> Any: - if shader_type == int32: - return npc.host_dtype("int32") if not npc.HAS_NUMPY else npc.numpy_module().int32 - elif shader_type == uint32: - return npc.host_dtype("uint32") if not npc.HAS_NUMPY else npc.numpy_module().uint32 - elif shader_type == float32: - return npc.host_dtype("float32") if not npc.HAS_NUMPY else npc.numpy_module().float32 - elif shader_type == complex64: - return npc.host_dtype("complex64") if not npc.HAS_NUMPY else npc.numpy_module().complex64 - else: + _TYPE_MAP = { + int16: "int16", + uint16: "uint16", + int32: "int32", + uint32: "uint32", + int64: "int64", + uint64: "uint64", + float16: "float16", + float32: "float32", 
+ float64: "float64", + complex32: "complex32", + complex64: "complex64", + complex128: "complex128", + } + + name = _TYPE_MAP.get(shader_type) + if name is None: raise ValueError(f"Unsupported shader_type ({shader_type})!") + + if npc.HAS_NUMPY and hasattr(npc.numpy_module(), name): + return getattr(npc.numpy_module(), name) + return npc.host_dtype(name) diff --git a/vkdispatch/base/errors.py b/vkdispatch/base/errors.py index 51bd308a..ca6068b1 100644 --- a/vkdispatch/base/errors.py +++ b/vkdispatch/base/errors.py @@ -1,4 +1,4 @@ -from .backend import native +from ..backends.backend_selection import native running = True @@ -26,7 +26,8 @@ def check_for_errors(): raise RuntimeError(error) else: raise RuntimeError("Unknown error occurred") - + + def check_for_compute_stage_errors(): """ Check for errors in the shader compilation stage of the vkdispatch_native library and raise a RuntimeError if found. diff --git a/vkdispatch/base/image.py b/vkdispatch/base/image.py index bb1d1427..f78ec483 100644 --- a/vkdispatch/base/image.py +++ b/vkdispatch/base/image.py @@ -1,9 +1,9 @@ import typing from enum import Enum -from .backend import native +from ..backends.backend_selection import native -from .._compat import numpy_compat as npc +from ..compat import numpy_compat as npc from . 
import dtype as vdt from .context import Handle diff --git a/vkdispatch/base/init.py b/vkdispatch/base/init.py index 34a084a4..a4aa7c26 100644 --- a/vkdispatch/base/init.py +++ b/vkdispatch/base/init.py @@ -6,12 +6,17 @@ import inspect from .errors import check_for_errors -from .backend import ( +from ..backends.backend_selection import ( + BACKEND_CUDA, + BACKEND_OPENCL, BACKEND_VULKAN, + BACKEND_DUMMY, + BackendUnavailableError, clear_active_backend, get_active_backend_name, + get_backend_module, native, - normalize_backend_name, + get_environment_backend, set_active_backend, ) @@ -262,7 +267,19 @@ def get_info_string(self, verbose: bool = False) -> str: result = f"Device {self.sorted_index}: {self.device_name}\n" - result += f"\tVulkan Version: {self.version_major}.{self.version_minor}.{self.version_patch}\n" + backend_type = "Vulkan" + version_number = f"{self.version_major}.{self.version_minor}.{self.version_patch}" + + if is_cuda(): + backend_type = "CUDA Compute Capability" + version_number = f"{self.version_major}.{self.version_minor}" + elif is_opencl(): + backend_type = "OpenCL" + version_number = f"{self.version_major}.{self.version_minor}" + elif is_dummy(): + backend_type = "Dummy" + + result += f"\t{backend_type} Version: {version_number}\n" result += f"\tDevice Type: {device_type_id_to_str_dict[self.device_type]}\n" if self.version_variant != 0: @@ -396,46 +413,55 @@ def get_cuda_device_map(): return uuid_map -def initialize( - debug_mode: bool = False, - log_level: LogLevel = LogLevel.WARNING, - loader_debug_logs: bool = False, - backend: Optional[str] = None, -): - """ - A function which initializes the Vulkan dispatch library. - - Args: - debug_mode (`bool`): A flag to enable debug mode. - log_level (`LogLevel`): The log level, which is one of the following: - LogLevel.VERBOSE - LogLevel.INFO - LogLevel.WARNING - LogLevel.ERROR - loader_debug_logs (bool): A flag to enable vulkan loader debug logs. 
- backend (`Optional[str]`): Runtime backend to use. Supported values are - "vulkan" and "pycuda". If omitted, the currently selected backend is - reused. If no backend was selected yet, `VKDISPATCH_BACKEND` is used - when set, otherwise "vulkan" is used. - """ +def _set_initialized_state(backend_name: str, devices: List[DeviceInfo]) -> None: global __initilized_instance - global __device_infos global __backend_name + global __device_infos - backend_name = normalize_backend_name( - backend - if backend is not None - else get_active_backend_name(os.environ.get("VKDISPATCH_BACKEND")) + __initilized_instance = True + __backend_name = backend_name + __device_infos = devices + + for ii, dev in enumerate(__device_infos): + dev.sorted_index = ii + + +def _build_no_gpu_backend_error( + vulkan_error: Exception, + cuda_python_error: Exception, + opencl_error: Exception, +) -> RuntimeError: + return RuntimeError( + "vkdispatch could not find an available GPU backend.\n" + f"Vulkan backend unavailable: {vulkan_error}\n" + f"CUDA Python backend unavailable: {cuda_python_error}\n" + f"OpenCL backend unavailable: {opencl_error}\n" + "Install the Vulkan backend with `pip install vkdispatch`, or install CUDA support " + "(`pip install cuda-python`), or install OpenCL support (`pip install pyopencl`), " + "or explicitly use `vd.initialize(backend='dummy')` " + "for codegen-only workflows." ) - if __initilized_instance: - if __backend_name != backend_name: - raise RuntimeError( - f"vkdispatch is already initialized with backend '{__backend_name}'. " - f"Cannot reinitialize with '{backend_name}' in the same process." 
- ) - return + +def _build_vulkan_backend_error(vulkan_error: Exception) -> RuntimeError: + return RuntimeError( + "vkdispatch could not load the Vulkan backend.\n" + f"Vulkan backend unavailable: {vulkan_error}\n" + "Install the Vulkan backend with `pip install vkdispatch`, use a CUDA backend " + "(`pip install cuda-python`), use an OpenCL backend (`pip install pyopencl`), " + "or explicitly use `vd.initialize(backend='dummy')` " + "for codegen-only workflows." + ) + + +def _initialize_with_backend( + backend_name: str, + debug_mode: bool, + log_level: LogLevel, + loader_debug_logs: bool, +) -> None: + global __initilized_instance set_active_backend(backend_name) @@ -443,6 +469,9 @@ def initialize( if loader_debug_logs and backend_name == BACKEND_VULKAN: os.environ["VK_LOADER_DEBUG"] = "all" + # Force import now so backend availability errors are distinct from runtime init errors. + get_backend_module(backend_name) + native.init(debug_mode, log_level.value) check_for_errors() @@ -452,59 +481,126 @@ def initialize( ] if backend_name != BACKEND_VULKAN: - __initilized_instance = True - __backend_name = backend_name - __device_infos = devivces - for ii, dev in enumerate(__device_infos): - dev.sorted_index = ii + _set_initialized_state(backend_name, devivces) return is_cuda = any(dev.is_nvidia() for dev in devivces) - cuda_uuids = get_cuda_device_map() if is_cuda else None if cuda_uuids is None: - __initilized_instance = True - __backend_name = backend_name - __device_infos = devivces - for ii, dev in enumerate(__device_infos): - dev.sorted_index = ii + _set_initialized_state(backend_name, devivces) return - + # try to match CUDA devices to Vulkan devices by UUID cuda_uuid_to_index = { uuid_bytes: cuda_index for cuda_index, uuid_bytes in cuda_uuids.items() } - matched_devices: List[Tuple[int, DeviceInfo, int]]= [] + matched_devices: List[Tuple[int, DeviceInfo]] = [] unmatched_devices: List[DeviceInfo] = [] for dev in devivces: if dev.uuid is not None and dev.uuid in 
cuda_uuid_to_index: - #print(f"Matched Vulkan device {ii} ({dev.device_name}) to CUDA device {cuda_uuid_to_index[dev.uuid]} with UUID {dev.uuid.hex()}") - matched_devices.append( (cuda_uuid_to_index[dev.uuid], dev) ) + matched_devices.append((cuda_uuid_to_index[dev.uuid], dev)) else: - #print(f"Could not match Vulkan device {ii} ({dev.device_name}) with UUID {dev.uuid.hex()} to any CUDA device") unmatched_devices.append(dev) - # sort matched devices by CUDA index matched_devices.sort(key=lambda x: x[0]) - - # return matched devices first (by CUDA index), then unmatched devices (by Vulkan order) result = [dev for _, dev in matched_devices] + unmatched_devices - #result_ids = [ii for _, _, ii in matched_devices] + unmatched_device_ids for dev_id, dev in enumerate(result): - #print(f"Final device order index {dev.sorted_index} -> Vulkan device {dev_id} ({dev.device_name})") dev.sorted_index = dev_id - - __initilized_instance = True - __backend_name = backend_name - __device_infos = result + + _set_initialized_state(backend_name, result) except Exception: if not __initilized_instance: clear_active_backend() raise +def initialize( + debug_mode: bool = False, + log_level: LogLevel = LogLevel.WARNING, + loader_debug_logs: bool = False, + backend: Optional[str] = None, +): + """ + A function which initializes the Vulkan dispatch library. + + Args: + debug_mode (`bool`): A flag to enable debug mode. + log_level (`LogLevel`): The log level, which is one of the following: + LogLevel.VERBOSE + LogLevel.INFO + LogLevel.WARNING + LogLevel.ERROR + loader_debug_logs (bool): A flag to enable vulkan loader debug logs. + backend (`Optional[str]`): Runtime backend to use. Supported values are + "vulkan", "cuda", "opencl", and "dummy". If omitted, the currently selected backend is + reused. If no backend was selected yet, `VKDISPATCH_BACKEND` is used + when set, otherwise "vulkan" is tried first, falling back to "cuda" and then "opencl" if it is unavailable. 
+ """ + + global __initilized_instance + + backend_name = get_active_backend_name(backend) + backend_explicitly_selected = (backend is not None) or (get_environment_backend() is not None) + + if __initilized_instance: + if __backend_name != backend_name: + raise RuntimeError( + f"vkdispatch is already initialized with backend '{__backend_name}'. " + f"Cannot reinitialize with '{backend_name}' in the same process." + ) + return + + if ( + not backend_explicitly_selected + and backend_name == BACKEND_VULKAN + ): + try: + _initialize_with_backend( + BACKEND_VULKAN, + debug_mode=debug_mode, + log_level=log_level, + loader_debug_logs=loader_debug_logs, + ) + return + except BackendUnavailableError as vulkan_error: + try: + _initialize_with_backend( + BACKEND_CUDA, + debug_mode=debug_mode, + log_level=log_level, + loader_debug_logs=loader_debug_logs, + ) + return + except Exception as cuda_python_error: + try: + _initialize_with_backend( + BACKEND_OPENCL, + debug_mode=debug_mode, + log_level=log_level, + loader_debug_logs=loader_debug_logs, + ) + return + except Exception as opencl_error: + raise _build_no_gpu_backend_error( + vulkan_error, + cuda_python_error, + opencl_error, + ) from opencl_error + + try: + _initialize_with_backend( + backend_name, + debug_mode=debug_mode, + log_level=log_level, + loader_debug_logs=loader_debug_logs, + ) + except BackendUnavailableError as backend_error: + if backend_name == BACKEND_VULKAN: + raise _build_vulkan_backend_error(backend_error) from backend_error + raise + def get_devices() -> List[DeviceInfo]: """ @@ -516,7 +612,7 @@ def get_devices() -> List[DeviceInfo]: global __device_infos - initialize(backend=get_active_backend_name()) + initialize() return __device_infos @@ -527,6 +623,46 @@ def get_backend() -> str: return get_active_backend_name() +def is_vulkan() -> bool: + """ + A function which checks if the active backend is the Vulkan backend. 
+ + Returns: + `bool`: A flag indicating whether the active backend is the Vulkan backend. + """ + + return get_backend() == BACKEND_VULKAN + +def is_cuda() -> bool: + """ + A function which checks if the active backend is a CUDA backend. + + Returns: + `bool`: A flag indicating whether the active backend is a CUDA backend. + """ + + return get_backend() == BACKEND_CUDA + +def is_opencl() -> bool: + """ + A function which checks if the active backend is the OpenCL backend. + + Returns: + `bool`: A flag indicating whether the active backend is the OpenCL backend. + """ + + return get_backend() == BACKEND_OPENCL + +def is_dummy() -> bool: + """ + A function which checks if the active backend is the dummy backend. + + Returns: + `bool`: A flag indicating whether the active backend is the dummy backend. + """ + + return get_backend() == BACKEND_DUMMY + def __log_noinit(text: str, end: str = '\n', level: LogLevel = LogLevel.ERROR, stack_offset: int = 1): """ A function which logs a message at the specified log level. @@ -553,7 +689,7 @@ def log(text: str, end: str = '\n', level: LogLevel = LogLevel.ERROR, stack_offs message (`str`): The message to log. """ - initialize(backend=get_active_backend_name()) + initialize() __log_noinit(text, end, level, stack_offset + 1) @@ -605,6 +741,6 @@ def set_log_level(level: LogLevel): level (`LogLevel`): The log level. 
""" - initialize(backend=get_active_backend_name()) + initialize() native.set_log_level(level.value) diff --git a/vkdispatch/codegen/__init__.py b/vkdispatch/codegen/__init__.py index 0aa98580..1d07e8eb 100644 --- a/vkdispatch/codegen/__init__.py +++ b/vkdispatch/codegen/__init__.py @@ -30,16 +30,31 @@ from .functions.atomic_memory import atomic_add -from .functions.type_casting import to_dtype, str_to_dtype, to_float, to_int, to_uint -from .functions.type_casting import to_vec2, to_vec3, to_vec4, to_complex -from .functions.type_casting import to_uvec2, to_uvec3, to_uvec4 +from .functions.type_casting import to_dtype, str_to_dtype +from .functions.type_casting import to_float16, to_float, to_float64 +from .functions.type_casting import to_int16, to_int, to_int64, to_uint16, to_uint, to_uint64 +from .functions.type_casting import to_complex, to_complex32, to_complex64, to_complex128 +from .functions.type_casting import to_hvec2, to_hvec3, to_hvec4 +from .functions.type_casting import to_vec2, to_vec3, to_vec4 +from .functions.type_casting import to_dvec2, to_dvec3, to_dvec4 +from .functions.type_casting import to_ihvec2, to_ihvec3, to_ihvec4 from .functions.type_casting import to_ivec2, to_ivec3, to_ivec4 +from .functions.type_casting import to_uhvec2, to_uhvec3, to_uhvec4 +from .functions.type_casting import to_uvec2, to_uvec3, to_uvec4 from .functions.type_casting import to_mat2, to_mat3, to_mat4 -from .functions.registers import new_register, new_float_register, new_int_register, new_uint_register -from .functions.registers import new_vec2_register, new_ivec2_register, new_uvec2_register, new_complex_register -from .functions.registers import new_vec3_register, new_ivec3_register, new_uvec3_register -from .functions.registers import new_vec4_register, new_ivec4_register, new_uvec4_register +from .functions.registers import new_register, new_complex_register +from .functions.registers import new_float16_register, new_float_register, new_float64_register +from 
.functions.registers import new_int16_register, new_int_register, new_int64_register +from .functions.registers import new_uint16_register, new_uint_register, new_uint64_register +from .functions.registers import new_complex32_register, new_complex64_register, new_complex128_register +from .functions.registers import new_hvec2_register, new_hvec3_register, new_hvec4_register +from .functions.registers import new_vec2_register, new_vec3_register, new_vec4_register +from .functions.registers import new_dvec2_register, new_dvec3_register, new_dvec4_register +from .functions.registers import new_ihvec2_register, new_ihvec3_register, new_ihvec4_register +from .functions.registers import new_ivec2_register, new_ivec3_register, new_ivec4_register +from .functions.registers import new_uhvec2_register, new_uhvec3_register, new_uhvec4_register +from .functions.registers import new_uvec2_register, new_uvec3_register, new_uvec4_register from .functions.registers import new_mat2_register, new_mat3_register, new_mat4_register from .functions.subgroups import subgroup_add, subgroup_mul @@ -56,7 +71,7 @@ from .functions.builtin_constants import global_invocation_id, local_invocation_id, workgroup_id, local_invocation_index from .functions.builtin_constants import workgroup_size, num_workgroups, num_subgroups, subgroup_id -from .functions.builtin_constants import subgroup_size, subgroup_invocation_id, inf_f32, ninf_f32 +from .functions.builtin_constants import subgroup_size, subgroup_invocation_id, inf_f32, ninf_f32, inf_f64, ninf_f64, inf_f16, ninf_f16 from .functions.index_raveling import ravel_index, unravel_index @@ -66,7 +81,7 @@ from .builder import ShaderBinding, ShaderDescription from .builder import ShaderBuilder, ShaderFlags -from .backends import CodeGenBackend, GLSLBackend, CUDABackend +from .backends import CodeGenBackend, GLSLBackend, CUDABackend, OpenCLBackend from .global_builder import set_builder, get_builder, shared_buffer, set_shader_print_line_numbers, 
get_shader_print_line_numbers from .global_builder import set_codegen_backend, get_codegen_backend diff --git a/vkdispatch/codegen/abreviations.py b/vkdispatch/codegen/abreviations.py index 1fdff076..f9815812 100644 --- a/vkdispatch/codegen/abreviations.py +++ b/vkdispatch/codegen/abreviations.py @@ -7,20 +7,40 @@ from .arguments import Image2D as Img2 from .arguments import Image3D as Img3 +from vkdispatch.base.dtype import float16 as f16 from vkdispatch.base.dtype import float32 as f32 -from vkdispatch.base.dtype import uint32 as u32 +from vkdispatch.base.dtype import float64 as f64 +from vkdispatch.base.dtype import int16 as i16 +from vkdispatch.base.dtype import uint16 as u16 from vkdispatch.base.dtype import int32 as i32 +from vkdispatch.base.dtype import uint32 as u32 +from vkdispatch.base.dtype import int64 as i64 +from vkdispatch.base.dtype import uint64 as u64 +from vkdispatch.base.dtype import complex32 as c32 from vkdispatch.base.dtype import complex64 as c64 +from vkdispatch.base.dtype import complex128 as c128 +from vkdispatch.base.dtype import hvec2 as hv2 +from vkdispatch.base.dtype import hvec3 as hv3 +from vkdispatch.base.dtype import hvec4 as hv4 from vkdispatch.base.dtype import vec2 as v2 from vkdispatch.base.dtype import vec3 as v3 from vkdispatch.base.dtype import vec4 as v4 -from vkdispatch.base.dtype import uvec2 as uv2 -from vkdispatch.base.dtype import uvec3 as uv3 -from vkdispatch.base.dtype import uvec4 as uv4 +from vkdispatch.base.dtype import dvec2 as dv2 +from vkdispatch.base.dtype import dvec3 as dv3 +from vkdispatch.base.dtype import dvec4 as dv4 +from vkdispatch.base.dtype import ihvec2 as ihv2 +from vkdispatch.base.dtype import ihvec3 as ihv3 +from vkdispatch.base.dtype import ihvec4 as ihv4 from vkdispatch.base.dtype import ivec2 as iv2 from vkdispatch.base.dtype import ivec3 as iv3 from vkdispatch.base.dtype import ivec4 as iv4 +from vkdispatch.base.dtype import uhvec2 as uhv2 +from vkdispatch.base.dtype import uhvec3 as uhv3 
+from vkdispatch.base.dtype import uhvec4 as uhv4 +from vkdispatch.base.dtype import uvec2 as uv2 +from vkdispatch.base.dtype import uvec3 as uv3 +from vkdispatch.base.dtype import uvec4 as uv4 from vkdispatch.base.dtype import mat2 as m2 from vkdispatch.base.dtype import mat4 as m4 diff --git a/vkdispatch/codegen/backends/__init__.py b/vkdispatch/codegen/backends/__init__.py index 0ddf53ce..773f5bee 100644 --- a/vkdispatch/codegen/backends/__init__.py +++ b/vkdispatch/codegen/backends/__init__.py @@ -1,3 +1,4 @@ from .base import CodeGenBackend from .glsl import GLSLBackend from .cuda import CUDABackend +from .opencl import OpenCLBackend diff --git a/vkdispatch/codegen/backends/base.py b/vkdispatch/codegen/backends/base.py index 5c34ab0b..aafdab6f 100644 --- a/vkdispatch/codegen/backends/base.py +++ b/vkdispatch/codegen/backends/base.py @@ -40,12 +40,74 @@ def mark_composite_binary_op( def type_name(self, var_type: dtypes.dtype) -> str: raise NotImplementedError - def constructor(self, var_type: dtypes.dtype, args: List[str]) -> str: + def constructor( + self, + var_type: dtypes.dtype, + args: List[str], + arg_types: Optional[List[Optional[dtypes.dtype]]] = None, + ) -> str: + _ = arg_types raise NotImplementedError + def component_access_expr(self, expr: str, component: str, base_type: dtypes.dtype) -> str: + return f"{expr}.{component}" + + def buffer_component_expr( + self, + scalar_buffer_expr: str, + base_type: dtypes.dtype, + element_index_expr: str, + component_index_expr: str, + ) -> Optional[str]: + _ = (scalar_buffer_expr, base_type, element_index_expr, component_index_expr) + return None + def fma_function_name(self, var_type: dtypes.dtype) -> str: return "fma" + def math_func_name(self, func_name: str, var_type: dtypes.dtype) -> str: + """Return the backend-specific function name for a math operation. + + Backends can override this to remap function names for specific types + (e.g. CUDA __half intrinsics). 
+ """ + return func_name + + def unary_math_expr(self, func_name: str, arg_type: dtypes.dtype, arg_expr: str) -> str: + return f"{self.math_func_name(func_name, arg_type)}({arg_expr})" + + def binary_math_expr( + self, + func_name: str, + lhs_type: dtypes.dtype, + lhs_expr: str, + rhs_type: dtypes.dtype, + rhs_expr: str, + ) -> str: + mapped = self.math_func_name(func_name, lhs_type) + if func_name == "atan2": + mapped_atan = self.math_func_name("atan", lhs_type) + return f"{mapped_atan}({lhs_expr}, {rhs_expr})" + + return f"{mapped}({lhs_expr}, {rhs_expr})" + + def arithmetic_unary_expr(self, op: str, var_type: dtypes.dtype, var_expr: str) -> Optional[str]: + """Optional backend override for unary arithmetic expressions.""" + _ = (op, var_type, var_expr) + return None + + def arithmetic_binary_expr( + self, + op: str, + lhs_type: dtypes.dtype, + lhs_expr: str, + rhs_type: dtypes.dtype, + rhs_expr: str, + ) -> Optional[str]: + """Optional backend override for binary arithmetic expressions.""" + _ = (op, lhs_type, lhs_expr, rhs_type, rhs_expr) + return None + def pre_header(self, *, enable_subgroup_ops: bool, enable_printf: bool) -> str: raise NotImplementedError @@ -85,6 +147,18 @@ def inf_f32_expr(self) -> str: def ninf_f32_expr(self) -> str: raise NotImplementedError + def inf_f64_expr(self) -> str: + raise NotImplementedError + + def ninf_f64_expr(self) -> str: + raise NotImplementedError + + def inf_f16_expr(self) -> str: + raise NotImplementedError + + def ninf_f16_expr(self) -> str: + raise NotImplementedError + def float_bits_to_int_expr(self, var_expr: str) -> str: raise NotImplementedError @@ -145,25 +219,32 @@ def memory_barrier_image_statement(self) -> str: def group_memory_barrier_statement(self) -> str: raise NotImplementedError - def subgroup_add_expr(self, arg_expr: str) -> str: + def subgroup_add_expr(self, arg_expr: str, arg_type: Optional[dtypes.dtype] = None) -> str: + _ = arg_type raise NotImplementedError - def subgroup_mul_expr(self, arg_expr: 
str) -> str: + def subgroup_mul_expr(self, arg_expr: str, arg_type: Optional[dtypes.dtype] = None) -> str: + _ = arg_type raise NotImplementedError - def subgroup_min_expr(self, arg_expr: str) -> str: + def subgroup_min_expr(self, arg_expr: str, arg_type: Optional[dtypes.dtype] = None) -> str: + _ = arg_type raise NotImplementedError - def subgroup_max_expr(self, arg_expr: str) -> str: + def subgroup_max_expr(self, arg_expr: str, arg_type: Optional[dtypes.dtype] = None) -> str: + _ = arg_type raise NotImplementedError - def subgroup_and_expr(self, arg_expr: str) -> str: + def subgroup_and_expr(self, arg_expr: str, arg_type: Optional[dtypes.dtype] = None) -> str: + _ = arg_type raise NotImplementedError - def subgroup_or_expr(self, arg_expr: str) -> str: + def subgroup_or_expr(self, arg_expr: str, arg_type: Optional[dtypes.dtype] = None) -> str: + _ = arg_type raise NotImplementedError - def subgroup_xor_expr(self, arg_expr: str) -> str: + def subgroup_xor_expr(self, arg_expr: str, arg_type: Optional[dtypes.dtype] = None) -> str: + _ = arg_type raise NotImplementedError def subgroup_elect_expr(self) -> str: @@ -183,3 +264,8 @@ def sample_texture_expr(self, texture_expr: str, coord_expr: str, lod_expr: Opti def mark_texture_sample_dimension(self, dimensions: int) -> None: return + + def atomic_add_expr(self, mem_expr: str, value_expr: str, var_type: dtypes.dtype) -> str: + raise NotImplementedError( + f"atomic_add is not supported for backend '{self.name}' and type '{var_type.name}'" + ) diff --git a/vkdispatch/codegen/backends/cuda.py b/vkdispatch/codegen/backends/cuda.py deleted file mode 100644 index 51e575f0..00000000 --- a/vkdispatch/codegen/backends/cuda.py +++ /dev/null @@ -1,1247 +0,0 @@ -from typing import Dict, List, Optional, Set - -import vkdispatch.base.dtype as dtypes - -from .base import CodeGenBackend - - -def _cuda_vec_components(dim: int) -> List[str]: - if dim < 2 or dim > 4: - raise ValueError(f"Unsupported vector dimension '{dim}'") - return 
list("xyzw"[:dim]) - - -def _cuda_join_statements(statements: List[str]) -> str: - if len(statements) == 0: - return "" - return " ".join(statements) - - -def _cuda_emit_vec_type( - vec_name: str, - scalar_type: str, - dim: int, - *, - allow_unary_neg: bool, - enable_bitwise: bool, - needed_ops: Optional[Set[str]] = None, -) -> str: - comps = _cuda_vec_components(dim) - if needed_ops is None: - needed_ops = set() - if allow_unary_neg: - needed_ops.add("un:-") - if enable_bitwise: - needed_ops.add("un:~") - for op in ["+", "-", "*", "/"]: - needed_ops.add(f"cmpd:{op}=:v") - needed_ops.add(f"cmpd:{op}=:s") - needed_ops.add(f"bin:{op}:vv") - needed_ops.add(f"bin:{op}:vs") - needed_ops.add(f"bin:{op}:sv") - if enable_bitwise: - for op in ["&", "|", "^", "<<", ">>"]: - needed_ops.add(f"cmpd:{op}=:v") - needed_ops.add(f"cmpd:{op}=:s") - needed_ops.add(f"bin:{op}:vv") - needed_ops.add(f"bin:{op}:vs") - needed_ops.add(f"bin:{op}:sv") - - def has(token: str) -> bool: - return token in needed_ops - - lines: List[str] = [f"struct {vec_name} {{"] - lines.extend([f" {scalar_type} {c};" for c in comps]) - lines.append("") - ctor_args = ", ".join([f"{scalar_type} {c}_" for c in comps]) - ctor_init = ", ".join([f"{c}({c}_)" for c in comps]) - splat_init = ", ".join([f"{c}(s)" for c in comps]) - cast_init = ", ".join([f"{c}(({scalar_type})v.{c})" for c in comps]) - lines.append(f" __device__ __forceinline__ {vec_name}() = default;") - lines.append(f" __device__ __forceinline__ {vec_name}({ctor_args}) : {ctor_init} {{}}") - lines.append(f" __device__ __forceinline__ explicit {vec_name}({scalar_type} s) : {splat_init} {{}}") - lines.append(" template ") - lines.append(f" __device__ __forceinline__ explicit {vec_name}(TVec v) : {cast_init} {{}}") - lines.append(f" __device__ __forceinline__ {scalar_type}& operator[](int i) {{ return (&x)[i]; }}") - lines.append(f" __device__ __forceinline__ const {scalar_type}& operator[](int i) const {{ return (&x)[i]; }}") - - if allow_unary_neg and 
has("un:-"): - neg_expr = ", ".join([f"-{c}" for c in comps]) - lines.append(f" __device__ __forceinline__ {vec_name} operator-() const {{ return {vec_name}({neg_expr}); }}") - - if enable_bitwise and has("un:~"): - not_expr = ", ".join([f"~{c}" for c in comps]) - lines.append(f" __device__ __forceinline__ {vec_name} operator~() const {{ return {vec_name}({not_expr}); }}") - - for op in ["+", "-", "*", "/"]: - op_assign = op + "=" - if has(f"cmpd:{op}=:v"): - vv_ops = _cuda_join_statements([f"{c} {op_assign} b.{c};" for c in comps]) - lines.append( - f" __device__ __forceinline__ {vec_name}& operator{op_assign}(const {vec_name}& b) {{ {vv_ops} return *this; }}" - ) - if has(f"cmpd:{op}=:s"): - sv_ops = _cuda_join_statements([f"{c} {op_assign} b;" for c in comps]) - lines.append( - f" __device__ __forceinline__ {vec_name}& operator{op_assign}({scalar_type} b) {{ {sv_ops} return *this; }}" - ) - - if enable_bitwise: - for op in ["&", "|", "^", "<<", ">>"]: - op_assign = op + "=" - if has(f"cmpd:{op}=:v"): - vv_ops = _cuda_join_statements([f"{c} {op_assign} b.{c};" for c in comps]) - lines.append( - f" __device__ __forceinline__ {vec_name}& operator{op_assign}(const {vec_name}& b) {{ {vv_ops} return *this; }}" - ) - if has(f"cmpd:{op}=:s"): - sv_ops = _cuda_join_statements([f"{c} {op_assign} b;" for c in comps]) - lines.append( - f" __device__ __forceinline__ {vec_name}& operator{op_assign}({scalar_type} b) {{ {sv_ops} return *this; }}" - ) - - lines.append("};") - - # Arithmetic operators (vector/vector, vector/scalar, scalar/vector) - for op in ["+", "-", "*", "/"]: - if has(f"bin:{op}:vv"): - vv_expr = ", ".join([f"(a.{c} {op} b.{c})" for c in comps]) - lines.append( - f"__device__ __forceinline__ {vec_name} operator{op}(const {vec_name}& a, const {vec_name}& b) {{ return {vec_name}({vv_expr}); }}" - ) - if has(f"bin:{op}:vs"): - vs_expr = ", ".join([f"(a.{c} {op} b)" for c in comps]) - lines.append( - f"__device__ __forceinline__ {vec_name} operator{op}(const 
{vec_name}& a, {scalar_type} b) {{ return {vec_name}({vs_expr}); }}" - ) - if has(f"bin:{op}:sv"): - if op in ["+", "*"]: - sv_expr = ", ".join([f"(a {op} b.{c})" for c in comps]) - else: - sv_expr = ", ".join([f"({scalar_type})(a {op} b.{c})" for c in comps]) - lines.append( - f"__device__ __forceinline__ {vec_name} operator{op}({scalar_type} a, const {vec_name}& b) {{ return {vec_name}({sv_expr}); }}" - ) - - if enable_bitwise: - for op in ["&", "|", "^", "<<", ">>"]: - if has(f"bin:{op}:vv"): - vv_expr = ", ".join([f"(a.{c} {op} b.{c})" for c in comps]) - lines.append( - f"__device__ __forceinline__ {vec_name} operator{op}(const {vec_name}& a, const {vec_name}& b) {{ return {vec_name}({vv_expr}); }}" - ) - if has(f"bin:{op}:vs"): - vs_expr = ", ".join([f"(a.{c} {op} b)" for c in comps]) - lines.append( - f"__device__ __forceinline__ {vec_name} operator{op}(const {vec_name}& a, {scalar_type} b) {{ return {vec_name}({vs_expr}); }}" - ) - if has(f"bin:{op}:sv"): - sv_expr = ", ".join([f"({scalar_type})(a {op} b.{c})" for c in comps]) - lines.append( - f"__device__ __forceinline__ {vec_name} operator{op}({scalar_type} a, const {vec_name}& b) {{ return {vec_name}({sv_expr}); }}" - ) - - return "\n".join(lines) - - -def _cuda_emit_vec_helper(helper_suffix: str, vec_name: str, scalar_type: str, dim: int) -> str: - comps = _cuda_vec_components(dim) - args = ", ".join([f"{scalar_type} {c}" for c in comps]) - ctor_args = ", ".join(comps) - return "\n".join( - [ - f"__device__ __forceinline__ {vec_name} vkdispatch_make_{helper_suffix}({args}) {{ return {vec_name}({ctor_args}); }}", - f"__device__ __forceinline__ {vec_name} vkdispatch_make_{helper_suffix}({scalar_type} x) {{ return {vec_name}(x); }}", - "template ", - f"__device__ __forceinline__ {vec_name} vkdispatch_make_{helper_suffix}(TVec v) {{ return {vec_name}(v); }}", - ] - ) - - -def _cuda_emit_mat_type(mat_name: str, vec_name: str, dim: int, needed_ops: Optional[Set[str]] = None) -> str: - cols = [f"c{i}" for i in 
range(dim)] - if needed_ops is None: - needed_ops = { - "un:-", - "cmpd:+=:m", "cmpd:+=:s", - "cmpd:-=:m", "cmpd:-=:s", - "cmpd:*=:s", "cmpd:/=:s", - "bin:+:mm", "bin:+:ms", "bin:+:sm", - "bin:-:mm", "bin:-:ms", "bin:-:sm", - "bin:*:ms", "bin:*:sm", "bin:/:ms", "bin:/:sm", - "bin:*:mv", "bin:*:vm", "bin:*:mm", - } - - def has(token: str) -> bool: - return token in needed_ops - - lines: List[str] = [f"struct {mat_name} {{"] - lines.extend([f" {vec_name} {c};" for c in cols]) - lines.append("") - lines.append(f" __device__ __forceinline__ {mat_name}() = default;") - ctor_args = ", ".join([f"{vec_name} {c}_" for c in cols]) - ctor_init = ", ".join([f"{c}({c}_)" for c in cols]) - lines.append(f" __device__ __forceinline__ {mat_name}({ctor_args}) : {ctor_init} {{}}") - - zero = "0.0f" - diag_init = ", ".join( - [f"c{col_idx}({vec_name}({', '.join(['s' if row_idx == col_idx else zero for row_idx in range(dim)])}))" for col_idx in range(dim)] - ) - lines.append(f" __device__ __forceinline__ explicit {mat_name}(float s) : {diag_init} {{}}") - lines.append(f" __device__ __forceinline__ {vec_name}& operator[](int i) {{ return (&c0)[i]; }}") - lines.append(f" __device__ __forceinline__ const {vec_name}& operator[](int i) const {{ return (&c0)[i]; }}") - if has("un:-"): - lines.append(f" __device__ __forceinline__ {mat_name} operator-() const {{ return {mat_name}({', '.join([f'-c{i}' for i in range(dim)])}); }}") - - for op in ["+", "-"]: - op_assign = op + "=" - if has(f"cmpd:{op}=:m"): - mm_ops = _cuda_join_statements([f"c{i} {op_assign} b.c{i};" for i in range(dim)]) - lines.append( - f" __device__ __forceinline__ {mat_name}& operator{op_assign}(const {mat_name}& b) {{ {mm_ops} return *this; }}" - ) - if has(f"cmpd:{op}=:s"): - ms_ops = _cuda_join_statements([f"c{i} {op_assign} b;" for i in range(dim)]) - lines.append( - f" __device__ __forceinline__ {mat_name}& operator{op_assign}(float b) {{ {ms_ops} return *this; }}" - ) - - for op in ["*", "/"]: - op_assign = op + "=" - 
if has(f"cmpd:{op}=:s"): - ms_ops = _cuda_join_statements([f"c{i} {op_assign} b;" for i in range(dim)]) - lines.append( - f" __device__ __forceinline__ {mat_name}& operator{op_assign}(float b) {{ {ms_ops} return *this; }}" - ) - - lines.append("};") - - # Basic arithmetic - for op in ["+", "-"]: - if has(f"bin:{op}:mm"): - cols_expr = ", ".join([f"(a.c{i} {op} b.c{i})" for i in range(dim)]) - lines.append( - f"__device__ __forceinline__ {mat_name} operator{op}(const {mat_name}& a, const {mat_name}& b) {{ return {mat_name}({cols_expr}); }}" - ) - if has(f"bin:{op}:ms"): - cols_expr = ", ".join([f"(a.c{i} {op} b)" for i in range(dim)]) - lines.append( - f"__device__ __forceinline__ {mat_name} operator{op}(const {mat_name}& a, float b) {{ return {mat_name}({cols_expr}); }}" - ) - if has(f"bin:{op}:sm"): - cols_expr = ", ".join([f"(a {op} b.c{i})" for i in range(dim)]) - lines.append( - f"__device__ __forceinline__ {mat_name} operator{op}(float a, const {mat_name}& b) {{ return {mat_name}({cols_expr}); }}" - ) - - for op in ["*", "/"]: - if has(f"bin:{op}:ms"): - cols_expr = ", ".join([f"(a.c{i} {op} b)" for i in range(dim)]) - lines.append( - f"__device__ __forceinline__ {mat_name} operator{op}(const {mat_name}& a, float b) {{ return {mat_name}({cols_expr}); }}" - ) - if has(f"bin:{op}:sm"): - cols_expr = ", ".join([f"(a {op} b.c{i})" for i in range(dim)]) - lines.append( - f"__device__ __forceinline__ {mat_name} operator{op}(float a, const {mat_name}& b) {{ return {mat_name}({cols_expr}); }}" - ) - - # GLSL-style matrix/vector products (column-major) - vec_comps = _cuda_vec_components(dim) - if has("bin:*:mv"): - mat_vec_terms = [f"(m.c{i} * v.{vec_comps[i]})" for i in range(dim)] - mat_vec_expr = " + ".join(mat_vec_terms) - lines.append( - f"__device__ __forceinline__ {vec_name} operator* (const {mat_name}& m, const {vec_name}& v) {{ return {mat_vec_expr}; }}" - ) - - if has("bin:*:vm"): - row_exprs: List[str] = [] - for col_idx in range(dim): - terms = 
[f"(v.{vec_comps[row_idx]} * m.c{col_idx}.{vec_comps[row_idx]})" for row_idx in range(dim)] - row_exprs.append(" + ".join(terms)) - lines.append( - f"__device__ __forceinline__ {vec_name} operator* (const {vec_name}& v, const {mat_name}& m) {{ return {vec_name}({', '.join(row_exprs)}); }}" - ) - - if has("bin:*:mm"): - col_products = ", ".join([f"(a * b.c{i})" for i in range(dim)]) - lines.append( - f"__device__ __forceinline__ {mat_name} operator* (const {mat_name}& a, const {mat_name}& b) {{ return {mat_name}({col_products}); }}" - ) - - return "\n".join(lines) - - -def _cuda_emit_mat_helpers(mat_name: str, helper_suffix: str, vec_name: str, vec_helper_suffix: str, dim: int) -> str: - col_type = vec_name - col_args = ", ".join([f"{col_type} c{i}" for i in range(dim)]) - col_ctor = ", ".join([f"c{i}" for i in range(dim)]) - - flat_names = [f"m{col}{row}" for col in range(dim) for row in range(dim)] - flat_args = ", ".join([f"float {name}" for name in flat_names]) - flat_cols: List[str] = [] - for col in range(dim): - values = [f"m{col}{row}" for row in range(dim)] - flat_cols.append(f"vkdispatch_make_{vec_helper_suffix}({', '.join(values)})") - flat_ctor = ", ".join(flat_cols) - - cast_cols = ", ".join([f"vkdispatch_make_{vec_helper_suffix}(m[{i}])" for i in range(dim)]) - - return "\n".join( - [ - f"__device__ __forceinline__ {mat_name} vkdispatch_make_{helper_suffix}({col_args}) {{ return {mat_name}({col_ctor}); }}", - f"__device__ __forceinline__ {mat_name} vkdispatch_make_{helper_suffix}(float s) {{ return {mat_name}(s); }}", - f"__device__ __forceinline__ {mat_name} vkdispatch_make_{helper_suffix}({flat_args}) {{ return {mat_name}({flat_ctor}); }}", - "template ", - f"__device__ __forceinline__ {mat_name} vkdispatch_make_{helper_suffix}(TMat m) {{ return {mat_name}({cast_cols}); }}", - ] - ) - - -def _cuda_composite_helpers() -> str: - parts: List[str] = [] - - vector_specs = [ - ("vkdispatch_int2", "int", 2, True, True, "int2"), - ("vkdispatch_int3", "int", 
3, True, True, "int3"), - ("vkdispatch_int4", "int", 4, True, True, "int4"), - ("vkdispatch_uint2", "unsigned int", 2, False, True, "uint2"), - ("vkdispatch_uint3", "unsigned int", 3, False, True, "uint3"), - ("vkdispatch_uint4", "unsigned int", 4, False, True, "uint4"), - ("vkdispatch_float2", "float", 2, True, False, "float2"), - ("vkdispatch_float3", "float", 3, True, False, "float3"), - ("vkdispatch_float4", "float", 4, True, False, "float4"), - ] - - for vec_name, scalar_type, dim, allow_neg, enable_bitwise, helper_suffix in vector_specs: - parts.append( - _cuda_emit_vec_type( - vec_name, - scalar_type, - dim, - allow_unary_neg=allow_neg, - enable_bitwise=enable_bitwise, - ) - ) - parts.append(_cuda_emit_vec_helper(helper_suffix, vec_name, scalar_type, dim)) - - matrix_specs = [ - ("vkdispatch_mat2", "mat2", "vkdispatch_float2", "float2", 2), - ("vkdispatch_mat3", "mat3", "vkdispatch_float3", "float3", 3), - ("vkdispatch_mat4", "mat4", "vkdispatch_float4", "float4", 4), - ] - - for mat_name, helper_suffix, vec_name, vec_helper_suffix, dim in matrix_specs: - parts.append(_cuda_emit_mat_type(mat_name, vec_name, dim)) - parts.append(_cuda_emit_mat_helpers(mat_name, helper_suffix, vec_name, vec_helper_suffix, dim)) - - return "\n\n".join(parts) - - -_CUDA_VEC_TYPE_SPECS = { - "int2": ("vkdispatch_int2", "int", 2, True, True), - "int3": ("vkdispatch_int3", "int", 3, True, True), - "int4": ("vkdispatch_int4", "int", 4, True, True), - "uint2": ("vkdispatch_uint2", "unsigned int", 2, False, True), - "uint3": ("vkdispatch_uint3", "unsigned int", 3, False, True), - "uint4": ("vkdispatch_uint4", "unsigned int", 4, False, True), - "float2": ("vkdispatch_float2", "float", 2, True, False), - "float3": ("vkdispatch_float3", "float", 3, True, False), - "float4": ("vkdispatch_float4", "float", 4, True, False), -} - -_CUDA_MAT_TYPE_SPECS = { - "mat2": ("vkdispatch_mat2", "vkdispatch_float2", "float2", 2), - "mat3": ("vkdispatch_mat3", "vkdispatch_float3", "float3", 3), - 
"mat4": ("vkdispatch_mat4", "vkdispatch_float4", "float4", 4), -} - - -class CUDABackend(CodeGenBackend): - name = "cuda" - - _HELPER_SNIPPETS: Dict[str, str] = { - "composite_types": "", - "mat2_type": "", - "mat3_type": "", - "mat4_type": "", - "make_mat2": "", - "make_mat3": "", - "make_mat4": "", - "make_int2": "", - "make_int3": "", - "make_int4": "", - "make_uint2": "", - "make_uint3": "", - "make_uint4": "", - "float2_ops": "", - "make_float2": "", - "make_float3": "", - "make_float4": "", - "global_invocation_id": ( - "__device__ __forceinline__ vkdispatch_uint3 vkdispatch_global_invocation_id() {\n" - " return vkdispatch_uint3(\n" - " (unsigned int)(blockIdx.x * blockDim.x + threadIdx.x),\n" - " (unsigned int)(blockIdx.y * blockDim.y + threadIdx.y),\n" - " (unsigned int)(blockIdx.z * blockDim.z + threadIdx.z)\n" - " );\n" - "}" - ), - "local_invocation_id": ( - "__device__ __forceinline__ vkdispatch_uint3 vkdispatch_local_invocation_id() {\n" - " return vkdispatch_uint3((unsigned int)threadIdx.x, (unsigned int)threadIdx.y, (unsigned int)threadIdx.z);\n" - "}" - ), - "workgroup_id": ( - "__device__ __forceinline__ vkdispatch_uint3 vkdispatch_workgroup_id() {\n" - " return vkdispatch_uint3((unsigned int)blockIdx.x, (unsigned int)blockIdx.y, (unsigned int)blockIdx.z);\n" - "}" - ), - "local_invocation_index": ( - "__device__ __forceinline__ unsigned int vkdispatch_local_invocation_index() {\n" - " return (unsigned int)(threadIdx.x + blockDim.x * (threadIdx.y + blockDim.y * threadIdx.z));\n" - "}" - ), - "subgroup_size": "__device__ __forceinline__ unsigned int vkdispatch_subgroup_size() { return (unsigned int)warpSize; }", - "num_subgroups": ( - "__device__ __forceinline__ unsigned int vkdispatch_num_subgroups() {\n" - " unsigned int local_count = (unsigned int)(blockDim.x * blockDim.y * blockDim.z);\n" - " return (local_count + vkdispatch_subgroup_size() - 1u) / vkdispatch_subgroup_size();\n" - "}" - ), - "subgroup_id": ( - "__device__ __forceinline__ 
unsigned int vkdispatch_subgroup_id() {\n" - " return vkdispatch_local_invocation_index() / vkdispatch_subgroup_size();\n" - "}" - ), - "subgroup_invocation_id": ( - "__device__ __forceinline__ unsigned int vkdispatch_subgroup_invocation_id() {\n" - " return vkdispatch_local_invocation_index() % vkdispatch_subgroup_size();\n" - "}" - ), - "subgroup_add": ( - "template \n" - "__device__ __forceinline__ T vkdispatch_subgroup_add(T value) {\n" - " unsigned int mask = 0xffffffffu;\n" - " for (unsigned int offset = vkdispatch_subgroup_size() >> 1; offset > 0u; offset >>= 1u) {\n" - " value += __shfl_xor_sync(mask, value, (int)offset);\n" - " }\n" - " return value;\n" - "}" - ), - "subgroup_mul": ( - "template \n" - "__device__ __forceinline__ T vkdispatch_subgroup_mul(T value) {\n" - " unsigned int mask = 0xffffffffu;\n" - " for (unsigned int offset = vkdispatch_subgroup_size() >> 1; offset > 0u; offset >>= 1u) {\n" - " value *= __shfl_xor_sync(mask, value, (int)offset);\n" - " }\n" - " return value;\n" - "}" - ), - "subgroup_min": ( - "template \n" - "__device__ __forceinline__ T vkdispatch_subgroup_min(T value) {\n" - " unsigned int mask = 0xffffffffu;\n" - " for (unsigned int offset = vkdispatch_subgroup_size() >> 1; offset > 0u; offset >>= 1u) {\n" - " T other = __shfl_xor_sync(mask, value, (int)offset);\n" - " value = other < value ? other : value;\n" - " }\n" - " return value;\n" - "}" - ), - "subgroup_max": ( - "template \n" - "__device__ __forceinline__ T vkdispatch_subgroup_max(T value) {\n" - " unsigned int mask = 0xffffffffu;\n" - " for (unsigned int offset = vkdispatch_subgroup_size() >> 1; offset > 0u; offset >>= 1u) {\n" - " T other = __shfl_xor_sync(mask, value, (int)offset);\n" - " value = other > value ? 
other : value;\n" - " }\n" - " return value;\n" - "}" - ), - "subgroup_and": ( - "template \n" - "__device__ __forceinline__ T vkdispatch_subgroup_and(T value) {\n" - " unsigned int mask = 0xffffffffu;\n" - " for (unsigned int offset = vkdispatch_subgroup_size() >> 1; offset > 0u; offset >>= 1u) {\n" - " value &= __shfl_xor_sync(mask, value, (int)offset);\n" - " }\n" - " return value;\n" - "}" - ), - "subgroup_or": ( - "template \n" - "__device__ __forceinline__ T vkdispatch_subgroup_or(T value) {\n" - " unsigned int mask = 0xffffffffu;\n" - " for (unsigned int offset = vkdispatch_subgroup_size() >> 1; offset > 0u; offset >>= 1u) {\n" - " value |= __shfl_xor_sync(mask, value, (int)offset);\n" - " }\n" - " return value;\n" - "}" - ), - "subgroup_xor": ( - "template \n" - "__device__ __forceinline__ T vkdispatch_subgroup_xor(T value) {\n" - " unsigned int mask = 0xffffffffu;\n" - " for (unsigned int offset = vkdispatch_subgroup_size() >> 1; offset > 0u; offset >>= 1u) {\n" - " value ^= __shfl_xor_sync(mask, value, (int)offset);\n" - " }\n" - " return value;\n" - "}" - ), - "mod": "__device__ __forceinline__ float mod(float x, float y) { return fmodf(x, y); }", - "fract": "__device__ __forceinline__ float fract(float x) { return x - floorf(x); }", - "roundEven": "__device__ __forceinline__ float roundEven(float x) { return nearbyintf(x); }", - "mix": "__device__ __forceinline__ float mix(float x, float y, float a) { return x + (y - x) * a; }", - "step": "__device__ __forceinline__ float step(float edge, float x) { return x < edge ? 
0.0f : 1.0f; }", - "smoothstep": ( - "__device__ __forceinline__ float smoothstep(float edge0, float edge1, float x) {\n" - " float t = fminf(fmaxf((x - edge0) / (edge1 - edge0), 0.0f), 1.0f);\n" - " return t * t * (3.0f - 2.0f * t);\n" - "}" - ), - "radians": "__device__ __forceinline__ float radians(float x) { return x * (3.14159265358979323846f / 180.0f); }", - "degrees": "__device__ __forceinline__ float degrees(float x) { return x * (180.0f / 3.14159265358979323846f); }", - "inversesqrt": "__device__ __forceinline__ float inversesqrt(float x) { return rsqrtf(x); }", - "floatBitsToInt": "__device__ __forceinline__ int floatBitsToInt(float x) { return __float_as_int(x); }", - "floatBitsToUint": "__device__ __forceinline__ unsigned int floatBitsToUint(float x) { return __float_as_uint(x); }", - "intBitsToFloat": "__device__ __forceinline__ float intBitsToFloat(int x) { return __int_as_float(x); }", - "uintBitsToFloat": "__device__ __forceinline__ float uintBitsToFloat(unsigned int x) { return __uint_as_float(x); }", - "sample_texture": "", - } - - _HELPER_ORDER: List[str] = [ - "composite_types", - "global_invocation_id", - "local_invocation_id", - "workgroup_id", - "local_invocation_index", - "subgroup_size", - "num_subgroups", - "subgroup_id", - "subgroup_invocation_id", - "subgroup_add", - "subgroup_mul", - "subgroup_min", - "subgroup_max", - "subgroup_and", - "subgroup_or", - "subgroup_xor", - "mod", - "fract", - "roundEven", - "mix", - "step", - "smoothstep", - "radians", - "degrees", - "inversesqrt", - "floatBitsToInt", - "floatBitsToUint", - "intBitsToFloat", - "uintBitsToFloat", - "sample_texture", - ] - - _HELPER_DEPENDENCIES: Dict[str, List[str]] = { - "mat2_type": ["composite_types"], - "mat3_type": ["composite_types"], - "mat4_type": ["composite_types"], - "make_mat2": ["composite_types"], - "make_mat3": ["composite_types"], - "make_mat4": ["composite_types"], - "make_int2": ["composite_types"], - "make_int3": ["composite_types"], - "make_int4": 
["composite_types"], - "make_uint2": ["composite_types"], - "make_uint3": ["composite_types"], - "make_uint4": ["composite_types"], - "float2_ops": ["composite_types"], - "make_float2": ["composite_types"], - "make_float3": ["composite_types"], - "make_float4": ["composite_types"], - "global_invocation_id": ["composite_types"], - "local_invocation_id": ["composite_types"], - "workgroup_id": ["composite_types"], - "sample_texture": ["composite_types"], - "num_subgroups": ["subgroup_size"], - "subgroup_id": ["local_invocation_index", "subgroup_size"], - "subgroup_invocation_id": ["local_invocation_index", "subgroup_size"], - "subgroup_add": ["subgroup_size"], - "subgroup_mul": ["subgroup_size"], - "subgroup_min": ["subgroup_size"], - "subgroup_max": ["subgroup_size"], - "subgroup_and": ["subgroup_size"], - "subgroup_or": ["subgroup_size"], - "subgroup_xor": ["subgroup_size"], - } - - def __init__(self) -> None: - self._fixed_preamble = "" - self.reset_state() - - def reset_state(self) -> None: - self._kernel_params: List[str] = [] - self._entry_alias_lines: List[str] = [] - self._composite_type_usage: Set[str] = set() - self._composite_vec_op_usage: Dict[str, Set[str]] = {} - self._composite_mat_op_usage: Dict[str, Set[str]] = {} - self._sample_texture_dims: Set[int] = set() - self._feature_usage: Dict[str, bool] = { - feature_name: False - for feature_name in self._HELPER_SNIPPETS - } - - def mark_feature_usage(self, feature_name: str) -> None: - if feature_name in self._feature_usage: - self._feature_usage[feature_name] = True - - def _composite_key_for_dtype(self, var_type: dtypes.dtype) -> Optional[str]: - if var_type == dtypes.complex64 or var_type == dtypes.vec2: - return "float2" - if var_type == dtypes.vec3: - return "float3" - if var_type == dtypes.vec4: - return "float4" - if var_type == dtypes.ivec2: - return "int2" - if var_type == dtypes.ivec3: - return "int3" - if var_type == dtypes.ivec4: - return "int4" - if var_type == dtypes.uvec2: - return "uint2" 
- if var_type == dtypes.uvec3: - return "uint3" - if var_type == dtypes.uvec4: - return "uint4" - if var_type == dtypes.mat2: - return "mat2" - if var_type == dtypes.mat3: - return "mat3" - if var_type == dtypes.mat4: - return "mat4" - return None - - def _record_composite_type_key(self, key: str) -> None: - self.mark_feature_usage("composite_types") - self._composite_type_usage.add(key) - - if key in _CUDA_MAT_TYPE_SPECS: - dim = _CUDA_MAT_TYPE_SPECS[key][3] - self._composite_type_usage.add(f"float{dim}") - - def _record_composite_type(self, var_type: dtypes.dtype) -> Optional[str]: - key = self._composite_key_for_dtype(var_type) - if key is None: - return None - self._record_composite_type_key(key) - return key - - def _record_vec_op(self, key: str, token: str) -> None: - self._record_composite_type_key(key) - self._composite_vec_op_usage.setdefault(key, set()).add(token) - - def _record_mat_op(self, key: str, token: str) -> None: - self._record_composite_type_key(key) - self._composite_mat_op_usage.setdefault(key, set()).add(token) - - def _propagate_matrix_vec_dependencies(self, mat_key: str, token: str) -> None: - dim = _CUDA_MAT_TYPE_SPECS[mat_key][3] - vec_key = f"float{dim}" - - if token == "un:-": - self._record_vec_op(vec_key, "un:-") - return - - if token.startswith("cmpd:"): - if token.endswith(":m"): - vec_token = token[:-1] + "v" - self._record_vec_op(vec_key, vec_token) - return - if token.endswith(":s"): - self._record_vec_op(vec_key, token) - return - - if token.startswith("bin:"): - parts = token.split(":") - if len(parts) != 3: - return - _, op, shape = parts - if shape == "mm": - if op in ["+", "-"]: - self._record_vec_op(vec_key, f"bin:{op}:vv") - elif op == "*": - self._record_mat_op(mat_key, "bin:*:mv") - self._propagate_matrix_vec_dependencies(mat_key, "bin:*:mv") - return - if shape == "ms": - self._record_vec_op(vec_key, f"bin:{op}:vs") - return - if shape == "sm": - self._record_vec_op(vec_key, f"bin:{op}:sv") - return - if shape == "mv": 
- self._record_vec_op(vec_key, "bin:*:vs") - self._record_vec_op(vec_key, "bin:+:vv") - return - if shape == "vm": - return - - def mark_composite_unary_op(self, var_type: dtypes.dtype, op: str) -> None: - key = self._record_composite_type(var_type) - if key is None: - return - - token = f"un:{op}" - if key in _CUDA_VEC_TYPE_SPECS: - self._record_vec_op(key, token) - return - if key in _CUDA_MAT_TYPE_SPECS: - self._record_mat_op(key, token) - self._propagate_matrix_vec_dependencies(key, token) - - def mark_composite_binary_op( - self, - lhs_type: dtypes.dtype, - rhs_type: dtypes.dtype, - op: str, - *, - inplace: bool = False, - ) -> None: - lhs_key = self._record_composite_type(lhs_type) - rhs_key = self._record_composite_type(rhs_type) - - lhs_is_composite = lhs_key is not None - rhs_is_composite = rhs_key is not None - if not lhs_is_composite and not rhs_is_composite: - return - - lhs_is_scalar = dtypes.is_scalar(lhs_type) - rhs_is_scalar = dtypes.is_scalar(rhs_type) - - if lhs_key in _CUDA_VEC_TYPE_SPECS and (rhs_is_scalar or rhs_key in _CUDA_VEC_TYPE_SPECS): - if inplace: - suffix = "s" if rhs_is_scalar else "v" - self._record_vec_op(lhs_key, f"cmpd:{op}=:{suffix}") - return - shape = "vs" if rhs_is_scalar else "vv" - self._record_vec_op(lhs_key, f"bin:{op}:{shape}") - return - - if rhs_key in _CUDA_VEC_TYPE_SPECS and lhs_is_scalar and not inplace: - self._record_vec_op(rhs_key, f"bin:{op}:sv") - return - - if lhs_key in _CUDA_MAT_TYPE_SPECS: - if inplace: - if rhs_is_scalar: - token = f"cmpd:{op}=:s" - elif rhs_key in _CUDA_MAT_TYPE_SPECS: - token = f"cmpd:{op}=:m" - else: - return - self._record_mat_op(lhs_key, token) - self._propagate_matrix_vec_dependencies(lhs_key, token) - return - - if rhs_is_scalar: - token = f"bin:{op}:ms" - self._record_mat_op(lhs_key, token) - self._propagate_matrix_vec_dependencies(lhs_key, token) - return - - if rhs_key in _CUDA_MAT_TYPE_SPECS: - token = "bin:*:mm" if op == "*" else f"bin:{op}:mm" - self._record_mat_op(lhs_key, 
token) - self._propagate_matrix_vec_dependencies(lhs_key, token) - return - - if rhs_key in _CUDA_VEC_TYPE_SPECS and op == "*": - token = "bin:*:mv" - self._record_mat_op(lhs_key, token) - self._propagate_matrix_vec_dependencies(lhs_key, token) - return - - if rhs_key in _CUDA_MAT_TYPE_SPECS and lhs_is_scalar and not inplace: - token = f"bin:{op}:sm" - self._record_mat_op(rhs_key, token) - self._propagate_matrix_vec_dependencies(rhs_key, token) - return - - if lhs_key in _CUDA_VEC_TYPE_SPECS and rhs_key in _CUDA_MAT_TYPE_SPECS and op == "*" and not inplace: - token = "bin:*:vm" - self._record_mat_op(rhs_key, token) - self._propagate_matrix_vec_dependencies(rhs_key, token) - - def mark_texture_sample_dimension(self, dimensions: int) -> None: - self._sample_texture_dims.add(dimensions) - self.mark_feature_usage("sample_texture") - self._record_composite_type_key("float4") - if dimensions == 2: - self._record_composite_type_key("float2") - elif dimensions == 3: - self._record_composite_type_key("float3") - - def _emit_used_composite_helpers(self) -> str: - if len(self._composite_type_usage) == 0: - return "" - - parts: List[str] = [] - - vec_order = ["int2", "int3", "int4", "uint2", "uint3", "uint4", "float2", "float3", "float4"] - for key in vec_order: - if key not in self._composite_type_usage: - continue - vec_name, scalar_type, dim, allow_neg, enable_bitwise = _CUDA_VEC_TYPE_SPECS[key] - parts.append( - _cuda_emit_vec_type( - vec_name, - scalar_type, - dim, - allow_unary_neg=allow_neg, - enable_bitwise=enable_bitwise, - needed_ops=self._composite_vec_op_usage.get(key, set()), - ) - ) - parts.append(_cuda_emit_vec_helper(key, vec_name, scalar_type, dim)) - - mat_order = ["mat2", "mat3", "mat4"] - for key in mat_order: - if key not in self._composite_type_usage: - continue - mat_name, vec_name, vec_helper_suffix, dim = _CUDA_MAT_TYPE_SPECS[key] - parts.append(_cuda_emit_mat_type(mat_name, vec_name, dim, self._composite_mat_op_usage.get(key, set()))) - 
parts.append(_cuda_emit_mat_helpers(mat_name, key, vec_name, vec_helper_suffix, dim)) - - return "\n\n".join(parts) - - def _emit_sample_texture_helpers(self) -> str: - dims = set(self._sample_texture_dims) - if len(dims) == 0: - dims = {1, 2, 3} - - lines: List[str] = [] - if 1 in dims: - lines.append( - "__device__ __forceinline__ vkdispatch_float4 vkdispatch_sample_texture(cudaTextureObject_t tex, float coord) { return vkdispatch_make_float4(tex1D(tex, coord)); }" - ) - lines.append( - "__device__ __forceinline__ vkdispatch_float4 vkdispatch_sample_texture(cudaTextureObject_t tex, float coord, float lod) { return vkdispatch_make_float4(tex1DLod(tex, coord, lod)); }" - ) - self._record_composite_type_key("float4") - if 2 in dims: - lines.append( - "__device__ __forceinline__ vkdispatch_float4 vkdispatch_sample_texture(cudaTextureObject_t tex, vkdispatch_float2 coord) { return vkdispatch_make_float4(tex2D(tex, coord.x, coord.y)); }" - ) - lines.append( - "__device__ __forceinline__ vkdispatch_float4 vkdispatch_sample_texture(cudaTextureObject_t tex, vkdispatch_float2 coord, float lod) { return vkdispatch_make_float4(tex2DLod(tex, coord.x, coord.y, lod)); }" - ) - self._record_composite_type_key("float2") - self._record_composite_type_key("float4") - if 3 in dims: - lines.append( - "__device__ __forceinline__ vkdispatch_float4 vkdispatch_sample_texture(cudaTextureObject_t tex, vkdispatch_float3 coord) { return vkdispatch_make_float4(tex3D(tex, coord.x, coord.y, coord.z)); }" - ) - lines.append( - "__device__ __forceinline__ vkdispatch_float4 vkdispatch_sample_texture(cudaTextureObject_t tex, vkdispatch_float3 coord, float lod) { return vkdispatch_make_float4(tex3DLod(tex, coord.x, coord.y, coord.z, lod)); }" - ) - self._record_composite_type_key("float3") - self._record_composite_type_key("float4") - - return "\n".join(lines) - - def _register_kernel_param(self, param_decl: str) -> None: - if param_decl not in self._kernel_params: - 
self._kernel_params.append(param_decl) - - def _register_alias_line(self, alias_line: str) -> None: - if alias_line not in self._entry_alias_lines: - self._entry_alias_lines.append(alias_line) - - @staticmethod - def _is_plain_integer_literal(expr: str) -> bool: - if len(expr) == 0: - return False - if expr[0] in "+-": - return len(expr) > 1 and expr[1:].isdigit() - return expr.isdigit() - - def type_name(self, var_type: dtypes.dtype) -> str: - if var_type == dtypes.int32: - return "int" - if var_type == dtypes.uint32: - return "unsigned int" - if var_type == dtypes.float32: - return "float" - if var_type == dtypes.complex64: - self._record_composite_type(var_type) - return "vkdispatch_float2" - - if var_type == dtypes.ivec2: - self._record_composite_type(var_type) - return "vkdispatch_int2" - if var_type == dtypes.ivec3: - self._record_composite_type(var_type) - return "vkdispatch_int3" - if var_type == dtypes.ivec4: - self._record_composite_type(var_type) - return "vkdispatch_int4" - - if var_type == dtypes.uvec2: - self._record_composite_type(var_type) - return "vkdispatch_uint2" - if var_type == dtypes.uvec3: - self._record_composite_type(var_type) - return "vkdispatch_uint3" - if var_type == dtypes.uvec4: - self._record_composite_type(var_type) - return "vkdispatch_uint4" - - if var_type == dtypes.vec2: - self._record_composite_type(var_type) - return "vkdispatch_float2" - if var_type == dtypes.vec3: - self._record_composite_type(var_type) - return "vkdispatch_float3" - if var_type == dtypes.vec4: - self._record_composite_type(var_type) - return "vkdispatch_float4" - - if var_type == dtypes.mat2: - self._record_composite_type(var_type) - return "vkdispatch_mat2" - if var_type == dtypes.mat3: - self._record_composite_type(var_type) - return "vkdispatch_mat3" - if var_type == dtypes.mat4: - self._record_composite_type(var_type) - return "vkdispatch_mat4" - - raise ValueError(f"Unsupported CUDA type mapping for '{var_type.name}'") - - def constructor(self, 
var_type: dtypes.dtype, args: List[str]) -> str: - if ( - len(args) == 1 - and var_type in (dtypes.complex64, dtypes.vec2, dtypes.vec3, dtypes.vec4) - and self._is_plain_integer_literal(args[0]) - ): - args = [f"{args[0]}.0f"] - - target_type = self.type_name(var_type) - - if dtypes.is_scalar(var_type): - assert len(args) > 0, f"Constructor for scalar type '{var_type.name}' needs at least one argument." - return f"(({target_type})({args[0]}))" - - if var_type == dtypes.mat2: - self.mark_feature_usage("make_mat2") - return f"vkdispatch_make_mat2({', '.join(args)})" - if var_type == dtypes.mat3: - self.mark_feature_usage("make_mat3") - return f"vkdispatch_make_mat3({', '.join(args)})" - if var_type == dtypes.mat4: - self.mark_feature_usage("make_mat4") - return f"vkdispatch_make_mat4({', '.join(args)})" - - helper_suffix = target_type[len("vkdispatch_"):] if target_type.startswith("vkdispatch_") else target_type - helper_name = f"vkdispatch_make_{helper_suffix}" - self.mark_feature_usage(f"make_{helper_suffix}") - return f"{helper_name}({', '.join(args)})" - - def pre_header(self, *, enable_subgroup_ops: bool, enable_printf: bool) -> str: - self.reset_state() - - subgroup_support = "1" if enable_subgroup_ops else "0" - printf_support = "1" if enable_printf else "0" - - self._fixed_preamble = ( - "#include \n" - "#include \n" - "#include \n\n" - f"#define VKDISPATCH_ENABLE_SUBGROUP_OPS {subgroup_support}\n" - f"#define VKDISPATCH_ENABLE_PRINTF {printf_support}\n\n" - ) - - return self._fixed_preamble - - def _resolve_helper_dependencies(self, helpers: Set[str]) -> Set[str]: - pending = list(helpers) - resolved = set(helpers) - - while len(pending) > 0: - helper_name = pending.pop() - - for dependency in self._HELPER_DEPENDENCIES.get(helper_name, []): - if dependency not in resolved: - resolved.add(dependency) - pending.append(dependency) - - return resolved - - def _helper_header(self) -> str: - enabled_helpers = { - helper_name - for helper_name, is_enabled in 
self._feature_usage.items() - if is_enabled - } - - resolved_helpers = self._resolve_helper_dependencies(enabled_helpers) - - if len(resolved_helpers) == 0: - return "" - - helper_sections: List[str] = [] - - for helper_name in self._HELPER_ORDER: - if helper_name in resolved_helpers: - if helper_name == "composite_types": - composite_helpers = self._emit_used_composite_helpers() - if len(composite_helpers) > 0: - helper_sections.append(composite_helpers) - continue - if helper_name == "sample_texture": - texture_helpers = self._emit_sample_texture_helpers() - if len(texture_helpers) > 0: - helper_sections.append(texture_helpers) - continue - - snippet = self._HELPER_SNIPPETS[helper_name] - if len(snippet) > 0: - helper_sections.append(snippet) - - return "\n\n".join(helper_sections) + "\n\n" - - def make_source(self, header: str, body: str, x: int, y: int, z: int) -> str: - expected_size_header = ( - f"// Expected local size: ({x}, {y}, {z})\n" - f"#define VKDISPATCH_EXPECTED_LOCAL_SIZE_X {x}\n" - f"#define VKDISPATCH_EXPECTED_LOCAL_SIZE_Y {y}\n" - f"#define VKDISPATCH_EXPECTED_LOCAL_SIZE_Z {z}\n" - ) - - helper_header = self._helper_header() - - if len(helper_header) == 0: - return f"{expected_size_header}\n{header}\n{body}" - - if len(self._fixed_preamble) > 0 and header.startswith(self._fixed_preamble): - header_suffix = header[len(self._fixed_preamble):] - finalized_header = f"{self._fixed_preamble}{helper_header}{header_suffix}" - else: - finalized_header = f"{header}\n{helper_header}" - - return f"{expected_size_header}\n{finalized_header}\n{body}" - - def constant_namespace(self) -> str: - return "UBO" - - def variable_namespace(self) -> str: - return "PC" - - def exec_bounds_guard(self, exec_count_expr: str) -> str: - gid = self.global_invocation_id_expr() - return ( - f"if (({exec_count_expr}).x <= ({gid}).x || " - f"({exec_count_expr}).y <= ({gid}).y || " - f"({exec_count_expr}).z <= ({gid}).z) {{ return; }}\n" - ) - - def shared_buffer_declaration(self, 
var_type: dtypes.dtype, name: str, size: int) -> str: - return f"__shared__ {self.type_name(var_type)} {name}[{size}];" - - def uniform_block_declaration(self, contents: str) -> str: - self._register_kernel_param("const UniformObjectBuffer* vkdispatch_uniform_ptr") - self._register_alias_line("const UniformObjectBuffer& UBO = *vkdispatch_uniform_ptr;") - return f"\nstruct UniformObjectBuffer {{\n{contents}\n}};\n" - - def storage_buffer_declaration(self, binding: int, var_type: dtypes.dtype, name: str) -> str: - struct_name = f"Buffer{binding}" - param_name = f"vkdispatch_binding_{binding}_ptr" - self._register_kernel_param(f"{self.type_name(var_type)}* {param_name}") - self._register_alias_line(f"{struct_name} {name} = {{{param_name}}};") - return f"struct {struct_name} {{ {self.type_name(var_type)}* data; }};\n" - - def sampler_declaration(self, binding: int, dimensions: int, name: str) -> str: - param_name = f"vkdispatch_sampler_{binding}" - self._register_kernel_param(f"cudaTextureObject_t {param_name}") - self._register_alias_line(f"cudaTextureObject_t {name} = {param_name};") - return f"// sampler binding {binding}, dimensions={dimensions}\n" - - def push_constant_declaration(self, contents: str) -> str: - self._register_kernel_param("const PushConstant* vkdispatch_pc_ptr") - self._register_alias_line("const PushConstant& PC = *vkdispatch_pc_ptr;") - return f"\nstruct PushConstant {{\n{contents}\n}};\n" - - def entry_point(self, body_contents: str) -> str: - params = ", ".join(self._kernel_params) - - alias_block = "" - for line in self._entry_alias_lines: - alias_block += f" {line}\n" - - return ( - f'extern "C" __global__ void vkdispatch_main({params}) {{\n' - f"{alias_block}" - f"{body_contents}" - f"}}\n" - ) - - def inf_f32_expr(self) -> str: - self.mark_feature_usage("uintBitsToFloat") - return "uintBitsToFloat(0x7F800000u)" - - def ninf_f32_expr(self) -> str: - self.mark_feature_usage("uintBitsToFloat") - return "uintBitsToFloat(0xFF800000u)" - - def 
fma_function_name(self, var_type: dtypes.dtype) -> str: - if var_type == dtypes.float32: - return "fmaf" - return "fma" - - def float_bits_to_int_expr(self, var_expr: str) -> str: - self.mark_feature_usage("floatBitsToInt") - return f"floatBitsToInt({var_expr})" - - def float_bits_to_uint_expr(self, var_expr: str) -> str: - self.mark_feature_usage("floatBitsToUint") - return f"floatBitsToUint({var_expr})" - - def int_bits_to_float_expr(self, var_expr: str) -> str: - self.mark_feature_usage("intBitsToFloat") - return f"intBitsToFloat({var_expr})" - - def uint_bits_to_float_expr(self, var_expr: str) -> str: - self.mark_feature_usage("uintBitsToFloat") - return f"uintBitsToFloat({var_expr})" - - def global_invocation_id_expr(self) -> str: - self._record_composite_type_key("uint3") - self.mark_feature_usage("global_invocation_id") - return "vkdispatch_global_invocation_id()" - - def local_invocation_id_expr(self) -> str: - self._record_composite_type_key("uint3") - self.mark_feature_usage("local_invocation_id") - return "vkdispatch_local_invocation_id()" - - def local_invocation_index_expr(self) -> str: - self.mark_feature_usage("local_invocation_index") - return "vkdispatch_local_invocation_index()" - - def workgroup_id_expr(self) -> str: - self._record_composite_type_key("uint3") - self.mark_feature_usage("workgroup_id") - return "vkdispatch_workgroup_id()" - - def workgroup_size_expr(self) -> str: - self._record_composite_type_key("uint3") - self.mark_feature_usage("make_uint3") - return "vkdispatch_make_uint3((unsigned int)blockDim.x, (unsigned int)blockDim.y, (unsigned int)blockDim.z)" - - def num_workgroups_expr(self) -> str: - self._record_composite_type_key("uint3") - self.mark_feature_usage("make_uint3") - return "vkdispatch_make_uint3((unsigned int)gridDim.x, (unsigned int)gridDim.y, (unsigned int)gridDim.z)" - - def num_subgroups_expr(self) -> str: - self.mark_feature_usage("num_subgroups") - return "vkdispatch_num_subgroups()" - - def subgroup_id_expr(self) 
-> str: - self.mark_feature_usage("subgroup_id") - return "vkdispatch_subgroup_id()" - - def subgroup_size_expr(self) -> str: - self.mark_feature_usage("subgroup_size") - return "vkdispatch_subgroup_size()" - - def subgroup_invocation_id_expr(self) -> str: - self.mark_feature_usage("subgroup_invocation_id") - return "vkdispatch_subgroup_invocation_id()" - - def barrier_statement(self) -> str: - return "__syncthreads();" - - def memory_barrier_statement(self) -> str: - return "__threadfence();" - - def memory_barrier_buffer_statement(self) -> str: - return "__threadfence();" - - def memory_barrier_shared_statement(self) -> str: - return "__threadfence_block();" - - def memory_barrier_image_statement(self) -> str: - return "__threadfence();" - - def group_memory_barrier_statement(self) -> str: - return "__threadfence_block();" - - def subgroup_add_expr(self, arg_expr: str) -> str: - self.mark_feature_usage("subgroup_add") - return f"vkdispatch_subgroup_add({arg_expr})" - - def subgroup_mul_expr(self, arg_expr: str) -> str: - self.mark_feature_usage("subgroup_mul") - return f"vkdispatch_subgroup_mul({arg_expr})" - - def subgroup_min_expr(self, arg_expr: str) -> str: - self.mark_feature_usage("subgroup_min") - return f"vkdispatch_subgroup_min({arg_expr})" - - def subgroup_max_expr(self, arg_expr: str) -> str: - self.mark_feature_usage("subgroup_max") - return f"vkdispatch_subgroup_max({arg_expr})" - - def subgroup_and_expr(self, arg_expr: str) -> str: - self.mark_feature_usage("subgroup_and") - return f"vkdispatch_subgroup_and({arg_expr})" - - def subgroup_or_expr(self, arg_expr: str) -> str: - self.mark_feature_usage("subgroup_or") - return f"vkdispatch_subgroup_or({arg_expr})" - - def subgroup_xor_expr(self, arg_expr: str) -> str: - self.mark_feature_usage("subgroup_xor") - return f"vkdispatch_subgroup_xor({arg_expr})" - - def subgroup_elect_expr(self) -> str: - self.mark_feature_usage("subgroup_invocation_id") - return "((int)(vkdispatch_subgroup_invocation_id() == 
0u))" - - def subgroup_barrier_statement(self) -> str: - return "__syncwarp();" - - def printf_statement(self, fmt: str, args: List[str]) -> str: - safe_fmt = fmt.replace("\\", "\\\\").replace('"', '\\"') - - if len(args) == 0: - return f'printf("{safe_fmt}");' - - return f'printf("{safe_fmt}", {", ".join(args)});' - - def texture_size_expr(self, texture_expr: str, lod: int, dimensions: int) -> str: - # CUDA texture objects do not expose shape directly in device code. - # The future CUDA backend should pass explicit texture shape parameters. - if dimensions == 1: - return "1.0f" - if dimensions == 2: - self.mark_feature_usage("make_float2") - return "vkdispatch_make_float2(1.0f)" - if dimensions == 3: - self.mark_feature_usage("make_float3") - return "vkdispatch_make_float3(1.0f)" - - raise ValueError(f"Unsupported texture dimensions '{dimensions}'") - - def sample_texture_expr(self, texture_expr: str, coord_expr: str, lod_expr: Optional[str] = None) -> str: - self.mark_feature_usage("sample_texture") - if lod_expr is None: - return f"vkdispatch_sample_texture({texture_expr}, {coord_expr})" - - return f"vkdispatch_sample_texture({texture_expr}, {coord_expr}, {lod_expr})" diff --git a/vkdispatch/codegen/backends/cuda/__init__.py b/vkdispatch/codegen/backends/cuda/__init__.py new file mode 100644 index 00000000..31730746 --- /dev/null +++ b/vkdispatch/codegen/backends/cuda/__init__.py @@ -0,0 +1,3 @@ +from .backend import CUDABackend + +__all__ = ["CUDABackend"] diff --git a/vkdispatch/codegen/backends/cuda/backend.py b/vkdispatch/codegen/backends/cuda/backend.py new file mode 100644 index 00000000..7cd91f29 --- /dev/null +++ b/vkdispatch/codegen/backends/cuda/backend.py @@ -0,0 +1,892 @@ +from typing import Dict, List, Optional, Set, Tuple + +import vkdispatch.base.dtype as dtypes + +from ..base import CodeGenBackend +from .composite_emitters import ( + _cuda_emit_mat_helpers, + _cuda_emit_mat_type, + _cuda_emit_subgroup_shuffle_xor_vec_overloads, + 
_cuda_emit_vec_helper, + _cuda_emit_vec_type, + _cuda_emit_vec_wrapper_conversion_helpers, +) +from .helper_snippets import ( + _HELPER_DEPENDENCIES as _CUDA_HELPER_DEPENDENCIES, + _HELPER_ORDER as _CUDA_HELPER_ORDER, + _HELPER_SNIPPETS as _CUDA_HELPER_SNIPPETS, + initialize_feature_usage, +) +from .math_utils import ( + cuda_fast_binary_math_name, + cuda_fast_unary_math_name, + cuda_float_vec_components_for_suffix, + cuda_float_vec_helper_suffix, + cuda_scalar_binary_math_name, + cuda_scalar_unary_math_name, + emit_used_vec_math_helpers, +) +from .specs import ( + _CUDA_MAT_ORDER, + _CUDA_MAT_TYPE_SPECS, + _CUDA_VEC_ORDER, + _CUDA_VEC_TYPE_SPECS, + _DTYPE_TO_COMPOSITE_KEY as _CUDA_DTYPE_TO_COMPOSITE_KEY, + _FLOAT_VEC_DTYPES as _CUDA_FLOAT_VEC_DTYPES, + _FLOAT_VEC_HELPER_SUFFIX_MAP as _CUDA_FLOAT_VEC_HELPER_SUFFIX_MAP, + _SCALAR_TYPE_NAMES as _CUDA_SCALAR_TYPE_NAMES, +) + +class CUDABackend(CodeGenBackend): + name = "cuda" + _CUDA_BUILTIN_UVEC3_SENTINELS: Dict[str, Dict[str, str]] = { + "global_invocation_id": { + "sentinel": "VKDISPATCH_CUDA_GLOBAL_INVOCATION_ID_SENTINEL()", + "x": "(unsigned int)(blockIdx.x * blockDim.x + threadIdx.x)", + "y": "(unsigned int)(blockIdx.y * blockDim.y + threadIdx.y)", + "z": "(unsigned int)(blockIdx.z * blockDim.z + threadIdx.z)", + }, + "local_invocation_id": { + "sentinel": "VKDISPATCH_CUDA_LOCAL_INVOCATION_ID_SENTINEL()", + "x": "(unsigned int)threadIdx.x", + "y": "(unsigned int)threadIdx.y", + "z": "(unsigned int)threadIdx.z", + }, + "workgroup_id": { + "sentinel": "VKDISPATCH_CUDA_WORKGROUP_ID_SENTINEL()", + "x": "(unsigned int)blockIdx.x", + "y": "(unsigned int)blockIdx.y", + "z": "(unsigned int)blockIdx.z", + }, + } + + _HELPER_SNIPPETS: Dict[str, str] = _CUDA_HELPER_SNIPPETS + _HELPER_ORDER: List[str] = _CUDA_HELPER_ORDER + _HELPER_DEPENDENCIES: Dict[str, List[str]] = _CUDA_HELPER_DEPENDENCIES + + def __init__(self) -> None: + self._fixed_preamble = "" + self.reset_state() + + def reset_state(self) -> None: + 
self._kernel_params: List[str] = [] + self._entry_alias_lines: List[str] = [] + self._composite_type_usage: Set[str] = set() + self._composite_vec_op_usage: Dict[str, Set[str]] = {} + self._composite_mat_op_usage: Dict[str, Set[str]] = {} + self._composite_vec_unary_math_usage: Dict[str, Set[str]] = {} + self._composite_vec_binary_math_usage: Dict[str, Set[str]] = {} + self._sample_texture_dims: Set[int] = set() + self._needs_cuda_fp16: bool = False + self._feature_usage: Dict[str, bool] = initialize_feature_usage() + + def mark_feature_usage(self, feature_name: str) -> None: + if feature_name in self._feature_usage: + self._feature_usage[feature_name] = True + + _DTYPE_TO_COMPOSITE_KEY = _CUDA_DTYPE_TO_COMPOSITE_KEY + + def _composite_key_for_dtype(self, var_type: dtypes.dtype) -> Optional[str]: + return self._DTYPE_TO_COMPOSITE_KEY.get(var_type) + + def _record_composite_type_key(self, key: str) -> None: + self.mark_feature_usage("composite_types") + self._composite_type_usage.add(key) + + if key in _CUDA_MAT_TYPE_SPECS: + dim = _CUDA_MAT_TYPE_SPECS[key][3] + self._composite_type_usage.add(f"float{dim}") + + def _record_composite_type(self, var_type: dtypes.dtype) -> Optional[str]: + key = self._composite_key_for_dtype(var_type) + if key is None: + return None + self._record_composite_type_key(key) + return key + + def _record_vec_op(self, key: str, token: str) -> None: + self._record_composite_type_key(key) + self._composite_vec_op_usage.setdefault(key, set()).add(token) + + def _record_mat_op(self, key: str, token: str) -> None: + self._record_composite_type_key(key) + self._composite_mat_op_usage.setdefault(key, set()).add(token) + + def _record_vec_unary_math(self, key: str, func_name: str) -> None: + self._record_composite_type_key(key) + self._composite_vec_unary_math_usage.setdefault(key, set()).add(func_name) + + def _record_vec_binary_math(self, key: str, func_name: str, signature: str) -> None: + self._record_composite_type_key(key) + 
self._composite_vec_binary_math_usage.setdefault(key, set()).add(f"{func_name}:{signature}") + + def _propagate_matrix_vec_dependencies(self, mat_key: str, token: str) -> None: + dim = _CUDA_MAT_TYPE_SPECS[mat_key][3] + vec_key = f"float{dim}" + + if token == "un:-": + self._record_vec_op(vec_key, "un:-") + return + + if token.startswith("cmpd:"): + if token.endswith(":m"): + vec_token = token[:-1] + "v" + self._record_vec_op(vec_key, vec_token) + return + if token.endswith(":s"): + self._record_vec_op(vec_key, token) + return + + if token.startswith("bin:"): + parts = token.split(":") + if len(parts) != 3: + return + _, op, shape = parts + if shape == "mm": + if op in ["+", "-"]: + self._record_vec_op(vec_key, f"bin:{op}:vv") + elif op == "*": + self._record_mat_op(mat_key, "bin:*:mv") + self._propagate_matrix_vec_dependencies(mat_key, "bin:*:mv") + return + if shape == "ms": + self._record_vec_op(vec_key, f"bin:{op}:vs") + return + if shape == "sm": + self._record_vec_op(vec_key, f"bin:{op}:sv") + return + if shape == "mv": + self._record_vec_op(vec_key, "bin:*:vs") + self._record_vec_op(vec_key, "bin:+:vv") + return + if shape == "vm": + return + + def mark_composite_unary_op(self, var_type: dtypes.dtype, op: str) -> None: + key = self._record_composite_type(var_type) + if key is None: + return + + token = f"un:{op}" + if key in _CUDA_VEC_TYPE_SPECS: + self._record_vec_op(key, token) + return + if key in _CUDA_MAT_TYPE_SPECS: + self._record_mat_op(key, token) + self._propagate_matrix_vec_dependencies(key, token) + + def mark_composite_binary_op( + self, + lhs_type: dtypes.dtype, + rhs_type: dtypes.dtype, + op: str, + *, + inplace: bool = False, + ) -> None: + lhs_key = self._record_composite_type(lhs_type) + rhs_key = self._record_composite_type(rhs_type) + + lhs_is_composite = lhs_key is not None + rhs_is_composite = rhs_key is not None + if not lhs_is_composite and not rhs_is_composite: + return + + lhs_is_scalar = dtypes.is_scalar(lhs_type) + rhs_is_scalar = 
dtypes.is_scalar(rhs_type) + + if lhs_key in _CUDA_VEC_TYPE_SPECS and (rhs_is_scalar or rhs_key in _CUDA_VEC_TYPE_SPECS): + if inplace: + suffix = "s" if rhs_is_scalar else "v" + self._record_vec_op(lhs_key, f"cmpd:{op}=:{suffix}") + return + shape = "vs" if rhs_is_scalar else "vv" + self._record_vec_op(lhs_key, f"bin:{op}:{shape}") + return + + if rhs_key in _CUDA_VEC_TYPE_SPECS and lhs_is_scalar and not inplace: + self._record_vec_op(rhs_key, f"bin:{op}:sv") + return + + if lhs_key in _CUDA_MAT_TYPE_SPECS: + if inplace: + if rhs_is_scalar: + token = f"cmpd:{op}=:s" + elif rhs_key in _CUDA_MAT_TYPE_SPECS: + token = f"cmpd:{op}=:m" + else: + return + self._record_mat_op(lhs_key, token) + self._propagate_matrix_vec_dependencies(lhs_key, token) + return + + if rhs_is_scalar: + token = f"bin:{op}:ms" + self._record_mat_op(lhs_key, token) + self._propagate_matrix_vec_dependencies(lhs_key, token) + return + + if rhs_key in _CUDA_MAT_TYPE_SPECS: + token = "bin:*:mm" if op == "*" else f"bin:{op}:mm" + self._record_mat_op(lhs_key, token) + self._propagate_matrix_vec_dependencies(lhs_key, token) + return + + if rhs_key in _CUDA_VEC_TYPE_SPECS and op == "*": + token = "bin:*:mv" + self._record_mat_op(lhs_key, token) + self._propagate_matrix_vec_dependencies(lhs_key, token) + return + + if rhs_key in _CUDA_MAT_TYPE_SPECS and lhs_is_scalar and not inplace: + token = f"bin:{op}:sm" + self._record_mat_op(rhs_key, token) + self._propagate_matrix_vec_dependencies(rhs_key, token) + return + + if lhs_key in _CUDA_VEC_TYPE_SPECS and rhs_key in _CUDA_MAT_TYPE_SPECS and op == "*" and not inplace: + token = "bin:*:vm" + self._record_mat_op(rhs_key, token) + self._propagate_matrix_vec_dependencies(rhs_key, token) + + def _emit_used_composite_helpers(self) -> str: + if len(self._composite_type_usage) == 0: + return "" + + parts: List[str] = [] + + # Subgroup helpers use vector binary operators internally (e.g. 
value = value + shuffled) + # even if user code never directly emits the corresponding operator on that vector type. + subgroup_vec_op_requirements = [ + ("subgroup_add", "bin:+:vv"), + ("subgroup_mul", "bin:*:vv"), + ("subgroup_and", "bin:&:vv"), + ("subgroup_or", "bin:|:vv"), + ("subgroup_xor", "bin:^:vv"), + ] + for feature_name, token in subgroup_vec_op_requirements: + if not self._feature_usage.get(feature_name, False): + continue + for key in self._composite_type_usage: + if key in _CUDA_VEC_TYPE_SPECS: + self._composite_vec_op_usage.setdefault(key, set()).add(token) + + emitted_vec_keys: Set[str] = set() + for key in _CUDA_VEC_ORDER: + if key not in self._composite_type_usage: + continue + vec_name, scalar_type, dim, cuda_native_type, allow_neg, enable_bitwise = _CUDA_VEC_TYPE_SPECS[key] + emitted_vec_keys.add(key) + parts.append( + _cuda_emit_vec_type( + vec_name, + scalar_type, + dim, + cuda_native_type, + allow_unary_neg=allow_neg, + enable_bitwise=enable_bitwise, + needed_ops=self._composite_vec_op_usage.get(key, set()), + ) + ) + parts.append(_cuda_emit_vec_helper(key, vec_name, scalar_type, dim)) + for key in _CUDA_VEC_ORDER: + if key not in emitted_vec_keys: + continue + vec_name, scalar_type, dim, _, _, _ = _CUDA_VEC_TYPE_SPECS[key] + conversion_helpers = _cuda_emit_vec_wrapper_conversion_helpers( + key, + vec_name, + scalar_type, + dim, + available_keys=emitted_vec_keys, + ) + if len(conversion_helpers) > 0: + parts.append(conversion_helpers) + + subgroup_shuffle_overloads = _cuda_emit_subgroup_shuffle_xor_vec_overloads(emitted_vec_keys) + if len(subgroup_shuffle_overloads) > 0: + parts.append(subgroup_shuffle_overloads) + + for key in _CUDA_MAT_ORDER: + if key not in self._composite_type_usage: + continue + mat_name, vec_name, vec_helper_suffix, dim = _CUDA_MAT_TYPE_SPECS[key] + parts.append(_cuda_emit_mat_type(mat_name, vec_name, dim, self._composite_mat_op_usage.get(key, set()))) + parts.append(_cuda_emit_mat_helpers(mat_name, key, vec_name, 
vec_helper_suffix, dim)) + + vec_math_helpers = self._emit_used_vec_math_helpers() + if len(vec_math_helpers) > 0: + parts.append(vec_math_helpers) + + return "\n\n".join(parts) + + @staticmethod + def _cuda_scalar_unary_math_name(func_name: str, scalar_type: str) -> str: + return cuda_scalar_unary_math_name(func_name, scalar_type) + + @staticmethod + def _cuda_scalar_binary_math_name(func_name: str, scalar_type: str) -> str: + return cuda_scalar_binary_math_name(func_name, scalar_type) + + def _emit_used_vec_math_helpers(self) -> str: + return emit_used_vec_math_helpers( + self._composite_vec_unary_math_usage, + self._composite_vec_binary_math_usage, + ) + + def _register_kernel_param(self, param_decl: str) -> None: + if param_decl not in self._kernel_params: + self._kernel_params.append(param_decl) + + def _register_alias_line(self, alias_line: str) -> None: + if alias_line not in self._entry_alias_lines: + self._entry_alias_lines.append(alias_line) + + @staticmethod + def _is_plain_integer_literal(expr: str) -> bool: + if len(expr) == 0: + return False + if expr[0] in "+-": + return len(expr) > 1 and expr[1:].isdigit() + return expr.isdigit() + + _SCALAR_TYPE_NAMES = _CUDA_SCALAR_TYPE_NAMES + + def type_name(self, var_type: dtypes.dtype) -> str: + scalar_name = self._SCALAR_TYPE_NAMES.get(var_type) + if scalar_name is not None: + if var_type == dtypes.float16: + self._needs_cuda_fp16 = True + return scalar_name + + key = self._composite_key_for_dtype(var_type) + if key is not None: + self._record_composite_type(var_type) + if key in _CUDA_VEC_TYPE_SPECS: + # Track fp16 header need when half vector types are used. 
+ if _CUDA_VEC_TYPE_SPECS[key][1] == "__half": + self._needs_cuda_fp16 = True + return _CUDA_VEC_TYPE_SPECS[key][0] + if key in _CUDA_MAT_TYPE_SPECS: + return _CUDA_MAT_TYPE_SPECS[key][0] + + raise ValueError(f"Unsupported CUDA type mapping for '{var_type.name}'") + + _FLOAT_VEC_DTYPES = _CUDA_FLOAT_VEC_DTYPES + + def constructor( + self, + var_type: dtypes.dtype, + args: List[str], + arg_types: Optional[List[Optional[dtypes.dtype]]] = None, + ) -> str: + _ = arg_types + if ( + len(args) == 1 + and var_type in self._FLOAT_VEC_DTYPES + and self._is_plain_integer_literal(args[0]) + ): + scalar_type = None + if dtypes.is_complex(var_type): + scalar_type = var_type.child_type + elif dtypes.is_vector(var_type): + scalar_type = var_type.scalar + + if scalar_type == dtypes.float64: + args = [f"{args[0]}.0"] + else: + args = [f"{args[0]}.0f"] + + target_type = self.type_name(var_type) + + if dtypes.is_scalar(var_type): + assert len(args) > 0, f"Constructor for scalar type '{var_type.name}' needs at least one argument." 
+ return f"(({target_type})({args[0]}))" + + if var_type == dtypes.mat2: + self.mark_feature_usage("make_mat2") + return f"vkdispatch_make_mat2({', '.join(args)})" + if var_type == dtypes.mat3: + self.mark_feature_usage("make_mat3") + return f"vkdispatch_make_mat3({', '.join(args)})" + if var_type == dtypes.mat4: + self.mark_feature_usage("make_mat4") + return f"vkdispatch_make_mat4({', '.join(args)})" + + helper_suffix = target_type[len("vkdispatch_"):] if target_type.startswith("vkdispatch_") else target_type + helper_name = f"vkdispatch_make_{helper_suffix}" + self.mark_feature_usage(f"make_{helper_suffix}") + return f"{helper_name}({', '.join(args)})" + + def component_access_expr(self, expr: str, component: str, base_type: dtypes.dtype) -> str: + if dtypes.is_scalar(base_type): + if component == "x": + return expr + return super().component_access_expr(expr, component, base_type) + + if dtypes.is_vector(base_type) or dtypes.is_complex(base_type): + direct_builtin_component = self._cuda_builtin_uvec3_component_expr(expr, component, base_type) + if direct_builtin_component is not None: + return direct_builtin_component + return f"{expr}.v.{component}" + + return super().component_access_expr(expr, component, base_type) + + def pre_header(self, *, enable_subgroup_ops: bool, enable_printf: bool) -> str: + subgroup_support = "1" if enable_subgroup_ops else "0" + printf_support = "1" if enable_printf else "0" + + self._enable_subgroup_ops = enable_subgroup_ops + self._enable_printf = enable_printf + + helper_header = self._helper_header() + fp16_include = "#include \n" if self._needs_cuda_fp16 else "" + + self._fixed_preamble = ( + "#include \n" + f"{fp16_include}\n" + f"#define VKDISPATCH_ENABLE_SUBGROUP_OPS {subgroup_support}\n" + f"#define VKDISPATCH_ENABLE_PRINTF {printf_support}\n\n" + f"{helper_header}\n\n" + ) + + return self._fixed_preamble + + def _resolve_helper_dependencies(self, helpers: Set[str]) -> Set[str]: + pending = list(helpers) + resolved = 
set(helpers) + + while len(pending) > 0: + helper_name = pending.pop() + + for dependency in self._HELPER_DEPENDENCIES.get(helper_name, []): + if dependency not in resolved: + resolved.add(dependency) + pending.append(dependency) + + return resolved + + def _helper_header(self) -> str: + enabled_helpers = { + helper_name + for helper_name, is_enabled in self._feature_usage.items() + if is_enabled + } + + resolved_helpers = self._resolve_helper_dependencies(enabled_helpers) + + if len(resolved_helpers) == 0: + return "" + + helper_sections: List[str] = [] + + for helper_name in self._HELPER_ORDER: + if helper_name in resolved_helpers: + if helper_name == "composite_types": + composite_helpers = self._emit_used_composite_helpers() + if len(composite_helpers) > 0: + helper_sections.append(composite_helpers) + continue + + snippet = self._HELPER_SNIPPETS[helper_name] + if len(snippet) > 0: + helper_sections.append(snippet) + + return "\n\n".join(helper_sections) + "\n\n" + + def make_source(self, header: str, body: str, x: int, y: int, z: int) -> str: + header, body = self._finalize_cuda_builtin_uvec3_sentinels(header, body) + + expected_size_header = ( + f"// Expected local size: ({x}, {y}, {z})\n" + f"#define VKDISPATCH_EXPECTED_LOCAL_SIZE_X {x}\n" + f"#define VKDISPATCH_EXPECTED_LOCAL_SIZE_Y {y}\n" + f"#define VKDISPATCH_EXPECTED_LOCAL_SIZE_Z {z}\n" + ) + + return f"{expected_size_header}\n{header}\n{body}" + + def constant_namespace(self) -> str: + return "UBO" + + def variable_namespace(self) -> str: + return "PC" + + def exec_bounds_guard(self, exec_count_expr: str) -> str: + gid = self.global_invocation_id_expr() + exec_expr = f"({exec_count_expr})" + gid_expr = f"({gid})" + return ( + f"if ({self.component_access_expr(exec_expr, 'x', dtypes.uvec4)} <= {self.component_access_expr(gid_expr, 'x', dtypes.uvec3)} || " + f"{self.component_access_expr(exec_expr, 'y', dtypes.uvec4)} <= {self.component_access_expr(gid_expr, 'y', dtypes.uvec3)} || " + 
f"{self.component_access_expr(exec_expr, 'z', dtypes.uvec4)} <= {self.component_access_expr(gid_expr, 'z', dtypes.uvec3)}) {{ return; }}\n" + ) + + def shared_buffer_declaration(self, var_type: dtypes.dtype, name: str, size: int) -> str: + return f"__shared__ {self.type_name(var_type)} {name}[{size}];" + + def uniform_block_declaration(self, contents: str) -> str: + self._register_kernel_param("const UniformObjectBuffer vkdispatch_uniform_value") + self._register_alias_line("const UniformObjectBuffer& UBO = vkdispatch_uniform_value;") + return f"\nstruct UniformObjectBuffer {{\n{contents}\n}};\n" + + def storage_buffer_declaration(self, binding: int, var_type: dtypes.dtype, name: str) -> str: + struct_name = f"Buffer{binding}" + param_name = f"vkdispatch_binding_{binding}_ptr" + self._register_kernel_param(f"{self.type_name(var_type)}* {param_name}") + self._register_alias_line(f"{struct_name} {name} = {{{param_name}}};") + return f"struct {struct_name} {{ {self.type_name(var_type)}* data; }};\n" + + def sampler_declaration(self, binding: int, dimensions: int, name: str) -> str: + param_name = f"vkdispatch_sampler_{binding}" + self._register_kernel_param(f"cudaTextureObject_t {param_name}") + self._register_alias_line(f"cudaTextureObject_t {name} = {param_name};") + return f"// sampler binding {binding}, dimensions={dimensions}\n" + + def push_constant_declaration(self, contents: str) -> str: + self._register_kernel_param("const PushConstant vkdispatch_pc_value") + self._register_alias_line("const PushConstant& PC = vkdispatch_pc_value;") + return f"\nstruct PushConstant {{\n{contents}\n}};\n" + + def entry_point(self, body_contents: str) -> str: + params = ", ".join(self._kernel_params) + + alias_block = "" + for line in self._entry_alias_lines: + alias_block += f" {line}\n" + + return ( + f'extern "C" __global__ void vkdispatch_main({params}) {{\n' + f"{alias_block}" + f"{body_contents}" + f"}}\n" + ) + + def inf_f32_expr(self) -> str: + 
self.mark_feature_usage("uintBitsToFloat") + return "uintBitsToFloat(0x7F800000u)" + + def ninf_f32_expr(self) -> str: + self.mark_feature_usage("uintBitsToFloat") + return "uintBitsToFloat(0xFF800000u)" + + def inf_f64_expr(self) -> str: + self.mark_feature_usage("longlong_as_double") + return "__longlong_as_double(0x7FF0000000000000LL)" + + def ninf_f64_expr(self) -> str: + self.mark_feature_usage("longlong_as_double") + return "__longlong_as_double(0xFFF0000000000000LL)" + + def inf_f16_expr(self) -> str: + self.mark_feature_usage("ushort_as_half") + return "__ushort_as_half(0x7C00u)" + + def ninf_f16_expr(self) -> str: + self.mark_feature_usage("ushort_as_half") + return "__ushort_as_half(0xFC00u)" + + def fma_function_name(self, var_type: dtypes.dtype) -> str: + if var_type == dtypes.float16: + return "__hfma" + if var_type == dtypes.float32: + return "fmaf" + return "fma" + + def math_func_name(self, func_name: str, var_type: dtypes.dtype) -> str: + scalar = var_type + if dtypes.is_vector(var_type) or dtypes.is_matrix(var_type): + scalar = var_type.scalar + elif dtypes.is_complex(var_type): + scalar = var_type.child_type + + if scalar == dtypes.float16: + return self._cuda_scalar_unary_math_name(func_name, "__half") + if scalar == dtypes.float32: + return self._cuda_fast_unary_math_name(func_name) + # double and integer types use standard C names + return func_name + + @staticmethod + def _cuda_fast_unary_math_name(func_name: str) -> str: + return cuda_fast_unary_math_name(func_name) + + @staticmethod + def _cuda_fast_binary_math_name(func_name: str) -> str: + return cuda_fast_binary_math_name(func_name) + + _FLOAT_VEC_HELPER_SUFFIX_MAP = _CUDA_FLOAT_VEC_HELPER_SUFFIX_MAP + + @staticmethod + def _cuda_float_vec_helper_suffix(var_type: dtypes.dtype) -> Optional[str]: + return cuda_float_vec_helper_suffix(var_type) + + @staticmethod + def _cuda_float_vec_components_for_suffix(helper_suffix: str) -> List[str]: + return 
cuda_float_vec_components_for_suffix(helper_suffix) + + def _cuda_componentwise_unary_math_expr(self, func_name: str, arg_type: dtypes.dtype, arg_expr: str) -> Optional[str]: + helper_suffix = self._cuda_float_vec_helper_suffix(arg_type) + if helper_suffix is None: + return None + + self._record_vec_unary_math(helper_suffix, func_name) + return f"{func_name}({arg_expr})" + + def _cuda_componentwise_binary_math_expr( + self, + func_name: str, + lhs_type: dtypes.dtype, + lhs_expr: str, + rhs_type: dtypes.dtype, + rhs_expr: str, + ) -> Optional[str]: + lhs_helper = self._cuda_float_vec_helper_suffix(lhs_type) + rhs_helper = self._cuda_float_vec_helper_suffix(rhs_type) + + if lhs_helper is None and rhs_helper is None: + return None + + if lhs_helper is not None and rhs_helper is not None and lhs_helper != rhs_helper: + return None + + helper_suffix = lhs_helper if lhs_helper is not None else rhs_helper + assert helper_suffix is not None + + signature = ("v" if lhs_helper is not None else "s") + ("v" if rhs_helper is not None else "s") + self._record_vec_binary_math(helper_suffix, func_name, signature) + return f"{func_name}({lhs_expr}, {rhs_expr})" + + def unary_math_expr(self, func_name: str, arg_type: dtypes.dtype, arg_expr: str) -> str: + vector_expr = self._cuda_componentwise_unary_math_expr(func_name, arg_type, arg_expr) + if vector_expr is not None: + return vector_expr + + mapped = self.math_func_name(func_name, arg_type) + return f"{mapped}({arg_expr})" + + def binary_math_expr( + self, + func_name: str, + lhs_type: dtypes.dtype, + lhs_expr: str, + rhs_type: dtypes.dtype, + rhs_expr: str, + ) -> str: + vector_expr = self._cuda_componentwise_binary_math_expr( + func_name, + lhs_type, + lhs_expr, + rhs_type, + rhs_expr, + ) + if vector_expr is not None: + return vector_expr + + if dtypes.is_scalar(lhs_type) and dtypes.is_scalar(rhs_type): + scalar = lhs_type + scalar_name = self._SCALAR_TYPE_NAMES.get(scalar, "float") + return 
f"{self._cuda_scalar_binary_math_name(func_name, scalar_name)}({lhs_expr}, {rhs_expr})" + + return f"{func_name}({lhs_expr}, {rhs_expr})" + + def float_bits_to_int_expr(self, var_expr: str) -> str: + self.mark_feature_usage("floatBitsToInt") + return f"floatBitsToInt({var_expr})" + + def float_bits_to_uint_expr(self, var_expr: str) -> str: + self.mark_feature_usage("floatBitsToUint") + return f"floatBitsToUint({var_expr})" + + def int_bits_to_float_expr(self, var_expr: str) -> str: + self.mark_feature_usage("intBitsToFloat") + return f"intBitsToFloat({var_expr})" + + def uint_bits_to_float_expr(self, var_expr: str) -> str: + self.mark_feature_usage("uintBitsToFloat") + return f"uintBitsToFloat({var_expr})" + + def global_invocation_id_expr(self) -> str: + return self._CUDA_BUILTIN_UVEC3_SENTINELS["global_invocation_id"]["sentinel"] + + def local_invocation_id_expr(self) -> str: + return self._CUDA_BUILTIN_UVEC3_SENTINELS["local_invocation_id"]["sentinel"] + + def local_invocation_index_expr(self) -> str: + self.mark_feature_usage("local_invocation_index") + return "vkdispatch_local_invocation_index()" + + def workgroup_id_expr(self) -> str: + return self._CUDA_BUILTIN_UVEC3_SENTINELS["workgroup_id"]["sentinel"] + + def workgroup_size_expr(self) -> str: + self._record_composite_type_key("uint3") + self.mark_feature_usage("make_uint3") + return "vkdispatch_make_uint3((unsigned int)blockDim.x, (unsigned int)blockDim.y, (unsigned int)blockDim.z)" + + def num_workgroups_expr(self) -> str: + self._record_composite_type_key("uint3") + self.mark_feature_usage("make_uint3") + return "vkdispatch_make_uint3((unsigned int)gridDim.x, (unsigned int)gridDim.y, (unsigned int)gridDim.z)" + + def num_subgroups_expr(self) -> str: + self.mark_feature_usage("num_subgroups") + return "vkdispatch_num_subgroups()" + + def subgroup_id_expr(self) -> str: + self.mark_feature_usage("subgroup_id") + return "vkdispatch_subgroup_id()" + + def subgroup_size_expr(self) -> str: + 
self.mark_feature_usage("subgroup_size") + return "vkdispatch_subgroup_size()" + + def subgroup_invocation_id_expr(self) -> str: + self.mark_feature_usage("subgroup_invocation_id") + return "vkdispatch_subgroup_invocation_id()" + + def barrier_statement(self) -> str: + return "__syncthreads();" + + def memory_barrier_statement(self) -> str: + return "__threadfence();" + + def memory_barrier_buffer_statement(self) -> str: + return "__threadfence();" + + def memory_barrier_shared_statement(self) -> str: + return "__threadfence_block();" + + def memory_barrier_image_statement(self) -> str: + return "__threadfence();" + + def group_memory_barrier_statement(self) -> str: + return "__threadfence_block();" + + @staticmethod + def _strip_outer_parens(expr: str) -> str: + stripped = expr.strip() + while len(stripped) >= 2 and stripped[0] == "(" and stripped[-1] == ")": + depth = 0 + balanced = True + for idx, ch in enumerate(stripped): + if ch == "(": + depth += 1 + elif ch == ")": + depth -= 1 + if depth < 0: + balanced = False + break + if depth == 0 and idx != len(stripped) - 1: + balanced = False + break + if not balanced or depth != 0: + break + stripped = stripped[1:-1].strip() + return stripped + + def _cuda_builtin_uvec3_component_expr( + self, + expr: str, + component: str, + base_type: dtypes.dtype, + ) -> Optional[str]: + if base_type != dtypes.uvec3 or component not in ("x", "y", "z"): + return None + + stripped_expr = self._strip_outer_parens(expr) + for builtin_spec in self._CUDA_BUILTIN_UVEC3_SENTINELS.values(): + if stripped_expr == builtin_spec["sentinel"]: + return builtin_spec[component] + + return None + + def _finalize_cuda_builtin_uvec3_sentinels(self, header: str, body: str) -> Tuple[str, str]: + for builtin_spec in self._CUDA_BUILTIN_UVEC3_SENTINELS.values(): + sentinel = builtin_spec["sentinel"] + if sentinel not in header and sentinel not in body: + continue + + self._record_composite_type_key("uint3") + self.mark_feature_usage("make_uint3") + 
replacement = ( + "vkdispatch_make_uint3(" + f"{builtin_spec['x']}, {builtin_spec['y']}, {builtin_spec['z']}" + ")" + ) + header = header.replace(sentinel, replacement) + body = body.replace(sentinel, replacement) + + return header, body + + def subgroup_add_expr(self, arg_expr: str, arg_type: Optional[dtypes.dtype] = None) -> str: + _ = arg_type + self.mark_feature_usage("subgroup_add") + return f"vkdispatch_subgroup_add({arg_expr})" + + def subgroup_mul_expr(self, arg_expr: str, arg_type: Optional[dtypes.dtype] = None) -> str: + _ = arg_type + self.mark_feature_usage("subgroup_mul") + return f"vkdispatch_subgroup_mul({arg_expr})" + + def subgroup_min_expr(self, arg_expr: str, arg_type: Optional[dtypes.dtype] = None) -> str: + _ = arg_type + self.mark_feature_usage("subgroup_min") + return f"vkdispatch_subgroup_min({arg_expr})" + + def subgroup_max_expr(self, arg_expr: str, arg_type: Optional[dtypes.dtype] = None) -> str: + _ = arg_type + self.mark_feature_usage("subgroup_max") + return f"vkdispatch_subgroup_max({arg_expr})" + + def subgroup_and_expr(self, arg_expr: str, arg_type: Optional[dtypes.dtype] = None) -> str: + _ = arg_type + self.mark_feature_usage("subgroup_and") + return f"vkdispatch_subgroup_and({arg_expr})" + + def subgroup_or_expr(self, arg_expr: str, arg_type: Optional[dtypes.dtype] = None) -> str: + _ = arg_type + self.mark_feature_usage("subgroup_or") + return f"vkdispatch_subgroup_or({arg_expr})" + + def subgroup_xor_expr(self, arg_expr: str, arg_type: Optional[dtypes.dtype] = None) -> str: + _ = arg_type + self.mark_feature_usage("subgroup_xor") + return f"vkdispatch_subgroup_xor({arg_expr})" + + def subgroup_elect_expr(self) -> str: + self.mark_feature_usage("subgroup_invocation_id") + return "((int)(vkdispatch_subgroup_invocation_id() == 0u))" + + def subgroup_barrier_statement(self) -> str: + return "__syncwarp();" + + def printf_statement(self, fmt: str, args: List[str]) -> str: + #safe_fmt = fmt.replace("\\", "\\\\").replace('"', '\\"') + 
+ if len(args) == 0: + return f'printf("{fmt}");' + + return f'printf("{fmt}", {", ".join(args)});' + + def texture_size_expr(self, texture_expr: str, lod: int, dimensions: int) -> str: + # CUDA texture objects do not expose shape directly in device code. + # The future CUDA backend should pass explicit texture shape parameters. + if dimensions == 1: + return "1.0f" + if dimensions == 2: + self.mark_feature_usage("make_float2") + return "vkdispatch_make_float2(1.0f)" + if dimensions == 3: + self.mark_feature_usage("make_float3") + return "vkdispatch_make_float3(1.0f)" + + raise ValueError(f"Unsupported texture dimensions '{dimensions}'") + + def sample_texture_expr(self, texture_expr: str, coord_expr: str, lod_expr: Optional[str] = None) -> str: + raise NotImplementedError("Direct texture sampling is not supported in CUDA backend. Use vkdispatch_sample_texture helper functions instead.") + + def atomic_add_expr(self, mem_expr: str, value_expr: str, var_type: dtypes.dtype) -> str: + if var_type not in (dtypes.int32, dtypes.uint32): + raise NotImplementedError(f"CUDA atomic_add only supports int32/uint32, got '{var_type.name}'") + + return f"atomicAdd(&({mem_expr}), {value_expr})" diff --git a/vkdispatch/codegen/backends/cuda/composite_emitters.py b/vkdispatch/codegen/backends/cuda/composite_emitters.py new file mode 100644 index 00000000..abb23ed6 --- /dev/null +++ b/vkdispatch/codegen/backends/cuda/composite_emitters.py @@ -0,0 +1,380 @@ +from typing import List, Optional, Set + +from .specs import _CUDA_MAT_TYPE_SPECS, _CUDA_VEC_ORDER, _CUDA_VEC_TYPE_SPECS + + +def _cuda_vec_components(dim: int) -> List[str]: + if dim < 2 or dim > 4: + raise ValueError(f"Unsupported vector dimension '{dim}'") + return list("xyzw"[:dim]) + + +def _cuda_join_statements(statements: List[str]) -> str: + if len(statements) == 0: + return "" + return " ".join(statements) + + +def _cuda_emit_vec_type( + vec_name: str, + scalar_type: str, + dim: int, + cuda_native_type: str, + *, + 
allow_unary_neg: bool, + enable_bitwise: bool, + needed_ops: Optional[Set[str]] = None, +) -> str: + comps = _cuda_vec_components(dim) + if needed_ops is None: + needed_ops = set() + if allow_unary_neg: + needed_ops.add("un:-") + if enable_bitwise: + needed_ops.add("un:~") + for op in ["+", "-", "*", "/"]: + needed_ops.add(f"cmpd:{op}=:v") + needed_ops.add(f"cmpd:{op}=:s") + needed_ops.add(f"bin:{op}:vv") + needed_ops.add(f"bin:{op}:vs") + needed_ops.add(f"bin:{op}:sv") + if enable_bitwise: + for op in ["&", "|", "^", "<<", ">>"]: + needed_ops.add(f"cmpd:{op}=:v") + needed_ops.add(f"cmpd:{op}=:s") + needed_ops.add(f"bin:{op}:vv") + needed_ops.add(f"bin:{op}:vs") + needed_ops.add(f"bin:{op}:sv") + + def has(token: str) -> bool: + return token in needed_ops + + def self_comp(c: str) -> str: + return f"v.{c}" + + def wrap_comp(obj: str, c: str) -> str: + return f"{obj}.v.{c}" + + def native_comp(obj: str, c: str) -> str: + return f"{obj}.{c}" + + def index_op_body() -> str: + branches: List[str] = [] + for idx, c in enumerate(comps): + prefix = "if" if idx == 0 else "else if" + branches.append(f"{prefix} (i == {idx}) return v.{c};") + branches.append(f"else return v.{comps[0]};") + return " ".join(branches) + + lines: List[str] = [f"struct {vec_name} {{"] + lines.append(f" {cuda_native_type} v;") + lines.append("") + ctor_args = ", ".join([f"{scalar_type} {c}_" for c in comps]) + ctor_init = "{" + ", ".join([f"{c}_" for c in comps]) + "}" + splat_init = "{" + ", ".join(["s" for _ in comps]) + "}" + cast_init = "{" + ", ".join([f"({scalar_type}){native_comp('src', c)}" for c in comps]) + "}" + member_guard = ", ".join([f"(void)(((const TVec*)0)->{c})" for c in comps]) + lines.append(f" __device__ __forceinline__ {vec_name}() = default;") + lines.append(f" __device__ __forceinline__ {vec_name}({ctor_args}) : v{ctor_init} {{}}") + lines.append(f" __device__ __forceinline__ explicit {vec_name}({scalar_type} s) : v{splat_init} {{}}") + lines.append(f" __device__ 
__forceinline__ explicit {vec_name}(const {cuda_native_type}& native) : v(native) {{}}") + lines.append(f" template ") + lines.append(f" __device__ __forceinline__ explicit {vec_name}(const TVec& src) : v{cast_init} {{}}") + lines.append(f" __device__ __forceinline__ {scalar_type}& operator[](int i) {{ {index_op_body()} }}") + lines.append(f" __device__ __forceinline__ const {scalar_type}& operator[](int i) const {{ {index_op_body()} }}") + + if allow_unary_neg and has("un:-"): + neg_expr = ", ".join([f"-{self_comp(c)}" for c in comps]) + lines.append(f" __device__ __forceinline__ {vec_name} operator-() const {{ return {vec_name}({neg_expr}); }}") + + if enable_bitwise and has("un:~"): + not_expr = ", ".join([f"~{self_comp(c)}" for c in comps]) + lines.append(f" __device__ __forceinline__ {vec_name} operator~() const {{ return {vec_name}({not_expr}); }}") + + for op in ["+", "-", "*", "/"]: + op_assign = op + "=" + if has(f"cmpd:{op}=:v"): + vv_ops = _cuda_join_statements([f"{self_comp(c)} {op_assign} {wrap_comp('b', c)};" for c in comps]) + lines.append( + f" __device__ __forceinline__ {vec_name}& operator{op_assign}(const {vec_name}& b) {{ {vv_ops} return *this; }}" + ) + if has(f"cmpd:{op}=:s"): + sv_ops = _cuda_join_statements([f"{self_comp(c)} {op_assign} b;" for c in comps]) + lines.append( + f" __device__ __forceinline__ {vec_name}& operator{op_assign}({scalar_type} b) {{ {sv_ops} return *this; }}" + ) + + if enable_bitwise: + for op in ["&", "|", "^", "<<", ">>"]: + op_assign = op + "=" + if has(f"cmpd:{op}=:v"): + vv_ops = _cuda_join_statements([f"{self_comp(c)} {op_assign} {wrap_comp('b', c)};" for c in comps]) + lines.append( + f" __device__ __forceinline__ {vec_name}& operator{op_assign}(const {vec_name}& b) {{ {vv_ops} return *this; }}" + ) + if has(f"cmpd:{op}=:s"): + sv_ops = _cuda_join_statements([f"{self_comp(c)} {op_assign} b;" for c in comps]) + lines.append( + f" __device__ __forceinline__ {vec_name}& operator{op_assign}({scalar_type} b) {{ 
{sv_ops} return *this; }}" + ) + + lines.append("};") + lines.append( + f'static_assert(sizeof({vec_name}) == sizeof({cuda_native_type}), "{vec_name} size must match {cuda_native_type}");' + ) + lines.append( + f'static_assert(alignof({vec_name}) == alignof({cuda_native_type}), "{vec_name} alignment must match {cuda_native_type}");' + ) + + for op in ["+", "-", "*", "/"]: + if has(f"bin:{op}:vv"): + vv_expr = ", ".join([f"({wrap_comp('a', c)} {op} {wrap_comp('b', c)})" for c in comps]) + lines.append( + f"__device__ __forceinline__ {vec_name} operator{op}(const {vec_name}& a, const {vec_name}& b) {{ return {vec_name}({vv_expr}); }}" + ) + if has(f"bin:{op}:vs"): + vs_expr = ", ".join([f"({wrap_comp('a', c)} {op} b)" for c in comps]) + lines.append( + f"__device__ __forceinline__ {vec_name} operator{op}(const {vec_name}& a, {scalar_type} b) {{ return {vec_name}({vs_expr}); }}" + ) + if has(f"bin:{op}:sv"): + if op in ["+", "*"]: + sv_expr = ", ".join([f"(a {op} {wrap_comp('b', c)})" for c in comps]) + else: + sv_expr = ", ".join([f"({scalar_type})(a {op} {wrap_comp('b', c)})" for c in comps]) + lines.append( + f"__device__ __forceinline__ {vec_name} operator{op}({scalar_type} a, const {vec_name}& b) {{ return {vec_name}({sv_expr}); }}" + ) + + if enable_bitwise: + for op in ["&", "|", "^", "<<", ">>"]: + if has(f"bin:{op}:vv"): + vv_expr = ", ".join([f"({wrap_comp('a', c)} {op} {wrap_comp('b', c)})" for c in comps]) + lines.append( + f"__device__ __forceinline__ {vec_name} operator{op}(const {vec_name}& a, const {vec_name}& b) {{ return {vec_name}({vv_expr}); }}" + ) + if has(f"bin:{op}:vs"): + vs_expr = ", ".join([f"({wrap_comp('a', c)} {op} b)" for c in comps]) + lines.append( + f"__device__ __forceinline__ {vec_name} operator{op}(const {vec_name}& a, {scalar_type} b) {{ return {vec_name}({vs_expr}); }}" + ) + if has(f"bin:{op}:sv"): + sv_expr = ", ".join([f"({scalar_type})(a {op} {wrap_comp('b', c)})" for c in comps]) + lines.append( + f"__device__ 
__forceinline__ {vec_name} operator{op}({scalar_type} a, const {vec_name}& b) {{ return {vec_name}({sv_expr}); }}" + ) + + return "\n".join(lines) + + +def _cuda_emit_vec_helper(helper_suffix: str, vec_name: str, scalar_type: str, dim: int) -> str: + comps = _cuda_vec_components(dim) + args = ", ".join([f"{scalar_type} {c}" for c in comps]) + ctor_args = ", ".join(comps) + member_guard = ", ".join([f"(void)(((const TVec*)0)->{c})" for c in comps]) + return "\n".join( + [ + f"__device__ __forceinline__ {vec_name} vkdispatch_make_{helper_suffix}({args}) {{ return {vec_name}({ctor_args}); }}", + f"__device__ __forceinline__ {vec_name} vkdispatch_make_{helper_suffix}({scalar_type} x) {{ return {vec_name}(x); }}", + f"__device__ __forceinline__ {vec_name} vkdispatch_make_{helper_suffix}(const {vec_name}& v) {{ return v; }}", + f"template ", + f"__device__ __forceinline__ {vec_name} vkdispatch_make_{helper_suffix}(const TVec& v) {{ return {vec_name}(v); }}", + ] + ) + + +def _cuda_emit_vec_wrapper_conversion_helpers( + helper_suffix: str, + vec_name: str, + scalar_type: str, + dim: int, + *, + available_keys: Optional[Set[str]] = None, +) -> str: + comps = _cuda_vec_components(dim) + dim_keys = [key for key in _CUDA_VEC_TYPE_SPECS if key.endswith(str(dim))] + if available_keys is not None: + dim_keys = [key for key in dim_keys if key in available_keys] + + lines: List[str] = [] + for src_key in dim_keys: + if src_key == helper_suffix: + continue + src_vec_name = _CUDA_VEC_TYPE_SPECS[src_key][0] + ctor_args = ", ".join([f"({scalar_type})src.v.{c}" for c in comps]) + lines.append( + f"__device__ __forceinline__ {vec_name} vkdispatch_make_{helper_suffix}(const {src_vec_name}& src) {{ return {vec_name}({ctor_args}); }}" + ) + + return "\n".join(lines) + + +def _cuda_emit_mat_type(mat_name: str, vec_name: str, dim: int, needed_ops: Optional[Set[str]] = None) -> str: + cols = [f"c{i}" for i in range(dim)] + if needed_ops is None: + needed_ops = { + "un:-", + "cmpd:+=:m", + 
"cmpd:+=:s", + "cmpd:-=:m", + "cmpd:-=:s", + "cmpd:*=:s", + "cmpd:/=:s", + "bin:+:mm", + "bin:+:ms", + "bin:+:sm", + "bin:-:mm", + "bin:-:ms", + "bin:-:sm", + "bin:*:ms", + "bin:*:sm", + "bin:/:ms", + "bin:/:sm", + "bin:*:mv", + "bin:*:vm", + "bin:*:mm", + } + + def has(token: str) -> bool: + return token in needed_ops + + lines: List[str] = [f"struct {mat_name} {{"] + lines.extend([f" {vec_name} {c};" for c in cols]) + lines.append("") + lines.append(f" __device__ __forceinline__ {mat_name}() = default;") + ctor_args = ", ".join([f"{vec_name} {c}_" for c in cols]) + ctor_init = ", ".join([f"{c}({c}_)" for c in cols]) + lines.append(f" __device__ __forceinline__ {mat_name}({ctor_args}) : {ctor_init} {{}}") + + zero = "0.0f" + diag_init = ", ".join( + [f"c{col_idx}({vec_name}({', '.join(['s' if row_idx == col_idx else zero for row_idx in range(dim)])}))" for col_idx in range(dim)] + ) + lines.append(f" __device__ __forceinline__ explicit {mat_name}(float s) : {diag_init} {{}}") + lines.append(f" __device__ __forceinline__ {vec_name}& operator[](int i) {{ return (&c0)[i]; }}") + lines.append(f" __device__ __forceinline__ const {vec_name}& operator[](int i) const {{ return (&c0)[i]; }}") + if has("un:-"): + lines.append( + f" __device__ __forceinline__ {mat_name} operator-() const {{ return {mat_name}({', '.join([f'-c{i}' for i in range(dim)])}); }}" + ) + + for op in ["+", "-"]: + op_assign = op + "=" + if has(f"cmpd:{op}=:m"): + mm_ops = _cuda_join_statements([f"c{i} {op_assign} b.c{i};" for i in range(dim)]) + lines.append( + f" __device__ __forceinline__ {mat_name}& operator{op_assign}(const {mat_name}& b) {{ {mm_ops} return *this; }}" + ) + if has(f"cmpd:{op}=:s"): + ms_ops = _cuda_join_statements([f"c{i} {op_assign} b;" for i in range(dim)]) + lines.append( + f" __device__ __forceinline__ {mat_name}& operator{op_assign}(float b) {{ {ms_ops} return *this; }}" + ) + + for op in ["*", "/"]: + op_assign = op + "=" + if has(f"cmpd:{op}=:s"): + ms_ops = 
_cuda_join_statements([f"c{i} {op_assign} b;" for i in range(dim)]) + lines.append( + f" __device__ __forceinline__ {mat_name}& operator{op_assign}(float b) {{ {ms_ops} return *this; }}" + ) + + lines.append("};") + + for op in ["+", "-"]: + if has(f"bin:{op}:mm"): + cols_expr = ", ".join([f"(a.c{i} {op} b.c{i})" for i in range(dim)]) + lines.append( + f"__device__ __forceinline__ {mat_name} operator{op}(const {mat_name}& a, const {mat_name}& b) {{ return {mat_name}({cols_expr}); }}" + ) + if has(f"bin:{op}:ms"): + cols_expr = ", ".join([f"(a.c{i} {op} b)" for i in range(dim)]) + lines.append( + f"__device__ __forceinline__ {mat_name} operator{op}(const {mat_name}& a, float b) {{ return {mat_name}({cols_expr}); }}" + ) + if has(f"bin:{op}:sm"): + cols_expr = ", ".join([f"(a {op} b.c{i})" for i in range(dim)]) + lines.append( + f"__device__ __forceinline__ {mat_name} operator{op}(float a, const {mat_name}& b) {{ return {mat_name}({cols_expr}); }}" + ) + + for op in ["*", "/"]: + if has(f"bin:{op}:ms"): + cols_expr = ", ".join([f"(a.c{i} {op} b)" for i in range(dim)]) + lines.append( + f"__device__ __forceinline__ {mat_name} operator{op}(const {mat_name}& a, float b) {{ return {mat_name}({cols_expr}); }}" + ) + if has(f"bin:{op}:sm"): + cols_expr = ", ".join([f"(a {op} b.c{i})" for i in range(dim)]) + lines.append( + f"__device__ __forceinline__ {mat_name} operator{op}(float a, const {mat_name}& b) {{ return {mat_name}({cols_expr}); }}" + ) + + vec_comps = _cuda_vec_components(dim) + if has("bin:*:mv"): + mat_vec_terms = [f"(m.c{i} * v.v.{vec_comps[i]})" for i in range(dim)] + mat_vec_expr = " + ".join(mat_vec_terms) + lines.append( + f"__device__ __forceinline__ {vec_name} operator* (const {mat_name}& m, const {vec_name}& v) {{ return {mat_vec_expr}; }}" + ) + + if has("bin:*:vm"): + row_exprs: List[str] = [] + for col_idx in range(dim): + terms = [f"(v.v.{vec_comps[row_idx]} * m.c{col_idx}.v.{vec_comps[row_idx]})" for row_idx in range(dim)] + row_exprs.append(" + 
".join(terms)) + lines.append( + f"__device__ __forceinline__ {vec_name} operator* (const {vec_name}& v, const {mat_name}& m) {{ return {vec_name}({', '.join(row_exprs)}); }}" + ) + + if has("bin:*:mm"): + col_products = ", ".join([f"(a * b.c{i})" for i in range(dim)]) + lines.append( + f"__device__ __forceinline__ {mat_name} operator* (const {mat_name}& a, const {mat_name}& b) {{ return {mat_name}({col_products}); }}" + ) + + return "\n".join(lines) + + +def _cuda_emit_mat_helpers(mat_name: str, helper_suffix: str, vec_name: str, vec_helper_suffix: str, dim: int) -> str: + col_type = vec_name + col_args = ", ".join([f"{col_type} c{i}" for i in range(dim)]) + col_ctor = ", ".join([f"c{i}" for i in range(dim)]) + + flat_names = [f"m{col}{row}" for col in range(dim) for row in range(dim)] + flat_args = ", ".join([f"float {name}" for name in flat_names]) + flat_cols: List[str] = [] + for col in range(dim): + values = [f"m{col}{row}" for row in range(dim)] + flat_cols.append(f"vkdispatch_make_{vec_helper_suffix}({', '.join(values)})") + flat_ctor = ", ".join(flat_cols) + + cast_cols = ", ".join([f"vkdispatch_make_{vec_helper_suffix}(m[{i}])" for i in range(dim)]) + + return "\n".join( + [ + f"__device__ __forceinline__ {mat_name} vkdispatch_make_{helper_suffix}({col_args}) {{ return {mat_name}({col_ctor}); }}", + f"__device__ __forceinline__ {mat_name} vkdispatch_make_{helper_suffix}(float s) {{ return {mat_name}(s); }}", + f"__device__ __forceinline__ {mat_name} vkdispatch_make_{helper_suffix}({flat_args}) {{ return {mat_name}({flat_ctor}); }}", + "template ", + f"__device__ __forceinline__ {mat_name} vkdispatch_make_{helper_suffix}(TMat m) {{ return {mat_name}({cast_cols}); }}", + ] + ) + + +def _cuda_emit_subgroup_shuffle_xor_vec_overloads(vec_keys: Set[str]) -> str: + lines: List[str] = [] + + for key in _CUDA_VEC_ORDER: + if key not in vec_keys: + continue + + vec_name, _, dim, _, _, _ = _CUDA_VEC_TYPE_SPECS[key] + comps = _cuda_vec_components(dim) + comp_exprs = 
", ".join([f"__shfl_xor_sync(mask, value.v.{c}, lane_mask)" for c in comps]) + lines.append( + f"__device__ __forceinline__ {vec_name} vkdispatch_subgroup_shuffle_xor(unsigned int mask, const {vec_name}& value, int lane_mask) " + f"{{ return vkdispatch_make_{key}({comp_exprs}); }}" + ) + + return "\n".join(lines) diff --git a/vkdispatch/codegen/backends/cuda/helper_snippets.py b/vkdispatch/codegen/backends/cuda/helper_snippets.py new file mode 100644 index 00000000..93fa3eeb --- /dev/null +++ b/vkdispatch/codegen/backends/cuda/helper_snippets.py @@ -0,0 +1,287 @@ +from typing import Dict, List + + +_HELPER_SNIPPETS: Dict[str, str] = { + "composite_types": "", + "mat2_type": "", + "mat3_type": "", + "mat4_type": "", + "make_mat2": "", + "make_mat3": "", + "make_mat4": "", + "make_short2": "", + "make_short3": "", + "make_short4": "", + "make_ushort2": "", + "make_ushort3": "", + "make_ushort4": "", + "make_int2": "", + "make_int3": "", + "make_int4": "", + "make_uint2": "", + "make_uint3": "", + "make_uint4": "", + "make_half2": "", + "make_half3": "", + "make_half4": "", + "float2_ops": "", + "make_float2": "", + "make_float3": "", + "make_float4": "", + "make_double2": "", + "make_double3": "", + "make_double4": "", + "global_invocation_id": ( + "__device__ __forceinline__ vkdispatch_uint3 vkdispatch_global_invocation_id() {\n" + " return vkdispatch_uint3(\n" + " (unsigned int)(blockIdx.x * blockDim.x + threadIdx.x),\n" + " (unsigned int)(blockIdx.y * blockDim.y + threadIdx.y),\n" + " (unsigned int)(blockIdx.z * blockDim.z + threadIdx.z)\n" + " );\n" + "}" + ), + "local_invocation_id": ( + "__device__ __forceinline__ vkdispatch_uint3 vkdispatch_local_invocation_id() {\n" + " return vkdispatch_uint3((unsigned int)threadIdx.x, (unsigned int)threadIdx.y, (unsigned int)threadIdx.z);\n" + "}" + ), + "workgroup_id": ( + "__device__ __forceinline__ vkdispatch_uint3 vkdispatch_workgroup_id() {\n" + " return vkdispatch_uint3((unsigned int)blockIdx.x, (unsigned 
int)blockIdx.y, (unsigned int)blockIdx.z);\n" + "}" + ), + "local_invocation_index": ( + "__device__ __forceinline__ unsigned int vkdispatch_local_invocation_index() {\n" + " return (unsigned int)(threadIdx.x + blockDim.x * (threadIdx.y + blockDim.y * threadIdx.z));\n" + "}" + ), + "subgroup_size": "__device__ __forceinline__ unsigned int vkdispatch_subgroup_size() { return (unsigned int)warpSize; }", + "num_subgroups": ( + "__device__ __forceinline__ unsigned int vkdispatch_num_subgroups() {\n" + " unsigned int local_count = (unsigned int)(blockDim.x * blockDim.y * blockDim.z);\n" + " return (local_count + vkdispatch_subgroup_size() - 1u) / vkdispatch_subgroup_size();\n" + "}" + ), + "subgroup_id": ( + "__device__ __forceinline__ unsigned int vkdispatch_subgroup_id() {\n" + " return vkdispatch_local_invocation_index() / vkdispatch_subgroup_size();\n" + "}" + ), + "subgroup_invocation_id": ( + "__device__ __forceinline__ unsigned int vkdispatch_subgroup_invocation_id() {\n" + " return vkdispatch_local_invocation_index() % vkdispatch_subgroup_size();\n" + "}" + ), + "subgroup_shuffle_xor": ( + "template \n" + "__device__ __forceinline__ T vkdispatch_subgroup_shuffle_xor(unsigned int mask, T value, int lane_mask) {\n" + " return __shfl_xor_sync(mask, value, lane_mask);\n" + "}" + ), + "subgroup_add": ( + "template \n" + "__device__ __forceinline__ T vkdispatch_subgroup_add(T value) {\n" + " unsigned int mask = 0xffffffffu;\n" + " for (unsigned int offset = vkdispatch_subgroup_size() >> 1; offset > 0u; offset >>= 1u) {\n" + " value = value + vkdispatch_subgroup_shuffle_xor(mask, value, (int)offset);\n" + " }\n" + " return value;\n" + "}" + ), + "subgroup_mul": ( + "template \n" + "__device__ __forceinline__ T vkdispatch_subgroup_mul(T value) {\n" + " unsigned int mask = 0xffffffffu;\n" + " for (unsigned int offset = vkdispatch_subgroup_size() >> 1; offset > 0u; offset >>= 1u) {\n" + " value = value * vkdispatch_subgroup_shuffle_xor(mask, value, (int)offset);\n" + " 
}\n" + " return value;\n" + "}" + ), + "subgroup_min": ( + "template \n" + "__device__ __forceinline__ T vkdispatch_subgroup_min(T value) {\n" + " unsigned int mask = 0xffffffffu;\n" + " for (unsigned int offset = vkdispatch_subgroup_size() >> 1; offset > 0u; offset >>= 1u) {\n" + " T other = vkdispatch_subgroup_shuffle_xor(mask, value, (int)offset);\n" + " value = other < value ? other : value;\n" + " }\n" + " return value;\n" + "}" + ), + "subgroup_max": ( + "template \n" + "__device__ __forceinline__ T vkdispatch_subgroup_max(T value) {\n" + " unsigned int mask = 0xffffffffu;\n" + " for (unsigned int offset = vkdispatch_subgroup_size() >> 1; offset > 0u; offset >>= 1u) {\n" + " T other = vkdispatch_subgroup_shuffle_xor(mask, value, (int)offset);\n" + " value = other > value ? other : value;\n" + " }\n" + " return value;\n" + "}" + ), + "subgroup_and": ( + "template \n" + "__device__ __forceinline__ T vkdispatch_subgroup_and(T value) {\n" + " unsigned int mask = 0xffffffffu;\n" + " for (unsigned int offset = vkdispatch_subgroup_size() >> 1; offset > 0u; offset >>= 1u) {\n" + " value = value & vkdispatch_subgroup_shuffle_xor(mask, value, (int)offset);\n" + " }\n" + " return value;\n" + "}" + ), + "subgroup_or": ( + "template \n" + "__device__ __forceinline__ T vkdispatch_subgroup_or(T value) {\n" + " unsigned int mask = 0xffffffffu;\n" + " for (unsigned int offset = vkdispatch_subgroup_size() >> 1; offset > 0u; offset >>= 1u) {\n" + " value = value | vkdispatch_subgroup_shuffle_xor(mask, value, (int)offset);\n" + " }\n" + " return value;\n" + "}" + ), + "subgroup_xor": ( + "template \n" + "__device__ __forceinline__ T vkdispatch_subgroup_xor(T value) {\n" + " unsigned int mask = 0xffffffffu;\n" + " for (unsigned int offset = vkdispatch_subgroup_size() >> 1; offset > 0u; offset >>= 1u) {\n" + " value = value ^ vkdispatch_subgroup_shuffle_xor(mask, value, (int)offset);\n" + " }\n" + " return value;\n" + "}" + ), + "mod": ( + "__device__ __forceinline__ float 
mod(float x, float y) { return fmodf(x, y); }\n" + "__device__ __forceinline__ double mod(double x, double y) { return fmod(x, y); }" + ), + "fract": ( + "__device__ __forceinline__ float fract(float x) { return x - floorf(x); }\n" + "__device__ __forceinline__ double fract(double x) { return x - floor(x); }" + ), + "roundEven": ( + "__device__ __forceinline__ float roundEven(float x) { return nearbyintf(x); }\n" + "__device__ __forceinline__ double roundEven(double x) { return nearbyint(x); }" + ), + "mix": ( + "__device__ __forceinline__ float mix(float x, float y, float a) { return x + (y - x) * a; }\n" + "__device__ __forceinline__ double mix(double x, double y, double a) { return x + (y - x) * a; }" + ), + "step": ( + "__device__ __forceinline__ float step(float edge, float x) { return x < edge ? 0.0f : 1.0f; }\n" + "__device__ __forceinline__ double step(double edge, double x) { return x < edge ? 0.0 : 1.0; }" + ), + "smoothstep": ( + "__device__ __forceinline__ float smoothstep(float edge0, float edge1, float x) {\n" + " float t = fminf(fmaxf((x - edge0) / (edge1 - edge0), 0.0f), 1.0f);\n" + " return t * t * (3.0f - 2.0f * t);\n" + "}\n" + "__device__ __forceinline__ double smoothstep(double edge0, double edge1, double x) {\n" + " double t = fmin(fmax((x - edge0) / (edge1 - edge0), 0.0), 1.0);\n" + " return t * t * (3.0 - 2.0 * t);\n" + "}" + ), + "radians": ( + "__device__ __forceinline__ float radians(float x) { return x * (3.14159265358979323846f / 180.0f); }\n" + "__device__ __forceinline__ double radians(double x) { return x * (3.14159265358979323846 / 180.0); }" + ), + "degrees": ( + "__device__ __forceinline__ float degrees(float x) { return x * (180.0f / 3.14159265358979323846f); }\n" + "__device__ __forceinline__ double degrees(double x) { return x * (180.0 / 3.14159265358979323846); }" + ), + "inversesqrt": ( + "__device__ __forceinline__ float inversesqrt(float x) { return rsqrtf(x); }\n" + "__device__ __forceinline__ double inversesqrt(double x) 
{ return rsqrt(x); }" + ), + "floatBitsToInt": "__device__ __forceinline__ int floatBitsToInt(float x) { return __float_as_int(x); }", + "floatBitsToUint": "__device__ __forceinline__ unsigned int floatBitsToUint(float x) { return __float_as_uint(x); }", + "intBitsToFloat": "__device__ __forceinline__ float intBitsToFloat(int x) { return __int_as_float(x); }", + "uintBitsToFloat": "__device__ __forceinline__ float uintBitsToFloat(unsigned int x) { return __uint_as_float(x); }", + "longlong_as_double": "__device__ __forceinline__ double longlong_as_double(long long x) { return __longlong_as_double(x); }", + "ushort_as_half": "__device__ __forceinline__ __half ushort_as_half(unsigned short x) { __half h; *reinterpret_cast(&h) = x; return h; }", + "sample_texture": "", +} + +_HELPER_ORDER: List[str] = [ + "composite_types", + "global_invocation_id", + "local_invocation_id", + "workgroup_id", + "local_invocation_index", + "subgroup_size", + "num_subgroups", + "subgroup_id", + "subgroup_invocation_id", + "subgroup_shuffle_xor", + "subgroup_add", + "subgroup_mul", + "subgroup_min", + "subgroup_max", + "subgroup_and", + "subgroup_or", + "subgroup_xor", + "mod", + "fract", + "roundEven", + "mix", + "step", + "smoothstep", + "radians", + "degrees", + "inversesqrt", + "floatBitsToInt", + "floatBitsToUint", + "intBitsToFloat", + "uintBitsToFloat", + "longlong_as_double", + "ushort_as_half", + "sample_texture", +] + +_HELPER_DEPENDENCIES: Dict[str, List[str]] = { + "mat2_type": ["composite_types"], + "mat3_type": ["composite_types"], + "mat4_type": ["composite_types"], + "make_mat2": ["composite_types"], + "make_mat3": ["composite_types"], + "make_mat4": ["composite_types"], + "make_short2": ["composite_types"], + "make_short3": ["composite_types"], + "make_short4": ["composite_types"], + "make_ushort2": ["composite_types"], + "make_ushort3": ["composite_types"], + "make_ushort4": ["composite_types"], + "make_int2": ["composite_types"], + "make_int3": ["composite_types"], + 
"make_int4": ["composite_types"], + "make_uint2": ["composite_types"], + "make_uint3": ["composite_types"], + "make_uint4": ["composite_types"], + "make_half2": ["composite_types"], + "make_half3": ["composite_types"], + "make_half4": ["composite_types"], + "float2_ops": ["composite_types"], + "make_float2": ["composite_types"], + "make_float3": ["composite_types"], + "make_float4": ["composite_types"], + "make_double2": ["composite_types"], + "make_double3": ["composite_types"], + "make_double4": ["composite_types"], + "global_invocation_id": ["composite_types"], + "local_invocation_id": ["composite_types"], + "workgroup_id": ["composite_types"], + "sample_texture": ["composite_types"], + "num_subgroups": ["subgroup_size"], + "subgroup_id": ["local_invocation_index", "subgroup_size"], + "subgroup_invocation_id": ["local_invocation_index", "subgroup_size"], + "subgroup_add": ["subgroup_size", "subgroup_shuffle_xor"], + "subgroup_mul": ["subgroup_size", "subgroup_shuffle_xor"], + "subgroup_min": ["subgroup_size", "subgroup_shuffle_xor"], + "subgroup_max": ["subgroup_size", "subgroup_shuffle_xor"], + "subgroup_and": ["subgroup_size", "subgroup_shuffle_xor"], + "subgroup_or": ["subgroup_size", "subgroup_shuffle_xor"], + "subgroup_xor": ["subgroup_size", "subgroup_shuffle_xor"], +} + + +def initialize_feature_usage() -> Dict[str, bool]: + return {feature_name: False for feature_name in _HELPER_SNIPPETS} diff --git a/vkdispatch/codegen/backends/cuda/math_utils.py b/vkdispatch/codegen/backends/cuda/math_utils.py new file mode 100644 index 00000000..fc5ce5ad --- /dev/null +++ b/vkdispatch/codegen/backends/cuda/math_utils.py @@ -0,0 +1,174 @@ +from typing import Dict, List, Optional, Set + +import vkdispatch.base.dtype as dtypes + +from .composite_emitters import _cuda_vec_components +from .specs import _CUDA_VEC_TYPE_SPECS, _FLOAT_VEC_HELPER_SUFFIX_MAP + + +def cuda_fast_unary_math_name(func_name: str) -> str: + if func_name == "sin": + return "__sinf" + if func_name == 
"cos": + return "__cosf" + if func_name == "tan": + return "__tanf" + if func_name == "exp": + return "__expf" + if func_name == "exp2": + return "__exp2f" + if func_name == "log": + return "__logf" + if func_name == "log2": + return "__log2f" + if func_name == "asin": + return "asinf" + if func_name == "acos": + return "acosf" + if func_name == "atan": + return "atanf" + if func_name == "sinh": + return "sinhf" + if func_name == "cosh": + return "coshf" + if func_name == "tanh": + return "tanhf" + if func_name == "asinh": + return "asinhf" + if func_name == "acosh": + return "acoshf" + if func_name == "atanh": + return "atanhf" + if func_name == "sqrt": + return "sqrtf" + + return func_name + + +def cuda_fast_binary_math_name(func_name: str) -> str: + if func_name == "atan2": + return "atan2f" + if func_name == "pow": + return "__powf" + + return func_name + + +def cuda_scalar_unary_math_name(func_name: str, scalar_type: str) -> str: + if scalar_type == "__half": + half_math = { + "sin": "hsin", + "cos": "hcos", + "exp": "hexp", + "exp2": "hexp2", + "log": "hlog", + "log2": "hlog2", + "sqrt": "hsqrt", + } + return half_math.get(func_name, func_name) + if scalar_type == "double": + return func_name + return cuda_fast_unary_math_name(func_name) + + +def cuda_scalar_binary_math_name(func_name: str, scalar_type: str) -> str: + if scalar_type == "__half": + return func_name + if scalar_type == "double": + return func_name + return cuda_fast_binary_math_name(func_name) + + +def cuda_float_vec_components_for_suffix(helper_suffix: str) -> List[str]: + dim_char = helper_suffix[-1] + if dim_char == "2": + return ["x", "y"] + if dim_char == "3": + return ["x", "y", "z"] + if dim_char == "4": + return ["x", "y", "z", "w"] + + raise ValueError(f"Unsupported CUDA float vector helper suffix '{helper_suffix}'") + + +def cuda_float_vec_helper_suffix(var_type: dtypes.dtype) -> Optional[str]: + return _FLOAT_VEC_HELPER_SUFFIX_MAP.get(var_type) + + +def emit_used_vec_math_helpers( + 
composite_vec_unary_math_usage: Dict[str, Set[str]], + composite_vec_binary_math_usage: Dict[str, Set[str]], +) -> str: + helper_sections: List[str] = [] + + unary_order = [ + "sin", + "cos", + "tan", + "asin", + "acos", + "atan", + "sinh", + "cosh", + "tanh", + "asinh", + "acosh", + "atanh", + "exp", + "exp2", + "log", + "log2", + "sqrt", + ] + binary_order = ["atan2", "pow"] + signature_order = ["vv", "vs", "sv"] + + for key in ["half2", "half3", "half4", "float2", "float3", "float4", "double2", "double3", "double4"]: + unary_funcs = composite_vec_unary_math_usage.get(key, set()) + binary_tokens = composite_vec_binary_math_usage.get(key, set()) + if len(unary_funcs) == 0 and len(binary_tokens) == 0: + continue + + if key not in _CUDA_VEC_TYPE_SPECS: + continue + + vec_name, scalar_type, dim, _, _, _ = _CUDA_VEC_TYPE_SPECS[key] + comps = _cuda_vec_components(dim) + lines: List[str] = [] + + for func_name in unary_order: + if func_name not in unary_funcs: + continue + scalar_func = cuda_scalar_unary_math_name(func_name, scalar_type) + comp_args = ", ".join([f"{scalar_func}(v.v.{c})" for c in comps]) + lines.append( + f"__device__ __forceinline__ {vec_name} {func_name}(const {vec_name}& v) {{ return vkdispatch_make_{key}({comp_args}); }}" + ) + + for func_name in binary_order: + scalar_func = cuda_scalar_binary_math_name(func_name, scalar_type) + for signature in signature_order: + token = f"{func_name}:{signature}" + if token not in binary_tokens: + continue + + if signature == "vv": + comp_args = ", ".join([f"{scalar_func}(a.v.{c}, b.v.{c})" for c in comps]) + lines.append( + f"__device__ __forceinline__ {vec_name} {func_name}(const {vec_name}& a, const {vec_name}& b) {{ return vkdispatch_make_{key}({comp_args}); }}" + ) + elif signature == "vs": + comp_args = ", ".join([f"{scalar_func}(a.v.{c}, b)" for c in comps]) + lines.append( + f"__device__ __forceinline__ {vec_name} {func_name}(const {vec_name}& a, {scalar_type} b) {{ return 
vkdispatch_make_{key}({comp_args}); }}" + ) + elif signature == "sv": + comp_args = ", ".join([f"{scalar_func}(a, b.v.{c})" for c in comps]) + lines.append( + f"__device__ __forceinline__ {vec_name} {func_name}({scalar_type} a, const {vec_name}& b) {{ return vkdispatch_make_{key}({comp_args}); }}" + ) + + if len(lines) > 0: + helper_sections.append("\n".join(lines)) + + return "\n\n".join(helper_sections) diff --git a/vkdispatch/codegen/backends/cuda/specs.py b/vkdispatch/codegen/backends/cuda/specs.py new file mode 100644 index 00000000..c029b5b0 --- /dev/null +++ b/vkdispatch/codegen/backends/cuda/specs.py @@ -0,0 +1,120 @@ +from typing import Dict, FrozenSet, Tuple + +import vkdispatch.base.dtype as dtypes + + +_CUDA_VEC_TYPE_SPECS: Dict[str, Tuple[str, str, int, str, bool, bool]] = { + "short2": ("vkdispatch_short2", "short", 2, "short2", True, True), + "short3": ("vkdispatch_short3", "short", 3, "short3", True, True), + "short4": ("vkdispatch_short4", "short", 4, "short4", True, True), + "ushort2": ("vkdispatch_ushort2", "unsigned short", 2, "ushort2", False, True), + "ushort3": ("vkdispatch_ushort3", "unsigned short", 3, "ushort3", False, True), + "ushort4": ("vkdispatch_ushort4", "unsigned short", 4, "ushort4", False, True), + "int2": ("vkdispatch_int2", "int", 2, "int2", True, True), + "int3": ("vkdispatch_int3", "int", 3, "int3", True, True), + "int4": ("vkdispatch_int4", "int", 4, "int4", True, True), + "uint2": ("vkdispatch_uint2", "unsigned int", 2, "uint2", False, True), + "uint3": ("vkdispatch_uint3", "unsigned int", 3, "uint3", False, True), + "uint4": ("vkdispatch_uint4", "unsigned int", 4, "uint4", False, True), + "half2": ("vkdispatch_half2", "__half", 2, "half2", True, False), + "half3": ("vkdispatch_half3", "__half", 3, "half3", True, False), + "half4": ("vkdispatch_half4", "__half", 4, "half4", True, False), + "float2": ("vkdispatch_float2", "float", 2, "float2", True, False), + "float3": ("vkdispatch_float3", "float", 3, "float3", True, 
False), + "float4": ("vkdispatch_float4", "float", 4, "float4", True, False), + "double2": ("vkdispatch_double2", "double", 2, "double2", True, False), + "double3": ("vkdispatch_double3", "double", 3, "double3", True, False), + "double4": ("vkdispatch_double4", "double", 4, "double4", True, False), +} + +_CUDA_MAT_TYPE_SPECS: Dict[str, Tuple[str, str, str, int]] = { + "mat2": ("vkdispatch_mat2", "vkdispatch_float2", "float2", 2), + "mat3": ("vkdispatch_mat3", "vkdispatch_float3", "float3", 3), + "mat4": ("vkdispatch_mat4", "vkdispatch_float4", "float4", 4), +} + +_CUDA_VEC_ORDER = [ + "short2", "short3", "short4", + "ushort2", "ushort3", "ushort4", + "int2", "int3", "int4", + "uint2", "uint3", "uint4", + "half2", "half3", "half4", + "float2", "float3", "float4", + "double2", "double3", "double4", +] + +_CUDA_MAT_ORDER = ["mat2", "mat3", "mat4"] + +_DTYPE_TO_COMPOSITE_KEY = { + dtypes.ihvec2: "short2", + dtypes.ihvec3: "short3", + dtypes.ihvec4: "short4", + dtypes.uhvec2: "ushort2", + dtypes.uhvec3: "ushort3", + dtypes.uhvec4: "ushort4", + dtypes.ivec2: "int2", + dtypes.ivec3: "int3", + dtypes.ivec4: "int4", + dtypes.uvec2: "uint2", + dtypes.uvec3: "uint3", + dtypes.uvec4: "uint4", + dtypes.hvec2: "half2", + dtypes.hvec3: "half3", + dtypes.hvec4: "half4", + dtypes.complex32: "half2", + dtypes.complex64: "float2", + dtypes.complex128: "double2", + dtypes.vec2: "float2", + dtypes.vec3: "float3", + dtypes.vec4: "float4", + dtypes.dvec2: "double2", + dtypes.dvec3: "double3", + dtypes.dvec4: "double4", + dtypes.mat2: "mat2", + dtypes.mat3: "mat3", + dtypes.mat4: "mat4", +} + +_SCALAR_TYPE_NAMES = { + dtypes.int16: "short", + dtypes.uint16: "unsigned short", + dtypes.int32: "int", + dtypes.uint32: "unsigned int", + dtypes.int64: "long long", + dtypes.uint64: "unsigned long long", + dtypes.float16: "__half", + dtypes.float32: "float", + dtypes.float64: "double", +} + +_FLOAT_VEC_DTYPES: FrozenSet[dtypes.dtype] = frozenset( + { + dtypes.complex32, + dtypes.complex64, + 
dtypes.complex128, + dtypes.hvec2, + dtypes.hvec3, + dtypes.hvec4, + dtypes.vec2, + dtypes.vec3, + dtypes.vec4, + dtypes.dvec2, + dtypes.dvec3, + dtypes.dvec4, + } +) + +_FLOAT_VEC_HELPER_SUFFIX_MAP = { + dtypes.hvec2: "half2", + dtypes.hvec3: "half3", + dtypes.hvec4: "half4", + dtypes.complex32: "half2", + dtypes.complex64: "float2", + dtypes.complex128: "double2", + dtypes.vec2: "float2", + dtypes.vec3: "float3", + dtypes.vec4: "float4", + dtypes.dvec2: "double2", + dtypes.dvec3: "double3", + dtypes.dvec4: "double4", +} diff --git a/vkdispatch/codegen/backends/glsl.py b/vkdispatch/codegen/backends/glsl.py index e0c82738..c2187e06 100644 --- a/vkdispatch/codegen/backends/glsl.py +++ b/vkdispatch/codegen/backends/glsl.py @@ -1,17 +1,52 @@ -from typing import List, Optional +from typing import List, Optional, Set import vkdispatch.base.dtype as dtypes from .base import CodeGenBackend +# Map scalar dtypes to GLSL extension names. +_GLSL_TYPE_EXTENSIONS = { + dtypes.float16: "GL_EXT_shader_explicit_arithmetic_types_float16", + dtypes.int16: "GL_EXT_shader_explicit_arithmetic_types_int16", + dtypes.uint16: "GL_EXT_shader_explicit_arithmetic_types_int16", + dtypes.int64: "GL_ARB_gpu_shader_int64", + dtypes.uint64: "GL_ARB_gpu_shader_int64", + dtypes.float64: "GL_ARB_gpu_shader_fp64", +} + class GLSLBackend(CodeGenBackend): name = "glsl" + def __init__(self) -> None: + super().__init__() + self._needed_extensions: Set[str] = set() + + def reset_state(self) -> None: + self._needed_extensions = set() + + def _track_type_extension(self, var_type: dtypes.dtype) -> None: + """Record the GLSL extension required by *var_type* (if any).""" + scalar = var_type + if dtypes.is_vector(var_type) or dtypes.is_matrix(var_type): + scalar = var_type.scalar + elif dtypes.is_complex(var_type): + scalar = var_type.child_type + ext = _GLSL_TYPE_EXTENSIONS.get(scalar) + if ext is not None: + self._needed_extensions.add(ext) + def type_name(self, var_type: dtypes.dtype) -> str: + 
self._track_type_extension(var_type) return var_type.glsl_type - def constructor(self, var_type: dtypes.dtype, args: List[str]) -> str: + def constructor( + self, + var_type: dtypes.dtype, + args: List[str], + arg_types: Optional[List[Optional[dtypes.dtype]]] = None, + ) -> str: + _ = arg_types return f"{self.type_name(var_type)}({', '.join(args)})" def pre_header(self, *, enable_subgroup_ops: bool, enable_printf: bool) -> str: @@ -24,10 +59,17 @@ def pre_header(self, *, enable_subgroup_ops: bool, enable_printf: bool) -> str: if enable_printf: header += "#extension GL_EXT_debug_printf : require\n" - return header + ext_block = "" + for ext in sorted(self._needed_extensions): + ext_line = f"#extension {ext} : require\n" + if ext_line not in header: + ext_block += ext_line + + return header + ext_block def make_source(self, header: str, body: str, x: int, y: int, z: int) -> str: layout_str = f"layout(local_size_x = {x}, local_size_y = {y}, local_size_z = {z}) in;" + return f"{header}\n{layout_str}\n{body}" def constant_namespace(self) -> str: @@ -63,6 +105,18 @@ def inf_f32_expr(self) -> str: def ninf_f32_expr(self) -> str: return "uintBitsToFloat(0xFF800000)" + def inf_f64_expr(self) -> str: + return "packDouble2x32(uvec2(0x00000000u, 0x7FF00000u))" + + def ninf_f64_expr(self) -> str: + return "packDouble2x32(uvec2(0x00000000u, 0xFFF00000u))" + + def inf_f16_expr(self) -> str: + return "float16_t(uintBitsToFloat(0x7F800000))" + + def ninf_f16_expr(self) -> str: + return "float16_t(uintBitsToFloat(0xFF800000))" + def float_bits_to_int_expr(self, var_expr: str) -> str: return f"floatBitsToInt({var_expr})" @@ -123,25 +177,32 @@ def memory_barrier_image_statement(self) -> str: def group_memory_barrier_statement(self) -> str: return "groupMemoryBarrier();" - def subgroup_add_expr(self, arg_expr: str) -> str: + def subgroup_add_expr(self, arg_expr: str, arg_type: Optional[dtypes.dtype] = None) -> str: + _ = arg_type return f"subgroupAdd({arg_expr})" - def 
subgroup_mul_expr(self, arg_expr: str) -> str: + def subgroup_mul_expr(self, arg_expr: str, arg_type: Optional[dtypes.dtype] = None) -> str: + _ = arg_type return f"subgroupMul({arg_expr})" - def subgroup_min_expr(self, arg_expr: str) -> str: + def subgroup_min_expr(self, arg_expr: str, arg_type: Optional[dtypes.dtype] = None) -> str: + _ = arg_type return f"subgroupMin({arg_expr})" - def subgroup_max_expr(self, arg_expr: str) -> str: + def subgroup_max_expr(self, arg_expr: str, arg_type: Optional[dtypes.dtype] = None) -> str: + _ = arg_type return f"subgroupMax({arg_expr})" - def subgroup_and_expr(self, arg_expr: str) -> str: + def subgroup_and_expr(self, arg_expr: str, arg_type: Optional[dtypes.dtype] = None) -> str: + _ = arg_type return f"subgroupAnd({arg_expr})" - def subgroup_or_expr(self, arg_expr: str) -> str: + def subgroup_or_expr(self, arg_expr: str, arg_type: Optional[dtypes.dtype] = None) -> str: + _ = arg_type return f"subgroupOr({arg_expr})" - def subgroup_xor_expr(self, arg_expr: str) -> str: + def subgroup_xor_expr(self, arg_expr: str, arg_type: Optional[dtypes.dtype] = None) -> str: + _ = arg_type return f"subgroupXor({arg_expr})" def subgroup_elect_expr(self) -> str: @@ -166,3 +227,9 @@ def sample_texture_expr(self, texture_expr: str, coord_expr: str, lod_expr: Opti return f"texture({texture_expr}, {coord_expr})" return f"texture({texture_expr}, {coord_expr}, {lod_expr})" + + def atomic_add_expr(self, mem_expr: str, value_expr: str, var_type: dtypes.dtype) -> str: + if var_type not in (dtypes.int32, dtypes.uint32): + raise NotImplementedError(f"GLSL atomic_add only supports int32/uint32, got '{var_type.name}'") + + return f"atomicAdd({mem_expr}, {value_expr})" diff --git a/vkdispatch/codegen/backends/opencl.py b/vkdispatch/codegen/backends/opencl.py new file mode 100644 index 00000000..907f0508 --- /dev/null +++ b/vkdispatch/codegen/backends/opencl.py @@ -0,0 +1,701 @@ +from typing import List, Optional, Set + +import vkdispatch.base.dtype as 
dtypes + +from .base import CodeGenBackend + + +class OpenCLBackend(CodeGenBackend): + name = "opencl" + + _SCALAR_TYPE_NAMES = { + dtypes.int16: "short", + dtypes.uint16: "ushort", + dtypes.int32: "int", + dtypes.uint32: "uint", + dtypes.int64: "long", + dtypes.uint64: "ulong", + dtypes.float16: "half", + dtypes.float32: "float", + dtypes.float64: "double", + } + + _MATRIX_TYPE_NAMES = { + dtypes.mat2: "vkdispatch_mat2", + dtypes.mat3: "vkdispatch_mat3", + dtypes.mat4: "vkdispatch_mat4", + } + + def __init__(self) -> None: + self.reset_state() + + def reset_state(self) -> None: + self._kernel_params: List[str] = [] + self._entry_alias_lines: List[str] = [] + self._shared_buffer_lines: List[str] = [] + self._matrix_type_usage: Set[int] = set() + + def _register_kernel_param(self, param_decl: str) -> None: + if param_decl not in self._kernel_params: + self._kernel_params.append(param_decl) + + def _register_alias_line(self, alias_line: str) -> None: + self._entry_alias_lines.append(alias_line) + + def _record_matrix_dim(self, dim: int) -> None: + if dim not in (2, 3, 4): + raise ValueError(f"Unsupported OpenCL matrix dimension '{dim}'") + self._matrix_type_usage.add(dim) + + def _record_matrix_type(self, var_type: dtypes.dtype) -> None: + if dtypes.is_matrix(var_type): + self._record_matrix_dim(var_type.child_count) + + @staticmethod + def _matrix_helper_name(dim: int, constructor_kind: str) -> str: + return f"vkdispatch_make_mat{dim}_{constructor_kind}" + + def _is_matrix_copy_constructor_arg(self, arg_expr: str, dim: int) -> bool: + stripped = arg_expr.strip() + mat_type = self._matrix_struct_name(dim) + + if stripped.startswith(f"({mat_type})") or stripped.startswith(f"(({mat_type})"): + return True + + if f"vkdispatch_make_mat{dim}_" in stripped: + return True + + if f"vkdispatch_mat{dim}_" in stripped: + return True + + return False + + @classmethod + def _scalar_type_name(cls, scalar_type: dtypes.dtype) -> str: + type_name = 
cls._SCALAR_TYPE_NAMES.get(scalar_type) + if type_name is None: + raise ValueError(f"Unsupported OpenCL scalar type mapping for '{scalar_type.name}'") + return type_name + + def type_name(self, var_type: dtypes.dtype) -> str: + if dtypes.is_scalar(var_type): + return self._scalar_type_name(var_type) + + if dtypes.is_vector(var_type): + return f"{self._scalar_type_name(var_type.scalar)}{var_type.child_count}" + + if dtypes.is_complex(var_type): + return f"{self._scalar_type_name(var_type.child_type)}2" + + if dtypes.is_matrix(var_type): + self._record_matrix_type(var_type) + matrix_name = self._MATRIX_TYPE_NAMES.get(var_type) + if matrix_name is None: + raise ValueError(f"Unsupported OpenCL matrix type mapping for '{var_type.name}'") + return matrix_name + + raise ValueError(f"Unsupported OpenCL type mapping for '{var_type.name}'") + + def constructor( + self, + var_type: dtypes.dtype, + args: List[str], + arg_types: Optional[List[Optional[dtypes.dtype]]] = None, + ) -> str: + target_type = self.type_name(var_type) + + if dtypes.is_scalar(var_type): + assert len(args) > 0, f"Constructor for scalar type '{var_type.name}' needs at least one argument." + return f"(({target_type})({args[0]}))" + + if dtypes.is_matrix(var_type): + dim = var_type.child_count + assert len(args) in (1, dim, dim * dim), ( + f"Constructor for matrix type '{var_type.name}' needs 1, {dim}, or {dim * dim} arguments." + ) + if len(args) == 1: + single_arg = args[0] + helper_name = self._matrix_helper_name( + dim, + "copy" if self._is_matrix_copy_constructor_arg(single_arg, dim) else "scalar", + ) + return f"{helper_name}({single_arg})" + + if len(args) == dim: + return f"{self._matrix_helper_name(dim, 'cols')}({', '.join(args)})" + + return f"{self._matrix_helper_name(dim, 'flat')}({', '.join(args)})" + + # NVIDIA's OpenCL frontend rejects direct vector casts between different + # vector base types (e.g. uint2 -> float2). 
Use convert_* builtins when + # we know this is a vector/complex-to-vector/complex conversion. + if ( + len(args) == 1 + and arg_types is not None + and len(arg_types) == 1 + and arg_types[0] is not None + and (dtypes.is_vector(var_type) or dtypes.is_complex(var_type)) + and (dtypes.is_vector(arg_types[0]) or dtypes.is_complex(arg_types[0])) + ): + return f"convert_{target_type}({args[0]})" + + return f"(({target_type})({', '.join(args)}))" + + def component_access_expr(self, expr: str, component: str, base_type: dtypes.dtype) -> str: + if dtypes.is_scalar(base_type) and component == "x": + return expr + return super().component_access_expr(expr, component, base_type) + + def buffer_component_expr( + self, + scalar_buffer_expr: str, + base_type: dtypes.dtype, + element_index_expr: str, + component_index_expr: str, + ) -> Optional[str]: + if dtypes.is_complex(base_type): + component_count = base_type.child_count + elif dtypes.is_vector(base_type): + component_count = base_type.child_count + else: + return None + + return ( + f"{scalar_buffer_expr}[" + f"(({element_index_expr}) * {component_count}) + ({component_index_expr})" + f"]" + ) + + def _cast_math_arg(self, arg_type: dtypes.dtype, arg_expr: str) -> str: + if dtypes.is_scalar(arg_type) or dtypes.is_vector(arg_type) or dtypes.is_complex(arg_type): + return self.constructor(arg_type, [arg_expr], arg_types=[arg_type]) + + return arg_expr + + def math_func_name(self, func_name: str, var_type: dtypes.dtype) -> str: + func_name_dict = { + "sin": "native_sin", + "cos": "native_cos", + "tan": "native_tan", + "sqrt": "native_sqrt", + "exp": "native_exp", + "exp2": "native_exp2", + "log": "native_log", + "log2": "native_log2", + } + + if func_name in func_name_dict: + return func_name_dict[func_name] + + return func_name + + def unary_math_expr(self, func_name: str, arg_type: dtypes.dtype, arg_expr: str) -> str: + mapped = self.math_func_name(func_name, arg_type) + return f"{mapped}({self._cast_math_arg(arg_type, 
arg_expr)})" + + def binary_math_expr( + self, + func_name: str, + lhs_type: dtypes.dtype, + lhs_expr: str, + rhs_type: dtypes.dtype, + rhs_expr: str, + ) -> str: + mapped = self.math_func_name(func_name, lhs_type) + lhs_cast_expr = self._cast_math_arg(lhs_type, lhs_expr) + rhs_cast_expr = self._cast_math_arg(rhs_type, rhs_expr) + return f"{mapped}({lhs_cast_expr}, {rhs_cast_expr})" + + def pre_header(self, *, enable_subgroup_ops: bool, enable_printf: bool) -> str: + _ = enable_subgroup_ops + _ = enable_printf + header = ( + "// OpenCL C source generated by vkdispatch\n" + "#ifdef cl_khr_global_int32_base_atomics\n" + "#pragma OPENCL EXTENSION cl_khr_global_int32_base_atomics : enable\n" + "#endif\n" + "#ifdef cl_khr_local_int32_base_atomics\n" + "#pragma OPENCL EXTENSION cl_khr_local_int32_base_atomics : enable\n" + "#endif\n" + "#ifdef cl_khr_fp64\n" + "#pragma OPENCL EXTENSION cl_khr_fp64 : enable\n" + "#endif\n" + "#ifdef cl_khr_fp16\n" + "#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n" + "#endif\n" + ) + matrix_helpers = self._emit_matrix_helpers() + if len(matrix_helpers) > 0: + header += f"\n{matrix_helpers}\n" + return header + + def _emit_matrix_helpers(self) -> str: + if len(self._matrix_type_usage) == 0: + return "" + + sections: List[str] = [] + if 3 in self._matrix_type_usage: + sections.append( + "typedef struct __attribute__((packed)) vkdispatch_packed_float3 {\n" + " float x;\n" + " float y;\n" + " float z;\n" + "} vkdispatch_packed_float3;\n" + "static inline float3 vkdispatch_unpack_float3(vkdispatch_packed_float3 v) { return (float3)(v.x, v.y, v.z); }\n" + "static inline vkdispatch_packed_float3 vkdispatch_pack_float3(float3 v) {\n" + " vkdispatch_packed_float3 out = {v.x, v.y, v.z};\n" + " return out;\n" + "}" + ) + + for dim in sorted(self._matrix_type_usage): + sections.append(self._emit_matrix_helpers_for_dim(dim)) + + return "\n\n".join(sections) + + @staticmethod + def _vector_components(dim: int) -> List[str]: + return list("xyzw"[:dim]) 
+ + @staticmethod + def _matrix_struct_name(dim: int) -> str: + return f"vkdispatch_mat{dim}" + + @staticmethod + def _vector_type_name(dim: int) -> str: + return f"float{dim}" + + def _matrix_col_expr(self, mat_expr: str, col: int, dim: int) -> str: + if dim == 3: + return f"vkdispatch_unpack_float3({mat_expr}.c{col})" + return f"{mat_expr}.c{col}" + + def _matrix_col_assign_stmt(self, target_expr: str, col: int, value_expr: str, dim: int) -> str: + if dim == 3: + return f"{target_expr}.c{col} = vkdispatch_pack_float3({value_expr});" + return f"{target_expr}.c{col} = {value_expr};" + + def _emit_matrix_helpers_for_dim(self, dim: int) -> str: + mat_type = self._matrix_struct_name(dim) + vec_type = self._vector_type_name(dim) + comps = self._vector_components(dim) + scalar_helper_name = self._matrix_helper_name(dim, "scalar") + copy_helper_name = self._matrix_helper_name(dim, "copy") + cols_helper_name = self._matrix_helper_name(dim, "cols") + flat_helper_name = self._matrix_helper_name(dim, "flat") + + lines: List[str] = [] + + if dim == 3: + lines.append( + "typedef struct __attribute__((packed)) vkdispatch_mat3 {\n" + " vkdispatch_packed_float3 c0;\n" + " vkdispatch_packed_float3 c1;\n" + " vkdispatch_packed_float3 c2;\n" + "} vkdispatch_mat3;" + ) + else: + cols = "\n".join([f" {vec_type} c{i};" for i in range(dim)]) + lines.append(f"typedef struct {mat_type} {{\n{cols}\n}} {mat_type};") + + # Constructors. 
+ lines.append(f"static inline {mat_type} {scalar_helper_name}(float s) {{") + lines.append(f" {mat_type} out;") + for col_idx in range(dim): + diag_values = [("s" if row_idx == col_idx else "0.0f") for row_idx in range(dim)] + vec_expr = f"({vec_type})(" + ", ".join(diag_values) + ")" + lines.append(f" {self._matrix_col_assign_stmt('out', col_idx, vec_expr, dim)}") + lines.append(" return out;") + lines.append("}") + + lines.append(f"static inline {mat_type} {copy_helper_name}({mat_type} m) {{ return m; }}") + + col_args = ", ".join([f"{vec_type} c{i}" for i in range(dim)]) + lines.append(f"static inline {mat_type} {cols_helper_name}({col_args}) {{") + lines.append(f" {mat_type} out;") + for col_idx in range(dim): + lines.append(f" {self._matrix_col_assign_stmt('out', col_idx, f'c{col_idx}', dim)}") + lines.append(" return out;") + lines.append("}") + + flat_names = [f"m{col}{row}" for col in range(dim) for row in range(dim)] + flat_args = ", ".join([f"float {name}" for name in flat_names]) + lines.append(f"static inline {mat_type} {flat_helper_name}({flat_args}) {{") + lines.append(f" return {cols_helper_name}(") + for col_idx in range(dim): + values = [f"m{col_idx}{row_idx}" for row_idx in range(dim)] + suffix = "," if col_idx < dim - 1 else "" + lines.append(f" ({vec_type})({', '.join(values)}){suffix}") + lines.append(" );") + lines.append("}") + + # Unary negation. + lines.append(f"static inline {mat_type} vkdispatch_mat{dim}_neg({mat_type} a) {{") + lines.append(f" {mat_type} out;") + for col_idx in range(dim): + col_expr = self._matrix_col_expr("a", col_idx, dim) + lines.append( + f" {self._matrix_col_assign_stmt('out', col_idx, f'-{col_expr}', dim)}" + ) + lines.append(" return out;") + lines.append("}") + + # Matrix +/- matrix. 
+ for op_name, op_symbol in (("add", "+"), ("sub", "-")): + lines.append( + f"static inline {mat_type} vkdispatch_mat{dim}_{op_name}_mm({mat_type} a, {mat_type} b) {{" + ) + lines.append(f" {mat_type} out;") + for col_idx in range(dim): + lhs_col = self._matrix_col_expr("a", col_idx, dim) + rhs_col = self._matrix_col_expr("b", col_idx, dim) + lines.append( + f" {self._matrix_col_assign_stmt('out', col_idx, f'{lhs_col} {op_symbol} {rhs_col}', dim)}" + ) + lines.append(" return out;") + lines.append("}") + + # Matrix/scalar and scalar/matrix arithmetic. + for op_name, op_symbol in (("add", "+"), ("sub", "-"), ("mul", "*"), ("div", "/")): + lines.append( + f"static inline {mat_type} vkdispatch_mat{dim}_{op_name}_ms({mat_type} a, float b) {{" + ) + lines.append(f" {mat_type} out;") + for col_idx in range(dim): + lhs_col = self._matrix_col_expr("a", col_idx, dim) + lines.append( + f" {self._matrix_col_assign_stmt('out', col_idx, f'{lhs_col} {op_symbol} b', dim)}" + ) + lines.append(" return out;") + lines.append("}") + + lines.append( + f"static inline {mat_type} vkdispatch_mat{dim}_{op_name}_sm(float a, {mat_type} b) {{" + ) + lines.append(f" {mat_type} out;") + for col_idx in range(dim): + rhs_col = self._matrix_col_expr("b", col_idx, dim) + lines.append( + f" {self._matrix_col_assign_stmt('out', col_idx, f'a {op_symbol} {rhs_col}', dim)}" + ) + lines.append(" return out;") + lines.append("}") + + # Matrix/vector product (column-major, GLSL-style): m * v. + mat_vec_terms = [f"({self._matrix_col_expr('m', i, dim)} * v.{comps[i]})" for i in range(dim)] + lines.append(f"static inline {vec_type} vkdispatch_mat{dim}_mul_mv({mat_type} m, {vec_type} v) {{") + lines.append(f" return {' + '.join(mat_vec_terms)};") + lines.append("}") + + # Vector/matrix product (column-major, GLSL-style): v * m. 
+ lines.append(f"static inline {vec_type} vkdispatch_mat{dim}_mul_vm({vec_type} v, {mat_type} m) {{") + for col_idx in range(dim): + lines.append(f" {vec_type} col{col_idx} = {self._matrix_col_expr('m', col_idx, dim)};") + row_exprs = [] + for col_idx in range(dim): + terms = [f"(v.{comps[row_idx]} * col{col_idx}.{comps[row_idx]})" for row_idx in range(dim)] + row_exprs.append(" + ".join(terms)) + lines.append(f" return ({vec_type})({', '.join(row_exprs)});") + lines.append("}") + + return "\n".join(lines) + + def arithmetic_unary_expr(self, op: str, var_type: dtypes.dtype, var_expr: str) -> Optional[str]: + if op == "-" and dtypes.is_matrix(var_type): + dim = var_type.child_count + self._record_matrix_dim(dim) + return f"vkdispatch_mat{dim}_neg({var_expr})" + return None + + def arithmetic_binary_expr( + self, + op: str, + lhs_type: dtypes.dtype, + lhs_expr: str, + rhs_type: dtypes.dtype, + rhs_expr: str, + ) -> Optional[str]: + if not (dtypes.is_matrix(lhs_type) or dtypes.is_matrix(rhs_type)): + return None + + if op not in ("+", "-", "*", "/"): + raise NotImplementedError( + f"OpenCL matrix arithmetic override does not support operator '{op}' " + f"for ({lhs_type.name}, {rhs_type.name})." + ) + + if dtypes.is_matrix(lhs_type): + dim = lhs_type.child_count + if dtypes.is_matrix(rhs_type): + if rhs_type.child_count != dim: + raise ValueError( + f"OpenCL matrix arithmetic requires matching dimensions, got '{lhs_type.name}' and '{rhs_type.name}'." + ) + if op not in ("+", "-"): + raise NotImplementedError( + f"OpenCL matrix arithmetic does not support operator '{op}' for two matrices." 
+ ) + self._record_matrix_dim(dim) + return f"vkdispatch_mat{dim}_{'add' if op == '+' else 'sub'}_mm({lhs_expr}, {rhs_expr})" + + if dtypes.is_scalar(rhs_type): + self._record_matrix_dim(dim) + suffix = "add" if op == "+" else "sub" if op == "-" else "mul" if op == "*" else "div" + return f"vkdispatch_mat{dim}_{suffix}_ms({lhs_expr}, {rhs_expr})" + + if dtypes.is_vector(rhs_type) and op == "*": + if rhs_type.child_count != dim or rhs_type.scalar != dtypes.float32: + raise ValueError( + f"OpenCL matrix/vector multiplication requires float32 vec{dim}, got '{rhs_type.name}'." + ) + self._record_matrix_dim(dim) + return f"vkdispatch_mat{dim}_mul_mv({lhs_expr}, {rhs_expr})" + + raise NotImplementedError( + f"Unsupported OpenCL matrix arithmetic for ({lhs_type.name}, {rhs_type.name}) with operator '{op}'." + ) + + # lhs is not matrix; rhs is matrix + dim = rhs_type.child_count + if dtypes.is_scalar(lhs_type): + self._record_matrix_dim(dim) + suffix = "add" if op == "+" else "sub" if op == "-" else "mul" if op == "*" else "div" + return f"vkdispatch_mat{dim}_{suffix}_sm({lhs_expr}, {rhs_expr})" + + if dtypes.is_vector(lhs_type) and op == "*": + if lhs_type.child_count != dim or lhs_type.scalar != dtypes.float32: + raise ValueError( + f"OpenCL vector/matrix multiplication requires float32 vec{dim}, got '{lhs_type.name}'." + ) + self._record_matrix_dim(dim) + return f"vkdispatch_mat{dim}_mul_vm({lhs_expr}, {rhs_expr})" + + raise NotImplementedError( + f"Unsupported OpenCL matrix arithmetic for ({lhs_type.name}, {rhs_type.name}) with operator '{op}'." 
+ ) + + def make_source(self, header: str, body: str, x: int, y: int, z: int) -> str: + expected_size_header = ( + f"// Expected local size: ({x}, {y}, {z})\n" + f"#define VKDISPATCH_EXPECTED_LOCAL_SIZE_X {x}\n" + f"#define VKDISPATCH_EXPECTED_LOCAL_SIZE_Y {y}\n" + f"#define VKDISPATCH_EXPECTED_LOCAL_SIZE_Z {z}\n" + ) + workgroup_attribute = f"__attribute__((reqd_work_group_size({x}, {y}, {z})))" + if "__kernel void vkdispatch_main" in body: + body = body.replace( + "__kernel void vkdispatch_main", + f"{workgroup_attribute}\n__kernel void vkdispatch_main", + 1, + ) + else: + body = f"{workgroup_attribute}\n{body}" + + return f"{expected_size_header}\n{header}\n{body}" + + def constant_namespace(self) -> str: + return "UBO" + + def variable_namespace(self) -> str: + return "PC" + + def exec_bounds_guard(self, exec_count_expr: str) -> str: + gid_expr = f"({self.global_invocation_id_expr()})" + exec_expr = f"({exec_count_expr})" + return ( + f"if ({self.component_access_expr(exec_expr, 'x', dtypes.uvec4)} <= {self.component_access_expr(gid_expr, 'x', dtypes.uvec3)} || " + f"{self.component_access_expr(exec_expr, 'y', dtypes.uvec4)} <= {self.component_access_expr(gid_expr, 'y', dtypes.uvec3)} || " + f"{self.component_access_expr(exec_expr, 'z', dtypes.uvec4)} <= {self.component_access_expr(gid_expr, 'z', dtypes.uvec3)}) {{ return; }}\n" + ) + + def shared_buffer_declaration(self, var_type: dtypes.dtype, name: str, size: int) -> str: + self._shared_buffer_lines.append(f"__local {self.type_name(var_type)} {name}[{size}];") + # OpenCL requires __local storage declarations at kernel/function scope. 
+ return "" + + def uniform_block_declaration(self, contents: str) -> str: + self._register_kernel_param("__global const UniformObjectBuffer* vkdispatch_uniform_ptr") + self._register_alias_line("const UniformObjectBuffer UBO = *vkdispatch_uniform_ptr;") + return f"\ntypedef struct UniformObjectBuffer {{\n{contents}\n}} UniformObjectBuffer;\n" + + def storage_buffer_declaration(self, binding: int, var_type: dtypes.dtype, name: str) -> str: + struct_name = f"Buffer{binding}" + param_name = f"vkdispatch_binding_{binding}_ptr" + data_type = self.type_name(var_type) + self._register_kernel_param(f"__global {data_type}* {param_name}") + if dtypes.is_complex(var_type): + scalar_type = self.type_name(var_type.child_type) + self._register_alias_line( + f"__global {scalar_type}* {name}_scalar = (__global {scalar_type}*)({param_name});" + ) + elif dtypes.is_vector(var_type): + scalar_type = self.type_name(var_type.scalar) + self._register_alias_line( + f"__global {scalar_type}* {name}_scalar = (__global {scalar_type}*)({param_name});" + ) + self._register_alias_line(f"{struct_name} {name} = {{{param_name}}};") + return f"typedef struct {struct_name} {{ __global {data_type}* data; }} {struct_name};\n" + + def sampler_declaration(self, binding: int, dimensions: int, name: str) -> str: + _ = (binding, dimensions, name) + raise NotImplementedError("image/sampler unsupported in OpenCL backend") + + def push_constant_declaration(self, contents: str) -> str: + self._register_kernel_param("const PushConstant vkdispatch_pc_value") + self._register_alias_line("const PushConstant PC = vkdispatch_pc_value;") + return f"\ntypedef struct PushConstant {{\n{contents}\n}} PushConstant;\n" + + def entry_point(self, body_contents: str) -> str: + params = ", ".join(self._kernel_params) + alias_block = "" + shared_block = "" + for line in self._shared_buffer_lines: + shared_block += f" {line}\n" + for line in self._entry_alias_lines: + alias_block += f" {line}\n" + + return ( + f"__kernel void 
vkdispatch_main({params}) {{\n" + f"{shared_block}" + f"{alias_block}" + f"{body_contents}" + f"}}\n" + ) + + def inf_f32_expr(self) -> str: + return "as_float((uint)0x7F800000u)" + + def ninf_f32_expr(self) -> str: + return "as_float((uint)0xFF800000u)" + + def inf_f64_expr(self) -> str: + return "as_double((ulong)0x7FF0000000000000UL)" + + def ninf_f64_expr(self) -> str: + return "as_double((ulong)0xFFF0000000000000UL)" + + def inf_f16_expr(self) -> str: + return "as_half((ushort)0x7C00u)" + + def ninf_f16_expr(self) -> str: + return "as_half((ushort)0xFC00u)" + + def float_bits_to_int_expr(self, var_expr: str) -> str: + return f"as_int({var_expr})" + + def float_bits_to_uint_expr(self, var_expr: str) -> str: + return f"as_uint({var_expr})" + + def int_bits_to_float_expr(self, var_expr: str) -> str: + return f"as_float({var_expr})" + + def uint_bits_to_float_expr(self, var_expr: str) -> str: + return f"as_float({var_expr})" + + def global_invocation_id_expr(self) -> str: + return "((uint3)((uint)get_global_id(0), (uint)get_global_id(1), (uint)get_global_id(2)))" + + def local_invocation_id_expr(self) -> str: + return "((uint3)((uint)get_local_id(0), (uint)get_local_id(1), (uint)get_local_id(2)))" + + def local_invocation_index_expr(self) -> str: + return ( + "((uint)(get_local_id(0) + " + "get_local_size(0) * (get_local_id(1) + get_local_size(1) * get_local_id(2))))" + ) + + def workgroup_id_expr(self) -> str: + return "((uint3)((uint)get_group_id(0), (uint)get_group_id(1), (uint)get_group_id(2)))" + + def workgroup_size_expr(self) -> str: + return "((uint3)((uint)get_local_size(0), (uint)get_local_size(1), (uint)get_local_size(2)))" + + def num_workgroups_expr(self) -> str: + return "((uint3)((uint)get_num_groups(0), (uint)get_num_groups(1), (uint)get_num_groups(2)))" + + def num_subgroups_expr(self) -> str: + raise NotImplementedError("subgroup operations unsupported in OpenCL backend") + + def subgroup_id_expr(self) -> str: + raise 
NotImplementedError("subgroup operations unsupported in OpenCL backend") + + def subgroup_size_expr(self) -> str: + raise NotImplementedError("subgroup operations unsupported in OpenCL backend") + + def subgroup_invocation_id_expr(self) -> str: + raise NotImplementedError("subgroup operations unsupported in OpenCL backend") + + def barrier_statement(self) -> str: + return "barrier(CLK_LOCAL_MEM_FENCE);" + + def memory_barrier_statement(self) -> str: + return "mem_fence(CLK_LOCAL_MEM_FENCE);" + + def memory_barrier_buffer_statement(self) -> str: + return "mem_fence(CLK_GLOBAL_MEM_FENCE);" + + def memory_barrier_shared_statement(self) -> str: + return "mem_fence(CLK_LOCAL_MEM_FENCE);" + + def memory_barrier_image_statement(self) -> str: + raise NotImplementedError("image/sampler unsupported in OpenCL backend") + + def group_memory_barrier_statement(self) -> str: + return "mem_fence(CLK_LOCAL_MEM_FENCE | CLK_GLOBAL_MEM_FENCE);" + + def subgroup_add_expr(self, arg_expr: str) -> str: + _ = arg_expr + raise NotImplementedError("subgroup operations unsupported in OpenCL backend") + + def subgroup_mul_expr(self, arg_expr: str) -> str: + _ = arg_expr + raise NotImplementedError("subgroup operations unsupported in OpenCL backend") + + def subgroup_min_expr(self, arg_expr: str) -> str: + _ = arg_expr + raise NotImplementedError("subgroup operations unsupported in OpenCL backend") + + def subgroup_max_expr(self, arg_expr: str) -> str: + _ = arg_expr + raise NotImplementedError("subgroup operations unsupported in OpenCL backend") + + def subgroup_and_expr(self, arg_expr: str) -> str: + _ = arg_expr + raise NotImplementedError("subgroup operations unsupported in OpenCL backend") + + def subgroup_or_expr(self, arg_expr: str) -> str: + _ = arg_expr + raise NotImplementedError("subgroup operations unsupported in OpenCL backend") + + def subgroup_xor_expr(self, arg_expr: str) -> str: + _ = arg_expr + raise NotImplementedError("subgroup operations unsupported in OpenCL backend") + + 
def subgroup_elect_expr(self) -> str: + raise NotImplementedError("subgroup operations unsupported in OpenCL backend") + + def subgroup_barrier_statement(self) -> str: + raise NotImplementedError("subgroup operations unsupported in OpenCL backend") + + def printf_statement(self, fmt: str, args: List[str]) -> str: + if len(args) == 0: + return f'printf("{fmt}");' + return f'printf("{fmt}", {", ".join(args)});' + + def texture_size_expr(self, texture_expr: str, lod: int, dimensions: int) -> str: + _ = (texture_expr, lod, dimensions) + raise NotImplementedError("image/sampler unsupported in OpenCL backend") + + def sample_texture_expr(self, texture_expr: str, coord_expr: str, lod_expr: Optional[str] = None) -> str: + _ = (texture_expr, coord_expr, lod_expr) + raise NotImplementedError("image/sampler unsupported in OpenCL backend") + + def mark_texture_sample_dimension(self, dimensions: int) -> None: + _ = dimensions + raise NotImplementedError("image/sampler unsupported in OpenCL backend") + + def atomic_add_expr(self, mem_expr: str, value_expr: str, var_type: dtypes.dtype) -> str: + if var_type not in (dtypes.int32, dtypes.uint32): + raise NotImplementedError(f"OpenCL atomic_add only supports int32/uint32, got '{var_type.name}'") + + return f"atomic_add(&({mem_expr}), {value_expr})" diff --git a/vkdispatch/codegen/builder.py b/vkdispatch/codegen/builder.py index c3214976..44e50e48 100644 --- a/vkdispatch/codegen/builder.py +++ b/vkdispatch/codegen/builder.py @@ -17,6 +17,15 @@ from .variables.variables import BaseVariable, ShaderVariable, ScaledAndOfftsetIntVariable from .variables.bound_variables import BufferVariable, ImageVariable +_PUSH_CONSTANT_UNSUPPORTED_BACKENDS = set() + + +def _push_constant_not_supported_error(backend_name: str) -> str: + return ( + f"Push Constants are not supported for the {backend_name.upper()} backend. " + "Use Const instead." 
+ ) + @dataclasses.dataclass class SharedBuffer: """ @@ -61,15 +70,22 @@ class ShaderDescription: uniform_structure: List[StructElement] binding_type_list: List[BindingType] binding_access: List[Tuple[bool, bool]] # List of tuples indicating read and write access for each binding - exec_count_name: str + exec_count_name: Optional[str] + resource_binding_base: int backend: Optional[CodeGenBackend] = None def make_source(self, x: int, y: int, z: int) -> str: if self.backend is None: layout_str = f"layout(local_size_x = {x}, local_size_y = {y}, local_size_z = {z}) in;" - return f"{self.header}\n{layout_str}\n{self.body}" + shader_source = f"{self.header}\n{layout_str}\n{self.body}" + else: + shader_source = self.backend.make_source(self.header, self.body, x, y, z) + + # ff = open(f"sources/{self.name}.comp", "w") + # ff.write(shader_source) + # ff.close() - return self.backend.make_source(self.header, self.body, x, y, z) + return shader_source def __repr__(self): description_string = "" @@ -123,7 +139,6 @@ class ShaderBuilder(ShaderWriter): pc_struct: StructBuilder uniform_struct: StructBuilder exec_count: Optional[ShaderVariable] - pre_header: str flags: ShaderFlags backend: CodeGenBackend @@ -140,11 +155,6 @@ def __init__(self, else: # Use the selected backend type while keeping per-builder backend state isolated. 
self.backend = get_codegen_backend().__class__() - - self.pre_header = self.backend.pre_header( - enable_subgroup_ops=not (self.flags & ShaderFlags.NO_SUBGROUP_OPS), - enable_printf=not (self.flags & ShaderFlags.NO_PRINTF) - ) self.reset() @@ -159,9 +169,10 @@ def reset(self) -> None: self.shared_buffers = [] self.scope_num = 1 - self.exec_count = self.declare_constant(dtypes.uvec4, var_name="exec_count") - + self.exec_count = None + if not (self.flags & ShaderFlags.NO_EXEC_BOUNDS): + self.exec_count = self.declare_constant(dtypes.uvec4, var_name="exec_count") self.append_contents(self.backend.exec_bounds_guard(self.exec_count.resolve())) def new_var(self, @@ -211,6 +222,9 @@ def declare_constant(self, var_type: dtypes.dtype, count: int = 1, var_name: Opt return new_var def declare_variable(self, var_type: dtypes.dtype, count: int = 1, var_name: Optional[str] = None): + if self.backend.name in _PUSH_CONSTANT_UNSUPPORTED_BACKENDS: + raise NotImplementedError(_push_constant_not_supported_error(self.backend.name)) + if var_name is None: var_name = self.new_name() @@ -235,6 +249,10 @@ def declare_buffer(self, var_type: dtypes.dtype, var_name: Optional[str] = None) buffer_name = f"buf{self.binding_count}" if var_name is None else var_name shape_name = f"{buffer_name}_shape" + scalar_expr = None + + if self.backend.name == "opencl" and (dtypes.is_vector(var_type) or dtypes.is_complex(var_type)): + scalar_expr = f"{buffer_name}_scalar" self.binding_list.append(ShaderBinding(var_type, buffer_name, 0, BindingType.STORAGE_BUFFER)) self.binding_read_access[self.binding_count] = False @@ -247,13 +265,18 @@ def read_lambda(): def write_lambda(): self.binding_write_access[current_binding_count] = True + + def shape_var_factory(): + return self.declare_constant(dtypes.ivec4, var_name=shape_name) return BufferVariable( var_type, self.binding_count, f"{buffer_name}.data", - self.declare_constant(dtypes.ivec4, var_name=shape_name), - shape_name, + 
shape_var_factory=shape_var_factory, + shape_name=shape_name, + scalar_expr=scalar_expr, + codegen_backend=self.backend, read_lambda=read_lambda, write_lambda=write_lambda ) @@ -287,12 +310,17 @@ def shared_buffer(self, var_type: dtypes.dtype, size: int, var_name: Optional[st shape_name = f"{var_name}_shape" + def shape_var_factory(): + return self.declare_constant(dtypes.ivec4, var_name=shape_name) + new_var = BufferVariable( var_type, -1, var_name, - self.declare_constant(dtypes.ivec4, var_name=shape_name), - shape_name, + shape_var_factory=shape_var_factory, + shape_name=shape_name, + scalar_expr=None, + codegen_backend=self.backend, read_lambda=lambda: None, write_lambda=lambda: None ) @@ -316,7 +344,7 @@ def compose_struct_decleration(self, elements: List[StructElement]) -> str: return "\n".join(declerations) def build(self, name: str) -> ShaderDescription: - header = "" + self.pre_header + header = "" for shared_buffer in self.shared_buffers: header += self.backend.shared_buffer_declaration( @@ -328,22 +356,28 @@ def build(self, name: str) -> ShaderDescription: uniform_elements = self.uniform_struct.build() uniform_decleration_contents = self.compose_struct_decleration(uniform_elements) - if len(uniform_decleration_contents) > 0: + has_uniform_buffer = len(uniform_decleration_contents) > 0 + if has_uniform_buffer: header += self.backend.uniform_block_declaration(uniform_decleration_contents) - binding_type_list = [BindingType.UNIFORM_BUFFER] - binding_access = [(True, False)] # UBO is read-only + binding_base = 1 if has_uniform_buffer else 0 + binding_type_list = [] + binding_access = [] + if has_uniform_buffer: + binding_type_list.append(BindingType.UNIFORM_BUFFER) + binding_access.append((True, False)) # UBO is read-only for ii, binding in enumerate(self.binding_list): + emitted_binding = ii + binding_base if binding.binding_type == BindingType.STORAGE_BUFFER: - header += self.backend.storage_buffer_declaration(ii + 1, binding.dtype, binding.name) + header 
+= self.backend.storage_buffer_declaration(emitted_binding, binding.dtype, binding.name) binding_type_list.append(binding.binding_type) binding_access.append(( self.binding_read_access[ii + 1], self.binding_write_access[ii + 1] )) else: - header += self.backend.sampler_declaration(ii + 1, binding.dimension, binding.name) + header += self.backend.sampler_declaration(emitted_binding, binding.dimension, binding.name) binding_type_list.append(binding.binding_type) binding_access.append(( self.binding_read_access[ii + 1], @@ -355,10 +389,18 @@ def build(self, name: str) -> ShaderDescription: pc_decleration_contents = self.compose_struct_decleration(pc_elements) if len(pc_decleration_contents) > 0: + assert self.backend.name not in _PUSH_CONSTANT_UNSUPPORTED_BACKENDS, ( + _push_constant_not_supported_error(self.backend.name) + ) header += self.backend.push_constant_declaration(pc_decleration_contents) + pre_header = self.backend.pre_header( + enable_subgroup_ops=not (self.flags & ShaderFlags.NO_SUBGROUP_OPS), + enable_printf=not (self.flags & ShaderFlags.NO_PRINTF) + ) + return ShaderDescription( - header=header, + header=f"{pre_header}{header}", body=self.backend.entry_point(self.contents), name=name, pc_size=self.pc_struct.size, @@ -366,6 +408,7 @@ def build(self, name: str) -> ShaderDescription: uniform_structure=uniform_elements, binding_type_list=[binding.value for binding in binding_type_list], binding_access=binding_access, - exec_count_name=self.exec_count.raw_name, + exec_count_name=self.exec_count.raw_name if self.exec_count is not None else None, + resource_binding_base=binding_base, backend=self.backend ) diff --git a/vkdispatch/codegen/functions/atomic_memory.py b/vkdispatch/codegen/functions/atomic_memory.py index 000350f7..7efb8590 100644 --- a/vkdispatch/codegen/functions/atomic_memory.py +++ b/vkdispatch/codegen/functions/atomic_memory.py @@ -1,20 +1,76 @@ +from typing import Any, List + +import vkdispatch.base.dtype as dtypes + +from 
..variables.base_variable import BaseVariable +from ..variables.bound_variables import BufferVariable from ..variables.variables import ShaderVariable +from . import utils -from typing import Any -# https://docs.vulkan.org/glsl/latest/chapters/builtinfunctions.html#atomic-memory-functions +def _is_buffer_backed_target(var: ShaderVariable) -> bool: + stack: List[BaseVariable] = [var] + visited_ids = set() + + while len(stack) > 0: + current = stack.pop() + current_id = id(current) + if current_id in visited_ids: + continue + visited_ids.add(current_id) + + if isinstance(current, BufferVariable): + return True + + stack.extend(current.parents) + + return False + +# https://docs.vulkan.org/glsl/latest/chapters/builtinfunctions.html#atomic-memory-functions def atomic_add(mem: ShaderVariable, y: Any) -> ShaderVariable: - raise NotImplementedError("atomic_add is not implemented yet") + assert isinstance(mem, ShaderVariable), f"atomic_add target must be a ShaderVariable, got {type(mem)}" + assert dtypes.is_scalar(mem.var_type), "atomic_add target must be a scalar lvalue" + assert mem.is_setable(), "atomic_add target must be a writable lvalue" + assert not mem.is_register(), "atomic_add does not support register/local variables as target" + assert _is_buffer_backed_target(mem), "atomic_add target must reference a buffer element (e.g., buf[idx])" + + assert mem.var_type in (dtypes.int32, dtypes.uint32), ( + f"atomic_add currently supports only int32/uint32 targets, got '{mem.var_type.name}'" + ) + + parents: List[BaseVariable] = [mem] + + if isinstance(y, ShaderVariable): + assert dtypes.is_scalar(y.var_type), "atomic_add increment variable must be scalar" + assert dtypes.is_integer_dtype(y.var_type), ( + f"atomic_add increment variable must be integer-typed, got '{y.var_type.name}'" + ) + y.read_callback() + parents.append(y) + y_expr = utils.backend_constructor(mem.var_type, y) + elif utils.is_int_number(y): + y_expr = utils.backend_constructor(mem.var_type, y) + elif 
utils.is_number(y): + raise TypeError(f"atomic_add increment must be an integer scalar, got {y!r}") + else: + raise TypeError(f"atomic_add increment must be an integer scalar or ShaderVariable, got {type(y)}") + + mem.read_callback() + mem.write_callback() - # assert isinstance(mem, BaseVariable), "mem must be a BaseVariable" + result_var = utils.new_var( + mem.var_type, + None, + parents=parents, + lexical_unit=True, + settable=True, + register=True + ) - # new_var = self.make_var(arg1.var_type, None, []) - # self.append_contents(f"{new_var.var_type.glsl_type} {new_var.name} = atomicAdd({arg1.resolve()}, {arg2.resolve()});\n") + atomic_expr = utils.codegen_backend().atomic_add_expr(mem.resolve(), y_expr, mem.var_type) + utils.append_contents( + f"{utils.backend_type_name(result_var.var_type)} {result_var.name} = {atomic_expr};\n" + ) - # return mem.new_var( - # mem.var_type, - # f"atomicAdd({mem.resolve()}, {resolve_input(y)})", - # parents=[y, x], - # lexical_unit=True - # ) \ No newline at end of file + return result_var diff --git a/vkdispatch/codegen/functions/base_functions/arithmetic.py b/vkdispatch/codegen/functions/base_functions/arithmetic.py index 4ecab608..79e890e5 100644 --- a/vkdispatch/codegen/functions/base_functions/arithmetic.py +++ b/vkdispatch/codegen/functions/base_functions/arithmetic.py @@ -1,11 +1,11 @@ import vkdispatch.base.dtype as dtypes from vkdispatch.codegen.variables.base_variable import BaseVariable -from typing import Any +from typing import Any, Tuple, Union -from ...._compat import numpy_compat as npc +from .. import scalar_eval as se def my_log2_int(x: int) -> int: - return int(npc.round(npc.log2(x))) + return int(se.round(se.log2(x))) from . 
import base_utils @@ -18,6 +18,26 @@ def _mark_arith_unary(var: BaseVariable, op: str) -> None: def _mark_arith_binary(lhs_type: dtypes.dtype, rhs_type: dtypes.dtype, op: str, *, inplace: bool = False) -> None: base_utils.get_codegen_backend().mark_composite_binary_op(lhs_type, rhs_type, op, inplace=inplace) +def _resolve_arithmetic_binary_expr( + op: str, + lhs_type: dtypes.dtype, + lhs_expr: str, + rhs_type: dtypes.dtype, + rhs_expr: str, +) -> Tuple[str, bool]: + override_expr = base_utils.get_codegen_backend().arithmetic_binary_expr( + op, lhs_type, lhs_expr, rhs_type, rhs_expr + ) + if override_expr is not None: + return override_expr, True + return f"{lhs_expr} {op} {rhs_expr}", False + +def _resolve_arithmetic_unary_expr(op: str, var_type: dtypes.dtype, var_expr: str) -> Tuple[str, bool]: + override_expr = base_utils.get_codegen_backend().arithmetic_unary_expr(op, var_type, var_expr) + if override_expr is not None: + return override_expr, True + return f"{op}{var_expr}", False + def arithmetic_op_common(var: BaseVariable, other: Any, reverse: bool = False, @@ -54,27 +74,55 @@ def add(var: BaseVariable, other: Any, inplace: bool = False) -> BaseVariable: return_type = arithmetic_op_common(var, other, inplace=inplace) if base_utils.is_scalar_number(other): - _mark_arith_binary(var.var_type, base_utils.number_to_dtype(other), "+", inplace=inplace) + scalar_type = base_utils.number_to_dtype(other) + scalar_expr = base_utils.format_number_literal(other) + _mark_arith_binary(var.var_type, scalar_type, "+", inplace=inplace) + expr, use_assignment = _resolve_arithmetic_binary_expr( + "+", + var.var_type, + var.resolve(), + scalar_type, + scalar_expr, + ) if not inplace: + if use_assignment: + return base_utils.new_base_var( + return_type, + expr, + parents=[var], + ) return base_utils.new_scaled_var( return_type, var.resolve(), offset=other, parents=[var]) - base_utils.append_contents(f"{var.resolve()} += {other};\n") + if use_assignment: + 
base_utils.append_contents(f"{var.resolve()} = {expr};\n") + else: + base_utils.append_contents(f"{var.resolve()} += {scalar_expr};\n") return var assert isinstance(other, BaseVariable) _mark_arith_binary(var.var_type, other.var_type, "+", inplace=inplace) + expr, use_assignment = _resolve_arithmetic_binary_expr( + "+", + var.var_type, + var.resolve(), + other.var_type, + other.resolve(), + ) if not inplace: return base_utils.new_base_var( return_type, - f"{var.resolve()} + {other.resolve()}", + expr, parents=[var, other]) - base_utils.append_contents(f"{var.resolve()} += {other.resolve()};\n") + if use_assignment: + base_utils.append_contents(f"{var.resolve()} = {expr};\n") + else: + base_utils.append_contents(f"{var.resolve()} += {other.resolve()};\n") return var def sub(var: BaseVariable, other: Any, reverse: bool = False, inplace: bool = False) -> BaseVariable: @@ -82,60 +130,103 @@ def sub(var: BaseVariable, other: Any, reverse: bool = False, inplace: bool = Fa if base_utils.is_scalar_number(other): scalar_type = base_utils.number_to_dtype(other) + scalar_expr = base_utils.format_number_literal(other) if reverse and not inplace: _mark_arith_unary(var, "-") _mark_arith_binary(var.var_type, scalar_type, "+", inplace=False) else: # Non-reverse scalar subtraction is emitted as `+ (-scalar)` via scaled-var optimization. 
_mark_arith_binary(var.var_type, scalar_type, "+" if not inplace else "-", inplace=inplace) + expr, use_assignment = _resolve_arithmetic_binary_expr( + "-", + scalar_type if reverse else var.var_type, + scalar_expr if reverse else var.resolve(), + var.var_type if reverse else scalar_type, + var.resolve() if reverse else scalar_expr, + ) if not inplace: + if use_assignment: + return base_utils.new_base_var( + return_type, + expr, + parents=[var], + ) return base_utils.new_scaled_var( return_type, f"(-{var.resolve()})" if reverse else var.resolve(), - offset=other, + offset=other if reverse else -other, parents=[var]) - base_utils.append_contents(f"{var.resolve()} -= {other};\n") + if use_assignment: + base_utils.append_contents(f"{var.resolve()} = {expr};\n") + else: + base_utils.append_contents(f"{var.resolve()} -= {scalar_expr};\n") return var assert isinstance(other, BaseVariable) - _mark_arith_binary(var.var_type if not reverse else other.var_type, other.var_type if not reverse else var.var_type, "-", inplace=inplace) + lhs_type = var.var_type if not reverse else other.var_type + rhs_type = other.var_type if not reverse else var.var_type + _mark_arith_binary(lhs_type, rhs_type, "-", inplace=inplace) + expr, use_assignment = _resolve_arithmetic_binary_expr( + "-", + lhs_type, + var.resolve() if not reverse else other.resolve(), + rhs_type, + other.resolve() if not reverse else var.resolve(), + ) if not inplace: return base_utils.new_base_var( return_type, - ( - f"{var.resolve()} - {other.resolve()}" - if not reverse else - f"{other.resolve()} - {var.resolve()}" - ), + expr, parents=[var, other]) - base_utils.append_contents(f"{var.resolve()} -= {other.resolve()};\n") + if use_assignment: + base_utils.append_contents(f"{var.resolve()} = {expr};\n") + else: + base_utils.append_contents(f"{var.resolve()} -= {other.resolve()};\n") return var def mul(var: BaseVariable, other: Any, inplace: bool = False) -> BaseVariable: - return_type = arithmetic_op_common(var, other, 
inplace=inplace) - if base_utils.is_scalar_number(other): + return_type = arithmetic_op_common(var, other, inplace=inplace) + scalar_type = base_utils.number_to_dtype(other) + scalar_expr = base_utils.format_number_literal(other) + expr, use_assignment = _resolve_arithmetic_binary_expr( + "*", + var.var_type, + var.resolve(), + scalar_type, + scalar_expr, + ) if not inplace: if other == 1: return var - if dtypes.is_integer_dtype(var.var_type) and base_utils.is_int_number(other) and base_utils.is_int_power_of_2(other): + if ( + not use_assignment + and dtypes.is_integer_dtype(var.var_type) + and base_utils.is_int_number(other) + and base_utils.is_int_power_of_2(other) + ): power = my_log2_int(other) - _mark_arith_binary(var.var_type, base_utils.number_to_dtype(other), "<<", inplace=False) + _mark_arith_binary(var.var_type, scalar_type, "<<", inplace=False) return base_utils.new_base_var(var.var_type, f"{var.resolve()} << {power}", [var]) - _mark_arith_binary(var.var_type, base_utils.number_to_dtype(other), "*", inplace=False) - return base_utils.new_scaled_var( - return_type, - var.resolve(), - scale=other, - parents=[var]) - - _mark_arith_binary(var.var_type, base_utils.number_to_dtype(other), "*", inplace=True) - base_utils.append_contents(f"{var.resolve()} *= {other};\n") + _mark_arith_binary(var.var_type, scalar_type, "*", inplace=False) + if use_assignment: + return base_utils.new_base_var( + return_type, + expr, + parents=[var], + ) + return base_utils.new_scaled_var(return_type, var.resolve(), scale=other, parents=[var]) + + _mark_arith_binary(var.var_type, scalar_type, "*", inplace=True) + if use_assignment: + base_utils.append_contents(f"{var.resolve()} = {expr};\n") + else: + base_utils.append_contents(f"{var.resolve()} *= {scalar_expr};\n") return var assert isinstance(other, BaseVariable) @@ -146,14 +237,32 @@ def mul(var: BaseVariable, other: Any, inplace: bool = False) -> BaseVariable: if dtypes.is_matrix(var.var_type) and 
dtypes.is_matrix(other.var_type): raise ValueError("Matrix multiplication is not supported via the `*` operator. Use `@` operator instead.") + return_type = dtypes.cross_multiply_type(var.var_type, other.var_type) + if inplace: + assert var.is_setable(), "Inplace arithmetic requires the variable to be settable." + var.read_callback() + var.write_callback() + other.read_callback() + assert return_type == var.var_type, "Inplace arithmetic requires the result type to match the variable type." + _mark_arith_binary(var.var_type, other.var_type, "*", inplace=inplace) + expr, use_assignment = _resolve_arithmetic_binary_expr( + "*", + var.var_type, + var.resolve(), + other.var_type, + other.resolve(), + ) if not inplace: return base_utils.new_base_var( - var.var_type, - f"{var.resolve()} * {other.resolve()}", + return_type, + expr, parents=[var, other]) - base_utils.append_contents(f"{var.resolve()} *= {other.resolve()};\n") + if use_assignment: + base_utils.append_contents(f"{var.resolve()} = {expr};\n") + else: + base_utils.append_contents(f"{var.resolve()} *= {other.resolve()};\n") return var def truediv(var: BaseVariable, other: Any, reverse: bool = False, inplace: bool = False) -> BaseVariable: @@ -165,21 +274,39 @@ def truediv(var: BaseVariable, other: Any, reverse: bool = False, inplace: bool if base_utils.is_scalar_number(other): scalar_f_type = dtypes.float32 + other_expr = base_utils.format_number_literal(other, force_float32=True) if not reverse: _mark_arith_binary(return_type, scalar_f_type, "/", inplace=inplace) else: _mark_arith_binary(scalar_f_type, return_type, "/", inplace=inplace) + lhs_expr = base_utils.to_dtype_base(return_type, var).resolve() if not reverse else other_expr + rhs_expr = other_expr if not reverse else base_utils.to_dtype_base(return_type, var).resolve() + lhs_type = return_type if not reverse else scalar_f_type + rhs_type = scalar_f_type if not reverse else return_type + expr, use_assignment = _resolve_arithmetic_binary_expr( + "/", + 
lhs_type, + lhs_expr, + rhs_type, + rhs_expr, + ) if not inplace: return base_utils.new_base_var( return_type, - ( - f"{base_utils.to_dtype_base(return_type, var).resolve()} / {float(other)}" - if not reverse else - f"{float(other)} / {base_utils.to_dtype_base(return_type, var).resolve()}" - ), + expr, parents=[var]) - base_utils.append_contents(f"{var.resolve()} /= {float(other)};\n") + if use_assignment: + inplace_expr, _ = _resolve_arithmetic_binary_expr( + "/", + var.var_type, + var.resolve(), + scalar_f_type, + other_expr, + ) + base_utils.append_contents(f"{var.resolve()} = {inplace_expr};\n") + else: + base_utils.append_contents(f"{var.resolve()} /= {other_expr};\n") return var assert isinstance(other, BaseVariable) @@ -193,17 +320,42 @@ def truediv(var: BaseVariable, other: Any, reverse: bool = False, inplace: bool lhs_mark_type = return_type if not reverse else dtypes.make_floating_dtype(other.var_type) rhs_mark_type = dtypes.make_floating_dtype(other.var_type) if not reverse else return_type _mark_arith_binary(lhs_mark_type, rhs_mark_type, "/", inplace=inplace) + + lhs_expr = ( + base_utils.to_dtype_base(lhs_mark_type, var).resolve() + if not reverse else + base_utils.to_dtype_base(lhs_mark_type, other).resolve() + ) + rhs_expr = ( + base_utils.to_dtype_base(rhs_mark_type, other).resolve() + if not reverse else + base_utils.to_dtype_base(rhs_mark_type, var).resolve() + ) + expr, use_assignment = _resolve_arithmetic_binary_expr( + "/", + lhs_mark_type, + lhs_expr, + rhs_mark_type, + rhs_expr, + ) + if not inplace: return base_utils.new_base_var( return_type, - ( - f"{base_utils.to_dtype_base(return_type, var).resolve()} / {base_utils.to_dtype_base(return_type, other).resolve()}" - if not reverse else - f"{base_utils.to_dtype_base(return_type, other).resolve()} / {base_utils.to_dtype_base(return_type, var).resolve()}" - ), + expr, parents=[var, other]) - base_utils.append_contents(f"{var.resolve()} /= {base_utils.to_dtype_base(return_type, 
other).resolve()};\n") + if use_assignment: + inplace_expr, _ = _resolve_arithmetic_binary_expr( + "/", + var.var_type, + var.resolve(), + rhs_mark_type, + rhs_expr, + ) + base_utils.append_contents(f"{var.resolve()} = {inplace_expr};\n") + else: + base_utils.append_contents(f"{var.resolve()} /= {rhs_expr};\n") return var def floordiv(var: BaseVariable, other: Any, reverse: bool = False, inplace: bool = False) -> BaseVariable: @@ -291,43 +443,84 @@ def mod(var: BaseVariable, other: Any, reverse: bool = False, inplace: bool = Fa base_utils.append_contents(f"{var.resolve()} %= {other.resolve()};\n") return var -def pow(var: BaseVariable, other: Any, reverse: bool = False, inplace: bool = False) -> BaseVariable: - return_type = arithmetic_op_common(var, other, reverse=reverse, inplace=inplace) - - if base_utils.is_scalar_number(other): - if not inplace: - return base_utils.new_base_var( - return_type, - ( - f"pow({var.resolve()}, {other})" - if not reverse else - f"pow({other}, {var.resolve()})" - ), - parents=[var]) - base_utils.append_contents(f"{var.resolve()} = pow({var.resolve()}, {other});\n") - return var +def pow_expr(x: Any, y: Any) -> Union[BaseVariable, float]: + if base_utils.is_int_number(y) and y == 0: + return 1 + + if base_utils.is_number(y) and base_utils.is_number(x): + return se.power(x, y) + + if base_utils.is_number(x) and isinstance(y, BaseVariable): + result_type = base_utils.dtype_to_floating(y.var_type) + return base_utils.new_base_var( + result_type, + base_utils.get_codegen_backend().binary_math_expr( + "pow", + dtypes.float32, + base_utils.resolve_input(x), + result_type, + y.resolve(), + ), + parents=[y] + ) + + if base_utils.is_number(y) and isinstance(x, BaseVariable): + result_type = base_utils.dtype_to_floating(x.var_type) - assert isinstance(other, BaseVariable) + if base_utils.is_int_number(y) and x.is_register(): + if y > 0 and y <= 4: + expr = " * ".join([x.resolve()] * int(y)) + return base_utils.new_base_var(result_type, expr, 
parents=[x]) + elif y < 0 and y >= -4: + expr = " * ".join([x.resolve()] * int(-y)) + return base_utils.new_base_var(result_type, f"1 / ({expr})", parents=[x]) - if not inplace: return base_utils.new_base_var( - return_type, - ( - f"pow({var.resolve()}, {other.resolve()})" - if not reverse else - f"pow({other.resolve()}, {var.resolve()})" + result_type, + base_utils.get_codegen_backend().binary_math_expr( + "pow", + result_type, + x.resolve(), + dtypes.float32, + base_utils.resolve_input(y), ), - parents=[var, other]) + parents=[x] + ) + + assert isinstance(y, BaseVariable), "First argument must be a ShaderVariable or number" + assert isinstance(x, BaseVariable), "Second argument must be a ShaderVariable or number" + + result_type = base_utils.dtype_to_floating(dtypes.cross_type(x.var_type, y.var_type)) + return base_utils.new_base_var( + result_type, + base_utils.get_codegen_backend().binary_math_expr( + "pow", + base_utils.dtype_to_floating(x.var_type), + x.resolve(), + base_utils.dtype_to_floating(y.var_type), + y.resolve(), + ), + parents=[y, x], + lexical_unit=True + ) + +def pow(var: BaseVariable, other: Any, reverse: bool = False, inplace: bool = False) -> BaseVariable: + _ = arithmetic_op_common(var, other, reverse=reverse, inplace=inplace) + experession = pow_expr(other, var) if reverse else pow_expr(var, other) + + if not inplace: + return experession - base_utils.append_contents(f"{var.resolve()} = pow({var.resolve()}, {other.resolve()});\n") + base_utils.append_contents(f"{var.resolve()} = {experession};\n") return var def neg(var: BaseVariable) -> BaseVariable: _mark_arith_unary(var, "-") + expr, _ = _resolve_arithmetic_unary_expr("-", var.var_type, var.resolve()) return base_utils.new_base_var( var.var_type, - f"-{var.resolve()}", + expr, parents=[var]) def absolute(var: BaseVariable) -> BaseVariable: diff --git a/vkdispatch/codegen/functions/base_functions/base_utils.py b/vkdispatch/codegen/functions/base_functions/base_utils.py index 
a6daaf5f..7a5d7d71 100644 --- a/vkdispatch/codegen/functions/base_functions/base_utils.py +++ b/vkdispatch/codegen/functions/base_functions/base_utils.py @@ -4,13 +4,18 @@ from typing import Any, Optional import numbers +import math -from ...._compat import numpy_compat as npc +from ....compat import numpy_compat as npc from vkdispatch.codegen.shader_writer import new_scaled_var, append_contents, new_name from vkdispatch.codegen.global_builder import get_codegen_backend from vkdispatch.codegen.shader_writer import new_var as new_var_impl +_I32_MIN = -(2 ** 31) +_I32_MAX = 2 ** 31 - 1 +_U32_MAX = 2 ** 32 - 1 + def new_base_var(var_type: dtypes.dtype, var_name: Optional[str], parents: list, @@ -45,9 +50,13 @@ def is_int_power_of_2(n: int) -> bool: def number_to_dtype(number: numbers.Number): if is_int_number(number): if number >= 0: - return dtypes.uint32 + if number <= _U32_MAX: + return dtypes.uint32 + return dtypes.uint64 - return dtypes.int32 + if number >= _I32_MIN and number <= _I32_MAX: + return dtypes.int32 + return dtypes.int64 elif is_float_number(number): return dtypes.float32 elif is_complex_number(number): @@ -62,33 +71,65 @@ def check_is_int(variable): return npc.is_integer_scalar(variable) def dtype_to_floating(var_type: dtypes.dtype) -> dtypes.dtype: - if var_type == dtypes.int32 or var_type == dtypes.uint32: - return dtypes.float32 + return dtypes.make_floating_dtype(var_type) + +def _inf_scalar_type(var_type: dtypes.dtype) -> dtypes.dtype: + """Extract the scalar float type from any dtype.""" + if dtypes.is_complex(var_type): + return var_type.child_type + if dtypes.is_vector(var_type) or dtypes.is_matrix(var_type): + return var_type.scalar + return var_type + +def format_number_literal(var: numbers.Number, *, force_float32: bool = False, dtype: Optional[dtypes.dtype] = None) -> str: + if is_complex_number(var): + return str(var) - if var_type == dtypes.ivec2 or var_type == dtypes.uvec2: - return dtypes.vec2 + if is_float_number(var) or 
(force_float32 and is_int_number(var)): + value = float(var) - if var_type == dtypes.ivec3 or var_type == dtypes.uvec3: - return dtypes.vec3 - - if var_type == dtypes.ivec4 or var_type == dtypes.uvec4: - return dtypes.vec4 - - return var_type + if math.isinf(value): + backend = get_codegen_backend() + scalar = _inf_scalar_type(dtype) if dtype is not None else dtypes.float32 + if scalar is dtypes.float64: + return backend.inf_f64_expr() if value > 0 else backend.ninf_f64_expr() + if scalar is dtypes.float16: + return backend.inf_f16_expr() if value > 0 else backend.ninf_f16_expr() + return backend.inf_f32_expr() if value > 0 else backend.ninf_f32_expr() + + if math.isnan(value): + return "(0.0f / 0.0f)" + + literal = repr(value) + if "e" not in literal and "E" not in literal and "." not in literal: + literal += ".0" + return literal + "f" -def resolve_input(var: Any) -> str: + return str(var) + +def resolve_input(var: Any, dtype: Optional[dtypes.dtype] = None) -> str: #print("Resolving input:", var) if is_number(var): - return str(var) - + return format_number_literal(var, dtype=dtype) + assert isinstance(var, BaseVariable), "Argument must be a ShaderVariable or number" return var.resolve() +def resolve_input_type(var: Any) -> Optional[dtypes.dtype]: + if is_number(var): + return number_to_dtype(var) + + if isinstance(var, BaseVariable): + return var.var_type + + return None + def backend_constructor(var_type: dtypes.dtype, *args) -> str: return get_codegen_backend().constructor( var_type, - [resolve_input(elem) for elem in args] + [resolve_input(elem, dtype=var_type) for elem in args], + arg_types=[resolve_input_type(elem) for elem in args], ) def to_dtype_base(var_type: dtypes.dtype, *args): diff --git a/vkdispatch/codegen/functions/block_synchonization.py b/vkdispatch/codegen/functions/block_synchonization.py index ca0da11c..3deccc45 100644 --- a/vkdispatch/codegen/functions/block_synchonization.py +++ b/vkdispatch/codegen/functions/block_synchonization.py @@ 
-1,4 +1,4 @@ -from ..global_builder import get_builder +from ..global_builder import get_builder, get_codegen_backend from . import utils @@ -6,7 +6,7 @@ def barrier(): # On Apple devices, a memory barrier is required before a barrier # to ensure memory operations are visible to all threads # (for some reason) - if get_builder().is_apple_device: + if get_builder().is_apple_device and get_codegen_backend().name == "glsl": memory_barrier() utils.append_contents(utils.codegen_backend().barrier_statement() + "\n") diff --git a/vkdispatch/codegen/functions/builtin_constants.py b/vkdispatch/codegen/functions/builtin_constants.py index f023fdb6..47812331 100644 --- a/vkdispatch/codegen/functions/builtin_constants.py +++ b/vkdispatch/codegen/functions/builtin_constants.py @@ -17,6 +17,38 @@ def ninf_f32(): lexical_unit=True ) +def inf_f64(): + return utils.new_var( + dtypes.float64, + utils.codegen_backend().inf_f64_expr(), + [], + lexical_unit=True + ) + +def ninf_f64(): + return utils.new_var( + dtypes.float64, + utils.codegen_backend().ninf_f64_expr(), + [], + lexical_unit=True + ) + +def inf_f16(): + return utils.new_var( + dtypes.float16, + utils.codegen_backend().inf_f16_expr(), + [], + lexical_unit=True + ) + +def ninf_f16(): + return utils.new_var( + dtypes.float16, + utils.codegen_backend().ninf_f16_expr(), + [], + lexical_unit=True + ) + def global_invocation_id(): return utils.new_var( dtypes.uvec3, diff --git a/vkdispatch/codegen/functions/common_builtins.py b/vkdispatch/codegen/functions/common_builtins.py index 741d590a..e801bdda 100644 --- a/vkdispatch/codegen/functions/common_builtins.py +++ b/vkdispatch/codegen/functions/common_builtins.py @@ -3,7 +3,7 @@ from typing import Any, Union, Tuple from . import utils -from ..._compat import numpy_compat as npc +from . 
import scalar_eval as se def comment(comment: str, preceding_new_line: bool = True) -> None: comment_text = str(comment).replace("\r\n", "\n").replace("\r", "\n") @@ -45,7 +45,7 @@ def abs(var: Any) -> Union[ShaderVariable, float]: def sign(var: Any) -> Union[ShaderVariable, float]: if utils.is_number(var): - return npc.sign(var) + return se.sign(var) assert isinstance(var, ShaderVariable), "Argument must be a ShaderVariable or number" @@ -58,7 +58,7 @@ def sign(var: Any) -> Union[ShaderVariable, float]: def floor(var: Any) -> Union[ShaderVariable, float]: if utils.is_number(var): - return npc.floor(var) + return se.floor(var) assert isinstance(var, ShaderVariable), "Argument must be a ShaderVariable or number" @@ -71,7 +71,7 @@ def floor(var: Any) -> Union[ShaderVariable, float]: def ceil(var: Any) -> Union[ShaderVariable, float]: if utils.is_number(var): - return npc.ceil(var) + return se.ceil(var) assert isinstance(var, ShaderVariable), "Argument must be a ShaderVariable or number" @@ -84,7 +84,7 @@ def ceil(var: Any) -> Union[ShaderVariable, float]: def trunc(var: Any) -> Union[ShaderVariable, float]: if utils.is_number(var): - return npc.trunc(var) + return se.trunc(var) assert isinstance(var, ShaderVariable), "Argument must be a ShaderVariable or number" @@ -97,7 +97,7 @@ def trunc(var: Any) -> Union[ShaderVariable, float]: def round(var: Any) -> Union[ShaderVariable, float]: if utils.is_number(var): - return npc.round(var) + return se.round(var) assert isinstance(var, ShaderVariable), "Argument must be a ShaderVariable or number" @@ -110,7 +110,7 @@ def round(var: Any) -> Union[ShaderVariable, float]: def round_even(var: Any) -> Union[ShaderVariable, float]: if utils.is_number(var): - return npc.round(var) + return se.round(var) assert isinstance(var, ShaderVariable), "Argument must be a ShaderVariable or number" utils.mark_backend_feature("roundEven") @@ -124,7 +124,7 @@ def round_even(var: Any) -> Union[ShaderVariable, float]: def fract(var: Any) -> 
Union[ShaderVariable, float]: if utils.is_number(var): - return float(var - npc.floor(var)) + return float(var - se.floor(var)) assert isinstance(var, ShaderVariable), "Argument must be a ShaderVariable or number" utils.mark_backend_feature("fract") @@ -138,7 +138,7 @@ def fract(var: Any) -> Union[ShaderVariable, float]: def mod(x: Any, y: Any) -> Union[ShaderVariable, float]: if utils.is_number(y) and utils.is_number(x): - return npc.mod(x, y) + return se.mod(x, y) base_var = None @@ -160,14 +160,14 @@ def mod(x: Any, y: Any) -> Union[ShaderVariable, float]: def modf(x: Any, y: Any) -> Tuple[ShaderVariable, ShaderVariable]: if utils.is_number(y) and utils.is_number(x): - a, b = npc.modf(x, y) + a, b = se.modf(x, y) return float(a), float(b) if utils.is_number(x) and isinstance(y, ShaderVariable): utils.mark_backend_feature("mod") return utils.new_var( utils.dtype_to_floating(y.var_type), - f"mod({x}, {y.resolve()})", + f"mod({utils.resolve_input(x)}, {y.resolve()})", parents=[y] ) @@ -175,7 +175,7 @@ def modf(x: Any, y: Any) -> Tuple[ShaderVariable, ShaderVariable]: utils.mark_backend_feature("mod") return utils.new_var( utils.dtype_to_floating(x.var_type), - f"mod({x.resolve()}, {y})", + f"mod({x.resolve()}, {utils.resolve_input(y)})", parents=[x] ) @@ -192,7 +192,7 @@ def modf(x: Any, y: Any) -> Tuple[ShaderVariable, ShaderVariable]: def min(x: Any, y: Any) -> Union[ShaderVariable, float]: if utils.is_number(y) and utils.is_number(x): - return npc.minimum(x, y) + return se.minimum(x, y) base_var = None @@ -212,7 +212,7 @@ def min(x: Any, y: Any) -> Union[ShaderVariable, float]: def max(x: Any, y: Any) -> Union[ShaderVariable, float]: if utils.is_number(y) and utils.is_number(x): - return npc.maximum(x, y) + return se.maximum(x, y) base_var = None @@ -232,7 +232,7 @@ def max(x: Any, y: Any) -> Union[ShaderVariable, float]: def clip(x: Any, min_val: Any, max_val: Any) -> Union[ShaderVariable, float]: if utils.is_number(x) and utils.is_number(min_val) and 
utils.is_number(max_val): - return npc.clip(x, min_val, max_val) + return se.clip(x, min_val, max_val) base_var = None @@ -257,7 +257,7 @@ def clamp(x: Any, min_val: Any, max_val: Any) -> Union[ShaderVariable, float]: def mix(x: Any, y: Any, a: Any) -> Union[ShaderVariable, float]: if utils.is_number(y) and utils.is_number(x) and utils.is_number(a): - return npc.interp(a, [0, 1], [x, y]) + return se.interp(a, [0, 1], [x, y]) base_var = None @@ -303,7 +303,7 @@ def step(edge: Any, x: Any) -> Union[ShaderVariable, float]: def smoothstep(edge0: Any, edge1: Any, x: Any) -> Union[ShaderVariable, float]: if utils.is_number(edge0) and utils.is_number(edge1) and utils.is_number(x): - t = npc.clip((x - edge0) / (edge1 - edge0), 0.0, 1.0) + t = se.clip((x - edge0) / (edge1 - edge0), 0.0, 1.0) return float(t * t * (3.0 - 2.0 * t)) base_var = None @@ -328,7 +328,7 @@ def smoothstep(edge0: Any, edge1: Any, x: Any) -> Union[ShaderVariable, float]: def isnan(var: Any) -> Union[ShaderVariable, bool]: if utils.is_number(var): - return npc.isnan(var) + return se.isnan(var) assert isinstance(var, ShaderVariable), "Argument must be a ShaderVariable or number" @@ -341,7 +341,7 @@ def isnan(var: Any) -> Union[ShaderVariable, bool]: def isinf(var: Any) -> Union[ShaderVariable, bool]: if utils.is_number(var): - return npc.isinf(var) + return se.isinf(var) assert isinstance(var, ShaderVariable), "Argument must be a ShaderVariable or number" @@ -354,7 +354,7 @@ def isinf(var: Any) -> Union[ShaderVariable, bool]: def float_bits_to_int(var: Any) -> Union[ShaderVariable, int]: if utils.is_number(var): - return npc.float_bits_to_int(var) + return se.float_bits_to_int(var) assert isinstance(var, ShaderVariable), "Argument must be a ShaderVariable or number" @@ -367,7 +367,7 @@ def float_bits_to_int(var: Any) -> Union[ShaderVariable, int]: def float_bits_to_uint(var: Any) -> Union[ShaderVariable, int]: if utils.is_number(var): - return npc.float_bits_to_uint(var) + return 
se.float_bits_to_uint(var) assert isinstance(var, ShaderVariable), "Argument must be a ShaderVariable or number" @@ -380,7 +380,7 @@ def float_bits_to_uint(var: Any) -> Union[ShaderVariable, int]: def int_bits_to_float(var: Any) -> Union[ShaderVariable, float]: if utils.is_number(var): - return npc.int_bits_to_float(var) + return se.int_bits_to_float(var) assert isinstance(var, ShaderVariable), "Argument must be a ShaderVariable or number" @@ -393,7 +393,7 @@ def int_bits_to_float(var: Any) -> Union[ShaderVariable, float]: def uint_bits_to_float(var: Any) -> Union[ShaderVariable, float]: if utils.is_number(var): - return npc.uint_bits_to_float(var) + return se.uint_bits_to_float(var) assert isinstance(var, ShaderVariable), "Argument must be a ShaderVariable or number" diff --git a/vkdispatch/codegen/functions/complex_numbers.py b/vkdispatch/codegen/functions/complex_numbers.py index af6a33ce..e99f3d7b 100644 --- a/vkdispatch/codegen/functions/complex_numbers.py +++ b/vkdispatch/codegen/functions/complex_numbers.py @@ -4,28 +4,32 @@ from .common_builtins import fma -from .type_casting import to_complex +from .type_casting import to_complex, to_dtype from . 
import utils from .trigonometry import cos, sin def complex_from_euler_angle(angle: ShaderVariable): - return to_complex(cos(angle), sin(angle)) + if not isinstance(angle, ShaderVariable): + raise TypeError("complex_from_euler_angle expects a ShaderVariable angle") + + target_complex_type = dtypes.complex_from_float(dtypes.make_floating_dtype(angle.var_type)) + return to_dtype(target_complex_type, cos(angle), sin(angle)) def validate_complex_number(arg1: Any) -> Union[ShaderVariable, complex]: if isinstance(arg1, ShaderVariable): - assert arg1.var_type == dtypes.complex64, "Input variables to complex multiplication must be complex" + assert dtypes.is_complex(arg1.var_type), "Input variables to complex multiplication must be complex" return arg1 assert utils.is_number(arg1), "Argument must be ShaderVariable or number" return complex(arg1) -def _new_big_complex(arg1: Any, arg2: Any): - var_str = utils.backend_constructor(dtypes.complex64, arg1, arg2) +def _new_big_complex(var_type: dtypes.dtype, arg1: Any, arg2: Any): + var_str = utils.backend_constructor(var_type, arg1, arg2) return utils.new_var( - dtypes.complex64, + var_type, var_str, [utils.resolve_input(arg1), utils.resolve_input(arg2)], lexical_unit=True @@ -35,4 +39,19 @@ def mult_complex(arg1: ShaderVariable, arg2: ShaderVariable): a1 = validate_complex_number(arg1) a2 = validate_complex_number(arg2) - return _new_big_complex(fma(a1.real, a2.real, -a1.imag * a2.imag), fma(a1.real, a2.imag, a1.imag * a2.real)) + fallback_type = dtypes.complex64 + for normalized_arg in (a1, a2): + if isinstance(normalized_arg, ShaderVariable): + fallback_type = normalized_arg.var_type + break + + result_type = None + for normalized_arg in (a1, a2): + arg_type = normalized_arg.var_type if isinstance(normalized_arg, ShaderVariable) else fallback_type + result_type = arg_type if result_type is None else dtypes.cross_type(result_type, arg_type) + + return _new_big_complex( + result_type, # type: ignore[arg-type] + fma(a1.real, 
a2.real, -a1.imag * a2.imag), + fma(a1.real, a2.imag, a1.imag * a2.real), + ) diff --git a/vkdispatch/codegen/functions/control_flow.py b/vkdispatch/codegen/functions/control_flow.py index 107627c3..88fcad45 100644 --- a/vkdispatch/codegen/functions/control_flow.py +++ b/vkdispatch/codegen/functions/control_flow.py @@ -52,8 +52,18 @@ def else_if_all(*args: List[ShaderVariable]): utils.scope_increment() def return_statement(arg=None): - arg = arg if arg is not None else "" - utils.append_contents(f"return {arg};\n") + if arg is None: + utils.append_contents("return;\n") + return + + if isinstance(arg, str): + arg_expr = arg + elif isinstance(arg, ShaderVariable) or utils.is_number(arg): + arg_expr = utils.resolve_input(arg) + else: + arg_expr = str(arg) + + utils.append_contents(f"return {arg_expr};\n") def while_statement(arg: ShaderVariable): utils.append_contents(f"while({proc_bool(arg)}) {'{'}\n") @@ -75,7 +85,7 @@ def end(indent: bool = True): utils.append_contents("}\n") def logical_and(arg1: ShaderVariable, arg2: ShaderVariable): - return utils.new_var(dtypes.int32, f"({arg1} && {arg2})", [arg1, arg2]) + return utils.new_var(dtypes.int32, f"({proc_bool(arg1)} && {proc_bool(arg2)})", [arg1, arg2]) def logical_or(arg1: ShaderVariable, arg2: ShaderVariable): - return utils.new_var(dtypes.int32, f"({arg1} || {arg2})", [arg1, arg2]) \ No newline at end of file + return utils.new_var(dtypes.int32, f"({proc_bool(arg1)} || {proc_bool(arg2)})", [arg1, arg2]) diff --git a/vkdispatch/codegen/functions/exponential.py b/vkdispatch/codegen/functions/exponential.py index 1b67e6b4..68b2ebc6 100644 --- a/vkdispatch/codegen/functions/exponential.py +++ b/vkdispatch/codegen/functions/exponential.py @@ -1,105 +1,149 @@ +import vkdispatch.base.dtype as dtypes from ..variables.variables import ShaderVariable from typing import Any, Union from . import utils -from ..._compat import numpy_compat as npc +from . 
import scalar_eval as se + +def _is_glsl_backend() -> bool: + return utils.codegen_backend().name == "glsl" + +def _is_float64_dtype(var_type: dtypes.dtype) -> bool: + if dtypes.is_scalar(var_type): + return var_type == dtypes.float64 + + if dtypes.is_vector(var_type): + return var_type.scalar == dtypes.float64 + + return False + +def _float64_to_float32_dtype(var_type: dtypes.dtype) -> dtypes.dtype: + if var_type == dtypes.float64: + return dtypes.float32 + + if dtypes.is_vector(var_type) and var_type.scalar == dtypes.float64: + return dtypes.to_vector(dtypes.float32, var_type.child_count) + + raise TypeError(f"Unsupported fp64 fallback dtype: {var_type}") + +def _needs_glsl_float64_trig_fallback(var_type: dtypes.dtype) -> bool: + return _is_glsl_backend() and _is_float64_dtype(var_type) + +def process_float_var(var: ShaderVariable) -> bool: + pass + +def _unary_math_var(func_name: str, var: ShaderVariable) -> ShaderVariable: + result_type = utils.dtype_to_floating(var.var_type) + expr_arg_type = result_type + expr_arg = var.resolve() + expr_result_type = result_type + + if _needs_glsl_float64_trig_fallback(result_type) and func_name in {"exp", "exp2", "log", "log2"}: + expr_arg_type = _float64_to_float32_dtype(result_type) + expr_result_type = expr_arg_type + expr_arg = utils.backend_constructor_from_resolved(expr_arg_type, [expr_arg]) + + expr = utils.codegen_backend().unary_math_expr(func_name, expr_result_type, expr_arg) + + if expr_result_type != result_type: + expr = utils.backend_constructor_from_resolved(result_type, [expr]) + + return utils.new_var( + result_type, + expr, + parents=[var], + lexical_unit=True + ) def pow(x: Any, y: Any) -> Union[ShaderVariable, float]: if utils.is_number(y) and utils.is_number(x): - return npc.power(x, y) + return se.power(x, y) if utils.is_number(x) and isinstance(y, ShaderVariable): + result_type = utils.dtype_to_floating(y.var_type) return utils.new_var( - utils.dtype_to_floating(y.var_type), - f"pow({x}, 
{y.resolve()})", + result_type, + utils.codegen_backend().binary_math_expr( + "pow", + dtypes.float32, + utils.resolve_input(x), + result_type, + y.resolve(), + ), parents=[y] ) if utils.is_number(y) and isinstance(x, ShaderVariable): + result_type = utils.dtype_to_floating(x.var_type) return utils.new_var( - utils.dtype_to_floating(x.var_type), - f"pow({x.resolve()}, {y})", + result_type, + utils.codegen_backend().binary_math_expr( + "pow", + result_type, + x.resolve(), + dtypes.float32, + utils.resolve_input(y), + ), parents=[x] ) assert isinstance(y, ShaderVariable), "First argument must be a ShaderVariable or number" assert isinstance(x, ShaderVariable), "Second argument must be a ShaderVariable or number" + result_type = utils.dtype_to_floating(dtypes.cross_type(x.var_type, y.var_type)) return utils.new_var( - utils.dtype_to_floating(y.var_type), - f"pow({x.resolve()}, {y.resolve()})", + result_type, + utils.codegen_backend().binary_math_expr( + "pow", + utils.dtype_to_floating(x.var_type), + x.resolve(), + utils.dtype_to_floating(y.var_type), + y.resolve(), + ), parents=[y, x], lexical_unit=True ) def exp(var: Any) -> Union[ShaderVariable, float]: if utils.is_number(var): - return npc.exp(var) + return se.exp(var) assert isinstance(var, ShaderVariable), "Argument must be a ShaderVariable or number" - - return utils.new_var( - utils.dtype_to_floating(var.var_type), - f"exp({var.resolve()})", - parents=[var], - lexical_unit=True - ) + return _unary_math_var("exp", var) def exp2(var: Any) -> Union[ShaderVariable, float]: if utils.is_number(var): - return npc.exp2(var) + return se.exp2(var) assert isinstance(var, ShaderVariable), "Argument must be a ShaderVariable or number" - - return utils.new_var( - utils.dtype_to_floating(var.var_type), - f"exp2({var.resolve()})", - parents=[var], - lexical_unit=True - ) + return _unary_math_var("exp2", var) def log(var: Any) -> Union[ShaderVariable, float]: if utils.is_number(var): - return npc.log(var) + return se.log(var) 
assert isinstance(var, ShaderVariable), "Argument must be a ShaderVariable or number" - - return utils.new_var( - utils.dtype_to_floating(var.var_type), - f"log({var.resolve()})", - parents=[var], - lexical_unit=True - ) + return _unary_math_var("log", var) def log2(var: Any) -> Union[ShaderVariable, float]: if utils.is_number(var): - return npc.log2(var) + return se.log2(var) assert isinstance(var, ShaderVariable), "Argument must be a ShaderVariable or number" + return _unary_math_var("log2", var) - return utils.new_var( - utils.dtype_to_floating(var.var_type), - f"log2({var.resolve()})", - parents=[var], - lexical_unit=True - ) - +# has double def sqrt(var: Any) -> Union[ShaderVariable, float]: if utils.is_number(var): - return npc.sqrt(var) + return se.sqrt(var) assert isinstance(var, ShaderVariable), "Argument must be a ShaderVariable or number" + return _unary_math_var("sqrt", var) - return utils.new_var( - utils.dtype_to_floating(var.var_type), - f"sqrt({var.resolve()})", - parents=[var], - lexical_unit=True - ) - +# has double def inversesqrt(var: Any) -> Union[ShaderVariable, float]: if utils.is_number(var): - return float(1.0 / npc.sqrt(var)) + return float(1.0 / se.sqrt(var)) assert isinstance(var, ShaderVariable), "Argument must be a ShaderVariable or number" utils.mark_backend_feature("inversesqrt") diff --git a/vkdispatch/codegen/functions/geometric.py b/vkdispatch/codegen/functions/geometric.py index 7e6fa864..6992a8ad 100644 --- a/vkdispatch/codegen/functions/geometric.py +++ b/vkdispatch/codegen/functions/geometric.py @@ -3,11 +3,11 @@ from typing import Any, Union from . import utils -from ..._compat import numpy_compat as npc +from . 
import scalar_eval as se def length(var: Any) -> Union[ShaderVariable, float]: if utils.is_number(var): - return npc.abs_value(var) + return se.abs_value(var) assert isinstance(var, ShaderVariable), "Argument must be a ShaderVariable or number" @@ -20,7 +20,7 @@ def length(var: Any) -> Union[ShaderVariable, float]: def distance(x: Any, y: Any) -> Union[ShaderVariable, float]: if utils.is_number(y) and utils.is_number(x): - return npc.abs_value(y - x) + return se.abs_value(y - x) base_var = None @@ -40,7 +40,7 @@ def distance(x: Any, y: Any) -> Union[ShaderVariable, float]: def dot(x: Any, y: Any) -> Union[ShaderVariable, float]: if utils.is_number(y) and utils.is_number(x): - return npc.dot(x, y) + return se.dot(x, y) base_var = None diff --git a/vkdispatch/codegen/functions/registers.py b/vkdispatch/codegen/functions/registers.py index 1aa2a622..64387ef1 100644 --- a/vkdispatch/codegen/functions/registers.py +++ b/vkdispatch/codegen/functions/registers.py @@ -4,7 +4,7 @@ from . import utils -from .type_casting import to_dtype, to_complex +from .type_casting import to_dtype, to_complex, to_complex32, to_complex64, to_complex128 def new_register(var_type: dtypes.dtype, *args, var_name: Optional[str] = None): new_var = utils.new_var( @@ -29,22 +29,65 @@ def new_register(var_type: dtypes.dtype, *args, var_name: Optional[str] = None): return new_var +def new_float16_register(*args, var_name: Optional[str] = None): + return new_register(dtypes.float16, *args, var_name=var_name) + def new_float_register(*args, var_name: Optional[str] = None): return new_register(dtypes.float32, *args, var_name=var_name) +def new_float64_register(*args, var_name: Optional[str] = None): + return new_register(dtypes.float64, *args, var_name=var_name) + +def new_int16_register(*args, var_name: Optional[str] = None): + return new_register(dtypes.int16, *args, var_name=var_name) + def new_int_register(*args, var_name: Optional[str] = None): return new_register(dtypes.int32, *args, 
var_name=var_name) +def new_int64_register(*args, var_name: Optional[str] = None): + return new_register(dtypes.int64, *args, var_name=var_name) + +def new_uint16_register(*args, var_name: Optional[str] = None): + return new_register(dtypes.uint16, *args, var_name=var_name) + def new_uint_register(*args, var_name: Optional[str] = None): return new_register(dtypes.uint32, *args, var_name=var_name) -def new_complex_register(*args, var_name: Optional[str] = None): +def new_uint64_register(*args, var_name: Optional[str] = None): + return new_register(dtypes.uint64, *args, var_name=var_name) + +def _new_complex_register(var_type: dtypes.dtype, complex_ctor, *args, var_name: Optional[str] = None): if len(args) > 0: - true_args = (to_complex(*args),) + true_args = (complex_ctor(*args),) else: true_args = (0,) - return new_register(dtypes.complex64, *true_args, var_name=var_name) + return new_register(var_type, *true_args, var_name=var_name) + +def new_complex_register(*args, var_name: Optional[str] = None): + if len(args) == 0: + return new_register(dtypes.complex64, 0, var_name=var_name) + + complex_value = to_complex(*args) + return new_register(complex_value.var_type, complex_value, var_name=var_name) + +def new_complex32_register(*args, var_name: Optional[str] = None): + return _new_complex_register(dtypes.complex32, to_complex32, *args, var_name=var_name) + +def new_complex64_register(*args, var_name: Optional[str] = None): + return _new_complex_register(dtypes.complex64, to_complex64, *args, var_name=var_name) + +def new_complex128_register(*args, var_name: Optional[str] = None): + return _new_complex_register(dtypes.complex128, to_complex128, *args, var_name=var_name) + +def new_hvec2_register(*args, var_name: Optional[str] = None): + return new_register(dtypes.hvec2, *args, var_name=var_name) + +def new_hvec3_register(*args, var_name: Optional[str] = None): + return new_register(dtypes.hvec3, *args, var_name=var_name) + +def new_hvec4_register(*args, var_name: 
Optional[str] = None): + return new_register(dtypes.hvec4, *args, var_name=var_name) def new_vec2_register(*args, var_name: Optional[str] = None): return new_register(dtypes.vec2, *args, var_name=var_name) @@ -55,6 +98,33 @@ def new_vec3_register(*args, var_name: Optional[str] = None): def new_vec4_register(*args, var_name: Optional[str] = None): return new_register(dtypes.vec4, *args, var_name=var_name) +def new_dvec2_register(*args, var_name: Optional[str] = None): + return new_register(dtypes.dvec2, *args, var_name=var_name) + +def new_dvec3_register(*args, var_name: Optional[str] = None): + return new_register(dtypes.dvec3, *args, var_name=var_name) + +def new_dvec4_register(*args, var_name: Optional[str] = None): + return new_register(dtypes.dvec4, *args, var_name=var_name) + +def new_ihvec2_register(*args, var_name: Optional[str] = None): + return new_register(dtypes.ihvec2, *args, var_name=var_name) + +def new_ihvec3_register(*args, var_name: Optional[str] = None): + return new_register(dtypes.ihvec3, *args, var_name=var_name) + +def new_ihvec4_register(*args, var_name: Optional[str] = None): + return new_register(dtypes.ihvec4, *args, var_name=var_name) + +def new_uhvec2_register(*args, var_name: Optional[str] = None): + return new_register(dtypes.uhvec2, *args, var_name=var_name) + +def new_uhvec3_register(*args, var_name: Optional[str] = None): + return new_register(dtypes.uhvec3, *args, var_name=var_name) + +def new_uhvec4_register(*args, var_name: Optional[str] = None): + return new_register(dtypes.uhvec4, *args, var_name=var_name) + def new_uvec2_register(*args, var_name: Optional[str] = None): return new_register(dtypes.uvec2, *args, var_name=var_name) diff --git a/vkdispatch/codegen/functions/scalar_eval.py b/vkdispatch/codegen/functions/scalar_eval.py new file mode 100644 index 00000000..5d406ba2 --- /dev/null +++ b/vkdispatch/codegen/functions/scalar_eval.py @@ -0,0 +1,194 @@ +from __future__ import annotations + +import builtins +import math 
+import struct + +from typing import Any, Sequence, Tuple + + +def sign(value: float) -> float: + if value > 0: + return 1.0 + if value < 0: + return -1.0 + return 0.0 + + +def floor(value: float) -> float: + return float(math.floor(value)) + + +def ceil(value: float) -> float: + return float(math.ceil(value)) + + +def trunc(value: float) -> float: + return float(math.trunc(value)) + + +def round(value: float) -> float: + return float(builtins.round(value)) + + +def abs_value(value: Any) -> float: + return float(abs(value)) + + +def mod(x: float, y: float) -> float: + return float(x % y) + + +def modf(x: float, _unused: Any = None) -> Tuple[float, float]: + frac, whole = math.modf(x) + return float(frac), float(whole) + + +def minimum(x: float, y: float) -> float: + return float(x if x <= y else y) + + +def maximum(x: float, y: float) -> float: + return float(x if x >= y else y) + + +def clip(x: float, min_value: float, max_value: float) -> float: + return float(min(max(x, min_value), max_value)) + + +def interp(x: float, xp: Sequence[float], fp: Sequence[float]) -> float: + if len(xp) != len(fp): + raise ValueError("xp and fp must have the same length") + if len(xp) == 0: + raise ValueError("xp and fp must be non-empty") + if len(xp) == 1: + return float(fp[0]) + + if x <= xp[0]: + return float(fp[0]) + if x >= xp[-1]: + return float(fp[-1]) + + for index in range(1, len(xp)): + if x <= xp[index]: + x0 = xp[index - 1] + x1 = xp[index] + y0 = fp[index - 1] + y1 = fp[index] + + if x1 == x0: + return float(y0) + + t = (x - x0) / (x1 - x0) + return float(y0 + t * (y1 - y0)) + + return float(fp[-1]) + + +def isnan(value: float) -> bool: + return math.isnan(value) + + +def isinf(value: float) -> bool: + return math.isinf(value) + + +def float_bits_to_int(value: float) -> int: + return int(struct.unpack("=i", struct.pack("=f", float(value)))[0]) + + +def float_bits_to_uint(value: float) -> int: + return int(struct.unpack("=I", struct.pack("=f", float(value)))[0]) + + 
+def int_bits_to_float(value: int) -> float: + return float(struct.unpack("=f", struct.pack("=i", int(value)))[0]) + + +def uint_bits_to_float(value: int) -> float: + return float(struct.unpack("=f", struct.pack("=I", int(value)))[0]) + + +def power(x: float, y: float) -> float: + return float(math.pow(x, y)) + + +def exp(value: float) -> float: + return float(math.exp(value)) + + +def exp2(value: float) -> float: + if hasattr(math, "exp2"): + return float(math.exp2(value)) + return float(math.pow(2.0, value)) + + +def log(value: float) -> float: + return float(math.log(value)) + + +def log2(value: float) -> float: + return float(math.log2(value)) + + +def sqrt(value: float) -> float: + return float(math.sqrt(value)) + + +def sin(value: float) -> float: + return float(math.sin(value)) + + +def cos(value: float) -> float: + return float(math.cos(value)) + + +def tan(value: float) -> float: + return float(math.tan(value)) + + +def arcsin(value: float) -> float: + return float(math.asin(value)) + + +def arccos(value: float) -> float: + return float(math.acos(value)) + + +def arctan(value: float) -> float: + return float(math.atan(value)) + + +def arctan2(y: float, x: float) -> float: + return float(math.atan2(y, x)) + + +def sinh(value: float) -> float: + return float(math.sinh(value)) + + +def cosh(value: float) -> float: + return float(math.cosh(value)) + + +def tanh(value: float) -> float: + return float(math.tanh(value)) + + +def arcsinh(value: float) -> float: + return float(math.asinh(value)) + + +def arccosh(value: float) -> float: + return float(math.acosh(value)) + + +def arctanh(value: float) -> float: + return float(math.atanh(value)) + + +def dot(x: Any, y: Any) -> float: + if isinstance(x, (int, float, complex)) and isinstance(y, (int, float, complex)): + return float(x * y) + + return float(sum(a * b for a, b in zip(x, y))) diff --git a/vkdispatch/codegen/functions/subgroups.py b/vkdispatch/codegen/functions/subgroups.py index 477d3f53..23f90952 100644 
--- a/vkdispatch/codegen/functions/subgroups.py +++ b/vkdispatch/codegen/functions/subgroups.py @@ -4,25 +4,60 @@ from . import utils def subgroup_add(arg1: ShaderVariable): - return utils.new_var(arg1.var_type, utils.codegen_backend().subgroup_add_expr(arg1.resolve()), [arg1], lexical_unit=True) + return utils.new_var( + arg1.var_type, + utils.codegen_backend().subgroup_add_expr(arg1.resolve(), arg1.var_type), + [arg1], + lexical_unit=True, + ) def subgroup_mul(arg1: ShaderVariable): - return utils.new_var(arg1.var_type, utils.codegen_backend().subgroup_mul_expr(arg1.resolve()), [arg1], lexical_unit=True) + return utils.new_var( + arg1.var_type, + utils.codegen_backend().subgroup_mul_expr(arg1.resolve(), arg1.var_type), + [arg1], + lexical_unit=True, + ) def subgroup_min(arg1: ShaderVariable): - return utils.new_var(arg1.var_type, utils.codegen_backend().subgroup_min_expr(arg1.resolve()), [arg1], lexical_unit=True) + return utils.new_var( + arg1.var_type, + utils.codegen_backend().subgroup_min_expr(arg1.resolve(), arg1.var_type), + [arg1], + lexical_unit=True, + ) def subgroup_max(arg1: ShaderVariable): - return utils.new_var(arg1.var_type, utils.codegen_backend().subgroup_max_expr(arg1.resolve()), [arg1], lexical_unit=True) + return utils.new_var( + arg1.var_type, + utils.codegen_backend().subgroup_max_expr(arg1.resolve(), arg1.var_type), + [arg1], + lexical_unit=True, + ) def subgroup_and(arg1: ShaderVariable): - return utils.new_var(arg1.var_type, utils.codegen_backend().subgroup_and_expr(arg1.resolve()), [arg1], lexical_unit=True) + return utils.new_var( + arg1.var_type, + utils.codegen_backend().subgroup_and_expr(arg1.resolve(), arg1.var_type), + [arg1], + lexical_unit=True, + ) def subgroup_or(arg1: ShaderVariable): - return utils.new_var(arg1.var_type, utils.codegen_backend().subgroup_or_expr(arg1.resolve()), [arg1], lexical_unit=True) + return utils.new_var( + arg1.var_type, + utils.codegen_backend().subgroup_or_expr(arg1.resolve(), arg1.var_type), + 
[arg1], + lexical_unit=True, + ) def subgroup_xor(arg1: ShaderVariable): - return utils.new_var(arg1.var_type, utils.codegen_backend().subgroup_xor_expr(arg1.resolve()), [arg1], lexical_unit=True) + return utils.new_var( + arg1.var_type, + utils.codegen_backend().subgroup_xor_expr(arg1.resolve(), arg1.var_type), + [arg1], + lexical_unit=True, + ) def subgroup_elect(): return utils.new_var(dtypes.int32, utils.codegen_backend().subgroup_elect_expr(), [], lexical_unit=True) diff --git a/vkdispatch/codegen/functions/trigonometry.py b/vkdispatch/codegen/functions/trigonometry.py index 504f25cc..19251db1 100644 --- a/vkdispatch/codegen/functions/trigonometry.py +++ b/vkdispatch/codegen/functions/trigonometry.py @@ -1,233 +1,250 @@ import vkdispatch.base.dtype as dtypes from ..variables.variables import ShaderVariable -from typing import Any, Union +from typing import Any, List, Union from . import utils -from ..._compat import numpy_compat as npc +from . import scalar_eval as se def dtype_to_floating(var_type: dtypes.dtype) -> dtypes.dtype: - if var_type == dtypes.int32 or var_type == dtypes.uint32: + return dtypes.make_floating_dtype(var_type) + +def _is_glsl_backend() -> bool: + return utils.codegen_backend().name == "glsl" + +def _is_float64_dtype(var_type: dtypes.dtype) -> bool: + if dtypes.is_scalar(var_type): + return var_type == dtypes.float64 + + if dtypes.is_vector(var_type): + return var_type.scalar == dtypes.float64 + + return False + +def _float64_to_float32_dtype(var_type: dtypes.dtype) -> dtypes.dtype: + if var_type == dtypes.float64: return dtypes.float32 - if var_type == dtypes.ivec2 or var_type == dtypes.uvec2: - return dtypes.vec2 + if dtypes.is_vector(var_type) and var_type.scalar == dtypes.float64: + return dtypes.to_vector(dtypes.float32, var_type.child_count) - if var_type == dtypes.ivec3 or var_type == dtypes.uvec3: - return dtypes.vec3 - - if var_type == dtypes.ivec4 or var_type == dtypes.uvec4: - return dtypes.vec4 - - return var_type + raise 
TypeError(f"Unsupported fp64 fallback dtype: {var_type}") -def radians(var: Any) -> Union[ShaderVariable, float]: - if utils.is_number(var): - return var * (3.141592653589793 / 180.0) +def _needs_glsl_float64_trig_fallback(var_type: dtypes.dtype) -> bool: + return _is_glsl_backend() and _is_float64_dtype(var_type) - assert isinstance(var, ShaderVariable), "Argument must be a ShaderVariable or number" - utils.mark_backend_feature("radians") +def _cast_expr(var_type: dtypes.dtype, expr: str) -> str: + return utils.backend_constructor_from_resolved(var_type, [expr]) + +def _unary_math_var(func_name: str, var: ShaderVariable) -> ShaderVariable: + result_type = dtype_to_floating(var.var_type) + expr_arg_type = result_type + expr_arg = var.resolve() + expr_result_type = result_type + + if _needs_glsl_float64_trig_fallback(result_type): + expr_arg_type = _float64_to_float32_dtype(result_type) + expr_result_type = expr_arg_type + expr_arg = _cast_expr(expr_arg_type, expr_arg) + + expr = utils.codegen_backend().unary_math_expr(func_name, expr_result_type, expr_arg) + + if expr_result_type != result_type: + expr = _cast_expr(result_type, expr) return utils.new_var( - dtype_to_floating(var.var_type), - f"radians({var.resolve()})", + result_type, + expr, parents=[var], lexical_unit=True ) +def _binary_math_var( + func_name: str, + result_type: dtypes.dtype, + lhs_type: dtypes.dtype, + lhs_expr: str, + rhs_type: dtypes.dtype, + rhs_expr: str, + parents: List[ShaderVariable], + *, + lexical_unit: bool = False, +) -> ShaderVariable: + expr_result_type = result_type + expr_lhs_type = lhs_type + expr_rhs_type = rhs_type + expr_lhs = lhs_expr + expr_rhs = rhs_expr + + if _needs_glsl_float64_trig_fallback(result_type): + expr_result_type = _float64_to_float32_dtype(result_type) + + if _is_float64_dtype(lhs_type): + expr_lhs_type = _float64_to_float32_dtype(lhs_type) + expr_lhs = _cast_expr(expr_lhs_type, lhs_expr) + + if _is_float64_dtype(rhs_type): + expr_rhs_type = 
_float64_to_float32_dtype(rhs_type) + expr_rhs = _cast_expr(expr_rhs_type, rhs_expr) + + expr = utils.codegen_backend().binary_math_expr( + func_name, + expr_lhs_type, + expr_lhs, + expr_rhs_type, + expr_rhs, + ) + + if expr_result_type != result_type: + expr = _cast_expr(result_type, expr) + + return utils.new_var( + result_type, + expr, + parents=parents, + lexical_unit=lexical_unit, + ) + +def radians(var: Any) -> Union[ShaderVariable, float]: + if utils.is_number(var): + return var * (3.141592653589793 / 180.0) + + assert isinstance(var, ShaderVariable), "Argument must be a ShaderVariable or number" + utils.mark_backend_feature("radians") + return _unary_math_var("radians", var) + def degrees(var: Any) -> Union[ShaderVariable, float]: if utils.is_number(var): return var * (180.0 / 3.141592653589793) assert isinstance(var, ShaderVariable), "Argument must be a ShaderVariable or number" utils.mark_backend_feature("degrees") - - return utils.new_var( - dtype_to_floating(var.var_type), - f"degrees({var.resolve()})", - parents=[var], - lexical_unit=True - ) + return _unary_math_var("degrees", var) def sin(var: Any) -> Union[ShaderVariable, float]: if utils.is_number(var): - return npc.sin(var) + return se.sin(var) assert isinstance(var, ShaderVariable), "Argument must be a ShaderVariable or number" - - return utils.new_var( - dtype_to_floating(var.var_type), - f"sin({var.resolve()})", - parents=[var], - lexical_unit=True - ) + return _unary_math_var("sin", var) def cos(var: Any) -> Union[ShaderVariable, float]: if utils.is_number(var): - return npc.cos(var) + return se.cos(var) assert isinstance(var, ShaderVariable), "Argument must be a ShaderVariable or number" - - return utils.new_var( - dtype_to_floating(var.var_type), - f"cos({var.resolve()})", - parents=[var], - lexical_unit=True - ) + return _unary_math_var("cos", var) def tan(var: Any) -> Union[ShaderVariable, float]: if utils.is_number(var): - return npc.tan(var) + return se.tan(var) assert isinstance(var, 
ShaderVariable), "Argument must be a ShaderVariable or number" - - return utils.new_var( - dtype_to_floating(var.var_type), - f"tan({var.resolve()})", - parents=[var], - lexical_unit=True - ) + return _unary_math_var("tan", var) def asin(var: Any) -> Union[ShaderVariable, float]: if utils.is_number(var): - return npc.arcsin(var) + return se.arcsin(var) assert isinstance(var, ShaderVariable), "Argument must be a ShaderVariable or number" - - return utils.new_var( - dtype_to_floating(var.var_type), - f"asin({var.resolve()})", - parents=[var], - lexical_unit=True - ) + return _unary_math_var("asin", var) def acos(var: Any) -> Union[ShaderVariable, float]: if utils.is_number(var): - return npc.arccos(var) + return se.arccos(var) assert isinstance(var, ShaderVariable), "Argument must be a ShaderVariable or number" - - return utils.new_var( - dtype_to_floating(var.var_type), - f"acos({var.resolve()})", - parents=[var], - lexical_unit=True - ) + return _unary_math_var("acos", var) def atan(var: Any) -> Union[ShaderVariable, float]: if utils.is_number(var): - return npc.arctan(var) + return se.arctan(var) assert isinstance(var, ShaderVariable), "Argument must be a ShaderVariable or number" - - return utils.new_var( - dtype_to_floating(var.var_type), - f"atan({var.resolve()})", - parents=[var], - lexical_unit=True - ) + return _unary_math_var("atan", var) def atan2(y: Any, x: Any) -> Union[ShaderVariable, float]: if utils.is_number(y) and utils.is_number(x): - return npc.arctan2(y, x) + return se.arctan2(y, x) if utils.is_number(x) and isinstance(y, ShaderVariable): - return utils.new_var( - dtype_to_floating(y.var_type), - f"atan({y.resolve()}, {x})", - parents=[y] + result_type = dtype_to_floating(y.var_type) + scalar_result_type = result_type.scalar if dtypes.is_vector(result_type) else result_type + return _binary_math_var( + "atan2", + result_type, + result_type, + y.resolve(), + scalar_result_type, + utils.resolve_input(x), + [y], ) if utils.is_number(y) and 
isinstance(x, ShaderVariable): - return utils.new_var( - dtype_to_floating(x.var_type), - f"atan({y}, {x.resolve()})", - parents=[x] + result_type = dtype_to_floating(x.var_type) + scalar_result_type = result_type.scalar if dtypes.is_vector(result_type) else result_type + return _binary_math_var( + "atan2", + result_type, + scalar_result_type, + utils.resolve_input(y), + result_type, + x.resolve(), + [x], ) assert isinstance(y, ShaderVariable), "First argument must be a ShaderVariable or number" assert isinstance(x, ShaderVariable), "Second argument must be a ShaderVariable or number" - return utils.new_var( - dtype_to_floating(y.var_type), - f"atan({y.resolve()}, {x.resolve()})", - parents=[y, x], - lexical_unit=True + result_type = dtype_to_floating(dtypes.cross_type(y.var_type, x.var_type)) + return _binary_math_var( + "atan2", + result_type, + result_type, + y.resolve(), + dtype_to_floating(x.var_type), + x.resolve(), + [y, x], + lexical_unit=True, ) def sinh(var: Any) -> Union[ShaderVariable, float]: if utils.is_number(var): - return npc.sinh(var) + return se.sinh(var) assert isinstance(var, ShaderVariable), "Argument must be a ShaderVariable or number" - - return utils.new_var( - dtype_to_floating(var.var_type), - f"sinh({var.resolve()})", - parents=[var], - lexical_unit=True - ) + return _unary_math_var("sinh", var) def cosh(var: Any) -> Union[ShaderVariable, float]: if utils.is_number(var): - return npc.cosh(var) + return se.cosh(var) assert isinstance(var, ShaderVariable), "Argument must be a ShaderVariable or number" - - return utils.new_var( - dtype_to_floating(var.var_type), - f"cosh({var.resolve()})", - parents=[var], - lexical_unit=True - ) + return _unary_math_var("cosh", var) def tanh(var: Any) -> Union[ShaderVariable, float]: if utils.is_number(var): - return npc.tanh(var) + return se.tanh(var) assert isinstance(var, ShaderVariable), "Argument must be a ShaderVariable or number" - - return utils.new_var( - dtype_to_floating(var.var_type), - 
f"tanh({var.resolve()})", - parents=[var], - lexical_unit=True - ) + return _unary_math_var("tanh", var) def asinh(var: Any) -> Union[ShaderVariable, float]: if utils.is_number(var): - return npc.arcsinh(var) + return se.arcsinh(var) assert isinstance(var, ShaderVariable), "Argument must be a ShaderVariable or number" - - return utils.new_var( - dtype_to_floating(var.var_type), - f"asinh({var.resolve()})", - parents=[var], - lexical_unit=True - ) + return _unary_math_var("asinh", var) def acosh(var: Any) -> Union[ShaderVariable, float]: if utils.is_number(var): - return npc.arccosh(var) + return se.arccosh(var) assert isinstance(var, ShaderVariable), "Argument must be a ShaderVariable or number" - - return utils.new_var( - dtype_to_floating(var.var_type), - f"acosh({var.resolve()})", - parents=[var], - lexical_unit=True - ) + return _unary_math_var("acosh", var) def atanh(var: Any) -> Union[ShaderVariable, float]: if utils.is_number(var): - return npc.arctanh(var) + return se.arctanh(var) assert isinstance(var, ShaderVariable), "Argument must be a ShaderVariable or number" - - return utils.new_var( - dtype_to_floating(var.var_type), - f"atanh({var.resolve()})", - parents=[var], - lexical_unit=True - ) + return _unary_math_var("atanh", var) diff --git a/vkdispatch/codegen/functions/type_casting.py b/vkdispatch/codegen/functions/type_casting.py index d70d894f..276a479a 100644 --- a/vkdispatch/codegen/functions/type_casting.py +++ b/vkdispatch/codegen/functions/type_casting.py @@ -2,6 +2,7 @@ from typing import Optional from . 
import utils +from ..variables.variables import ShaderVariable def to_dtype(var_type: dtypes.dtype, *args): return utils.new_var( @@ -26,22 +27,88 @@ def str_to_dtype(var_type: dtypes.dtype, register=register ) +def to_float16(*args): + return to_dtype(dtypes.float16, *args) + def to_float(*args): return to_dtype(dtypes.float32, *args) +def to_float64(*args): + return to_dtype(dtypes.float64, *args) + +def to_int16(*args): + return to_dtype(dtypes.int16, *args) + def to_int(*args): return to_dtype(dtypes.int32, *args) +def to_int64(*args): + return to_dtype(dtypes.int64, *args) + +def to_uint16(*args): + return to_dtype(dtypes.uint16, *args) + def to_uint(*args): return to_dtype(dtypes.uint32, *args) -def to_complex(*args): +def to_uint64(*args): + return to_dtype(dtypes.uint64, *args) + +def _complex_from_real_arg(arg) -> dtypes.dtype: + if isinstance(arg, ShaderVariable): + if dtypes.is_complex(arg.var_type): + return arg.var_type + if dtypes.is_scalar(arg.var_type): + return dtypes.complex_from_float(dtypes.make_floating_dtype(arg.var_type)) + raise TypeError(f"Unsupported variable type for complex conversion: {arg.var_type}") + + if utils.is_number(arg): + base_type = utils.number_to_dtype(arg) + if dtypes.is_complex(base_type): + return base_type + return dtypes.complex_from_float(dtypes.make_floating_dtype(base_type)) + + raise TypeError(f"Unsupported argument type for complex conversion: {type(arg)}") + +def _infer_complex_dtype(*args) -> dtypes.dtype: + complex_type = _complex_from_real_arg(args[0]) + + for arg in args[1:]: + complex_type = dtypes.cross_type(complex_type, _complex_from_real_arg(arg)) + + return complex_type + +def _to_complex_dtype(var_type: dtypes.dtype, *args): assert len(args) == 1 or len(args) == 2, "Must give one of two arguments for complex init" + if len(args) == 1 and isinstance(args[0], ShaderVariable) and dtypes.is_complex(args[0].var_type): + return to_dtype(var_type, args[0]) + if len(args) == 1: - return 
to_dtype(dtypes.complex64, args[0], 0) + return to_dtype(var_type, args[0], 0) + + return to_dtype(var_type, *args) + +def to_complex32(*args): + return _to_complex_dtype(dtypes.complex32, *args) + +def to_complex(*args): + return _to_complex_dtype(_infer_complex_dtype(*args), *args) + +def to_complex64(*args): + return _to_complex_dtype(dtypes.complex64, *args) + +def to_complex128(*args): + return _to_complex_dtype(dtypes.complex128, *args) + +def to_hvec2(*args): + return to_dtype(dtypes.hvec2, *args) + +def to_hvec3(*args): + return to_dtype(dtypes.hvec3, *args) - return to_dtype(dtypes.complex64, *args) +def to_hvec4(*args): + return to_dtype(dtypes.hvec4, *args) def to_vec2(*args): return to_dtype(dtypes.vec2, *args) @@ -52,14 +119,23 @@ def to_vec3(*args): def to_vec4(*args): return to_dtype(dtypes.vec4, *args) -def to_uvec2(*args): - return to_dtype(dtypes.uvec2, *args) +def to_dvec2(*args): + return to_dtype(dtypes.dvec2, *args) -def to_uvec3(*args): - return to_dtype(dtypes.uvec3, *args) +def to_dvec3(*args): + return to_dtype(dtypes.dvec3, *args) -def to_uvec4(*args): - return to_dtype(dtypes.uvec4, *args) +def to_dvec4(*args): + return to_dtype(dtypes.dvec4, *args) + +def to_ihvec2(*args): + return to_dtype(dtypes.ihvec2, *args) + +def to_ihvec3(*args): + return to_dtype(dtypes.ihvec3, *args) + +def to_ihvec4(*args): + return to_dtype(dtypes.ihvec4, *args) def to_ivec2(*args): return to_dtype(dtypes.ivec2, *args) @@ -70,6 +146,24 @@ def to_ivec3(*args): def to_ivec4(*args): return to_dtype(dtypes.ivec4, *args) +def to_uhvec2(*args): + return to_dtype(dtypes.uhvec2, *args) + +def to_uhvec3(*args): + return to_dtype(dtypes.uhvec3, *args) + +def to_uhvec4(*args): + return to_dtype(dtypes.uhvec4, *args) + +def to_uvec2(*args): + return to_dtype(dtypes.uvec2, *args) + +def to_uvec3(*args): + return to_dtype(dtypes.uvec3, *args) + +def to_uvec4(*args): + return to_dtype(dtypes.uvec4, *args) + def to_mat2(*args): return to_dtype(dtypes.mat2, *args) diff --git 
a/vkdispatch/codegen/functions/utils.py b/vkdispatch/codegen/functions/utils.py index 85879d48..ddb866fb 100644 --- a/vkdispatch/codegen/functions/utils.py +++ b/vkdispatch/codegen/functions/utils.py @@ -1,6 +1,6 @@ import vkdispatch.base.dtype as dtypes from ..variables.variables import ShaderVariable -from typing import List +from typing import List, Optional from .base_functions.base_utils import * from ..global_builder import get_codegen_backend @@ -24,11 +24,33 @@ def mark_backend_feature(feature_name: str) -> None: def backend_type_name(var_type: dtypes.dtype) -> str: return codegen_backend().type_name(var_type) +def _resolve_arg_types(args: tuple) -> List[Optional[dtypes.dtype]]: + resolved_types: List[Optional[dtypes.dtype]] = [] + + for elem in args: + if isinstance(elem, ShaderVariable): + resolved_types.append(elem.var_type) + continue + + if is_number(elem): + resolved_types.append(number_to_dtype(elem)) + continue + + resolved_types.append(None) + + return resolved_types + def backend_constructor(var_type: dtypes.dtype, *args) -> str: + resolved_types = _resolve_arg_types(args) return codegen_backend().constructor( var_type, - [resolve_input(elem) for elem in args] + [resolve_input(elem) for elem in args], + arg_types=resolved_types, ) -def backend_constructor_from_resolved(var_type: dtypes.dtype, args: List[str]) -> str: - return codegen_backend().constructor(var_type, args) +def backend_constructor_from_resolved( + var_type: dtypes.dtype, + args: List[str], + arg_types: Optional[List[Optional[dtypes.dtype]]] = None, +) -> str: + return codegen_backend().constructor(var_type, args, arg_types=arg_types) diff --git a/vkdispatch/codegen/global_builder.py b/vkdispatch/codegen/global_builder.py index 204cd425..8a14b1b9 100644 --- a/vkdispatch/codegen/global_builder.py +++ b/vkdispatch/codegen/global_builder.py @@ -1,7 +1,8 @@ import threading import vkdispatch.base.dtype as dtypes from .shader_writer import set_shader_writer -from .backends import 
CodeGenBackend, GLSLBackend, CUDABackend +from .backends import CodeGenBackend, GLSLBackend, CUDABackend, OpenCLBackend +from vkdispatch.base.init import is_cuda, is_opencl from typing import Optional, TYPE_CHECKING, Union if TYPE_CHECKING: @@ -11,16 +12,12 @@ _shader_print_line_numbers = threading.local() _codegen_backend = threading.local() - def _make_runtime_default_codegen_backend() -> CodeGenBackend: - try: - from vkdispatch.base.backend import BACKEND_PYCUDA, get_active_backend_name + if is_cuda(): + return CUDABackend() - if get_active_backend_name() == BACKEND_PYCUDA: - return CUDABackend() - except Exception: - # If runtime backend metadata is unavailable, fall back to GLSL. - pass + if is_opencl(): + return OpenCLBackend() return GLSLBackend() @@ -52,6 +49,10 @@ def set_codegen_backend(backend: Optional[Union[CodeGenBackend, str]]): _codegen_backend.active_backend = CUDABackend() return + if backend_name == "opencl": + _codegen_backend.active_backend = OpenCLBackend() + return + raise ValueError(f"Unknown codegen backend '{backend}'") _codegen_backend.active_backend = backend diff --git a/vkdispatch/codegen/variables/bound_variables.py b/vkdispatch/codegen/variables/bound_variables.py index 5c6a25e4..228ff299 100644 --- a/vkdispatch/codegen/variables/bound_variables.py +++ b/vkdispatch/codegen/variables/bound_variables.py @@ -2,6 +2,7 @@ import vkdispatch.base.dtype as dtypes from ..functions import type_casting +from ..functions.base_functions import base_utils from ..global_builder import get_codegen_backend from typing import Callable, Optional @@ -21,14 +22,19 @@ def __init__(self, class BufferVariable(BoundVariable): read_lambda: Callable[[], None] write_lambda: Callable[[], None] + scalar_expr: Optional[str] + codegen_backend: Optional[object] def __init__(self, var_type: dtypes.dtype, binding: int, name: str, shape_var: "ShaderVariable" = None, + shape_var_factory: Optional[Callable[[], "ShaderVariable"]] = None, shape_name: Optional[str] = None, 
raw_name: Optional[str] = None, + scalar_expr: Optional[str] = None, + codegen_backend: Optional[object] = None, read_lambda: Callable[[], None] = None, write_lambda: Callable[[], None] = None, ) -> None: @@ -41,17 +47,62 @@ def __init__(self, self.read_lambda = read_lambda self.write_lambda = write_lambda - self.shape = shape_var + self._shape_var = shape_var + self._shape_var_factory = shape_var_factory self.shape_name = shape_name + self.scalar_expr = scalar_expr + self.codegen_backend = codegen_backend self.can_index = True self.use_child_type = False + @property + def shape(self) -> "ShaderVariable": + if self._shape_var is None: + assert self._shape_var_factory is not None, "Buffer shape variable factory is not available!" + self._shape_var = self._shape_var_factory() + + return self._shape_var + def read_callback(self): self.read_lambda() def write_callback(self): self.write_lambda() + def __getitem__(self, index) -> "ShaderVariable": + assert self.can_index, f"Variable '{self.resolve()}' of type '{self.var_type.name}' cannot be indexed into!" + + return_type = self.var_type.child_type if self.use_child_type else self.var_type + + if isinstance(index, tuple): + assert len(index) == 1, "Only single index is supported, cannot use multi-dimentional indexing!" + index = index[0] + + if base_utils.is_int_number(index): + return ShaderVariable( + return_type, + f"{self.resolve()}[{index}]", + parents=[self], + settable=self.settable, + lexical_unit=True, + buffer_root=self, + buffer_index_expr=str(index), + ) + + assert isinstance(index, ShaderVariable), f"Index must be a ShaderVariable or int type, not {type(index)}!" + assert dtypes.is_scalar(index.var_type), "Indexing variable must be a scalar!" + assert dtypes.is_integer_dtype(index.var_type), "Indexing variable must be an integer type!" 
+ + return ShaderVariable( + return_type, + f"{self.resolve()}[{index.resolve()}]", + parents=[self, index], + settable=self.settable, + lexical_unit=True, + buffer_root=self, + buffer_index_expr=index.resolve(), + ) + class ImageVariable(BoundVariable): dimensions: int = 0 read_lambda: Callable[[], None] @@ -89,14 +140,27 @@ def sample(self, coord: "ShaderVariable", lod: "ShaderVariable" = None) -> "Shad if self.dimensions == 1: sample_coord_string = f"((({coord.resolve()}) + 0.5) / {backend.texture_size_expr(self.resolve(), 0, self.dimensions)})" elif self.dimensions == 2: - coord_expr = backend.constructor(dtypes.vec2, [f"{coord.resolve()}.x", f"{coord.resolve()}.y"]) + coord_expr = backend.constructor( + dtypes.vec2, + [ + backend.component_access_expr(coord.resolve(), "x", coord.var_type), + backend.component_access_expr(coord.resolve(), "y", coord.var_type), + ] + ) tex_size_expr = backend.constructor( dtypes.vec2, [backend.texture_size_expr(self.resolve(), 0, self.dimensions)] ) sample_coord_string = f"(({coord_expr} + 0.5) / {tex_size_expr})" elif self.dimensions == 3: - coord_expr = backend.constructor(dtypes.vec3, [f"{coord.resolve()}.x", f"{coord.resolve()}.y", f"{coord.resolve()}.z"]) + coord_expr = backend.constructor( + dtypes.vec3, + [ + backend.component_access_expr(coord.resolve(), "x", coord.var_type), + backend.component_access_expr(coord.resolve(), "y", coord.var_type), + backend.component_access_expr(coord.resolve(), "z", coord.var_type), + ] + ) tex_size_expr = backend.constructor( dtypes.vec3, [backend.texture_size_expr(self.resolve(), 0, self.dimensions)] diff --git a/vkdispatch/codegen/variables/variables.py b/vkdispatch/codegen/variables/variables.py index 3bebd883..e8e776ee 100644 --- a/vkdispatch/codegen/variables/variables.py +++ b/vkdispatch/codegen/variables/variables.py @@ -13,24 +13,14 @@ ENABLE_SCALED_AND_OFFSET_INT = True def var_types_to_floating(var_type: dtypes.dtype) -> dtypes.dtype: - if var_type == dtypes.int32 or var_type 
== dtypes.uint32: - return dtypes.float32 - - if var_type == dtypes.ivec2 or var_type == dtypes.uvec2: - return dtypes.vec2 - - if var_type == dtypes.ivec3 or var_type == dtypes.uvec3: - return dtypes.vec3 - - if var_type == dtypes.ivec4 or var_type == dtypes.uvec4: - return dtypes.vec4 - - return var_type + return dtypes.make_floating_dtype(var_type) class ShaderVariable(BaseVariable): _initilized: bool is_complex: bool is_conjugate: Optional[bool] + buffer_root: Optional["ShaderVariable"] + buffer_index_expr: Optional[str] def __init__(self, var_type: dtypes.dtype, @@ -40,7 +30,9 @@ def __init__(self, settable: bool = False, register: bool = False, parents: List["ShaderVariable"] = None, - is_conjugate: bool = False + is_conjugate: bool = False, + buffer_root: Optional["ShaderVariable"] = None, + buffer_index_expr: Optional[str] = None, ) -> None: super().__setattr__("_initilized", False) @@ -56,6 +48,8 @@ def __init__(self, self.is_complex = False self.is_conjugate = None + self.buffer_root = buffer_root + self.buffer_index_expr = buffer_index_expr if dtypes.is_complex(self.var_type): self.can_index = True @@ -80,6 +74,28 @@ def __init__(self, self._initilized = True + def _buffer_component_expr(self, component_index_expr: str) -> Optional[str]: + if self.buffer_root is None or self.buffer_index_expr is None: + return None + + if not (dtypes.is_vector(self.var_type) or dtypes.is_complex(self.var_type)): + return None + + scalar_expr = getattr(self.buffer_root, "scalar_expr", None) + if scalar_expr is None: + return None + + backend = getattr(self.buffer_root, "codegen_backend", None) + if backend is None: + backend = get_codegen_backend() + + return backend.buffer_component_expr( + scalar_expr, + self.var_type, + self.buffer_index_expr, + component_index_expr, + ) + def __getitem__(self, index) -> "ShaderVariable": assert self.can_index, f"Variable '{self.resolve()}' of type '{self.var_type.name}' cannot be indexed into!" 
@@ -90,11 +106,31 @@ def __getitem__(self, index) -> "ShaderVariable": index = index[0] if base_utils.is_int_number(index): + component_expr = self._buffer_component_expr(str(index)) + if component_expr is not None: + return ShaderVariable( + return_type, + component_expr, + parents=[self], + settable=self.settable, + lexical_unit=True + ) + return ShaderVariable(return_type, f"{self.resolve()}[{index}]", parents=[self], settable=self.settable, lexical_unit=True) assert isinstance(index, ShaderVariable), f"Index must be a ShaderVariable or int type, not {type(index)}!" assert dtypes.is_scalar(index.var_type), "Indexing variable must be a scalar!" assert dtypes.is_integer_dtype(index.var_type), "Indexing variable must be an integer type!" + + component_expr = self._buffer_component_expr(index.resolve()) + if component_expr is not None: + return ShaderVariable( + return_type, + component_expr, + parents=[self, index], + settable=self.settable, + lexical_unit=True + ) return ShaderVariable(return_type, f"{self.resolve()}[{index.resolve()}]", parents=[self, index], settable=self.settable, lexical_unit=True) @@ -109,15 +145,18 @@ def swizzle(self, components: str) -> "ShaderVariable": sample_type = self.var_type if dtypes.is_scalar(self.var_type) else self.var_type.child_type return_type = sample_type if len(components) == 1 else dtypes.to_vector(sample_type, len(components)) + backend = get_codegen_backend() + base_expr = self.resolve() if dtypes.is_scalar(self.var_type): assert all(c == 'x' for c in components), f"Cannot swizzle scalar variable '{self.resolve()}' with components other than 'x'!" 
- swizzle_expr = f"{self.resolve()}.x" + scalar_x_expr = backend.component_access_expr(base_expr, "x", self.var_type) + swizzle_expr = scalar_x_expr if len(components) > 1: - swizzle_expr = get_codegen_backend().constructor( + swizzle_expr = backend.constructor( return_type, - [f"{self.resolve()}.x" for _ in components] + [scalar_x_expr for _ in components] ) return ShaderVariable( @@ -125,8 +164,8 @@ def swizzle(self, components: str) -> "ShaderVariable": name=swizzle_expr, parents=[self], lexical_unit=True, - settable=self.settable, - register=self.register + settable=self.settable and len(components) == 1, + register=self.register and len(components) == 1 ) if self.var_type.shape[0] < 4: @@ -138,11 +177,24 @@ def swizzle(self, components: str) -> "ShaderVariable": if self.var_type.shape[0] < 2: assert 'y' not in components, f"Cannot swizzle variable '{self.resolve()}' of type '{self.var_type.name}' with component 'y'!" - swizzle_expr = f"{self.resolve()}.{components}" + if len(components) == 1: + component_index = "xyzw".index(components) + component_expr = self._buffer_component_expr(str(component_index)) + if component_expr is not None: + return ShaderVariable( + var_type=return_type, + name=component_expr, + parents=[self], + lexical_unit=True, + settable=self.settable, + register=self.register + ) + + swizzle_expr = backend.component_access_expr(base_expr, components, self.var_type) if len(components) > 1: - swizzle_expr = get_codegen_backend().constructor( + swizzle_expr = backend.constructor( return_type, - [f"{self.resolve()}.{elem}" for elem in components] + [backend.component_access_expr(base_expr, elem, self.var_type) for elem in components] ) return ShaderVariable( @@ -150,8 +202,8 @@ def swizzle(self, components: str) -> "ShaderVariable": name=swizzle_expr, parents=[self], lexical_unit=True, - settable=self.settable, - register=self.register + settable=self.settable and len(components) == 1, + register=self.register and len(components) == 1 ) def 
conjugate(self) -> "ShaderVariable": @@ -175,16 +227,19 @@ def set_value(self, value: "ShaderVariable") -> None: self.read_callback() if base_utils.is_number(value): - if self.var_type == dtypes.complex64: + if dtypes.is_complex(self.var_type): complex_value = complex(value) complex_constructor = get_codegen_backend().constructor( - dtypes.complex64, - [str(complex_value.real), str(complex_value.imag)] + self.var_type, + [ + base_utils.format_number_literal(complex_value.real), + base_utils.format_number_literal(complex_value.imag), + ] ) base_utils.append_contents(f"{self.resolve()} = {complex_constructor};\n") return - base_utils.append_contents(f"{self.resolve()} = {value};\n") + base_utils.append_contents(f"{self.resolve()} = {base_utils.format_number_literal(value)};\n") return assert self.var_type == value.var_type, f"Cannot set variable of type '{self.var_type.name}' to value of type '{value.var_type.name}'!" @@ -218,6 +273,9 @@ def __setattr__(self, name: str, value: "ShaderVariable") -> "ShaderVariable": self.imag.set_value(value) return + + if dtypes.is_complex(self.var_type) and (name == "x" or name == "y"): + raise ValueError(f"Cannot set attribute '{name}' of complex variable '{self.resolve()}', use 'real' and 'imag' instead!") if dtypes.is_vector(self.var_type) and (name == "x" or name == "y" or name == "z" or name == "w"): if name == "x": @@ -254,7 +312,7 @@ def to_register(self, var_name: str = None) -> "ShaderVariable": def to_dtype(self, var_type: dtypes.dtype) -> "ShaderVariable": return base_utils.new_base_var( var_type, - get_codegen_backend().constructor(var_type, [self.resolve()]), + get_codegen_backend().constructor(var_type, [self.resolve()], arg_types=[self.var_type]), [self], lexical_unit=True ) @@ -316,16 +374,18 @@ def __init__(self, offset: int = 0, parents: List["ShaderVariable"] = None ) -> None: + # ShaderVariable.__init__ eagerly creates vector swizzles (`x`, `y`, ...), + # which call resolve() during construction. 
Pre-seed these fields so + # ScaledAndOfftsetIntVariable.resolve() is safe before super().__init__ completes. + object.__setattr__(self, "base_name", str(name)) + object.__setattr__(self, "scale", scale) + object.__setattr__(self, "offset", offset) super().__init__(var_type, name, parents=parents) - - self.base_name = str(name) - self.scale = scale - self.offset = offset def new_from_self(self, scale: int = 1, offset: int = 0): child_vartype = self.var_type - if isinstance(scale, float) or isinstance(offset, float): + if base_utils.is_float_number(scale) or base_utils.is_float_number(offset): child_vartype = var_types_to_floating(self.var_type) return ScaledAndOfftsetIntVariable( @@ -337,8 +397,14 @@ def new_from_self(self, scale: int = 1, offset: int = 0): ) def resolve(self) -> str: - scale_str = f" * {self.scale}" if self.scale != 1 else "" - offset_str = f" + {self.offset}" if self.offset != 0 else "" + scale_str = ( + f" * {base_utils.format_number_literal(self.scale)}" + if self.scale != 1 else "" + ) + offset_str = ( + f" + {base_utils.format_number_literal(self.offset)}" + if self.offset != 0 else "" + ) if scale_str == "" and offset_str == "": return self.base_name diff --git a/vkdispatch/_compat/__init__.py b/vkdispatch/compat/__init__.py similarity index 100% rename from vkdispatch/_compat/__init__.py rename to vkdispatch/compat/__init__.py diff --git a/vkdispatch/_compat/numpy_compat.py b/vkdispatch/compat/numpy_compat.py similarity index 56% rename from vkdispatch/_compat/numpy_compat.py rename to vkdispatch/compat/numpy_compat.py index 62e9dbf9..7d42ab43 100644 --- a/vkdispatch/_compat/numpy_compat.py +++ b/vkdispatch/compat/numpy_compat.py @@ -48,245 +48,11 @@ def ceil(value: float) -> float: return float(_np.ceil(value)) return float(math.ceil(value)) - -def floor(value: float) -> float: - if HAS_NUMPY: - return float(_np.floor(value)) - return float(math.floor(value)) - - -def trunc(value: float) -> float: - if HAS_NUMPY: - return 
float(_np.trunc(value)) - return float(math.trunc(value)) - - def round(value: float) -> float: if HAS_NUMPY: return float(_np.round(value)) return float(builtins.round(value)) - -def sign(value: float) -> float: - if HAS_NUMPY: - return float(_np.sign(value)) - - if value > 0: - return 1.0 - if value < 0: - return -1.0 - return 0.0 - - -def abs_value(value: Any) -> float: - if HAS_NUMPY: - return float(_np.abs(value)) - return float(abs(value)) - - -def minimum(x: float, y: float) -> float: - if HAS_NUMPY: - return float(_np.minimum(x, y)) - return float(x if x <= y else y) - - -def maximum(x: float, y: float) -> float: - if HAS_NUMPY: - return float(_np.maximum(x, y)) - return float(x if x >= y else y) - - -def clip(x: float, min_value: float, max_value: float) -> float: - if HAS_NUMPY: - return float(_np.clip(x, min_value, max_value)) - return float(min(max(x, min_value), max_value)) - - -def mod(x: float, y: float) -> float: - if HAS_NUMPY: - return float(_np.mod(x, y)) - return float(x % y) - - -def modf(x: float, _unused: Any = None) -> Tuple[float, float]: - if HAS_NUMPY: - frac, whole = _np.modf(x) - return float(frac), float(whole) - - frac, whole = math.modf(x) - return float(frac), float(whole) - - -def interp(x: float, xp: Sequence[float], fp: Sequence[float]) -> float: - if HAS_NUMPY: - return float(_np.interp(x, xp, fp)) - - if len(xp) != len(fp): - raise ValueError("xp and fp must have the same length") - if len(xp) == 0: - raise ValueError("xp and fp must be non-empty") - if len(xp) == 1: - return float(fp[0]) - - if x <= xp[0]: - return float(fp[0]) - if x >= xp[-1]: - return float(fp[-1]) - - for index in range(1, len(xp)): - if x <= xp[index]: - x0 = xp[index - 1] - x1 = xp[index] - y0 = fp[index - 1] - y1 = fp[index] - - if x1 == x0: - return float(y0) - - t = (x - x0) / (x1 - x0) - return float(y0 + t * (y1 - y0)) - - return float(fp[-1]) - - -def isnan(value: float) -> bool: - if HAS_NUMPY: - return bool(_np.isnan(value)) - return 
math.isnan(value) - - -def isinf(value: float) -> bool: - if HAS_NUMPY: - return bool(_np.isinf(value)) - return math.isinf(value) - - -def power(x: float, y: float) -> float: - if HAS_NUMPY: - return float(_np.power(x, y)) - return float(math.pow(x, y)) - - -def exp(value: float) -> float: - if HAS_NUMPY: - return float(_np.exp(value)) - return float(math.exp(value)) - - -def exp2(value: float) -> float: - if HAS_NUMPY: - return float(_np.exp2(value)) - if hasattr(math, "exp2"): - return float(math.exp2(value)) - return float(math.pow(2.0, value)) - - -def log(value: float) -> float: - if HAS_NUMPY: - return float(_np.log(value)) - return float(math.log(value)) - - -def log2(value: float) -> float: - if HAS_NUMPY: - return float(_np.log2(value)) - return float(math.log2(value)) - - -def sqrt(value: float) -> float: - if HAS_NUMPY: - return float(_np.sqrt(value)) - return float(math.sqrt(value)) - - -def sin(value: float) -> float: - if HAS_NUMPY: - return float(_np.sin(value)) - return float(math.sin(value)) - - -def cos(value: float) -> float: - if HAS_NUMPY: - return float(_np.cos(value)) - return float(math.cos(value)) - - -def tan(value: float) -> float: - if HAS_NUMPY: - return float(_np.tan(value)) - return float(math.tan(value)) - - -def arcsin(value: float) -> float: - if HAS_NUMPY: - return float(_np.arcsin(value)) - return float(math.asin(value)) - - -def arccos(value: float) -> float: - if HAS_NUMPY: - return float(_np.arccos(value)) - return float(math.acos(value)) - - -def arctan(value: float) -> float: - if HAS_NUMPY: - return float(_np.arctan(value)) - return float(math.atan(value)) - - -def arctan2(y: float, x: float) -> float: - if HAS_NUMPY: - return float(_np.arctan2(y, x)) - return float(math.atan2(y, x)) - - -def sinh(value: float) -> float: - if HAS_NUMPY: - return float(_np.sinh(value)) - return float(math.sinh(value)) - - -def cosh(value: float) -> float: - if HAS_NUMPY: - return float(_np.cosh(value)) - return float(math.cosh(value)) - - 
-def tanh(value: float) -> float: - if HAS_NUMPY: - return float(_np.tanh(value)) - return float(math.tanh(value)) - - -def arcsinh(value: float) -> float: - if HAS_NUMPY: - return float(_np.arcsinh(value)) - return float(math.asinh(value)) - - -def arccosh(value: float) -> float: - if HAS_NUMPY: - return float(_np.arccosh(value)) - return float(math.acosh(value)) - - -def arctanh(value: float) -> float: - if HAS_NUMPY: - return float(_np.arctanh(value)) - return float(math.atanh(value)) - - -def dot(x: Any, y: Any) -> float: - if HAS_NUMPY: - return float(_np.dot(x, y)) - - if isinstance(x, (int, float, complex)) and isinstance(y, (int, float, complex)): - return float(x * y) - - return float(sum(a * b for a, b in zip(x, y))) - - def angle(value: complex) -> float: if HAS_NUMPY: return float(_np.angle(value)) @@ -319,16 +85,32 @@ class HostDType: kind: str +INT16 = HostDType("int16", 2, "h", "int") +UINT16 = HostDType("uint16", 2, "H", "uint") INT32 = HostDType("int32", 4, "i", "int") UINT32 = HostDType("uint32", 4, "I", "uint") +INT64 = HostDType("int64", 8, "q", "int") +UINT64 = HostDType("uint64", 8, "Q", "uint") +FLOAT16 = HostDType("float16", 2, "e", "float") FLOAT32 = HostDType("float32", 4, "f", "float") +FLOAT64 = HostDType("float64", 8, "d", "float") +COMPLEX32 = HostDType("complex32", 4, "ee", "complex") COMPLEX64 = HostDType("complex64", 8, "ff", "complex") +COMPLEX128 = HostDType("complex128", 16, "dd", "complex") _HOST_DTYPES = { + "int16": INT16, + "uint16": UINT16, "int32": INT32, "uint32": UINT32, + "int64": INT64, + "uint64": UINT64, + "float16": FLOAT16, "float32": FLOAT32, + "float64": FLOAT64, + "complex32": COMPLEX32, "complex64": COMPLEX64, + "complex128": COMPLEX128, } @@ -355,6 +137,16 @@ def host_dtype_name(dtype: Any) -> str: raise ValueError(f"Unsupported dtype ({dtype})!") +def _numpy_dtype_or_none(dtype_name: str): + if not HAS_NUMPY: + return None + + try: + return _np.dtype(dtype_name) + except TypeError: + return None + + def 
dtype_itemsize(dtype: Any) -> int: if isinstance(dtype, HostDType): return dtype.itemsize @@ -455,7 +247,13 @@ def from_buffer(buffer: bytes, dtype: Any, shape: Tuple[int, ...]): dtype_name = host_dtype_name(dtype) if HAS_NUMPY: - return _np.frombuffer(buffer, dtype=_np.dtype(dtype_name)).reshape(shape) + np_dtype = _numpy_dtype_or_none(dtype_name) + if np_dtype is not None: + return _np.frombuffer(buffer, dtype=np_dtype).reshape(shape) + + if dtype_name == "complex32": + half_pairs = _np.frombuffer(buffer, dtype=_np.float16).reshape(*shape, 2) + return half_pairs[..., 0].astype(_np.float32) + (1j * half_pairs[..., 1].astype(_np.float32)) return CompatArray(buffer, host_dtype(dtype_name), tuple(shape)) @@ -516,16 +314,19 @@ def pack_values(values: Sequence[Any], dtype: Any) -> bytes: dtype_name = host_dtype_name(dtype) if HAS_NUMPY: - array = _np.asarray(values_list, dtype=_np.dtype(dtype_name)) - return array.tobytes() + np_dtype = _numpy_dtype_or_none(dtype_name) + if np_dtype is not None: + array = _np.asarray(values_list, dtype=np_dtype) + return array.tobytes() host = host_dtype(dtype_name) if host.kind == "complex": output = bytearray() + pack_fmt = "=" + host.struct_format for value in values_list: coerced = _coerce_scalar(value, host) - output.extend(struct.pack("=ff", float(coerced.real), float(coerced.imag))) + output.extend(struct.pack(pack_fmt, float(coerced.real), float(coerced.imag))) return bytes(output) pack_fmt = "=" + host.struct_format @@ -539,13 +340,16 @@ def unpack_values(data: bytes, dtype: Any) -> List[Any]: dtype_name = host_dtype_name(dtype) if HAS_NUMPY: - return _np.frombuffer(data, dtype=_np.dtype(dtype_name)).tolist() + np_dtype = _numpy_dtype_or_none(dtype_name) + if np_dtype is not None: + return _np.frombuffer(data, dtype=np_dtype).tolist() host = host_dtype(dtype_name) if host.kind == "complex": values: List[Any] = [] - for real, imag in struct.iter_unpack("=ff", data): + unpack_fmt = "=" + host.struct_format + for real, imag in 
struct.iter_unpack(unpack_fmt, data): values.append(complex(real, imag)) return values @@ -558,26 +362,3 @@ def unpack_values(data: bytes, dtype: Any) -> List[Any]: return values - -def float_bits_to_int(value: float) -> int: - if HAS_NUMPY: - return int(_np.frombuffer(_np.float32(value).tobytes(), dtype=_np.int32)[0]) - return int(struct.unpack("=i", struct.pack("=f", float(value)))[0]) - - -def float_bits_to_uint(value: float) -> int: - if HAS_NUMPY: - return int(_np.frombuffer(_np.float32(value).tobytes(), dtype=_np.uint32)[0]) - return int(struct.unpack("=I", struct.pack("=f", float(value)))[0]) - - -def int_bits_to_float(value: int) -> float: - if HAS_NUMPY: - return float(_np.frombuffer(_np.int32(value).tobytes(), dtype=_np.float32)[0]) - return float(struct.unpack("=f", struct.pack("=i", int(value)))[0]) - - -def uint_bits_to_float(value: int) -> float: - if HAS_NUMPY: - return float(_np.frombuffer(_np.uint32(value).tobytes(), dtype=_np.float32)[0]) - return float(struct.unpack("=f", struct.pack("=I", int(value)))[0]) diff --git a/vkdispatch/execution_pipeline/buffer_builder.py b/vkdispatch/execution_pipeline/buffer_builder.py index 43086904..01418bae 100644 --- a/vkdispatch/execution_pipeline/buffer_builder.py +++ b/vkdispatch/execution_pipeline/buffer_builder.py @@ -11,7 +11,7 @@ import vkdispatch as vd import vkdispatch.codegen as vc -from .._compat import numpy_compat as npc +from ..compat import numpy_compat as npc from vkdispatch.base.dtype import to_numpy_dtype @@ -150,6 +150,12 @@ def _write_payload(self, instance_index: int, element_slice: slice, payload: byt if len(payload) != expected_size: raise ValueError(f"Packed value size mismatch! 
Expected {expected_size}, got {len(payload)}") + if npc.HAS_NUMPY: + np = npc.numpy_module() + row = self.backing_buffer[instance_index] + row[element_slice] = np.frombuffer(payload, dtype=np.uint8) + return + start = instance_index * self.instance_bytes + element_slice.start end = start + expected_size @@ -178,7 +184,7 @@ def _setitem_python(self, key: Tuple[str, str], value: Any) -> None: return # Broadcast scalar values across all instances for scalar fields. - if not isinstance(value, (list, tuple)) and not isinstance(value, npc.CompatArray) and buffer_element.shape == (1,): + if not isinstance(value, (list, tuple)) and not npc.is_array_like(value) and buffer_element.shape == (1,): payload = self._pack_single_instance_value([value], key, buffer_element) for instance_index in range(self.instance_count): self._write_payload(instance_index, buffer_element.memory_slice, payload) @@ -186,7 +192,7 @@ def _setitem_python(self, key: Tuple[str, str], value: Any) -> None: expected_element_count = npc.prod(buffer_element.shape) - if isinstance(value, npc.CompatArray): + if npc.is_array_like(value): flat_values = npc.flatten(value) expected_total = expected_element_count * self.instance_count @@ -224,7 +230,9 @@ def __setitem__( if self.backing_buffer is None: raise RuntimeError("BufferBuilder.prepare(...) 
must be called before assigning values") - if npc.HAS_NUMPY: + buffer_element = self.element_map[key] + + if npc.HAS_NUMPY and not npc.is_host_dtype(buffer_element.dtype): self._setitem_numpy(key, value) return @@ -236,7 +244,7 @@ def __repr__(self) -> str: for key, elem in self.element_map.items(): buffer_element = self.element_map[key] - if npc.HAS_NUMPY: + if npc.HAS_NUMPY and not npc.is_host_dtype(buffer_element.dtype): value = (self.backing_buffer[:, buffer_element.memory_slice]).view(buffer_element.dtype) else: decoded_instances = [] diff --git a/vkdispatch/execution_pipeline/command_graph.py b/vkdispatch/execution_pipeline/command_graph.py index 13ac8d25..efdfc40f 100644 --- a/vkdispatch/execution_pipeline/command_graph.py +++ b/vkdispatch/execution_pipeline/command_graph.py @@ -18,6 +18,9 @@ import dataclasses +def _runtime_supports_push_constants() -> bool: + return True + @dataclasses.dataclass class BufferBindInfo: """A dataclass to hold information about a buffer binding.""" @@ -63,9 +66,10 @@ class CommandGraph(CommandList): uniform_bindings: Any uniform_constants_size: int - uniform_constants_buffer: vd.Buffer + uniform_constants_buffer: Optional[vd.Buffer] uniform_descriptors: List[Tuple[DescriptorSet, int, int]] + recorded_descriptor_sets: List[DescriptorSet] name_to_pc_key_dict: Dict[str, List[Tuple[str, str]]] queued_pc_values: Dict[Tuple[str, str], Any] @@ -84,12 +88,75 @@ def __init__(self, reset_on_submit: bool = False, submit_on_record: bool = False self.queued_pc_values = {} self.uniform_descriptors = [] + self.recorded_descriptor_sets = [] self._reset_on_submit = reset_on_submit self.submit_on_record = submit_on_record + # Lazily allocate host-uploaded UBO backing only when needed by non-CUDA backends. 
self.uniform_constants_size = 0 - self.uniform_constants_buffer = vd.Buffer(shape=(4096,), var_type=vd.uint32) # Create a base static constants buffer at size 4k bytes + self.uniform_constants_buffer = None + + def _ensure_uniform_constants_capacity(self, uniform_word_size: int) -> None: + if self.uniform_constants_buffer is not None and uniform_word_size <= self.uniform_constants_size: + return + + # Grow exponentially to reduce reallocation churn for larger UBO layouts. + if self.uniform_constants_size == 0: + self.uniform_constants_size = max(4096, uniform_word_size) + else: + self.uniform_constants_size = max(uniform_word_size, self.uniform_constants_size * 2) + self.uniform_constants_buffer = vd.Buffer(shape=(self.uniform_constants_size,), var_type=vd.uint32) + + def _prepare_submission_state(self, instance_count: int) -> None: + if len(self.pc_builder.element_map) > 0 and ( + self.pc_builder.instance_count != instance_count or not self.buffers_valid + ): + + assert _runtime_supports_push_constants(), ( + "Push constants not supported for backends without push-constant support " + "(OpenCL). Use UBO-backed variables instead." 
+ ) + + self.pc_builder.prepare(instance_count) + + for key, value in self.pc_values.items(): + self.pc_builder[key] = value + + if len(self.uniform_builder.element_map) > 0 and not self.buffers_valid: + self.uniform_builder.prepare(1) + + for key, value in self.uniform_values.items(): + self.uniform_builder[key] = value + + uniform_word_size = (self.uniform_builder.instance_bytes + 3) // 4 + uniform_payload = self.uniform_builder.tobytes() + + if vd.is_cuda(): + for descriptor_set, offset, size in self.uniform_descriptors: + descriptor_set.set_inline_uniform_payload(uniform_payload[offset:offset + size]) + else: + self._ensure_uniform_constants_capacity(uniform_word_size) + assert self.uniform_constants_buffer is not None + + for descriptor_set, offset, size in self.uniform_descriptors: + descriptor_set.bind_buffer(self.uniform_constants_buffer, 0, offset, size, True, write_access=False) + + self.uniform_constants_buffer.write(uniform_payload) + + if not self.buffers_valid: + self.buffers_valid = True + + def prepare_for_cuda_graph_capture(self, instance_count: int = None) -> None: + """Initialize internal data uploads before torch CUDA graph capture. + + This method performs one-time uniform/push-constant staging without submitting + the command list, so only kernel launches are captured by ``torch.cuda.graph``. 
+ """ + if instance_count is None: + instance_count = 1 + + self._prepare_submission_state(instance_count) def reset(self) -> None: """Reset the command graph by clearing the push constant buffer and descriptor @@ -100,15 +167,29 @@ def reset(self) -> None: self.pc_builder.reset() self.uniform_builder.reset() - self.pc_values = {} - self.uniform_values = {} - self.name_to_pc_key_dict = {} - self.queued_pc_values = {} + for descriptor_set in self.recorded_descriptor_sets: + descriptor_set.destroy() + + self.pc_values.clear() + self.uniform_values.clear() + self.name_to_pc_key_dict.clear() + self.queued_pc_values.clear() + self.uniform_descriptors.clear() + self.recorded_descriptor_sets.clear() - self.uniform_descriptors = [] self.buffers_valid = False + + def _destroy(self) -> None: + self.reset() + super()._destroy() def bind_var(self, name: str): + if not _runtime_supports_push_constants(): + raise RuntimeError( + "CommandGraph.bind_var() is disabled for backends without push-constant " + "support (OpenCL). Pass Variable values directly at shader invocation." + ) + def register_var(key: Tuple[str, str]): if not name in self.name_to_pc_key_dict.keys(): self.name_to_pc_key_dict[name] = [] @@ -118,6 +199,12 @@ def register_var(key: Tuple[str, str]): return register_var def set_var(self, name: str, value: Any): + if not _runtime_supports_push_constants(): + raise RuntimeError( + "CommandGraph.set_var() is disabled for backends without push-constant " + "support (OpenCL). Pass Variable values directly at shader invocation." 
+ ) + if name not in self.name_to_pc_key_dict.keys(): raise ValueError("Variable not bound!") @@ -154,18 +241,36 @@ def record_shader(self, """ descriptor_set = DescriptorSet(plan) + self.recorded_descriptor_sets.append(descriptor_set) if shader_uuid is None: shader_uuid = shader_description.name + "_" + str(uuid.uuid4()) + if (not _runtime_supports_push_constants()) and len(pc_values) > 0: + raise RuntimeError( + "Push-constant Variable payloads are disabled for backends without " + "push-constant support (OpenCL). " + "Variable values must be UBO-backed and provided at shader invocation." + ) + if len(shader_description.pc_structure) != 0: + if not _runtime_supports_push_constants(): + raise RuntimeError( + "Kernels should not emit push-constant layouts for backends without " + "push-constant support (OpenCL). Use UBO-backed variables." + ) self.pc_builder.register_struct(shader_uuid, shader_description.pc_structure) - - uniform_offset, uniform_range = self.uniform_builder.register_struct(shader_uuid, shader_description.uniform_structure) - self.uniform_descriptors.append((descriptor_set, uniform_offset, uniform_range)) + uniform_field_names = {elem.name for elem in shader_description.uniform_structure} + resolved_uniform_values: Dict[Tuple[str, str], Any] = {} - self.uniform_values[(shader_uuid, shader_description.exec_count_name)] = [exec_limits[0], exec_limits[1], exec_limits[2], 0] + if shader_description.exec_count_name is not None: + resolved_uniform_values[(shader_uuid, shader_description.exec_count_name)] = [ + exec_limits[0], + exec_limits[1], + exec_limits[2], + 0, + ] for buffer_bind_info in bound_buffers: descriptor_set.bind_buffer( @@ -175,7 +280,8 @@ def record_shader(self, write_access=buffer_bind_info.write_access, ) - self.uniform_values[(shader_uuid, buffer_bind_info.shape_name)] = buffer_bind_info.buffer.shader_shape + if buffer_bind_info.shape_name in uniform_field_names: + resolved_uniform_values[(shader_uuid, buffer_bind_info.shape_name)] = 
buffer_bind_info.buffer.shader_shape for sampler_bind_info in bound_samplers: descriptor_set.bind_sampler( @@ -186,7 +292,14 @@ def record_shader(self, ) for key, value in uniform_values.items(): - self.uniform_values[(shader_uuid, key)] = value + resolved_uniform_values[(shader_uuid, key)] = value + + if len(shader_description.uniform_structure) > 0: + uniform_offset, uniform_range = self.uniform_builder.register_struct(shader_uuid, shader_description.uniform_structure) + self.uniform_descriptors.append((descriptor_set, uniform_offset, uniform_range)) + + for key, value in resolved_uniform_values.items(): + self.uniform_values[key] = value for key, value in pc_values.items(): self.pc_values[(shader_uuid, key)] = value @@ -194,11 +307,15 @@ def record_shader(self, super().record_compute_plan(plan, descriptor_set, blocks) self.buffers_valid = False - + if self.submit_on_record: self.submit() - - def submit(self, instance_count: int = None, queue_index: int = -2) -> None: + + def submit( + self, + instance_count: int = None, + queue_index: int = -2 + ) -> None: """Submit the command list to the specified device with additional data to append to the front of the command list. 
@@ -210,30 +327,8 @@ def submit(self, instance_count: int = None, queue_index: int = -2) -> None: if instance_count is None: instance_count = 1 - - if len(self.pc_builder.element_map) > 0 and ( - self.pc_builder.instance_count != instance_count or not self.buffers_valid - ): - - self.pc_builder.prepare(instance_count) - - for key, value in self.pc_values.items(): - self.pc_builder[key] = value - - if len(self.uniform_builder.element_map) > 0 and not self.buffers_valid: - - self.uniform_builder.prepare(1) - - for key, value in self.uniform_values.items(): - self.uniform_builder[key] = value - - for descriptor_set, offset, size in self.uniform_descriptors: - descriptor_set.bind_buffer(self.uniform_constants_buffer, 0, offset, size, True, write_access=False) - self.uniform_constants_buffer.write(self.uniform_builder.tobytes()) - - if not self.buffers_valid: - self.buffers_valid = True + self._prepare_submission_state(instance_count) for key, val in self.queued_pc_values.items(): self.pc_builder[key] = val @@ -243,7 +338,12 @@ def submit(self, instance_count: int = None, queue_index: int = -2) -> None: if len(self.pc_builder.element_map) > 0: my_data = self.pc_builder.tobytes() - super().submit(data=my_data, queue_index=queue_index, instance_count=instance_count) + super().submit( + data=my_data, + queue_index=queue_index, + instance_count=instance_count, + cuda_stream=None, + ) if self._reset_on_submit: self.reset() @@ -253,9 +353,6 @@ def submit_any(self, instance_count: int = None) -> None: _global_graph = threading.local() -#__default_graph = None -#__custom_graph = None - def _get_global_graph() -> Optional[CommandGraph]: return getattr(_global_graph, 'custom_graph', None) @@ -279,4 +376,4 @@ def set_global_graph(graph: CommandGraph = None) -> CommandGraph: return assert _get_global_graph() is None, "A global CommandGraph is already set for the current thread!" 
- _global_graph.custom_graph = graph \ No newline at end of file + _global_graph.custom_graph = graph diff --git a/vkdispatch/execution_pipeline/cuda_graph_capture.py b/vkdispatch/execution_pipeline/cuda_graph_capture.py new file mode 100644 index 00000000..a96f6a9e --- /dev/null +++ b/vkdispatch/execution_pipeline/cuda_graph_capture.py @@ -0,0 +1,51 @@ +import vkdispatch as vd + +from contextlib import contextmanager + +import threading + +import typing + +class CUDAGraphCapture: + cuda_stream = typing.Any + uniform_buffers = typing.List[typing.Any] + + def add_uniform_buffer(self, buffer): + self.uniform_buffers.append(buffer) + +_cap = threading.local() + +def _set_capture(capture): + _cap.capture = capture + +def get_cuda_capture() -> CUDAGraphCapture: + return getattr(_cap, "capture", None) + +@contextmanager +def cuda_graph_capture(cuda_stream=None): + assert vd.is_cuda(), "CUDA graph capture is only supported when using the CUDA backend." + + cap = CUDAGraphCapture() + cap.cuda_stream = cuda_stream + cap.uniform_buffers = [] + + _set_capture(cap) + + try: + yield cap + finally: + _set_capture(None) + +@contextmanager +def suspend_cuda_capture(): + """Temporarily disable vkdispatch CUDA capture state for non-captured ops.""" + cap = get_cuda_capture() + if cap is None: + yield + return + + _set_capture(None) + try: + yield + finally: + _set_capture(cap) diff --git a/vkdispatch/fft/__init__.py b/vkdispatch/fft/__init__.py index b16e51ef..5dab17ff 100644 --- a/vkdispatch/fft/__init__.py +++ b/vkdispatch/fft/__init__.py @@ -22,6 +22,15 @@ from .functions import fft, fft2, fft3, ifft, ifft2, ifft3 from .functions import rfft, rfft2, rfft3, irfft, irfft2, irfft3 +from .src_functions import fft_src, fft2_src, fft3_src, ifft_src, ifft2_src, ifft3_src +from .src_functions import rfft_src, rfft2_src, rfft3_src, irfft_src, irfft2_src, irfft3_src + +from .src_functions import fft_print_src, fft2_print_src, fft3_print_src, ifft_print_src, ifft2_print_src, ifft3_print_src 
+from .src_functions import rfft_print_src, rfft2_print_src, rfft3_print_src, irfft_print_src, irfft2_print_src, irfft3_print_src + from .functions import convolve, convolve2D, convolve2DR, transpose +from .src_functions import convolve_src, convolve2D_src, convolve2DR_src, transpose_src +from .src_functions import convolve_print_src, convolve2D_print_src, convolve2DR_print_src + from .prime_utils import pad_dim \ No newline at end of file diff --git a/vkdispatch/fft/config.py b/vkdispatch/fft/config.py index ca8e1d6d..02628e84 100644 --- a/vkdispatch/fft/config.py +++ b/vkdispatch/fft/config.py @@ -1,98 +1,196 @@ import vkdispatch as vd import vkdispatch.codegen as vc import dataclasses -from typing import List, Tuple, Optional +from typing import List, Tuple, Optional, Dict -from .._compat import numpy_compat as npc +from ..compat import numpy_compat as npc +import vkdispatch.base.dtype as dtypes from .prime_utils import prime_factors, group_primes, default_register_limit, default_max_prime -@dataclasses.dataclass -class FFTRegisterStageConfig: - """ - Configuration for an FFT register stage. - - Attributes: - - primes (Tuple[int]): The prime numbers used for factorization. - fft_length (int): The length of each FFT stage. - instance_count (int): The number of instances required to achieve the desired level of parallelism. - registers_used (int): The total number of registers used by the FFT stage. - remainder (int): The remainder of `N` divided by `registers_used`. - remainder_offset (int): A flag indicating whether the remainder is non-zero. - extra_ffts (int): The additional number of FFT stages required to process the remainder. - thread_count (int): The total number of threads used in the computation. - sdata_size (int): The size of the shared memory buffer used to store intermediate results. - sdata_width (int): The width of each element in the shared memory buffer. - sdata_width_padded (int): The padded width of each element in the shared memory buffer. 
- - """ - - primes: Tuple[int] - fft_length: int - instance_count: int - registers_used: int - remainder: int - remainder_offset: int - extra_ffts: int - thread_count: int - sdata_size: int - sdata_width: int - sdata_width_padded: int - - def __init__(self, primes: List[int], max_register_count: int, N: int): - """ - Initializes the FFTRegisterStageConfig object. +from .stages import FFTRegisterStageConfig - Parameters: +def plan_fft_stages(N: int, max_register_count: int, compute_item_size: int) -> Tuple[FFTRegisterStageConfig]: + all_factors = prime_factors(N) - primes (List[int]): The prime numbers to use for factorization. - max_register_count (int): The maximum number of registers allowed per thread. - N (int): The length of the input data. + for factor in all_factors: + assert factor <= default_max_prime(), f"A prime factor of {N} is {factor}, which exceeds the maximum prime supported {default_max_prime()}" - """ - self.primes = tuple(primes) - self.fft_length = int(round(npc.prod(primes))) - instance_primes = prime_factors(N // self.fft_length) - - self.instance_count = 1 + prime_groups = group_primes(all_factors, max_register_count) - while len(instance_primes) > 0: - if self.instance_count * self.fft_length * instance_primes[0] > max_register_count: - break - self.instance_count *= instance_primes[0] - instance_primes = instance_primes[1:] + stages = [] + input_stride = 1 - self.registers_used = self.fft_length * self.instance_count + for group in prime_groups: + stage = FFTRegisterStageConfig( + group, + max_register_count, + N, + compute_item_size, + input_stride + ) + stages.append(stage) + input_stride = stage.output_stride - self.remainder = N % self.registers_used - assert self.remainder % self.fft_length == 0, "Remainder must be divisible by the FFT length" - self.remainder_offset = 1 if self.remainder != 0 else 0 - self.extra_ffts = self.remainder // self.fft_length + return tuple(stages) - self.thread_count = N // self.registers_used + 
self.remainder_offset - - self.sdata_width = self.registers_used - - threads_primes = prime_factors(self.thread_count) +@dataclasses.dataclass +class FFTPlanCandidate: + max_register_count: int + stages: Tuple[FFTRegisterStageConfig] + register_count: int + batch_threads: int + transfer_count: Optional[int] = None + + def __init__(self, N: int, max_register_count: int,compute_item_size: int): + stages = plan_fft_stages(N, max_register_count, compute_item_size) + register_count = max(stage.registers_used for stage in stages) + batch_threads = max(stage.thread_count for stage in stages) + + if register_count > max_register_count: + self.max_register_count = None + self.stages = None + self.register_count = None + self.batch_threads = None + self.transfer_count = None + return + + transfer_count = 0 + output_stride = 1 + + for stage_index in range(len(stages) - 1): + output_stage = stages[stage_index] + input_stage = stages[stage_index + 1] + + output_keys = output_stage.get_output_format(register_count).keys() + input_keys = input_stage.get_input_format(register_count).keys() + + if output_keys != input_keys: + transfer_count += 1 + + output_stride *= output_stage.fft_length + + self.max_register_count = max_register_count + self.stages = stages + self.register_count = register_count + self.batch_threads = batch_threads + self.transfer_count = transfer_count + +def register_limit_candidates(N: int, initial_limit: int) -> List[int]: + divisors = {1} + + for factor in prime_factors(N): + divisors.update(divisor * factor for divisor in tuple(divisors)) + + candidates = [initial_limit] + candidates.extend( + divisor + for divisor in sorted(divisors) + if initial_limit < divisor <= N + ) + return candidates + +def required_batch_threads_limit(batch_inner_count: int) -> int: + context = vd.get_context() + thread_dimension_limit = ( + context.max_workgroup_size[1] + if batch_inner_count > 1 + else context.max_workgroup_size[0] + ) + return max(1, 
min(int(thread_dimension_limit), int(context.max_workgroup_invocations))) + +def select_fft_plan_candidate( + N: int, + batch_inner_count: int, + compute_item_size: int, + max_register_count: Optional[int], +) -> FFTPlanCandidate: + batch_threads_limit = required_batch_threads_limit(batch_inner_count) + dimension_name = "y" if batch_inner_count > 1 else "x" + + if max_register_count is not None: + requested_limit = min(max_register_count, N) + candidate = FFTPlanCandidate( + N=N, + max_register_count=requested_limit, + compute_item_size=compute_item_size, + ) + + assert candidate.stages is not None, f"Failed to create an FFT plan candidate for N={N} with max_register_count={requested_limit}" + + if candidate.batch_threads <= batch_threads_limit: + return candidate + + best_candidate = candidate + explicit_text = "requested" + searched_limit = requested_limit + else: + max_registers = default_register_limit() - while self.sdata_width < 16 and len(threads_primes) > 0: - self.sdata_width *= threads_primes[0] - threads_primes = threads_primes[1:] + if N==16 or N==8 or N==4 or N==2 and vd.get_devices()[0].is_nvidia(): + max_registers = max(2, N//2) + + baseline_limit = min(8, N) + requested_limit = baseline_limit + candidate_limits = register_limit_candidates(max_registers, baseline_limit) + searched_limit = candidate_limits[-1] + + baseline_candidate = FFTPlanCandidate( + N=N, + max_register_count=baseline_limit, + compute_item_size=compute_item_size, + ) + best_candidate = baseline_candidate if baseline_candidate.stages is not None else None + + if best_candidate is not None and baseline_candidate.batch_threads <= batch_threads_limit: + for candidate_limit in candidate_limits[1:]: + candidate = FFTPlanCandidate( + N=N, + max_register_count=candidate_limit, + compute_item_size=compute_item_size, + ) + + if candidate.stages is None: + continue + + if best_candidate is None or candidate.batch_threads < best_candidate.batch_threads: + best_candidate = candidate + + if 
candidate.batch_threads > batch_threads_limit: + continue + + if candidate.transfer_count < baseline_candidate.transfer_count: + return candidate + + return baseline_candidate + + for candidate_limit in candidate_limits[1:]: + candidate = FFTPlanCandidate( + N=N, + max_register_count=candidate_limit, + compute_item_size=compute_item_size, + ) + if candidate.stages is None: + continue - self.sdata_width_padded = self.sdata_width + if best_candidate is None or candidate.batch_threads < best_candidate.batch_threads: + best_candidate = candidate - if self.sdata_width_padded % 2 == 0: - self.sdata_width_padded += 1 + if candidate.batch_threads <= batch_threads_limit: + return candidate - self.sdata_size = self.sdata_width_padded * int(npc.prod(threads_primes)) + explicit_text = "default" - if self.sdata_size > vd.get_context().max_shared_memory // vd.complex64.item_size: - self.sdata_width_padded = self.sdata_width - self.sdata_size = self.sdata_width_padded * int(npc.prod(threads_primes)) + raise ValueError( + f"Unable to build an FFT plan for size {N}: minimum achievable batch thread count " + f"{best_candidate.batch_threads} exceeds the device's local {dimension_name}-dimension " + f"limit {batch_threads_limit} (starting from {explicit_text} max_register_count=" + f"{requested_limit}, searched up to {searched_limit})." 
+ ) @dataclasses.dataclass class FFTConfig: N: int + compute_type: dtypes.dtype register_count: int max_prime_radix: int stages: Tuple[FFTRegisterStageConfig] @@ -107,10 +205,21 @@ class FFTConfig: sdata_row_size: int sdata_row_size_padded: int - def __init__(self, buffer_shape: Tuple, axis: int = None, max_register_count: int = None): + def __init__( + self, + buffer_shape: Tuple, + axis: int = None, + max_register_count: int = None, + compute_type: dtypes.dtype = vd.complex64, + ): if axis is None: axis = len(buffer_shape) - 1 + if not dtypes.is_complex(compute_type): + raise ValueError(f"compute_type must be a complex dtype, got {compute_type}") + + self.compute_type = compute_type + total_buffer_length = int(round(npc.prod(buffer_shape))) N = buffer_shape[axis] @@ -123,14 +232,6 @@ def __init__(self, buffer_shape: Tuple, axis: int = None, max_register_count: in self.N = N - if max_register_count is None: - max_register_count = default_register_limit() - - if N==16 or N==8 or N==4 or N==2 and vd.get_devices()[0].is_nvidia(): - max_register_count = max(2, N//2) - - max_register_count = min(max_register_count, N) - all_factors = prime_factors(N) for factor in all_factors: @@ -138,13 +239,14 @@ def __init__(self, buffer_shape: Tuple, axis: int = None, max_register_count: in self.max_prime_radix = max(all_factors) - prime_groups = group_primes(all_factors, max_register_count) - - self.stages = tuple([FFTRegisterStageConfig(group, max_register_count, N) for group in prime_groups]) - register_utilizations = [stage.registers_used for stage in self.stages] - self.register_count = max(register_utilizations) - - assert self.register_count <= max_register_count, f"Register count {self.register_count} exceeds max register count {max_register_count}" + plan_candidate = select_fft_plan_candidate( + N=N, + batch_inner_count=self.batch_inner_count, + compute_item_size=self.compute_type.item_size, + max_register_count=max_register_count, + ) + self.stages = plan_candidate.stages 
+ self.register_count = plan_candidate.register_count self.sdata_allocation = 1 self.sdata_row_size = 1 @@ -158,9 +260,9 @@ def __init__(self, buffer_shape: Tuple, axis: int = None, max_register_count: in self.sdata_row_size = stage.sdata_width self.sdata_row_size_padded = stage.sdata_width_padded - self.thread_counts = [stage.thread_count for stage in self.stages] + self.thread_counts = tuple(stage.thread_count for stage in self.stages) - self.batch_threads = max(self.thread_counts) + self.batch_threads = plan_candidate.batch_threads def __str__(self): return f"FFT Config:\nN: {self.N}\nregister_count: {self.register_count}\nstages:\n{self.stages}\nlocal_size: {self.thread_counts}" diff --git a/vkdispatch/fft/context.py b/vkdispatch/fft/context.py index 2afa1ece..8a6bc7cc 100644 --- a/vkdispatch/fft/context.py +++ b/vkdispatch/fft/context.py @@ -1,5 +1,6 @@ import vkdispatch as vd import vkdispatch.codegen as vc +import vkdispatch.base.dtype as dtypes import contextlib from typing import Optional, Tuple, Union, List, Dict @@ -31,12 +32,13 @@ def __init__(self, buffer_shape: Tuple, axis: int = None, max_register_count: int = None, + compute_type: dtypes.dtype = vd.complex64, name: str = None): self.shader_context = shader_context self.declared_shader_args = False self.declarer = None - self.config = FFTConfig(buffer_shape, axis, max_register_count) + self.config = FFTConfig(buffer_shape, axis, max_register_count, compute_type=compute_type) self.grid = FFTGridManager(self.config, True, True) self.resources = FFTResources(self.config, self.grid) @@ -63,6 +65,8 @@ def declare_shader_args(self, types: List) -> List[vc.ShaderVariable]: def make_io_manager(self, output_map: Optional[vd.MappingFunction], + output_type: dtypes.dtype = vd.complex64, + input_type: Optional[dtypes.dtype] = None, input_map: Optional[vd.MappingFunction] = None, kernel_map: Optional[vd.MappingFunction] = None) -> IOManager: assert not self.declared_shader_args, f"Shader arguments already 
declared with {self.declarer}" @@ -72,6 +76,8 @@ def make_io_manager(self, default_registers=self.registers, shader_context=self.shader_context, output_map=output_map, + output_type=output_type, + input_type=input_type, input_map=input_map, kernel_map=kernel_map ) @@ -127,7 +133,8 @@ def register_shuffle(self, def compile_shader(self): self.fft_callable = self.shader_context.get_function( local_size=self.grid.local_size, - exec_count=self.grid.exec_size + exec_count=self.grid.exec_size, + name=self.name ) def get_callable(self) -> vd.ShaderFunction: @@ -148,7 +155,7 @@ def execute(self, inverse: bool): self.register_shuffle(output_stage=i-1, input_stage=i) self.resources.stage_begin(i) - for ii, invocation in enumerate(self.resources.invocations[i]): + for ii, invocation in enumerate(self.config.stages[i].invocations): self.resources.invocation_gaurd(i, ii) self.registers.slice_set(invocation.register_selection, radix_composite( @@ -156,7 +163,7 @@ def execute(self, inverse: bool): inverse=inverse, register_list=self.registers.register_slice(invocation.register_selection), primes=stage.primes, - twiddle_index=invocation.inner_block_offset, + twiddle_index=invocation.get_inner_block_offset(self.resources.tid), twiddle_N=invocation.block_width )) @@ -166,7 +173,9 @@ def execute(self, inverse: bool): @contextlib.contextmanager def fft_context(buffer_shape: Tuple, axis: Optional[int] = None, - max_register_count: Optional[int] = None): + max_register_count: Optional[int] = None, + compute_type: dtypes.dtype = vd.complex64, + name: Optional[str] = None): try: with vd.shader_context(vc.ShaderFlags.NO_EXEC_BOUNDS) as context: @@ -174,7 +183,9 @@ def fft_context(buffer_shape: Tuple, shader_context=context, buffer_shape=buffer_shape, axis=axis, - max_register_count=max_register_count + max_register_count=max_register_count, + compute_type=compute_type, + name=name ) yield fft_context diff --git a/vkdispatch/fft/cooley_tukey.py b/vkdispatch/fft/cooley_tukey.py index 
006e0763..f2821907 100644 --- a/vkdispatch/fft/cooley_tukey.py +++ b/vkdispatch/fft/cooley_tukey.py @@ -3,7 +3,7 @@ from typing import List, Union -from .._compat import numpy_compat as npc +from ..compat import numpy_compat as npc def get_angle_factor(inverse: bool) -> float: return 2 * npc.pi * (1 if inverse else -1) @@ -46,6 +46,9 @@ def _apply_twiddle_to_register( if isinstance(twiddle, complex): if _apply_constant_twiddle(resources, register, twiddle): return + + twiddle = vc.to_dtype(register.var_type, twiddle.real, twiddle.imag) + resources.radix_registers[0][:] = vc.mult_complex(register, twiddle) register[:] = resources.radix_registers[0] @@ -81,7 +84,8 @@ def radix_P(resources: FFTResources, inverse: bool, register_list: List[vc.Shade continue omega = npc.exp_complex(1j * angle_factor * i * j / len(register_list)) - resources.omega_register[:] = vc.mult_complex(register_list[j], omega) + typed_omega = vc.to_dtype(register_list[j].var_type, omega.real, omega.imag) + resources.omega_register[:] = vc.mult_complex(register_list[j], typed_omega) resources.radix_registers[i] += resources.omega_register for i in range(0, len(register_list)): @@ -118,7 +122,9 @@ def apply_twiddle_factors( _apply_twiddle_to_register(resources, register_list[i], omega) continue - resources.omega_register.real = (angle_factor * i / twiddle_N) * twiddle_index + angle_scale = vc.to_dtype(resources.omega_register.real.var_type, angle_factor * i / twiddle_N) + twiddle_scale = vc.to_dtype(resources.omega_register.real.var_type, twiddle_index) + resources.omega_register.real = angle_scale * twiddle_scale resources.omega_register[:] = vc.complex_from_euler_angle(resources.omega_register.real) resources.radix_registers[0][:] = vc.mult_complex(register_list[i], resources.omega_register) register_list[i][:] = resources.radix_registers[0] diff --git a/vkdispatch/fft/functions.py b/vkdispatch/fft/functions.py index 9c400b4b..0818a8eb 100644 --- a/vkdispatch/fft/functions.py +++ 
b/vkdispatch/fft/functions.py @@ -1,8 +1,105 @@ import vkdispatch as vd from .shader_factories import make_fft_shader, make_convolution_shader, make_transpose_shader, get_transposed_size +from .precision import ( + ensure_supported_complex_precision, + resolve_compute_precision, + validate_complex_precision, +) -from typing import Tuple, Union, Optional +from typing import List, Tuple, Union, Optional + + +def _validate_map_argument_annotations(map_fn: vd.MappingFunction, map_name: str) -> None: + for buffer_type in map_fn.buffer_types: + if not hasattr(buffer_type, "__args__") or len(buffer_type.__args__) != 1: + raise ValueError( + f"{map_name} contains an annotation without exactly one type argument: {buffer_type}" + ) + + +def _resolve_output_precision( + buffers: Tuple[vd.Buffer, ...], + output_map: Optional[vd.MappingFunction], + output_type: Optional[vd.dtype], +) -> Optional[vd.dtype]: + if output_map is not None: + if output_type is not None: + raise ValueError("output_type cannot be provided when output_map is used") + return None + + resolved_output = buffers[0].var_type if output_type is None else output_type + validate_complex_precision(resolved_output, arg_name="output_type") + ensure_supported_complex_precision(resolved_output, role="Output") + return resolved_output + + +def _resolve_input_precision( + buffers: Tuple, + input_map: Optional[vd.MappingFunction], + output_map: Optional[vd.MappingFunction], + input_type: Optional[vd.dtype], + output_precision: Optional[vd.dtype], +) -> Optional[vd.dtype]: + if input_map is not None: + if input_type is not None: + raise ValueError("input_type cannot be provided when input_map is used") + return None + + if output_map is not None: + output_arg_count = len(output_map.buffer_types) + if len(buffers) <= output_arg_count: + raise ValueError( + "When output_map is used without input_map, an input buffer argument must be provided " + "after output_map arguments" + ) + + resolved_input = input_type + if 
resolved_input is None: + inferred_input = buffers[output_arg_count] + if not hasattr(inferred_input, "var_type"): + raise ValueError( + "When output_map is used without input_map, the argument after output_map arguments " + "must be a buffer" + ) + resolved_input = inferred_input.var_type + + validate_complex_precision(resolved_input, arg_name="input_type") + ensure_supported_complex_precision(resolved_input, role="Input") + return resolved_input + + if output_precision is None: + raise ValueError("output_precision must be provided when output_map is not used") + + resolved_input = output_precision if input_type is None else input_type + validate_complex_precision(resolved_input, arg_name="input_type") + ensure_supported_complex_precision(resolved_input, role="Input") + + if resolved_input != output_precision: + raise ValueError( + "input_type must match output_type when input_map is None (default FFT path is in-place)" + ) + + return resolved_input + + +def _resolve_kernel_precision( + buffers: Tuple[vd.Buffer, ...], + kernel_map: Optional[vd.MappingFunction], + kernel_type: Optional[vd.dtype], +) -> Optional[vd.dtype]: + if kernel_map is not None: + if kernel_type is not None: + raise ValueError("kernel_type cannot be provided when kernel_map is used") + return None + + if len(buffers) < 2: + raise ValueError("Kernel precision inference requires a kernel buffer argument") + + resolved_kernel = buffers[1].var_type if kernel_type is None else kernel_type + validate_complex_precision(resolved_kernel, arg_name="kernel_type") + ensure_supported_complex_precision(resolved_kernel, role="Kernel") + return resolved_kernel def fft( *buffers: vd.Buffer, @@ -16,13 +113,36 @@ def fft( r2c: bool = False, input_map: vd.MappingFunction = None, output_map: vd.MappingFunction = None, + output_type: vd.dtype = None, + input_type: vd.dtype = None, + compute_type: vd.dtype = None, input_signal_range: Union[Tuple[Optional[int], Optional[int]], None] = None): assert len(buffers) >= 1, 
"At least one buffer must be provided" + + if input_map is None and output_map is None and len(buffers) != 1: + raise ValueError("fft() expects exactly one buffer unless input_map/output_map are used") if buffer_shape is None: buffer_shape = buffers[0].shape + resolved_output_type = _resolve_output_precision(buffers, output_map, output_type) + resolved_input_type = _resolve_input_precision(buffers, input_map, output_map, input_type, resolved_output_type) + + io_precisions: List[vd.dtype] = [] + if output_map is None: + io_precisions.append(resolved_output_type) + else: + _validate_map_argument_annotations(output_map, "output_map") + + if input_map is None: + if resolved_input_type is not None: + io_precisions.append(resolved_input_type) + else: + _validate_map_argument_annotations(input_map, "input_map") + + resolved_compute_type = resolve_compute_precision(io_precisions, compute_type) + fft_shader = make_fft_shader( tuple(buffer_shape), axis, @@ -31,6 +151,9 @@ def fft( r2c=r2c, input_map=input_map, output_map=output_map, + input_type=resolved_input_type, + output_type=resolved_output_type, + compute_type=resolved_compute_type, input_signal_range=input_signal_range) if print_shader: @@ -38,18 +161,80 @@ def fft( fft_shader(*buffers, graph=graph) -def fft2(buffer: vd.Buffer, graph: vd.CommandGraph = None, print_shader: bool = False, input_map: vd.MappingFunction = None, output_map: vd.MappingFunction = None): +def fft2( + buffer: vd.Buffer, + graph: vd.CommandGraph = None, + print_shader: bool = False, + input_map: vd.MappingFunction = None, + output_map: vd.MappingFunction = None, + output_type: vd.dtype = None, + input_type: vd.dtype = None, + compute_type: vd.dtype = None, +): assert len(buffer.shape) == 2 or len(buffer.shape) == 3, 'Buffer must have 2 or 3 dimensions' - fft(buffer, graph=graph, print_shader=print_shader, axis=len(buffer.shape) - 2, input_map=input_map) - fft(buffer, graph=graph, print_shader=print_shader, axis=len(buffer.shape) - 1, 
output_map=output_map) + fft( + buffer, + graph=graph, + print_shader=print_shader, + axis=len(buffer.shape) - 2, + input_map=input_map, + output_type=output_type, + input_type=input_type, + compute_type=compute_type, + ) + fft( + buffer, + graph=graph, + print_shader=print_shader, + axis=len(buffer.shape) - 1, + output_map=output_map, + output_type=output_type if output_map is None else None, + input_type=input_type if output_map is None else None, + compute_type=compute_type, + ) -def fft3(buffer: vd.Buffer, graph: vd.CommandGraph = None, print_shader: bool = False, input_map: vd.MappingFunction = None, output_map: vd.MappingFunction = None): +def fft3( + buffer: vd.Buffer, + graph: vd.CommandGraph = None, + print_shader: bool = False, + input_map: vd.MappingFunction = None, + output_map: vd.MappingFunction = None, + output_type: vd.dtype = None, + input_type: vd.dtype = None, + compute_type: vd.dtype = None, +): assert len(buffer.shape) == 3, 'Buffer must have 3 dimensions' - fft(buffer, graph=graph, print_shader=print_shader, axis=0, input_map=input_map) - fft(buffer, graph=graph, print_shader=print_shader, axis=1) - fft(buffer, graph=graph, print_shader=print_shader, axis=2, output_map=output_map) + fft( + buffer, + graph=graph, + print_shader=print_shader, + axis=0, + input_map=input_map, + output_type=output_type, + input_type=input_type, + compute_type=compute_type, + ) + fft( + buffer, + graph=graph, + print_shader=print_shader, + axis=1, + output_type=output_type, + input_type=input_type, + compute_type=compute_type, + ) + fft( + buffer, + graph=graph, + print_shader=print_shader, + axis=2, + output_map=output_map, + output_type=output_type if output_map is None else None, + input_type=input_type if output_map is None else None, + compute_type=compute_type, + ) def ifft( @@ -60,54 +245,225 @@ def ifft( name: str = None, normalize: bool = True, input_map: vd.MappingFunction = None, - output_map: vd.MappingFunction = None): - fft(buffer, graph=graph, 
print_shader=print_shader, axis=axis, name=name, inverse=True, normalize_inverse=normalize, input_map=input_map, output_map=output_map) + output_map: vd.MappingFunction = None, + output_type: vd.dtype = None, + input_type: vd.dtype = None, + compute_type: vd.dtype = None): + fft( + buffer, + graph=graph, + print_shader=print_shader, + axis=axis, + name=name, + inverse=True, + normalize_inverse=normalize, + input_map=input_map, + output_map=output_map, + output_type=output_type, + input_type=input_type, + compute_type=compute_type, + ) -def ifft2(buffer: vd.Buffer, graph: vd.CommandGraph = None, print_shader: bool = False, normalize: bool = True, input_map: vd.MappingFunction = None, output_map: vd.MappingFunction = None): +def ifft2( + buffer: vd.Buffer, + graph: vd.CommandGraph = None, + print_shader: bool = False, + normalize: bool = True, + input_map: vd.MappingFunction = None, + output_map: vd.MappingFunction = None, + output_type: vd.dtype = None, + input_type: vd.dtype = None, + compute_type: vd.dtype = None, +): assert len(buffer.shape) == 2 or len(buffer.shape) == 3, 'Buffer must have 2 or 3 dimensions' - ifft(buffer, graph=graph, print_shader=print_shader, axis=len(buffer.shape) - 2, normalize=normalize, input_map=input_map) - ifft(buffer, graph=graph, print_shader=print_shader, axis=len(buffer.shape) - 1, normalize=normalize, output_map=output_map) + ifft( + buffer, + graph=graph, + print_shader=print_shader, + axis=len(buffer.shape) - 2, + normalize=normalize, + input_map=input_map, + output_type=output_type, + input_type=input_type, + compute_type=compute_type, + ) + ifft( + buffer, + graph=graph, + print_shader=print_shader, + axis=len(buffer.shape) - 1, + normalize=normalize, + output_map=output_map, + output_type=output_type if output_map is None else None, + input_type=input_type if output_map is None else None, + compute_type=compute_type, + ) -def ifft3(buffer: vd.Buffer, graph: vd.CommandGraph = None, print_shader: bool = False, normalize: bool = 
True, input_map: vd.MappingFunction = None, output_map: vd.MappingFunction = None): +def ifft3( + buffer: vd.Buffer, + graph: vd.CommandGraph = None, + print_shader: bool = False, + normalize: bool = True, + input_map: vd.MappingFunction = None, + output_map: vd.MappingFunction = None, + output_type: vd.dtype = None, + input_type: vd.dtype = None, + compute_type: vd.dtype = None, +): assert len(buffer.shape) == 3, 'Buffer must have 3 dimensions' - ifft(buffer, graph=graph, print_shader=print_shader, axis=0, normalize=normalize, input_map=input_map) - ifft(buffer, graph=graph, print_shader=print_shader, axis=1, normalize=normalize) - ifft(buffer, graph=graph, print_shader=print_shader, axis=2, normalize=normalize, output_map=output_map) + ifft( + buffer, + graph=graph, + print_shader=print_shader, + axis=0, + normalize=normalize, + input_map=input_map, + output_type=output_type, + input_type=input_type, + compute_type=compute_type, + ) + ifft( + buffer, + graph=graph, + print_shader=print_shader, + axis=1, + normalize=normalize, + output_type=output_type, + input_type=input_type, + compute_type=compute_type, + ) + ifft( + buffer, + graph=graph, + print_shader=print_shader, + axis=2, + normalize=normalize, + output_map=output_map, + output_type=output_type if output_map is None else None, + input_type=input_type if output_map is None else None, + compute_type=compute_type, + ) -def rfft(buffer: vd.RFFTBuffer, graph: vd.CommandGraph = None, print_shader: bool = False, name: str = None): - fft(buffer, buffer_shape=buffer.real_shape, graph=graph, print_shader=print_shader, name=name, r2c=True) +def rfft( + buffer: vd.RFFTBuffer, + graph: vd.CommandGraph = None, + print_shader: bool = False, + name: str = None, + compute_type: vd.dtype = None, +): + fft( + buffer, + buffer_shape=buffer.real_shape, + graph=graph, + print_shader=print_shader, + name=name, + r2c=True, + output_type=buffer.var_type, + input_type=buffer.var_type, + compute_type=compute_type, + ) -def 
rfft2(buffer: vd.RFFTBuffer, graph: vd.CommandGraph = None, print_shader: bool = False): +def rfft2(buffer: vd.RFFTBuffer, graph: vd.CommandGraph = None, print_shader: bool = False, compute_type: vd.dtype = None): assert len(buffer.real_shape) == 2 or len(buffer.real_shape) == 3, 'Buffer must have 2 or 3 dimensions' - rfft(buffer, graph=graph, print_shader=print_shader) - fft(buffer, graph=graph, print_shader=print_shader, axis=len(buffer.real_shape) - 2) + rfft(buffer, graph=graph, print_shader=print_shader, compute_type=compute_type) + fft( + buffer, + graph=graph, + print_shader=print_shader, + axis=len(buffer.real_shape) - 2, + output_type=buffer.var_type, + input_type=buffer.var_type, + compute_type=compute_type, + ) -def rfft3(buffer: vd.RFFTBuffer, graph: vd.CommandGraph = None, print_shader: bool = False): +def rfft3(buffer: vd.RFFTBuffer, graph: vd.CommandGraph = None, print_shader: bool = False, compute_type: vd.dtype = None): assert len(buffer.real_shape) == 3, 'Buffer must have 3 dimensions' - rfft(buffer, graph=graph, print_shader=print_shader) - fft(buffer, graph=graph, print_shader=print_shader, axis=1) - fft(buffer, graph=graph, print_shader=print_shader, axis=0) + rfft(buffer, graph=graph, print_shader=print_shader, compute_type=compute_type) + fft( + buffer, + graph=graph, + print_shader=print_shader, + axis=1, + output_type=buffer.var_type, + input_type=buffer.var_type, + compute_type=compute_type, + ) + fft( + buffer, + graph=graph, + print_shader=print_shader, + axis=0, + output_type=buffer.var_type, + input_type=buffer.var_type, + compute_type=compute_type, + ) -def irfft(buffer: vd.RFFTBuffer, graph: vd.CommandGraph = None, print_shader: bool = False, name: str = None, normalize: bool = True): - fft(buffer, buffer_shape=buffer.real_shape, graph=graph, print_shader=print_shader, name=name, inverse=True, normalize_inverse=normalize, r2c=True) +def irfft( + buffer: vd.RFFTBuffer, + graph: vd.CommandGraph = None, + print_shader: bool = False, + 
name: str = None, + normalize: bool = True, + compute_type: vd.dtype = None, +): + fft( + buffer, + buffer_shape=buffer.real_shape, + graph=graph, + print_shader=print_shader, + name=name, + inverse=True, + normalize_inverse=normalize, + r2c=True, + output_type=buffer.var_type, + input_type=buffer.var_type, + compute_type=compute_type, + ) -def irfft2(buffer: vd.RFFTBuffer, graph: vd.CommandGraph = None, print_shader: bool = False, normalize: bool = True): +def irfft2(buffer: vd.RFFTBuffer, graph: vd.CommandGraph = None, print_shader: bool = False, normalize: bool = True, compute_type: vd.dtype = None): assert len(buffer.real_shape) == 2 or len(buffer.real_shape) == 3, 'Buffer must have 2 or 3 dimensions' - ifft(buffer, graph=graph, print_shader=print_shader, axis=len(buffer.real_shape) - 2, normalize=normalize) - irfft(buffer, graph=graph, print_shader=print_shader, normalize=normalize) + ifft( + buffer, + graph=graph, + print_shader=print_shader, + axis=len(buffer.real_shape) - 2, + normalize=normalize, + output_type=buffer.var_type, + input_type=buffer.var_type, + compute_type=compute_type, + ) + irfft(buffer, graph=graph, print_shader=print_shader, normalize=normalize, compute_type=compute_type) -def irfft3(buffer: vd.RFFTBuffer, graph: vd.CommandGraph = None, print_shader: bool = False, normalize: bool = True): +def irfft3(buffer: vd.RFFTBuffer, graph: vd.CommandGraph = None, print_shader: bool = False, normalize: bool = True, compute_type: vd.dtype = None): assert len(buffer.real_shape) == 3, 'Buffer must have 3 dimensions' - ifft(buffer, graph=graph, print_shader=print_shader, axis=0, normalize=normalize) - ifft(buffer, graph=graph, print_shader=print_shader, axis=1, normalize=normalize) - irfft(buffer, graph=graph, print_shader=print_shader, normalize=normalize) + ifft( + buffer, + graph=graph, + print_shader=print_shader, + axis=0, + normalize=normalize, + output_type=buffer.var_type, + input_type=buffer.var_type, + compute_type=compute_type, + ) + ifft( + 
buffer, + graph=graph, + print_shader=print_shader, + axis=1, + normalize=normalize, + output_type=buffer.var_type, + input_type=buffer.var_type, + compute_type=compute_type, + ) + irfft(buffer, graph=graph, print_shader=print_shader, normalize=normalize, compute_type=compute_type) def convolve( *buffers: vd.Buffer, @@ -123,10 +479,43 @@ def convolve( kernel_inner_only: bool = False, input_map: vd.MappingFunction = None, output_map: vd.MappingFunction = None, + output_type: vd.dtype = None, + input_type: vd.dtype = None, + kernel_type: vd.dtype = None, + compute_type: vd.dtype = None, input_signal_range: Union[Tuple[Optional[int], Optional[int]], None] = None): + assert len(buffers) >= 1, "At least one buffer must be provided" + + if kernel_map is None and len(buffers) < 2: + raise ValueError("convolve() requires at least an output buffer and kernel buffer") + if buffer_shape is None: buffer_shape = buffers[0].shape + resolved_output_type = _resolve_output_precision(buffers, output_map, output_type) + resolved_input_type = _resolve_input_precision(buffers, input_map, output_map, input_type, resolved_output_type) + resolved_kernel_type = _resolve_kernel_precision(buffers, kernel_map, kernel_type) + + io_precisions: List[vd.dtype] = [] + + if output_map is None: + io_precisions.append(resolved_output_type) + else: + _validate_map_argument_annotations(output_map, "output_map") + + if input_map is None: + if resolved_input_type is not None: + io_precisions.append(resolved_input_type) + else: + _validate_map_argument_annotations(input_map, "input_map") + + if kernel_map is None: + io_precisions.append(resolved_kernel_type) + else: + _validate_map_argument_annotations(kernel_map, "kernel_map") + + resolved_compute_type = resolve_compute_precision(io_precisions, compute_type) + fft_shader = make_convolution_shader( tuple(buffer_shape), kernel_map, @@ -137,6 +526,10 @@ def convolve( normalize=normalize, input_map=input_map, output_map=output_map, + 
input_type=resolved_input_type, + output_type=resolved_output_type, + kernel_type=resolved_kernel_type, + compute_type=resolved_compute_type, input_signal_range=input_signal_range) if print_shader: @@ -155,7 +548,11 @@ def convolve2D( transposed_kernel: bool = False, kernel_inner_only: bool = False, input_map: vd.MappingFunction = None, - output_map: vd.MappingFunction = None): + output_map: vd.MappingFunction = None, + output_type: vd.dtype = None, + input_type: vd.dtype = None, + kernel_type: vd.dtype = None, + compute_type: vd.dtype = None): assert len(buffer.shape) == 2 or len(buffer.shape) == 3, 'Buffer must have 2 or 3 dimensions' @@ -168,7 +565,15 @@ def convolve2D( if output_map is not None: output_buffers.append(buffer) - fft(*input_buffers, graph=graph, print_shader=print_shader, input_map=input_map) + fft( + *input_buffers, + graph=graph, + print_shader=print_shader, + input_map=input_map, + output_type=output_type, + input_type=input_type, + compute_type=compute_type, + ) convolve( buffer, kernel, @@ -179,9 +584,22 @@ def convolve2D( kernel_inner_only=kernel_inner_only, print_shader=print_shader, axis=len(buffer.shape) - 2, - normalize=normalize + normalize=normalize, + output_type=output_type, + input_type=input_type, + kernel_type=kernel_type, + compute_type=compute_type, + ) + ifft( + *output_buffers, + graph=graph, + print_shader=print_shader, + normalize=normalize, + output_map=output_map, + output_type=output_type if output_map is None else None, + input_type=input_type if output_map is None else None, + compute_type=compute_type, ) - ifft(*output_buffers, graph=graph, print_shader=print_shader, normalize=normalize, output_map=output_map) def convolve2DR( buffer: vd.RFFTBuffer, @@ -192,11 +610,12 @@ def convolve2DR( kernel_inner_only: bool = False, graph: vd.CommandGraph = None, print_shader: bool = False, - normalize: bool = True): + normalize: bool = True, + compute_type: vd.dtype = None): assert len(buffer.shape) == 2 or len(buffer.shape) == 3, 
'Buffer must have 2 or 3 dimensions' - rfft(buffer, graph=graph, print_shader=print_shader) + rfft(buffer, graph=graph, print_shader=print_shader, compute_type=compute_type) convolve( buffer, kernel, @@ -207,9 +626,13 @@ def convolve2DR( kernel_inner_only=kernel_inner_only, print_shader=print_shader, axis=len(buffer.shape) - 2, - normalize=normalize + normalize=normalize, + output_type=buffer.var_type, + input_type=buffer.var_type, + kernel_type=kernel.var_type, + compute_type=compute_type, ) - irfft(buffer, graph=graph, print_shader=print_shader, normalize=normalize) + irfft(buffer, graph=graph, print_shader=print_shader, normalize=normalize, compute_type=compute_type) def transpose( in_buffer: vd.Buffer, @@ -218,25 +641,54 @@ def transpose( out_buffer: vd.Buffer = None, graph: vd.CommandGraph = None, kernel_inner_only: bool = False, - print_shader: bool = False) -> vd.Buffer: - + print_shader: bool = False, + input_type: vd.dtype = None, + output_type: vd.dtype = None, + compute_type: vd.dtype = None) -> vd.Buffer: + + resolved_input_type = in_buffer.var_type if input_type is None else input_type + validate_complex_precision(resolved_input_type, arg_name="input_type") + ensure_supported_complex_precision(resolved_input_type, role="Input") + + resolved_output_type = ( + out_buffer.var_type if (out_buffer is not None and output_type is None) + else in_buffer.var_type if output_type is None + else output_type + ) + validate_complex_precision(resolved_output_type, arg_name="output_type") + ensure_supported_complex_precision(resolved_output_type, role="Output") + + resolved_compute_type = resolve_compute_precision( + [resolved_input_type, resolved_output_type], + compute_type, + ) + transposed_size = get_transposed_size( tuple(in_buffer.shape), - axis=axis + axis=axis, + compute_type=resolved_compute_type, ) if out_buffer is None: - out_buffer = vd.Buffer((transposed_size,), var_type=in_buffer.var_type) + out_buffer = vd.Buffer((transposed_size,), 
var_type=resolved_output_type) + else: + if out_buffer.var_type != resolved_output_type: + raise ValueError( + f"out_buffer type ({out_buffer.var_type.name}) does not match output_type ({resolved_output_type.name})" + ) assert out_buffer.size >= transposed_size, f"Output buffer size {out_buffer.size} does not match expected transposed size {transposed_size}" if conv_shape is None: conv_shape = in_buffer.shape - + transpose_shader = make_transpose_shader( tuple(conv_shape), axis=axis, - kernel_inner_only=kernel_inner_only + kernel_inner_only=kernel_inner_only, + input_type=resolved_input_type, + output_type=resolved_output_type, + compute_type=resolved_compute_type, ) if print_shader: @@ -244,4 +696,4 @@ def transpose( transpose_shader(out_buffer, in_buffer, graph=graph) - return out_buffer \ No newline at end of file + return out_buffer diff --git a/vkdispatch/fft/global_memory_iterators.py b/vkdispatch/fft/global_memory_iterators.py index e897846a..c621f6b6 100644 --- a/vkdispatch/fft/global_memory_iterators.py +++ b/vkdispatch/fft/global_memory_iterators.py @@ -7,6 +7,13 @@ from .registers import FFTRegisters from .memory_iterators import memory_reads_iterator, memory_writes_iterator, MemoryOp + +def _cast_if_needed(value: vc.ShaderVariable, dst_type): + if value.var_type == dst_type: + return value + + return vc.to_dtype(dst_type, value) + def global_batch_offset( registers: FFTRegisters, r2c: bool = False, @@ -57,7 +64,7 @@ def from_memory_op(cls, inverse=inverse) def write_to_buffer(self, - buffer: vc.Buff[vc.c64], + buffer: vc.Buffer, register: Optional[vc.ShaderVariable] = None, io_index: Optional[vc.ShaderVariable] = None): if register is None: @@ -67,21 +74,18 @@ def write_to_buffer(self, io_index = self.io_index if not self.r2c: - buffer[io_index] = register + buffer[io_index] = _cast_if_needed(register, buffer.var_type) return if not self.inverse: vc.if_statement(self.fft_index < (self.fft_size // 2) + 1) - buffer[io_index] = register + buffer[io_index] 
= _cast_if_needed(register, buffer.var_type) vc.end() return - packed_value = buffer[io_index // 2] - vc.if_statement((io_index % 2) == 0) - packed_value.real = register.real - vc.else_statement() - packed_value.imag = register.real - vc.end() + out_scalar_type = buffer.var_type.child_type + out_real = _cast_if_needed(register.real, out_scalar_type) + buffer[io_index // 2][io_index % 2] = out_real def global_writes_iterator( registers: FFTRegisters, @@ -171,11 +175,11 @@ def signal_range_end(self, register: vc.ShaderVariable): return vc.else_statement() - register[:] = vc.to_complex(0) + register[:] = vc.to_dtype(register.var_type, 0) vc.end() def read_from_buffer(self, - buffer: vc.Buff[vc.c64], + buffer: vc.Buffer, register: Optional[vc.ShaderVariable] = None, io_index: Optional[vc.ShaderVariable] = None): self.check_in_signal_range() @@ -187,26 +191,23 @@ def read_from_buffer(self, register = self.register if not self.r2c: - register[:] = buffer[io_index] + register[:] = _cast_if_needed(buffer[io_index], register.var_type) self.signal_range_end(register) return if not self.inverse: - packed_value = buffer[io_index // 2] - vc.if_statement((io_index % 2) == 0) - register[:] = vc.to_complex(packed_value.real) - vc.else_statement() - register[:] = vc.to_complex(packed_value.imag) - vc.end() + packed_real = buffer[io_index // 2][io_index % 2] + packed_complex = vc.to_complex(packed_real) + register[:] = _cast_if_needed(packed_complex, register.var_type) self.signal_range_end(register) return vc.if_statement(self.fft_index >= (self.fft_size // 2) + 1) self.io_index_2[:] = self.r2c_inverse_offset - io_index - register[:] = buffer[self.io_index_2] + register[:] = _cast_if_needed(buffer[self.io_index_2], register.var_type) register.imag = -register.imag vc.else_statement() - register[:] = buffer[io_index] + register[:] = _cast_if_needed(buffer[io_index], register.var_type) vc.end() self.signal_range_end(register) @@ -302,7 +303,7 @@ def from_memory_op(cls, ) def 
write_to_buffer(self, - buffer: vc.Buff[vc.c64], + buffer: vc.Buffer, register: Optional[vc.ShaderVariable] = None, io_index: Optional[vc.ShaderVariable] = None): if io_index is None: @@ -311,7 +312,7 @@ def write_to_buffer(self, if register is None: register = self.register - buffer[io_index] = register + buffer[io_index] = _cast_if_needed(register, buffer.var_type) def global_trasposed_write_iterator(registers: FFTRegisters, inner_only: bool = False): vc.comment("""Writing registers to global memory in transposed order. diff --git a/vkdispatch/fft/grid_manager.py b/vkdispatch/fft/grid_manager.py index 22d642af..5d6aa4e9 100644 --- a/vkdispatch/fft/grid_manager.py +++ b/vkdispatch/fft/grid_manager.py @@ -6,7 +6,7 @@ from .config import FFTConfig from .prime_utils import prime_factors -from .._compat import numpy_compat as npc +from ..compat import numpy_compat as npc def allocation_valid(workgroup_size: int, shared_memory_size: int): valid_workgroup = workgroup_size <= vd.get_context().max_workgroup_invocations @@ -16,11 +16,12 @@ def allocation_valid(workgroup_size: int, shared_memory_size: int): def allocate_inline_batches( batch_num: int, batch_threads: int, - N: int, + shared_elements: int, + element_size: int, max_workgroup_size: int, max_total_threads: int): - shared_memory_allocation = N * vd.complex64.item_size + shared_memory_allocation = shared_elements * element_size batch_num_primes = prime_factors(batch_num) prime_index = 0 workgroup_size = batch_threads @@ -157,6 +158,7 @@ def __init__(self, config: FFTConfig, force_sdata: bool = False, declare_variabl config.batch_inner_count, config.batch_threads, config.sdata_allocation if make_sdata_buffer else 0, + config.compute_type.item_size, min(vd.get_context().max_workgroup_size[0], 4), vd.get_context().max_workgroup_invocations) @@ -171,6 +173,7 @@ def __init__(self, config: FFTConfig, force_sdata: bool = False, declare_variabl config.batch_outer_count, config.batch_threads * self.inline_batches_inner, 
config.sdata_allocation * self.inline_batches_inner if make_sdata_buffer else 0, + config.compute_type.item_size, vd.get_context().max_workgroup_size[ 1 if self.inline_batches_inner == 1 else 2 ], diff --git a/vkdispatch/fft/io_manager.py b/vkdispatch/fft/io_manager.py index 1f54fc99..b91d6bd9 100644 --- a/vkdispatch/fft/io_manager.py +++ b/vkdispatch/fft/io_manager.py @@ -1,5 +1,6 @@ import vkdispatch as vd import vkdispatch.codegen as vc +import vkdispatch.base.dtype as dtypes from typing import Optional, Tuple @@ -55,11 +56,22 @@ def __init__(self, default_registers: FFTRegisters, shader_context: vd.ShaderContext, output_map: Optional[vd.MappingFunction], + output_type: dtypes.dtype = vd.complex64, + input_type: Optional[dtypes.dtype] = None, input_map: Optional[vd.MappingFunction] = None, kernel_map: Optional[vd.MappingFunction] = None): self.default_registers = default_registers - self.output_proxy = IOProxy(vd.complex64 if output_map is None else output_map, "Output") - self.input_proxy = IOProxy(input_map, "Input") + self.output_proxy = IOProxy(output_type if output_map is None else output_map, "Output") + + if input_map is not None: + self.input_proxy = IOProxy(input_map, "Input") + elif output_map is not None: + if input_type is None: + raise ValueError("input_type must be provided when output_map is used without input_map") + self.input_proxy = IOProxy(input_type, "Input") + else: + self.input_proxy = IOProxy(None, "Input") + self.kernel_proxy = IOProxy(kernel_map, "Kernel") output_types = self.output_proxy.buffer_types @@ -163,4 +175,4 @@ def read_kernel(self, registers: Optional[FFTRegisters] = None, format_transpose registers, format_transposed=format_transposed, inner_only=inner_only - ) \ No newline at end of file + ) diff --git a/vkdispatch/fft/memory_iterators.py b/vkdispatch/fft/memory_iterators.py index 4c85e046..a7793ab7 100644 --- a/vkdispatch/fft/memory_iterators.py +++ b/vkdispatch/fft/memory_iterators.py @@ -22,14 +22,14 @@ def 
memory_reads_iterator(resources: FFTResources, stage_index: int = 0): resources.stage_begin(stage_index) index_list = list(range(resources.config.register_count)) - invocations = resources.invocations[stage_index] + invocations = resources.config.stages[stage_index].invocations for ii, invocation in enumerate(invocations): resources.invocation_gaurd(stage_index, ii) register_indicies = index_list[invocation.register_selection] - offset = invocation.instance_id + offset = invocation.get_offset(resources.tid) stride = resources.config.N // resources.config.stages[stage_index].fft_length for i in range(len(register_indicies)): @@ -58,14 +58,14 @@ def memory_writes_iterator(resources: FFTResources, stage_index: int = -1): index_list = list(range(resources.config.register_count)) element_count = resources.config.stages[stage_index].fft_length - invocations = resources.invocations[stage_index] + invocations = resources.config.stages[stage_index].invocations for i in range(element_count): for ii, invocation in enumerate(invocations): resources.invocation_gaurd(stage_index, ii) - offset = invocation.sub_sequence_offset - stride = resources.output_strides[stage_index] + offset = invocation.get_sub_sequence_offset(resources.tid) + stride = resources.config.stages[stage_index].input_stride fft_index = offset + i * stride diff --git a/vkdispatch/fft/precision.py b/vkdispatch/fft/precision.py new file mode 100644 index 00000000..d9d6d640 --- /dev/null +++ b/vkdispatch/fft/precision.py @@ -0,0 +1,99 @@ +import vkdispatch as vd + +from typing import Iterable, List, Optional + + +_COMPLEX_PRECISION_ORDER = (vd.complex32, vd.complex64, vd.complex128) +_COMPLEX_PRECISION_RANK = {dtype: rank for rank, dtype in enumerate(_COMPLEX_PRECISION_ORDER)} + + +def is_complex_precision(dtype) -> bool: + return dtype in _COMPLEX_PRECISION_RANK + + +def validate_complex_precision(dtype, *, arg_name: str) -> None: + if not is_complex_precision(dtype): + raise ValueError(f"{arg_name} must be one 
of complex32, complex64, or complex128 (got {dtype})") + + +def promote_complex_precisions(dtypes: Iterable) -> vd.dtype: + candidates = list(dtypes) + if len(candidates) == 0: + raise ValueError("At least one complex dtype is required for promotion") + + for candidate in candidates: + validate_complex_precision(candidate, arg_name="dtype") + + return max(candidates, key=lambda dtype: _COMPLEX_PRECISION_RANK[dtype]) + + +def default_compute_precision(io_precisions: Iterable) -> vd.dtype: + promoted = promote_complex_precisions(io_precisions) + + # Default to at least complex64 for numerical stability. + if _COMPLEX_PRECISION_RANK[promoted] < _COMPLEX_PRECISION_RANK[vd.complex64]: + return vd.complex64 + + return promoted + + +def supports_complex_precision(dtype) -> bool: + validate_complex_precision(dtype, arg_name="dtype") + scalar_type = dtype.child_type + + for device in vd.get_context().device_infos: + if scalar_type == vd.float16: + if device.float_16_support != 1: + return False + + # Half precision in storage buffers typically needs one of these capabilities. 
+ if ( + device.storage_buffer_16_bit_access != 1 + and device.uniform_and_storage_buffer_16_bit_access != 1 + ): + return False + + if scalar_type == vd.float64 and device.float_64_support != 1: + return False + + return True + + +def ensure_supported_complex_precision(dtype, *, role: str) -> None: + if not supports_complex_precision(dtype): + raise ValueError(f"{role} precision '{dtype.name}' is not supported on the active device set") + + +def resolve_compute_precision(io_precisions: List, compute_precision: Optional[vd.dtype]) -> vd.dtype: + if compute_precision is not None: + validate_complex_precision(compute_precision, arg_name="compute_type") + ensure_supported_complex_precision(compute_precision, role="Compute") + return compute_precision + + for io_precision in io_precisions: + validate_complex_precision(io_precision, arg_name="io_precision") + + if len(io_precisions) == 0: + for candidate in (vd.complex64, vd.complex32): + if supports_complex_precision(candidate): + return candidate + + raise ValueError( + "Unable to resolve a default compute precision supported by all active devices" + ) + + target = default_compute_precision(io_precisions) + if supports_complex_precision(target): + return target + + # Auto fallback: drop from complex128 to complex64 when fp64 is unsupported. 
+ for candidate in (vd.complex64, vd.complex32): + if ( + _COMPLEX_PRECISION_RANK[candidate] <= _COMPLEX_PRECISION_RANK[target] + and supports_complex_precision(candidate) + ): + return candidate + + raise ValueError( + "Unable to resolve an auto compute precision supported by all active devices" + ) diff --git a/vkdispatch/fft/prime_utils.py b/vkdispatch/fft/prime_utils.py index 2db85020..ee1624fa 100644 --- a/vkdispatch/fft/prime_utils.py +++ b/vkdispatch/fft/prime_utils.py @@ -1,13 +1,10 @@ from typing import List import vkdispatch as vd -from .._compat import numpy_compat as npc +from ..compat import numpy_compat as npc def default_register_limit(): - if vd.get_devices()[0].is_nvidia(): - return 16 - - return 15 + return 16 def default_max_prime(): return 13 diff --git a/vkdispatch/fft/registers.py b/vkdispatch/fft/registers.py index 6fe671b3..31c79e32 100644 --- a/vkdispatch/fft/registers.py +++ b/vkdispatch/fft/registers.py @@ -32,7 +32,7 @@ def __init__(self, resources: FFTResources, count: int, name: str): self.config = resources.config self.registers = [ - vc.new_complex_register(var_name=f"{name}_reg_{i}") for i in range(count) + vc.new_register(self.config.compute_type, var_name=f"{name}_reg_{i}") for i in range(count) ] self.count = count @@ -53,40 +53,13 @@ def __setitem__(self, index: int, value: vc.ShaderVariable): self.registers[index][:] = value def normalize(self): + normalization = vc.to_dtype(self.config.compute_type.child_type, self.config.N) for i in range(self.count): - self.registers[i][:] = self.registers[i] / self.config.N - - def get_input_format(self, stage_index: int = 0) -> Dict[int, int]: - in_format = {} - - stride = self.config.N // self.config.stages[stage_index].fft_length - - register_count = len(self.registers) - register_index_list = list(range(register_count)) - - for invocation in self.resources.invocations[stage_index]: - sub_registers = register_index_list[invocation.register_selection] - - for i in 
range(len(sub_registers)): - in_format[invocation.get_read_index(stride * i)] = sub_registers[i] - - return in_format - - def get_output_format(self, stage_index: int = -1) -> Dict[int, int]: - out_format = {} - - register_count = len(self.registers) - register_index_list = list(range(register_count)) - - for jj in range(self.config.stages[stage_index].fft_length): - for invocation in self.resources.invocations[stage_index]: - out_format[invocation.get_write_index(jj)] = register_index_list[invocation.register_selection][jj] - - return out_format + self.registers[i][:] = self.registers[i] / normalization def try_shuffle(self, output_stage: int = -1, input_stage: int = 0) -> bool: - out_format = self.get_output_format(output_stage) - in_format = self.get_input_format(input_stage) + out_format = self.config.stages[output_stage].get_output_format(len(self.registers)) + in_format = self.config.stages[input_stage].get_input_format(len(self.registers)) if out_format.keys() != in_format.keys(): return False diff --git a/vkdispatch/fft/resources.py b/vkdispatch/fft/resources.py index 17b2085d..f63bd04e 100644 --- a/vkdispatch/fft/resources.py +++ b/vkdispatch/fft/resources.py @@ -8,59 +8,6 @@ from .config import FFTConfig from .grid_manager import FFTGridManager -@dataclasses.dataclass -class FFTRegisterStageInvocation: - output_stride: int - block_width: int - inner_block_offset: vc.ShaderVariable - sub_sequence_offset: vc.ShaderVariable - register_selection: slice - - instance_id: vc.ShaderVariable - - instance_id0: int - inner_block_offset0: int - sub_sequence_offset0: int - - def __init__(self, - stage_fft_length: int, - stage_instance_count: int, - output_stride: int, - instance_index: int, - tid: vc.ShaderVariable, - N: int): - self.output_stride = output_stride - - self.block_width = output_stride * stage_fft_length - - instance_index_stride = N // (stage_fft_length * stage_instance_count) - - self.instance_id = tid + instance_index_stride * instance_index - - 
self.inner_block_offset = self.instance_id % output_stride - - if output_stride == 1: - self.inner_block_offset = 0 - - self.sub_sequence_offset = self.instance_id * stage_fft_length - self.inner_block_offset * (stage_fft_length - 1) - - # pretend tid is 0, used for calculating register shuffles - self.instance_id0 = instance_index_stride * instance_index - self.inner_block_offset0 = self.instance_id0 % output_stride - self.sub_sequence_offset0 = self.instance_id0 * stage_fft_length - self.inner_block_offset0 * (stage_fft_length - 1) - - if self.block_width == N: - self.inner_block_offset = self.instance_id - self.sub_sequence_offset = self.inner_block_offset - - self.register_selection = slice(instance_index * stage_fft_length, (instance_index + 1) * stage_fft_length) - - def get_write_index(self, fft_index: int): - return self.sub_sequence_offset0 + fft_index * self.output_stride - - def get_read_index(self, offset: int): - return self.instance_id0 + offset - @dataclasses.dataclass class FFTResources: input_batch_offset: vc.ShaderVariable @@ -78,49 +25,21 @@ class FFTResources: config: FFTConfig - output_strides: List[int] - invocations: List[List[FFTRegisterStageInvocation]] - def __init__(self, config: FFTConfig, grid: FFTGridManager): self.tid = grid.tid self.grid = grid self.config = config self.input_batch_offset = vc.new_uint_register(var_name="input_batch_offset") self.output_batch_offset = vc.new_uint_register(var_name="output_batch_offset") - self.omega_register = vc.new_complex_register(var_name="omega_register") + self.omega_register = vc.new_register(config.compute_type, var_name="omega_register") self.subsequence_offset = vc.new_uint_register(var_name="subsequence_offset") self.io_index = vc.new_uint_register(var_name="io_index") self.io_index_2 = vc.new_uint_register(var_name="io_index_2") self.radix_registers = [ - vc.new_complex_register(var_name=f"radix_register_{i}") for i in range(config.max_prime_radix) + vc.new_register(config.compute_type, 
var_name=f"radix_register_{i}") for i in range(config.max_prime_radix) ] - self.output_strides = [] - self.invocations = [] - - output_stride = 1 - stage_count = len(config.stages) - - for i in range(stage_count): - stage = config.stages[i] - stage_invocations = [] - - for ii in range(stage.instance_count): - stage_invocations.append(FFTRegisterStageInvocation( - stage.fft_length, - stage.instance_count, - output_stride, - ii, - self.tid, - config.N - )) - - self.output_strides.append(output_stride) - self.invocations.append(stage_invocations) - - output_stride *= config.stages[i].fft_length - def stage_begin(self, stage_index: int): thread_count = self.config.stages[stage_index].thread_count @@ -144,4 +63,3 @@ def invocation_end(self, stage_index: int): if stage.remainder_offset == 1: vc.end() - diff --git a/vkdispatch/fft/sdata_manager.py b/vkdispatch/fft/sdata_manager.py index f7e41fa7..d00ff31e 100644 --- a/vkdispatch/fft/sdata_manager.py +++ b/vkdispatch/fft/sdata_manager.py @@ -11,7 +11,7 @@ from .memory_iterators import memory_reads_iterator, memory_writes_iterator class FFTSDataManager: - sdata: vc.Buff[vc.c64] + sdata: vc.Buffer sdata_offset: Union[vc.Const[vc.u32], Literal[0]] sdata_row_size: int @@ -46,7 +46,7 @@ def __init__(self, config: FFTConfig, grid: FFTGridManager, default_registers: F total_inner_batches = grid.inline_batches_inner * grid.inline_batches_outer self.sdata = vc.shared_buffer( - vd.complex64, + config.compute_type, config.sdata_allocation * total_inner_batches, var_name="sdata") @@ -90,7 +90,7 @@ def read_from_sdata(self, registers: Optional[FFTRegisters] = None, stage_index: def write_to_sdata(self, registers: Optional[FFTRegisters] = None, stage_index: int = -1): self.op_write() - self.use_padding = self.padding_enabled and self.resources.output_strides[stage_index] < 32 + self.use_padding = self.padding_enabled and self.resources.config.stages[stage_index].input_stride < 32 if registers is None: registers = self.default_registers 
@@ -101,4 +101,4 @@ def write_to_sdata(self, registers: Optional[FFTRegisters] = None, stage_index: if self.use_padding: self.resources.io_index[:] = self.resources.io_index + (self.resources.io_index // self.sdata_row_size) - self.sdata[self.resources.io_index] = registers[write_op.register_id] \ No newline at end of file + self.sdata[self.resources.io_index] = registers[write_op.register_id] diff --git a/vkdispatch/fft/shader_factories.py b/vkdispatch/fft/shader_factories.py index 7ccf92c7..28a481fd 100644 --- a/vkdispatch/fft/shader_factories.py +++ b/vkdispatch/fft/shader_factories.py @@ -2,7 +2,7 @@ import vkdispatch.codegen as vc from vkdispatch.codegen.abreviations import * -from .._compat import numpy_compat as npc +from ..compat import numpy_compat as npc from typing import Tuple, Optional from functools import lru_cache @@ -17,12 +17,28 @@ def make_fft_shader( r2c: bool = False, input_map: vd.MappingFunction = None, output_map: vd.MappingFunction = None, + input_type: vd.dtype = None, + output_type: vd.dtype = None, + compute_type: vd.dtype = None, input_signal_range: Optional[Tuple[Optional[int], Optional[int]]] = None) -> vd.ShaderFunction: - with vd.fft.fft_context(buffer_shape, axis=axis) as ctx: + if output_type is None: + output_type = vd.complex64 + + if input_type is None and input_map is None: + input_type = output_type + + if compute_type is None: + compute_type = vd.complex64 + + name = f"fft_shader_{buffer_shape}_{axis}_{inverse}_{normalize_inverse}_{r2c}" + + with vd.fft.fft_context(buffer_shape, axis=axis, compute_type=compute_type, name=name) as ctx: io_manager = ctx.make_io_manager( input_map=input_map, - output_map=output_map + output_map=output_map, + output_type=output_type, + input_type=input_type, ) io_manager.read_input( @@ -46,9 +62,10 @@ def make_fft_shader( @lru_cache(maxsize=None) def get_transposed_size( buffer_shape: Tuple, - axis: int = None) -> vd.ShaderFunction: + axis: int = None, + compute_type: vd.dtype = vd.complex64) -> 
vd.ShaderFunction: - config = vd.fft.FFTConfig(buffer_shape, axis) + config = vd.fft.FFTConfig(buffer_shape, axis, compute_type=compute_type) grid = vd.fft.FFTGridManager(config, True, False) return npc.prod(grid.local_size) * npc.prod(grid.workgroup_count) * config.register_count @@ -57,10 +74,13 @@ def get_transposed_size( def make_transpose_shader( buffer_shape: Tuple, axis: int = None, - kernel_inner_only: bool = False) -> vd.ShaderFunction: + kernel_inner_only: bool = False, + input_type: vd.dtype = vd.complex64, + output_type: vd.dtype = vd.complex64, + compute_type: vd.dtype = vd.complex64) -> vd.ShaderFunction: - with vd.fft.fft_context(buffer_shape, axis=axis) as ctx: - args = ctx.declare_shader_args([vc.Buffer[c64], vc.Buffer[c64]]) + with vd.fft.fft_context(buffer_shape, axis=axis, compute_type=compute_type) as ctx: + args = ctx.declare_shader_args([vc.Buffer[output_type], vc.Buffer[input_type]]) if kernel_inner_only: vc.if_statement(ctx.grid.global_outer_offset == 0) @@ -95,23 +115,43 @@ def make_convolution_shader( kernel_inner_only: bool = False, input_map: vd.MappingFunction = None, output_map: vd.MappingFunction = None, + input_type: vd.dtype = None, + output_type: vd.dtype = None, + kernel_type: vd.dtype = None, + compute_type: vd.dtype = None, input_signal_range: Optional[Tuple[Optional[int], Optional[int]]] = None) -> vd.ShaderFunction: + if output_type is None: + output_type = vd.complex64 + + if input_type is None and input_map is None: + input_type = output_type + + if kernel_type is None: + kernel_type = vd.complex64 + + if compute_type is None: + compute_type = vd.complex64 + if kernel_map is None: - def kernel_map_func(kernel_buffer: vc.Buffer[c64]): + def kernel_map_func(kernel_buffer: vc.Buffer[kernel_type]): read_op = vd.fft.read_op() - kernel_val = vc.new_complex_register() + kernel_val = vc.new_register(compute_type) read_op.read_from_buffer(kernel_buffer, register=kernel_val) read_op.register[:] = vc.mult_complex(read_op.register, 
kernel_val.conjugate()) - kernel_map = vd.map(kernel_map_func, input_types=[vc.Buffer[c64]]) + kernel_map = vd.map(kernel_map_func, input_types=[vc.Buffer[kernel_type]]) + + name = f"convolution_shader_{buffer_shape}_{axis}" - with vd.fft.fft_context(buffer_shape, axis=axis) as ctx: + with vd.fft.fft_context(buffer_shape, axis=axis, compute_type=compute_type, name=name) as ctx: io_manager = ctx.make_io_manager( input_map=input_map, output_map=output_map, + output_type=output_type, + input_type=input_type, kernel_map=kernel_map ) @@ -123,10 +163,6 @@ def kernel_map_func(kernel_buffer: vc.Buffer[c64]): ctx.execute(inverse=False) ctx.register_shuffle() - vc.comment("""Convolution pipeline phase 2/3. -Apply one or more frequency-domain kernels to the transformed input spectrum. -For multi-kernel runs, restore from backup registers so each kernel sees -identical FFT-domain source values before inverse transformation.""") backup_registers = None if kernel_num > 1: @@ -134,17 +170,19 @@ def kernel_map_func(kernel_buffer: vc.Buffer[c64]): backup_registers.read_from_registers(ctx.registers) for kern_index in range(kernel_num): - vc.comment(f"""Convolution pipeline phase 3/3. Kernel {kern_index + 1}/{kernel_num}. -Map this kernel onto the current spectrum. -Run inverse FFT back to the spatial domain, optionally normalize by length, -and write this kernel's output slice to global memory.""") + vc.comment(f"""Convolution pipeline phase 2/3. Kernel {kern_index + 1}/{kernel_num}. +Map this kernel onto the current spectrum.""") if backup_registers is not None: ctx.registers.read_from_registers(backup_registers) set_global_kernel_index(kern_index) io_manager.read_kernel(format_transposed=transposed_kernel, inner_only=kernel_inner_only) - + + vc.comment(f"""Convolution pipeline phase 3/3. 
+Run inverse FFT back to the spatial domain, optionally normalize by length, +and write this kernel's output slice to global memory.""") + ctx.execute(inverse=True) if normalize: diff --git a/vkdispatch/fft/src_functions.py b/vkdispatch/fft/src_functions.py new file mode 100644 index 00000000..e8952bb3 --- /dev/null +++ b/vkdispatch/fft/src_functions.py @@ -0,0 +1,342 @@ +import vkdispatch as vd + +from .shader_factories import make_fft_shader, make_convolution_shader, make_transpose_shader, get_transposed_size + +from typing import Tuple, Union, Optional + +def fft_src( + buffer_shape: Tuple, + axis: int = None, + inverse: bool = False, + normalize_inverse: bool = True, + r2c: bool = False, + input_map: vd.MappingFunction = None, + output_map: vd.MappingFunction = None, + input_signal_range: Union[Tuple[Optional[int], Optional[int]], None] = None, + line_numbers: bool = False) -> vd.ShaderSource: + + fft_shader = make_fft_shader( + tuple(buffer_shape), + axis, + inverse=inverse, + normalize_inverse=normalize_inverse, + r2c=r2c, + input_map=input_map, + output_map=output_map, + input_signal_range=input_signal_range) + + return fft_shader.get_src(line_numbers=line_numbers) + +def fft2_src(buffer_shape: Tuple, input_map: vd.MappingFunction = None, output_map: vd.MappingFunction = None): + assert len(buffer_shape) == 2 or len(buffer_shape) == 3, 'Buffer Shape must have 2 or 3 dimensions' + + return ( + fft_src(axis=len(buffer_shape) - 2, input_map=input_map), + fft_src(axis=len(buffer_shape) - 1, output_map=output_map) + ) + +def fft3_src(buffer_shape: Tuple, input_map: vd.MappingFunction = None, output_map: vd.MappingFunction = None): + assert len(buffer_shape) == 3, 'Buffer must have 3 dimensions' + + return ( + fft_src(buffer_shape, axis=0, input_map=input_map), + fft_src(buffer_shape, axis=1), + fft_src(buffer_shape, axis=2, output_map=output_map) + ) + + +def ifft_src( + buffer_shape: Tuple, + axis: int = None, + normalize: bool = True, + input_map: 
vd.MappingFunction = None, + output_map: vd.MappingFunction = None): + return fft_src(buffer_shape, axis=axis, inverse=True, normalize_inverse=normalize, input_map=input_map, output_map=output_map) + +def ifft2_src(buffer_shape: Tuple, normalize: bool = True, input_map: vd.MappingFunction = None, output_map: vd.MappingFunction = None): + assert len(buffer_shape) == 2 or len(buffer_shape) == 3, 'Buffer must have 2 or 3 dimensions' + + return ( + ifft_src(buffer_shape, axis=len(buffer_shape) - 2, normalize=normalize, input_map=input_map), + ifft_src(buffer_shape, axis=len(buffer_shape) - 1, normalize=normalize, output_map=output_map) + ) + +def ifft3_src(buffer_shape: Tuple, normalize: bool = True, input_map: vd.MappingFunction = None, output_map: vd.MappingFunction = None): + assert len(buffer_shape) == 3, 'Buffer must have 3 dimensions' + + return ( + ifft_src(buffer_shape, axis=0, normalize=normalize, input_map=input_map), + ifft_src(buffer_shape, axis=1, normalize=normalize), + ifft_src(buffer_shape, axis=2, normalize=normalize, output_map=output_map) + ) + + +def rfft_src(buffer_shape: Tuple): + return fft_src(buffer_shape, r2c=True) + +def rfft2_src(buffer_shape: Tuple): + assert len(buffer_shape) == 2 or len(buffer_shape) == 3, 'Buffer must have 2 or 3 dimensions' + + return ( + rfft_src(buffer_shape), + fft_src(buffer_shape, axis=len(buffer_shape) - 2) + ) + +def rfft3_src(buffer_shape: Tuple): + assert len(buffer_shape) == 3, 'Buffer must have 3 dimensions' + + return ( + rfft_src(buffer_shape), + fft_src(buffer_shape, axis=1), + fft_src(buffer_shape, axis=0) + ) + +def irfft_src(buffer_shape: Tuple, normalize: bool = True): + return fft_src(buffer_shape, inverse=True, normalize_inverse=normalize, r2c=True) + +def irfft2_src(buffer_shape: Tuple, normalize: bool = True): + assert len(buffer_shape) == 2 or len(buffer_shape) == 3, 'Buffer must have 2 or 3 dimensions' + + return ( + ifft_src(buffer_shape, axis=len(buffer_shape) - 2, normalize=normalize), + 
irfft_src(buffer_shape, normalize=normalize) + ) + +def irfft3_src(buffer_shape: Tuple, normalize: bool = True): + assert len(buffer_shape) == 3, 'Buffer must have 3 dimensions' + + return ( + ifft_src(buffer_shape, axis=0, normalize=normalize), + ifft_src(buffer_shape, axis=1, normalize=normalize), + irfft_src(buffer_shape, normalize=normalize) + ) + +def convolve_src( + buffer_shape: Tuple, + kernel_map: vd.MappingFunction = None, + kernel_num: int = 1, + axis: int = None, + normalize: bool = True, + transposed_kernel: bool = False, + kernel_inner_only: bool = False, + input_map: vd.MappingFunction = None, + output_map: vd.MappingFunction = None, + input_signal_range: Union[Tuple[Optional[int], Optional[int]], None] = None, + line_numbers: bool = False) -> vd.ShaderSource: + + fft_shader = make_convolution_shader( + tuple(buffer_shape), + kernel_map, + kernel_num, + axis, + transposed_kernel=transposed_kernel, + kernel_inner_only=kernel_inner_only, + normalize=normalize, + input_map=input_map, + output_map=output_map, + input_signal_range=input_signal_range) + + return fft_shader.get_src(line_numbers=line_numbers) + +def convolve2D_src( + buffer_shape: Tuple, + kernel_map: vd.MappingFunction = None, + normalize: bool = True, + transposed_kernel: bool = False, + kernel_inner_only: bool = False, + input_map: vd.MappingFunction = None, + output_map: vd.MappingFunction = None): + + assert len(buffer_shape) == 2 or len(buffer_shape) == 3, 'Buffer must have 2 or 3 dimensions' + + return ( + fft_src(buffer_shape, input_map=input_map), + convolve_src( + buffer_shape, + kernel_map=kernel_map, + transposed_kernel=transposed_kernel, + kernel_inner_only=kernel_inner_only, + axis=len(buffer_shape) - 2, + normalize=normalize + ), + ifft_src(buffer_shape, normalize=normalize, output_map=output_map) + ) + +def convolve2DR_src( + buffer_shape: Tuple, + kernel_map: vd.MappingFunction = None, + transposed_kernel: bool = False, + kernel_inner_only: bool = False, + normalize: bool = 
True): + + assert len(buffer_shape) == 2 or len(buffer_shape) == 3, 'Buffer must have 2 or 3 dimensions' + + return ( + rfft_src(buffer_shape), + convolve_src( + buffer_shape, + kernel_map=kernel_map, + transposed_kernel=transposed_kernel, + kernel_inner_only=kernel_inner_only, + axis=len(buffer_shape) - 2, + normalize=normalize + ), + irfft_src(buffer_shape, normalize=normalize) + ) + +def transpose_src( + buffer_shape: Tuple, + axis: int = None, + kernel_inner_only: bool = False, + line_numbers: bool = False) -> vd.Buffer: + + transpose_shader = make_transpose_shader( + tuple(buffer_shape), + axis=axis, + kernel_inner_only=kernel_inner_only + ) + + return transpose_shader.get_src(line_numbers=line_numbers) + + +def fft_print_src( + buffer_shape: Tuple, + axis: int = None, + inverse: bool = False, + normalize_inverse: bool = True, + r2c: bool = False, + input_map: vd.MappingFunction = None, + output_map: vd.MappingFunction = None, + input_signal_range: Union[Tuple[Optional[int], Optional[int]], None] = None, + line_numbers: bool = False) -> vd.ShaderSource: + + print(fft_src( + buffer_shape, + axis, + inverse=inverse, + normalize_inverse=normalize_inverse, + r2c=r2c, + input_map=input_map, + output_map=output_map, + input_signal_range=input_signal_range, + line_numbers=line_numbers)) + +def fft2_print_src(buffer_shape: Tuple, input_map: vd.MappingFunction = None, output_map: vd.MappingFunction = None): + srcs = fft2_src(buffer_shape, input_map=input_map, output_map=output_map) + print(f"// FFT Stage 1 (axis {len(buffer_shape) - 2}):\n{srcs[0]}\n// FFT Stage 2 (axis {len(buffer_shape) - 1}):\n{srcs[1]}") + +def fft3_print_src(buffer_shape: Tuple, input_map: vd.MappingFunction = None, output_map: vd.MappingFunction = None): + srcs = fft3_src(buffer_shape, input_map=input_map, output_map=output_map) + print(f"// FFT Stage 1 (axis 0):\n{srcs[0]}\n// FFT Stage 2 (axis 1):\n{srcs[1]}\n// FFT Stage 3 (axis 2):\n{srcs[2]}") + +def ifft_print_src( + buffer_shape: Tuple, + 
axis: int = None, + normalize: bool = True, + input_map: vd.MappingFunction = None, + output_map: vd.MappingFunction = None): + print(ifft_src(buffer_shape, axis=axis, normalize=normalize, input_map=input_map, output_map=output_map)) + +def ifft2_print_src(buffer_shape: Tuple, normalize: bool = True, input_map: vd.MappingFunction = None, output_map: vd.MappingFunction = None): + srcs = ifft2_src(buffer_shape, normalize=normalize, input_map=input_map, output_map=output_map) + print(f"// IFFT Stage 1 (axis {len(buffer_shape) - 2}):\n{srcs[0]}\n// IFFT Stage 2 (axis {len(buffer_shape) - 1}):\n{srcs[1]}") + +def ifft3_print_src(buffer_shape: Tuple, normalize: bool = True, input_map: vd.MappingFunction = None, output_map: vd.MappingFunction = None): + srcs = ifft3_src(buffer_shape, normalize=normalize, input_map=input_map, output_map=output_map) + print(f"// IFFT Stage 1 (axis 0):\n{srcs[0]}\n// IFFT Stage 2 (axis 1):\n{srcs[1]}\n// IFFT Stage 3 (axis 2):\n{srcs[2]}") + +def rfft_print_src(buffer_shape: Tuple): + print(rfft_src(buffer_shape)) + +def rfft2_print_src(buffer_shape: Tuple): + srcs = rfft2_src(buffer_shape) + print(f"// RFFT Stage 1:\n{srcs[0]}\n// RFFT Stage 2 (axis {len(buffer_shape) - 2}):\n{srcs[1]}") + +def rfft3_print_src(buffer_shape: Tuple): + srcs = rfft3_src(buffer_shape) + print(f"// RFFT Stage 1:\n{srcs[0]}\n// RFFT Stage 2 (axis 1):\n{srcs[1]}\n// RFFT Stage 3 (axis 0):\n{srcs[2]}") + +def irfft_print_src(buffer_shape: Tuple, normalize: bool = True): + print(irfft_src(buffer_shape, normalize=normalize)) + +def irfft2_print_src(buffer_shape: Tuple, normalize: bool = True): + srcs = irfft2_src(buffer_shape, normalize=normalize) + print(f"// IRFFT Stage 1 (axis {len(buffer_shape) - 2}):\n{srcs[0]}\n// IRFFT Stage 2:\n{srcs[1]}") + +def irfft3_print_src(buffer_shape: Tuple, normalize: bool = True): + srcs = irfft3_src(buffer_shape, normalize=normalize) + print(f"// IRFFT Stage 1 (axis 0):\n{srcs[0]}\n// IRFFT Stage 2 (axis 1):\n{srcs[1]}\n// IRFFT 
Stage 3:\n{srcs[2]}") + +def convolve_print_src( + buffer_shape: Tuple, + kernel_map: vd.MappingFunction = None, + kernel_num: int = 1, + axis: int = None, + normalize: bool = True, + transposed_kernel: bool = False, + kernel_inner_only: bool = False, + input_map: vd.MappingFunction = None, + output_map: vd.MappingFunction = None, + input_signal_range: Union[Tuple[Optional[int], Optional[int]], None] = None, + line_numbers: bool = False) -> vd.ShaderSource: + + print(convolve_src( + buffer_shape, + kernel_map=kernel_map, + kernel_num=kernel_num, + axis=axis, + normalize=normalize, + transposed_kernel=transposed_kernel, + kernel_inner_only=kernel_inner_only, + input_map=input_map, + output_map=output_map, + input_signal_range=input_signal_range, + line_numbers=line_numbers + )) + +def convolve2D_print_src( + buffer_shape: Tuple, + kernel_map: vd.MappingFunction = None, + normalize: bool = True, + transposed_kernel: bool = False, + kernel_inner_only: bool = False, + input_map: vd.MappingFunction = None, + output_map: vd.MappingFunction = None): + srcs = convolve2D_src( + buffer_shape, + kernel_map=kernel_map, + normalize=normalize, + transposed_kernel=transposed_kernel, + kernel_inner_only=kernel_inner_only, + input_map=input_map, + output_map=output_map + ) + print(f"// FFT Stage (axis {len(buffer_shape) - 2}):\n{srcs[0]}\n// Convolution Stage (axis {len(buffer_shape) - 2}):\n{srcs[1]}\n// IFFT Stage:\n{srcs[2]}") + +def convolve2DR_print_src( + buffer_shape: Tuple, + kernel_map: vd.MappingFunction = None, + transposed_kernel: bool = False, + kernel_inner_only: bool = False, + normalize: bool = True): + srcs = convolve2DR_src( + buffer_shape, + kernel_map=kernel_map, + transposed_kernel=transposed_kernel, + kernel_inner_only=kernel_inner_only, + normalize=normalize + ) + print(f"// RFFT Stage:\n{srcs[0]}\n// Convolution Stage (axis {len(buffer_shape) - 2}):\n{srcs[1]}\n// IRFFT Stage:\n{srcs[2]}") + +def transpose_print_src( + buffer_shape: Tuple, + axis: int = 
None, + kernel_inner_only: bool = False, + line_numbers: bool = False) -> vd.Buffer: + + print(transpose_src( + buffer_shape, + axis=axis, + kernel_inner_only=kernel_inner_only, + line_numbers=line_numbers + )) \ No newline at end of file diff --git a/vkdispatch/fft/stages.py b/vkdispatch/fft/stages.py new file mode 100644 index 00000000..0cb348fd --- /dev/null +++ b/vkdispatch/fft/stages.py @@ -0,0 +1,198 @@ +import vkdispatch as vd +import vkdispatch.codegen as vc +import dataclasses +from typing import List, Tuple, Dict + +from ..compat import numpy_compat as npc +from .prime_utils import prime_factors + +@dataclasses.dataclass +class FFTStagePlanInvocation: + fft_length: int + input_stride: int + instance_index: int + instance_index_stride: int + block_width: int + full_width_block: bool + instance_id0: int + inner_block_offset0: int + sub_sequence_offset0: int + register_selection: slice + + def __init__(self, + stage_fft_length: int, + stage_instance_count: int, + input_stride: int, + instance_index: int, + N: int): + self.fft_length = stage_fft_length + self.input_stride = input_stride + self.instance_index = instance_index + self.block_width = input_stride * stage_fft_length + self.instance_index_stride = N // (stage_fft_length * stage_instance_count) + + self.full_width_block = self.block_width == N + + # pretend tid is 0, used for calculating register shuffles + self.instance_id0 = self.instance_index_stride * instance_index + self.inner_block_offset0 = self.instance_id0 % input_stride + self.sub_sequence_offset0 = self.instance_id0 * stage_fft_length - self.inner_block_offset0 * (stage_fft_length - 1) + + self.register_selection = slice(instance_index * stage_fft_length, (instance_index + 1) * stage_fft_length) + + def get_offset(self, tid: vc.ShaderVariable): + return tid + self.instance_index_stride * self.instance_index + + def get_inner_block_offset(self, tid: vc.ShaderVariable): + if self.input_stride == 1: + return 0 + + if self.full_width_block: + 
return self.get_offset(tid) + + return self.get_offset(tid) % self.input_stride + + def get_sub_sequence_offset(self, tid: vc.ShaderVariable): + if self.full_width_block: + return self.get_offset(tid) + + return self.get_offset(tid) * self.fft_length - self.get_inner_block_offset(tid) * (self.fft_length - 1) + + def get_write_index(self, fft_index: int): + return self.sub_sequence_offset0 + fft_index * self.input_stride + + def get_read_index(self, offset: int): + return self.instance_id0 + offset + +@dataclasses.dataclass +class FFTRegisterStageConfig: + """ + Configuration for an FFT register stage. + + Attributes: + + primes (Tuple[int]): The prime numbers used for factorization. + fft_length (int): The length of each FFT stage. + instance_count (int): The number of instances required to achieve the desired level of parallelism. + registers_used (int): The total number of registers used by the FFT stage. + remainder (int): The remainder of `N` divided by `registers_used`. + remainder_offset (int): A flag indicating whether the remainder is non-zero. + extra_ffts (int): The additional number of FFT stages required to process the remainder. + thread_count (int): The total number of threads used in the computation. + sdata_size (int): The size of the shared memory buffer used to store intermediate results. + sdata_width (int): The width of each element in the shared memory buffer. + sdata_width_padded (int): The padded width of each element in the shared memory buffer. 
+ + """ + + N: int + primes: Tuple[int] + fft_length: int + instance_count: int + registers_used: int + remainder: int + remainder_offset: int + extra_ffts: int + thread_count: int + sdata_size: int + sdata_width: int + sdata_width_padded: int + input_stride: int + output_stride: int + invocations: Tuple[FFTStagePlanInvocation] + + def __init__(self, primes: List[int], + max_register_count: int, + N: int, + compute_item_size: int, + input_stride: int): + """ + Initializes the FFTRegisterStageConfig object. + + Parameters: + + primes (List[int]): The prime numbers to use for factorization. + max_register_count (int): The maximum number of registers allowed per thread. + N (int): The length of the input data. + + """ + self.N = N + self.primes = tuple(primes) + self.input_stride = input_stride + self.fft_length = int(round(npc.prod(primes))) + self.output_stride = self.input_stride * self.fft_length + instance_primes = prime_factors(N // self.fft_length) + + self.instance_count = 1 + + while len(instance_primes) > 0: + if self.instance_count * self.fft_length * instance_primes[0] > max_register_count: + break + self.instance_count *= instance_primes[0] + instance_primes = instance_primes[1:] + + self.registers_used = self.fft_length * self.instance_count + + self.remainder = N % self.registers_used + assert self.remainder % self.fft_length == 0, "Remainder must be divisible by the FFT length" + self.remainder_offset = 1 if self.remainder != 0 else 0 + self.extra_ffts = self.remainder // self.fft_length + + self.thread_count = N // self.registers_used + self.remainder_offset + + self.sdata_width = self.registers_used + + threads_primes = prime_factors(self.thread_count) + + while self.sdata_width < 16 and len(threads_primes) > 0: + self.sdata_width *= threads_primes[0] + threads_primes = threads_primes[1:] + + self.sdata_width_padded = self.sdata_width + + if self.sdata_width_padded % 2 == 0: + self.sdata_width_padded += 1 + + self.sdata_size = self.sdata_width_padded 
* int(npc.prod(threads_primes)) + + if self.sdata_size > vd.get_context().max_shared_memory // compute_item_size: + self.sdata_width_padded = self.sdata_width + self.sdata_size = self.sdata_width_padded * int(npc.prod(threads_primes)) + + invocations = [] + for instance_index in range(self.instance_count): + invocations.append(FFTStagePlanInvocation( + stage_fft_length=self.fft_length, + stage_instance_count=self.instance_count, + input_stride=input_stride, + instance_index=instance_index, + N=N + )) + + self.invocations = tuple(invocations) + + def get_input_format(self, register_count: int) -> Dict[int, int]: + in_format = {} + + stride = self.N // self.fft_length + + register_index_list = list(range(register_count)) + + for invocation in self.invocations: + sub_registers = register_index_list[invocation.register_selection] + + for i in range(len(sub_registers)): + in_format[invocation.get_read_index(stride * i)] = sub_registers[i] + + return in_format + + def get_output_format(self, register_count: int) -> Dict[int, int]: + out_format = {} + + register_index_list = list(range(register_count)) + + for jj in range(self.fft_length): + for invocation in self.invocations: + out_format[invocation.get_write_index(jj)] = register_index_list[invocation.register_selection][jj] + + return out_format \ No newline at end of file diff --git a/vkdispatch/reduce/operations.py b/vkdispatch/reduce/operations.py index 9cabb583..4081982b 100644 --- a/vkdispatch/reduce/operations.py +++ b/vkdispatch/reduce/operations.py @@ -7,6 +7,8 @@ from typing import Union from typing import Optional + + @dataclasses.dataclass class ReduceOp: name: str @@ -31,14 +33,14 @@ class ReduceOp: SubgroupMin = ReduceOp( name="min", reduction=lambda x, y: vc.min(x, y), - identity=vc.inf_f32, + identity="inf", subgroup_reduction=vc.subgroup_min ) SubgroupMax = ReduceOp( name="max", reduction=lambda x, y: vc.max(x, y), - identity=vc.ninf_f32, + identity="-inf", subgroup_reduction=vc.subgroup_max ) @@ -61,4 
+63,4 @@ class ReduceOp: reduction=lambda x, y: x ^ y, identity=0, subgroup_reduction=vc.subgroup_xor -) \ No newline at end of file +) diff --git a/vkdispatch/reduce/reduce_function.py b/vkdispatch/reduce/reduce_function.py index 6691b141..e8438498 100644 --- a/vkdispatch/reduce/reduce_function.py +++ b/vkdispatch/reduce/reduce_function.py @@ -6,7 +6,7 @@ from typing import List, Optional -from .._compat import numpy_compat as npc +from ..compat import numpy_compat as npc class ReduceFunction: def __init__(self, @@ -49,11 +49,26 @@ def make_stages(self): self.group_size, True, ) + + def get_src(self, line_numbers: bool = None) -> str: + self.make_stages() + + return [ + self.stage1.get_src(line_numbers), + self.stage2.get_src(line_numbers) + ] + + def print_src(self, line_numbers: bool = None): + srcs = self.get_src(line_numbers) + + print(f"// Reduction Stage 1:\n{srcs[0]}\n// Reduction Stage 2:\n{srcs[1]}") def __repr__(self) -> str: self.make_stages() - return f"Stage 1:\n{self.stage1}\nStage 2:\n{self.stage2}" + srcs = self.get_src() + + return f"// Reduction Stage 1:\n{srcs[0]}\n// Reduction Stage 2:\n{srcs[1]}" def __call__(self, *args, **kwargs) -> vd.Buffer: self.make_stages() diff --git a/vkdispatch/reduce/stage.py b/vkdispatch/reduce/stage.py index a9c91770..3bce6759 100644 --- a/vkdispatch/reduce/stage.py +++ b/vkdispatch/reduce/stage.py @@ -8,16 +8,16 @@ @dataclasses.dataclass class ReductionParams: - input_offset: vd.int32 - input_size: vd.int32 - input_stride: vd.int32 - input_y_batch_stride: vd.int32 - input_z_batch_stride: vd.int32 + input_offset: vd.uint32 + input_size: vd.uint32 + input_stride: vd.uint32 + input_y_batch_stride: vd.uint32 + input_z_batch_stride: vd.uint32 - output_offset: vd.int32 - output_stride: vd.int32 - output_y_batch_stride: vd.int32 - output_z_batch_stride: vd.int32 + output_offset: vd.uint32 + output_stride: vd.uint32 + output_y_batch_stride: vd.uint32 + output_z_batch_stride: vd.uint32 __static_global_io_index: 
vc.ShaderVariable = None @@ -36,7 +36,14 @@ def global_reduce( map_func: Optional[vd.MappingFunction] = None): ind = (vc.global_invocation_id().x * params.input_stride).to_register("ind") - reduction_aggregate = vc.new_register(out_type, reduction.identity, var_name="reduction_aggregate") + + reduction_identity = reduction.identity + if reduction_identity == "inf": + reduction_identity = vc.inf_f32() if out_type == vd.float32 else vc.inf_f64() + elif reduction_identity == "-inf": + reduction_identity = vc.ninf_f32() if out_type == vd.float32 else vc.ninf_f64() + + reduction_aggregate = vc.new_register(out_type, reduction_identity, var_name="reduction_aggregate") batch_offset = vc.workgroup_id().y * params.input_y_batch_stride inside_batch_offset = vc.workgroup_id().z * params.input_z_batch_stride @@ -76,15 +83,25 @@ def workgroup_reduce( sdata[tid] = reduction_aggregate vc.barrier() + + subgroup_reduce_size = vd.get_context().subgroup_size + + if not vd.get_context().subgroup_enabled: + subgroup_reduce_size = 1 current_size = group_size // 2 - while current_size > vd.get_context().subgroup_size: + while current_size > subgroup_reduce_size: vc.if_statement(tid < current_size) sdata[tid] = reduction.reduction(sdata[tid], sdata[tid + current_size]) - if current_size // 2 > vd.get_context().subgroup_size: + if current_size // 2 > subgroup_reduce_size: vc.end() else: - vc.else_if_statement(tid < 2*vc.subgroup_size()) + tid_limit = 2 + + if subgroup_reduce_size != 1: + tid_limit = 2*vc.subgroup_size() + + vc.else_if_statement(tid < tid_limit) sdata[tid] = vc.new_register(out_type, 0) vc.end() @@ -99,22 +116,28 @@ def subgroup_reduce( reduction: ReduceOp, group_size: int): tid = vc.local_invocation_id().x - subgroup_size = vd.get_context().subgroup_size + subgroup_reduce_size = vd.get_context().subgroup_size - if group_size > subgroup_size: - vc.if_all(tid < subgroup_size) - sdata[tid] = reduction.reduction(sdata[tid], sdata[tid + subgroup_size]) + if not 
vd.get_context().subgroup_enabled: + subgroup_reduce_size = 1 + + if group_size > subgroup_reduce_size: + vc.if_statement(tid < subgroup_reduce_size) + sdata[tid] = reduction.reduction(sdata[tid], sdata[tid + subgroup_reduce_size]) vc.end() + + if subgroup_reduce_size == 1: + return sdata[tid].to_register("local_var") + vc.subgroup_barrier() - if reduction.subgroup_reduction is not None: local_var = sdata[tid].to_register("local_var") local_var[:] = reduction.subgroup_reduction(local_var) return local_var else: - current_size = subgroup_size // 2 + current_size = subgroup_reduce_size // 2 while current_size > 1: vc.if_statement(tid < current_size) sdata[tid] = reduction.reduction(sdata[tid], sdata[tid + current_size]) @@ -134,6 +157,8 @@ def make_reduction_stage( output_is_input: bool, map_func: Optional[vd.MappingFunction] = None, input_types: List = None) -> vd.ShaderFunction: + + name = f"reduction_stage_{reduction.name}_{out_type.name}_{input_types}_{group_size}" with vd.shader_context() as context: signature_type_array = [] @@ -162,4 +187,4 @@ def make_reduction_stage( input_variables[0][batch_offset + output_offset + params.output_offset] = local_var vc.end() - return context.get_function(local_size=(group_size, 1, 1)) + return context.get_function(local_size=(group_size, 1, 1), name=name) diff --git a/vkdispatch/shader/context.py b/vkdispatch/shader/context.py index 74688e63..9bd5713c 100644 --- a/vkdispatch/shader/context.py +++ b/vkdispatch/shader/context.py @@ -1,9 +1,8 @@ import vkdispatch as vd import vkdispatch.codegen as vc -from .signature import ShaderSignature - -from typing import List +from .signature import ShaderSignature, ShaderArgumentType +from typing import List, Optional, Any import contextlib @@ -15,21 +14,64 @@ class ShaderContext: def __init__(self, builder: vc.ShaderBuilder): self.builder = builder self.signature = None - + self.shader_function = None + def get_function(self, local_size=None, workgroups=None, - exec_count=None) -> 
vd.ShaderFunction: - return vd.ShaderFunction.from_description( - self.builder.build("shader"), + exec_count=None, + name: Optional[str] = None) -> vd.ShaderFunction: + if self.shader_function is not None: + return self.shader_function + + description = self.builder.build("shader" if name is None else name) + + # Resource bindings are declared before final shader layout is known. + # For some shader construction paths (e.g. from_description), signatures are + # pre-populated and still hold logical bindings assuming a reserved UBO at 0. + binding_shift = description.resource_binding_base - 1 + if binding_shift != 0: + binding_access_len = len(description.binding_access) + needs_remap = False + + for shader_arg in self.signature.arguments: + if ( + shader_arg.binding is not None + and ( + shader_arg.arg_type == ShaderArgumentType.BUFFER + or shader_arg.arg_type == ShaderArgumentType.IMAGE + ) + and shader_arg.binding >= binding_access_len + ): + needs_remap = True + break + + if needs_remap: + for shader_arg in self.signature.arguments: + if ( + shader_arg.binding is not None + and ( + shader_arg.arg_type == ShaderArgumentType.BUFFER + or shader_arg.arg_type == ShaderArgumentType.IMAGE + ) + ): + shader_arg.binding += binding_shift + + self.shader_function = vd.ShaderFunction( + description, self.signature, local_size=local_size, workgroups=workgroups, exec_count=exec_count ) + + return self.shader_function - def declare_input_arguments(self, annotations: List): - self.signature = ShaderSignature.from_type_annotations(self.builder, annotations) + def declare_input_arguments(self, + annotations: List, + names: Optional[List[str]] = None, + defaults: Optional[List[Any]] = None): + self.signature = ShaderSignature.from_type_annotations(self.builder, annotations, names, defaults) return self.signature.get_variables() @contextlib.contextmanager diff --git a/vkdispatch/shader/decorator.py b/vkdispatch/shader/decorator.py index 88e2ab8e..0dbe5239 100644 --- 
a/vkdispatch/shader/decorator.py +++ b/vkdispatch/shader/decorator.py @@ -1,9 +1,12 @@ import vkdispatch as vd import vkdispatch.codegen as vc +import dataclasses import inspect from typing import Callable, TypeVar +from .context import shader_context + import sys if sys.version_info >= (3, 10): @@ -12,6 +15,31 @@ else: P = ... # Placeholder for older Python versions +def inspect_function_signature(func: Callable): + func_signature = inspect.signature(func) + + annotations = [] + names = [] + defaults = [] + + for param in func_signature.parameters.values(): + if param.annotation == inspect.Parameter.empty: + raise ValueError("All parameters must be annotated") + + + if not dataclasses.is_dataclass(param.annotation): # issubclass(param.annotation.__origin__, dataclasses.dataclass): + if not hasattr(param.annotation, '__args__'): + raise TypeError(f"Argument '{param.name}: vd.{param.annotation}' must have a type annotation") + + if len(param.annotation.__args__) != 1: + raise ValueError(f"Type '{param.name}: vd.{param.annotation.__name__}' must have exactly one type argument") + + annotations.append(param.annotation) + names.append(param.name) + defaults.append(param.default if param.default != inspect.Parameter.empty else None) + + return annotations, names, defaults + def shader( exec_size=None, local_size=None, @@ -43,12 +71,11 @@ def shader( raise ValueError("Cannot specify both 'workgroups' and 'exec_size'") def decorator_callback(func: Callable[P, None]) -> Callable[P, None]: - return vd.ShaderFunction( - func, - local_size=local_size, - workgroups=workgroups, - exec_count=exec_size, - flags=flags - ) + with shader_context(flags=flags) as context: + annotations, names, defaults = inspect_function_signature(func) + args = context.declare_input_arguments(annotations, names, defaults) + func(*args) + + return context.get_function(local_size=local_size, workgroups=workgroups, exec_count=exec_size, name=func.__name__) return decorator_callback diff --git 
a/vkdispatch/shader/shader_function.py b/vkdispatch/shader/shader_function.py index 72c9ee83..18e135ab 100644 --- a/vkdispatch/shader/shader_function.py +++ b/vkdispatch/shader/shader_function.py @@ -12,12 +12,10 @@ from .signature import ShaderArgumentType, ShaderSignature import uuid -import sys import dataclasses -from .._compat import numpy_compat as npc -from ..base.backend import BACKEND_PYCUDA, BACKEND_VULKAN +from ..compat import numpy_compat as npc class LaunchParametersHolder: def __init__(self, names_and_defaults, args, kwargs) -> None: @@ -58,7 +56,7 @@ class ExectionBounds: def __init__(self, names_and_defaults, local_size, workgroups, exec_size) -> None: self.names_and_defaults = names_and_defaults - self.local_size = local_size + self.local_size = tuple(local_size) self.workgroups = workgroups self.exec_size = exec_size @@ -134,56 +132,49 @@ def get_blocks_and_limits(self, args, kwargs) -> Tuple[Tuple[int, int, int], Tup return (my_blocks, my_limits) +@dataclasses.dataclass +class ShaderSource: + name: str + code: str + local_size: Tuple[int, int, int] + + def __repr__(self): + return f"// ====== Source Code for '{self.name}', workgroup_size: {self.local_size} ======\n{self.code}" + class ShaderFunction: plan: ComputePlan - func: Callable shader_description: vc.ShaderDescription shader_signature: ShaderSignature bounds: ExectionBounds ready: bool + name: str source: str flags: vc.ShaderFlags + local_size: Union[Tuple[int, int, int], Callable, None] + workgroups: Union[Tuple[int, int, int], Callable, None] + exec_size: Union[Tuple[int, int, int], Callable, None] def __init__(self, - func: Callable, + shader_description: vc.ShaderDescription, + shader_signature: ShaderSignature, local_size=None, workgroups=None, exec_count=None, - flags: vc.ShaderFlags = vc.ShaderFlags.NONE) -> None: + flags: vc.ShaderFlags = vc.ShaderFlags.NONE, + name: str = None) -> None: self.plan = None - self.func = func - self.shader_description = None - self.shader_signature = 
None + self.shader_description = shader_description + self.shader_signature = shader_signature self.bounds = None self.ready = False + self.name = name if name is not None else None self.source = None self.local_size = local_size self.workgroups = workgroups self.exec_size = exec_count self.flags = flags - def from_description( - shader_description: vc.ShaderDescription, - shader_signature: ShaderSignature, - local_size=None, - workgroups=None, - exec_count=None, - - ) -> "ShaderFunction": - shader_obj = ShaderFunction( - func=None, - local_size=local_size, - workgroups=workgroups, - exec_count=exec_count, - flags=vc.ShaderFlags.NONE - ) - - shader_obj.shader_description = shader_description - shader_obj.shader_signature = shader_signature - - return shader_obj - def build(self): if self.ready: return @@ -194,70 +185,65 @@ def build(self): else [vd.get_context().max_workgroup_size[0], 1, 1] ) - if self.shader_description is None or self.shader_signature is None: - assert self.shader_description is None and self.shader_signature is None, "Shader description and signature must both be set or both be None!" - assert self.func is not None, "Cannot build a shader without a function!" - - builder = vc.ShaderBuilder( - flags=self.flags, - is_apple_device=vd.get_context().is_apple() - ) - old_builder = vc.set_builder(builder) - - signature = ShaderSignature.from_inspectable_function(builder, self.func) - - self.func(*signature.get_variables()) - - vc.set_builder(old_builder) - - self.shader_description = builder.build(self.func.__module__ + "." 
+ self.func.__name__) - self.shader_signature = signature - self.bounds = ExectionBounds(self.shader_signature.get_names_and_defaults(), my_local_size, self.workgroups, self.exec_size) - if not sys.implementation.name == "Brython": - runtime_backend = vd.get_backend() - shader_backend_name = ( - self.shader_description.backend.name - if self.shader_description.backend is not None - else "glsl" - ) - - if runtime_backend == BACKEND_PYCUDA and shader_backend_name != "cuda": - raise RuntimeError( - "PyCUDA runtime backend requires CUDA codegen output. " - "Call vd.initialize(backend='pycuda') before building shaders." - ) + shader_backend_name = ( + self.shader_description.backend.name + if self.shader_description.backend is not None + else "glsl" + ) - if runtime_backend == BACKEND_VULKAN and shader_backend_name == "cuda": - raise RuntimeError( - "Vulkan runtime backend cannot execute CUDA codegen output. " - "Use GLSL codegen or initialize with backend='pycuda'." - ) + if vd.is_dummy(): + pass + elif vd.is_cuda() and shader_backend_name != "cuda": + raise RuntimeError( + "The selected CUDA runtime backend requires CUDA codegen output. " + "Call vd.initialize(backend='cuda') " + "before building shaders." + ) + elif vd.is_opencl() and shader_backend_name != "opencl": + raise RuntimeError( + "The selected OpenCL runtime backend requires OpenCL codegen output. " + "Call vd.initialize(backend='opencl') " + "before building shaders." + ) + elif vd.is_vulkan() and shader_backend_name == "cuda": + raise RuntimeError( + "Vulkan runtime backend cannot execute CUDA codegen output. " + "Use GLSL codegen or initialize with backend='cuda'." + ) + elif vd.is_vulkan() and shader_backend_name == "opencl": + raise RuntimeError( + "Vulkan runtime backend cannot execute OpenCL codegen output. " + "Use GLSL codegen or initialize with backend='opencl'." 
+ ) self.source = self.shader_description.make_source( my_local_size[0], my_local_size[1], my_local_size[2] ) try: - self.plan = ComputePlan( - self.source, - self.shader_description.binding_type_list, - self.shader_description.pc_size, - self.shader_description.name - ) + if not vd.is_dummy(): + self.plan = ComputePlan( + self.source, + self.shader_description.binding_type_list, + self.shader_description.pc_size, + self.shader_description.name + ) except Exception as e: print(f"Error building shader: {e}") - print(self.make_repr()) + print(self.get_src(build=False, line_numbers=True)) raise e self.ready = True def __repr__(self) -> str: - self.build() - return self.make_repr() + return self.get_src().__repr__() - def make_repr(self, line_numbers: bool = None) -> str: + def get_src(self, line_numbers: bool = None, build: bool = True) -> ShaderSource: + if build: + self.build() + result = "" if line_numbers is None: @@ -268,9 +254,14 @@ def make_repr(self, line_numbers: bool = None) -> str: result += f"{line_prefix}{line}\n" - return result + return ShaderSource(name=self.name, code=result, local_size=self.bounds.local_size) + + def print_src(self, line_numbers: bool = None): + print(self.get_src(line_numbers)) def __call__(self, *args, **kwargs): + assert not vd.is_dummy(), "Cannot execute shader functions with dummy backend!" + self.build() if not self.ready: diff --git a/vkdispatch/shader/signature.py b/vkdispatch/shader/signature.py index c9cb53b7..8d6f4a46 100644 --- a/vkdispatch/shader/signature.py +++ b/vkdispatch/shader/signature.py @@ -19,6 +19,16 @@ import enum +_PUSH_CONSTANT_UNSUPPORTED_BACKENDS = set() + + +def _push_constant_not_supported_error(backend_name: str) -> str: + return ( + f"Push Constants are not supported for the {backend_name.upper()} backend. " + "Use Const instead." 
+ ) + + class ShaderArgumentType(enum.Enum): BUFFER = 0 IMAGE = 1 @@ -139,6 +149,9 @@ def from_type_annotations(cls, value_name = shader_param.raw_name arg_type = ShaderArgumentType.CONSTANT elif(issubclass(annotations[i].__origin__, vc.Variable)): + if builder.backend.name in _PUSH_CONSTANT_UNSUPPORTED_BACKENDS: + raise NotImplementedError(_push_constant_not_supported_error(builder.backend.name)) + shader_param = builder.declare_variable(annotations[i].__args__[0]) arg_type = ShaderArgumentType.VARIABLE value_name = shader_param.raw_name diff --git a/vkdispatch/vkfft/vkfft_plan.py b/vkdispatch/vkfft/vkfft_plan.py index 64f201f3..0ad12dea 100644 --- a/vkdispatch/vkfft/vkfft_plan.py +++ b/vkdispatch/vkfft/vkfft_plan.py @@ -1,4 +1,4 @@ -from vkdispatch.base.backend import native +from vkdispatch.backends.backend_selection import native import vkdispatch as vd