diff --git a/README.md b/README.md index ea295ee2..10878cbc 100644 --- a/README.md +++ b/README.md @@ -1,67 +1,189 @@ -# Getting Started +# vkdispatch -Welcome to **vkdispatch**! This guide will help you install the library and run your first GPU-accelerated code. +`vkdispatch` is a Python GPU computing framework for writing single-source kernels in Python and dispatching them across multiple runtime backends. -**[WARNING: The documentation is still under heavy development, and has many missing sections]** +It combines runtime code generation, execution helpers, and FFT/reduction utilities in one package. The default PyPI install ships with the Vulkan backend. CUDA and OpenCL backends can be enabled with optional runtime dependencies. -> **Note:** vkdispatch requires a Vulkan-compatible GPU and drivers installed on your system. Please ensure your system meets these requirements before proceeding. +## Highlights + +- Single-source Python shaders via `@vd.shader` and `vkdispatch.codegen` +- Multiple runtime backends: Vulkan, CUDA, OpenCL, and a dummy codegen-only backend +- Backend-aware code generation: GLSL for Vulkan, CUDA source for CUDA, and OpenCL C for OpenCL +- Native FFT workflows through `vd.fft`, including mapping hooks for fusion and custom I/O +- VkFFT-backed transforms through `vd.vkfft` on the Vulkan backend +- Reductions through `vd.reduce` +- Batched submission and deferred execution through `vd.CommandGraph` +- CUDA interop through `__cuda_array_interface__` and CUDA Graph capture helpers ## Installation -The default installation method for `vkdispatch` is through PyPI (pip): +### Default Vulkan Install + +To install `vkdispatch` with the Vulkan backend, run: ```bash -# Install the package pip install vkdispatch ``` -On mainstream platforms — Windows (x86_64), macOS (x86_64 and Apple Silicon/arm64), and Linux (x86_64) — pip will download a **prebuilt wheel** (built with `cibuildwheel` on GitHub Actions and tagged as *manylinux* where applicable), so no compiler is needed. +This installs the core library, the code generation system, and the Vulkan runtime backend. The Vulkan backend is designed to run on systems supporting Vulkan 1.2 or higher, including macOS via a statically linked MoltenVK. Alternate backends can be added with optional dependencies as described below. + +On mainstream platforms - Windows (`x86_64`), macOS (`x86_64` and Apple Silicon/`arm64`), and Linux (`x86_64`) - `pip` will usually download a prebuilt wheel, so no compiler is needed. + +On less common platforms, `pip` may fall back to a source build, which takes a few minutes. See [Building From Source](https://sharhar.github.io/vkdispatch/tutorials/building_from_source.html) for toolchain requirements and developer-oriented instructions. + +### Core package + +For cases where only the codegen component is needed, or in environments where only the CUDA or OpenCL backends are needed, install the core package: + +```bash +pip install vkdispatch-core +``` + +This installs the core library and codegen components, but not the Vulkan runtime backend. To enable runtime features beyond pure codegen, install the optional dependencies below. + +### Optional components + +- Optional CLI: `pip install vkdispatch-core[cli]` +- CUDA runtime backend: `pip install vkdispatch-core[cuda]` +- OpenCL runtime backend: `pip install vkdispatch-core[opencl]` + +## Runtime backends + +`vkdispatch` currently supports these runtime backends: + +- `vulkan` +- `cuda` +- `opencl` +- `dummy` + +If you do not explicitly select a backend, ``vkdispatch`` prefers Vulkan. When the Vulkan backend cannot be imported because it is not installed, initialization falls back to CUDA and then OpenCL. + +You can select a backend explicitly in Python: + +```python +import vkdispatch as vd + +vd.initialize(backend="vulkan") +# vd.initialize(backend="cuda") +# vd.initialize(backend="opencl") +# vd.initialize(backend="dummy") +``` + +You can also select the backend with an environment variable: + +```bash +export VKDISPATCH_BACKEND=vulkan +``` + +The dummy backend is useful for codegen-only workflows, source inspection, and development environments where no GPU runtime is available. + +There are two intended shader-generation modes: -On less common platforms (e.g., non-Apple ARM or other niche architectures), pip may fall back to a **source build**, which takes a few minutes. See **[Building From Source](https://sharhar.github.io/vkdispatch/tutorials/building_from_source.html)** for toolchain requirements and developer-oriented instructions. +- Default mode: generate for the current machine/runtime. This is the normal path and is how `vkdispatch` picks backend-specific defaults and limits. +- Custom mode: initialize with `backend="dummy"` and optionally tune the dummy device limits when you want controlled codegen without relying on the current runtime. -> **Tip:** If you see output like `Building wheel for vkdispatch (pyproject.toml)`, you’re compiling from source. -## Verifying Your Installation +## Verifying your installation -To ensure `vkdispatch` is installed correctly and can detect your GPU, run: +If you installed the optional CLI, you can list devices with: ```bash -# Quick device listing vdlist -# If the above command is unavailable, try: -python3 -m vkdispatch +# Explicit backend selection can be done with cmdline flags: +vdlist --vulkan +vdlist --cuda +vdlist --opencl ``` -If the installation was successful, you should see output listing your GPU(s), for example: - -```text -Device 0: Apple M2 Pro - Vulkan Version: 1.2.283 - Device Type: Integrated GPU - - Features: - Float32 Atomic Add: True - - Properties: - 64-bit Float Support: False - 16-bit Float Support: True - 64-bit Int Support: True - 16-bit Int Support: True - Max Push Constant Size: 4096 bytes - Subgroup Size: 32 - Max Compute Shared Memory Size: 32768 - - Queues: - 0 (count=1, flags=0x7): Graphics | Compute - 1 (count=1, flags=0x7): Graphics | Compute - 2 (count=1, flags=0x7): Graphics | Compute - 3 (count=1, flags=0x7): Graphics | Compute +You can always inspect devices from Python: + +```python +import vkdispatch as vd + +for device in vd.get_devices(): + print(device.get_info_string()) ``` -## Next Steps +The reported version label depends on the active backend: + +- Vulkan devices show a Vulkan version +- CUDA devices show CUDA compute capability +- OpenCL devices show an OpenCL version + +## Quick start + +The example below defines a simple in-place compute kernel in Python: + +```python +import numpy as np +import vkdispatch as vd +import vkdispatch.codegen as vc +from vkdispatch.codegen.abbreviations import Buff, Const, f32 + +# @vd.shader(exec_size=lambda args: args.buff.size) +@vd.shader("buff.size") +def add_scalar(buff: Buff[f32], bias: Const[f32]): + tid = vc.global_invocation_id().x + buff[tid] = buff[tid] + bias + +arr = np.arange(8, dtype=np.float32) +buff = vd.asbuffer(arr) + +# If you want a non-default backend, call vd.initialize(backend=...) first. +add_scalar(buff, 1.5) + +print(buff.read(0)) +``` + +String launch sizing is the shortest form and is kept for convenience. If you want +the launch rule to be more explicit and deterministic, use the equivalent lambda form +instead: `@vd.shader(exec_size=lambda args: args.buff.size)`. + +In normal usage, `vkdispatch` initializes itself and creates a default context on first runtime use. Call `vd.initialize()` and `vd.make_context()` manually only when you want non-default settings such as backend selection, custom device selection, debug logging, or multi-device Vulkan contexts. + +## Codegen-Only Workflows + +If you want generated source without compiling or dispatching it on the current machine, use the dummy backend explicitly: + +```python +import vkdispatch as vd +import vkdispatch.codegen as vc +from vkdispatch.codegen.abbreviations import Buff, Const, f32 + +vd.initialize(backend='dummy') +vd.set_dummy_context_params( + subgroup_size=32, + max_workgroup_size=(128, 1, 1), + max_workgroup_count=(65535, 65535, 65535), +) +vc.set_codegen_backend('cuda') + +# @vd.shader(exec_size=lambda args: args.buff.size) +@vd.shader('buff.size') +def add_scalar(buff: Buff[f32], bias: Const[f32]): + tid = vc.global_invocation_id().x + buff[tid] = buff[tid] + bias + +src = add_scalar.get_src(line_numbers=True) +print(src) +``` + +In this mode, `vkdispatch` uses the dummy device model for launch/layout defaults and emits source for the backend selected with `vc.set_codegen_backend(...)`. + +## Documentation + +The docs site is still under active development, but the main entry points are here: + +- [Getting Started](https://sharhar.github.io/vkdispatch/getting_started.html) +- [Tutorials](https://sharhar.github.io/vkdispatch/tutorials/index.html) +- [Python API Reference](https://sharhar.github.io/vkdispatch/python_api.html) + +Some especially useful tutorials: -- **[Tutorials](https://sharhar.github.io/vkdispatch/tutorials/index.html)** — our curated guide to common workflows and examples -- **[Full Python API Reference](https://sharhar.github.io/vkdispatch/python_api.html)** — comprehensive reference for Python-facing components +- [Shader Authoring and Dispatch](https://sharhar.github.io/vkdispatch/tutorials/shader_tutorial.html) +- [Initialization and Context Creation](https://sharhar.github.io/vkdispatch/tutorials/context_system.html) +- [Command Graph Recording](https://sharhar.github.io/vkdispatch/tutorials/command_graph_tutorial.html) +- [Reductions and FFT Workflows](https://sharhar.github.io/vkdispatch/tutorials/reductions_and_fft.html) Happy GPU programming! diff --git a/docs/getting_started.rst b/docs/getting_started.rst index 79cdf173..a43cdd8c 100644 --- a/docs/getting_started.rst +++ b/docs/getting_started.rst @@ -1,83 +1,225 @@ Getting Started -=============================== +=============== -Welcome to vkdispatch! This guide will help you install the library and run your first GPU-accelerated code. +``vkdispatch`` is a Python GPU computing framework for writing single-source +kernels in Python and dispatching them across multiple runtime backends. -.. note:: - vkdispatch requires a Vulkan-compatible GPU and drivers installed on your system. - Please ensure your system meets these requirements before proceeding. +It combines runtime code generation, execution helpers, and FFT/reduction +utilities in one package. The default PyPI install ships with the Vulkan +backend. CUDA and OpenCL backends can be enabled with optional runtime +dependencies. + +Highlights +---------- + +* Single-source Python shaders via ``@vd.shader`` and ``vkdispatch.codegen`` +* Multiple runtime backends: Vulkan, CUDA, OpenCL, and a dummy codegen-only backend +* Backend-aware code generation: GLSL for Vulkan, CUDA source for CUDA, and OpenCL C for OpenCL +* Native FFT workflows through ``vd.fft``, including mapping hooks for fusion and custom I/O +* VkFFT-backed transforms through ``vd.vkfft`` on the Vulkan backend +* Reductions through ``vd.reduce`` +* Batched submission and deferred execution through ``vd.CommandGraph`` +* CUDA interop through ``__cuda_array_interface__`` and CUDA Graph capture helpers Installation ---------------------------- +------------ + +Default Vulkan install +~~~~~~~~~~~~~~~~~~~~~~ -The default installation method for `vkdispatch` is through PyPI (pip): +To install ``vkdispatch`` with the Vulkan backend, run: .. code-block:: bash - # Install the package pip install vkdispatch -On mainstream platforms — Windows (x86_64), macOS (x86_64 and Apple Silicon/arm64), -and Linux (x86_64) — pip will download a **prebuilt wheel** (built with `cibuildwheel` -on GitHub Actions and tagged as *manylinux* where applicable), so no compiler is needed. +This installs the core library, the code generation system, and the Vulkan +runtime backend. The Vulkan backend is designed to run on systems supporting +Vulkan 1.2 or higher, including macOS via a statically linked MoltenVK. +Alternate backends can be added with optional dependencies as described below. + +On mainstream platforms, Windows (``x86_64``), macOS (``x86_64`` and Apple +Silicon/``arm64``), and Linux (``x86_64``), ``pip`` will usually download a +prebuilt wheel, so no compiler is needed. -On less common platforms (e.g., non-Apple ARM or other niche architectures), pip may -fall back to a **source build**, which takes a few minutes. See :doc:`Building From Source` +On less common platforms, ``pip`` may fall back to a source build, which takes +a few minutes. See :doc:`Building From Source` for toolchain requirements and developer-oriented instructions. -.. note:: - If you see output like ``Building wheel for vkdispatch (pyproject.toml)``, - you’re compiling from source. +Core package +~~~~~~~~~~~~ + +For cases where only the codegen component is needed, or in environments where +only the CUDA or OpenCL backends are needed, install the core package: + +.. code-block:: bash + + pip install vkdispatch-core + +This installs the core library and codegen components, but not the Vulkan +runtime backend. To enable runtime features beyond pure codegen, install the +optional dependencies below. + +Optional components +~~~~~~~~~~~~~~~~~~~ + +* Optional CLI: ``pip install "vkdispatch-core[cli]"`` +* CUDA runtime backend: ``pip install "vkdispatch-core[cuda]"`` +* OpenCL runtime backend: ``pip install "vkdispatch-core[opencl]"`` + +Runtime backends +---------------- + +``vkdispatch`` currently supports these runtime backends: + +* ``vulkan`` +* ``cuda`` +* ``opencl`` +* ``dummy`` + +If you do not explicitly select a backend, ``vkdispatch`` prefers Vulkan. When +the Vulkan backend cannot be imported because it is not installed, +initialization falls back to CUDA and then OpenCL. + +You can select a backend explicitly in Python: + +.. code-block:: python + + import vkdispatch as vd + + vd.initialize(backend="vulkan") + # vd.initialize(backend="cuda") + # vd.initialize(backend="opencl") + # vd.initialize(backend="dummy") + +You can also select the backend with an environment variable: + +.. code-block:: bash + + export VKDISPATCH_BACKEND=vulkan + +The dummy backend is useful for codegen-only workflows, source inspection, and +development environments where no GPU runtime is available. + +There are two intended shader-generation modes: -Verifying Your Installation +* Default mode: generate for the current machine/runtime. This is the normal + path and is how ``vkdispatch`` picks backend-specific defaults and limits. +* Custom mode: initialize with ``backend="dummy"`` and optionally tune the + dummy device limits when you want controlled codegen without relying on the + current runtime. + +Verifying your installation --------------------------- -To ensure `vkdispatch` is installed correctly and can detect your GPU, -run this simple Python script: +If you installed the optional CLI, you can list devices with: .. code-block:: bash - - # Run the example script to verify installation + vdlist - # If the above command fails, you can try this alternative - python3 -m vkdispatch + # Explicit backend selection can be done with command-line flags: + vdlist --vulkan + vdlist --cuda + vdlist --opencl + +You can always inspect devices from Python: + +.. code-block:: python + + import vkdispatch as vd + + for device in vd.get_devices(): + print(device.get_info_string()) + +The reported version label depends on the active backend: + +* Vulkan devices show a Vulkan version +* CUDA devices show CUDA compute capability +* OpenCL devices show an OpenCL version + +Quick start +----------- -If the installation was successful, you should see output listing your GPU(s) which may look something like this: +The example below defines a simple in-place compute kernel in Python: -.. code-block:: text +.. code-block:: python - Device 0: Apple M2 Pro - Vulkan Version: 1.2.283 - Device Type: Integrated GPU + import numpy as np + import vkdispatch as vd + import vkdispatch.codegen as vc + from vkdispatch.codegen.abbreviations import Buff, Const, f32 - Features: - Float32 Atomic Add: True + # @vd.shader(exec_size=lambda args: args.buff.size) + @vd.shader("buff.size") + def add_scalar(buff: Buff[f32], bias: Const[f32]): + tid = vc.global_invocation_id().x + buff[tid] = buff[tid] + bias - Properties: - 64-bit Float Support: False - 16-bit Float Support: True - 64-bit Int Support: True - 16-bit Int Support: True - Max Push Constant Size: 4096 bytes - Subgroup Size: 32 - Max Compute Shared Memory Size: 32768 + arr = np.arange(8, dtype=np.float32) + buff = vd.asbuffer(arr) - Queues: - 0 (count=1, flags=0x7): Graphics | Compute - 1 (count=1, flags=0x7): Graphics | Compute - 2 (count=1, flags=0x7): Graphics | Compute - 3 (count=1, flags=0x7): Graphics | Compute + # If you want a non-default backend, call vd.initialize(backend=...) first. + add_scalar(buff, 1.5) + print(buff.read(0)) +String launch sizing is the shortest form and is kept for convenience. If you +want the launch rule to be more explicit and deterministic, use the equivalent +lambda form instead: ``@vd.shader(exec_size=lambda args: args.buff.size)``. -Next Steps +In normal usage, ``vkdispatch`` initializes itself and creates a default +context on first runtime use. Call ``vd.initialize()`` and ``vd.make_context()`` +manually only when you want non-default settings such as backend selection, +custom device selection, debug logging, or multi-device Vulkan contexts. + +Codegen-only workflows +---------------------- + +If you want generated source without compiling or dispatching it on the current +machine, use the dummy backend explicitly: + +.. code-block:: python + + import vkdispatch as vd + import vkdispatch.codegen as vc + from vkdispatch.codegen.abbreviations import Buff, Const, f32 + + vd.initialize(backend="dummy") + vd.set_dummy_context_params( + subgroup_size=32, + max_workgroup_size=(128, 1, 1), + max_workgroup_count=(65535, 65535, 65535), + ) + vc.set_codegen_backend("cuda") + + # @vd.shader(exec_size=lambda args: args.buff.size) + @vd.shader("buff.size") + def add_scalar(buff: Buff[f32], bias: Const[f32]): + tid = vc.global_invocation_id().x + buff[tid] = buff[tid] + bias + + src = add_scalar.get_src(line_numbers=True) + print(src) + +In this mode, ``vkdispatch`` uses the dummy device model for launch and layout +defaults and emits source for the backend selected with +``vc.set_codegen_backend(...)``. + + +Next steps ---------- -Now that you've got `vkdispatch` up and running, consider exploring the following: +The main entry points in the documentation are: + +* :doc:`Tutorials` +* :doc:`Full Python API Reference` + +Some especially useful tutorials: -* :doc:`Code Structure and Execution Flow`: A guided tour of how Python, codegen, and native layers fit together. -* :doc:`Tutorials`: Our curated guide to the most commonly used classes and functions. -* :doc:`Full Python API Reference`: A comprehensive list of all Python-facing components. +* :doc:`Shader Authoring and Dispatch` +* :doc:`Initialization and Context Creation` +* :doc:`Command Graph Recording` +* :doc:`Reductions and FFT Workflows` Happy GPU programming! diff --git a/docs/special_pages/shader_playground.html b/docs/special_pages/shader_playground.html index 9fb37d17..fb60c217 100644 --- a/docs/special_pages/shader_playground.html +++ b/docs/special_pages/shader_playground.html @@ -501,6 +501,7 @@

VkDispatch Shader Playground

import vkdispatch.codegen as vc from vkdispatch.codegen.abreviations import * +# @vd.shader(exec_size=lambda args: args.buff.size) @vd.shader("buff.size") def add_scalar(buff: Buff[f32], bias: Const[f32]): tid = vc.global_invocation_id().x diff --git a/docs/tutorials/code_structure.rst b/docs/tutorials/code_structure.rst index b05cb6fe..86489e7e 100644 --- a/docs/tutorials/code_structure.rst +++ b/docs/tutorials/code_structure.rst @@ -90,6 +90,7 @@ Minimal Example (API Layer View) import vkdispatch.codegen as vc from vkdispatch.codegen.abreviations import * + # @vd.shader(exec_size=lambda args: args.data.size) @vd.shader("data.size") def scale_inplace(data: Buff[f32], alpha: Const[f32]): tid = vc.global_invocation_id().x diff --git a/docs/tutorials/command_graph_tutorial.rst b/docs/tutorials/command_graph_tutorial.rst index 51cdf98f..6316bc11 100644 --- a/docs/tutorials/command_graph_tutorial.rst +++ b/docs/tutorials/command_graph_tutorial.rst @@ -26,6 +26,7 @@ Single Graph, Multiple Dispatches graph = vd.CommandGraph() + # @vd.shader(exec_size=lambda args: args.buff.size) @vd.shader("buff.size") def add_scalar(buff: Buff[f32], value: Const[f32]): tid = vc.global_invocation_id().x diff --git a/docs/tutorials/images_and_sampling.rst b/docs/tutorials/images_and_sampling.rst index f60bc9b7..120cff11 100644 --- a/docs/tutorials/images_and_sampling.rst +++ b/docs/tutorials/images_and_sampling.rst @@ -48,6 +48,7 @@ Use codegen image argument types (``Img1``, ``Img2``, ``Img3``) inside ``@vd.sha upscale = 4 out = vd.Buffer((data.shape[0] * upscale, data.shape[1] * upscale), vd.float32) + # @vd.shader(exec_size=lambda args: args.out.size) @vd.shader("out.size") def sample_2d(out: Buff[f32], src: Img2[f32], scale: Const[f32]): tid = vc.global_invocation_id().x diff --git a/docs/tutorials/shader_tutorial.rst b/docs/tutorials/shader_tutorial.rst index 060425dc..5b0ab767 100644 --- a/docs/tutorials/shader_tutorial.rst +++ b/docs/tutorials/shader_tutorial.rst @@ -6,8 +6,9 @@ runtime. This page covers shader launch patterns and the key semantics of vkdisp runtime shader generation model. Examples below omit ``vd.initialize()`` and ``vd.make_context()`` because vkdispatch -creates them automatically on first runtime use. Call them manually only when you need -custom initialization/context settings. +creates them automatically on first runtime use. That default path is intentional: +generated shaders are specialized against the current machine/runtime unless you +explicitly choose dummy-mode codegen. Runtime Generation Model ------------------------ @@ -23,6 +24,21 @@ This is different from AST/IR compilers: it is a forward streaming model, so exp register materialization and explicit shader control-flow helpers matter for performance and correctness. +Default Runtime-Coupled Generation +---------------------------------- + +By default, ``vkdispatch`` generates shaders for the active runtime backend and uses that +runtime's limits when choosing implicit launch defaults such as ``local_size``. + +This is the normal mode for end-to-end execution: + +1. define the kernel with ``@vd.shader`` +2. let ``vkdispatch`` auto-initialize or call ``vd.initialize(...)`` yourself +3. execute the shader or inspect ``get_src()`` for the current machine + +If you want controlled source generation without relying on the active runtime, use the +dummy backend explicitly. + Imports and Type Annotations ---------------------------- @@ -32,7 +48,7 @@ Most shader examples use these imports: import vkdispatch as vd import vkdispatch.codegen as vc - from vkdispatch.codegen.abreviations import * + from vkdispatch.codegen.abbreviations import * * ``Buff[...]`` is a shader buffer argument type. * ``Const[...]`` is a uniform/constant argument type. @@ -46,8 +62,9 @@ Basic In-Place Kernel import numpy as np import vkdispatch as vd import vkdispatch.codegen as vc - from vkdispatch.codegen.abreviations import * + from vkdispatch.codegen.abbreviations import * + # @vd.shader(exec_size=lambda args: args.buff.size) @vd.shader("buff.size") def add_scalar(buff: Buff[f32], bias: Const[f32]): tid = vc.global_invocation_id().x @@ -69,6 +86,7 @@ Use one of these launch patterns: .. code-block:: python + # @vd.shader(exec_size=lambda args: args.in_buf.size) @vd.shader("in_buf.size") def kernel(in_buf: Buff[f32], out_buf: Buff[f32]): ... @@ -99,6 +117,9 @@ Use one of these launch patterns: ``exec_size`` and ``workgroups`` are mutually exclusive. The string form is often the most concise option for argument-dependent dispatch size. +It is evaluated dynamically, so it is slightly more brittle than the lambda form. +When you want the declaration itself to be more explicit and deterministic, prefer +``exec_size=lambda args: ...``. You can also override launch parameters per call: @@ -118,6 +139,7 @@ To materialize a value once and mutate it, convert it to a register with .. code-block:: python + # @vd.shader(exec_size=lambda args: args.buff.size) @vd.shader("buff.size") def register_example(buff: Buff[f32]): tid = vc.global_invocation_id().x @@ -139,6 +161,7 @@ store syntax ``x[:] = ...``. .. code-block:: python + # @vd.shader(exec_size=lambda args: args.buff.size) @vd.shader("buff.size") def register_store(buff: Buff[f32]): tid = vc.global_invocation_id().x @@ -153,6 +176,7 @@ Native Python control flow with vkdispatch variables is intentionally blocked: .. code-block:: python + # @vd.shader(exec_size=lambda args: args.buff.size) @vd.shader("buff.size") def bad_branch(buff: Buff[f32]): tid = vc.global_invocation_id().x @@ -163,6 +187,7 @@ Use shader control-flow helpers so both branches are emitted into generated code .. code-block:: python + # @vd.shader(exec_size=lambda args: args.buff.size) @vd.shader("buff.size") def threshold(buff: Buff[f32], cutoff: Const[f32]): tid = vc.global_invocation_id().x @@ -182,6 +207,7 @@ conditionals are useful for specialization and unrolling. .. code-block:: python def make_unrolled_sum(unroll: int): + # @vd.shader(exec_size=lambda args: args.dst.size) @vd.shader("dst.size") def unrolled_sum(src: Buff[f32], dst: Buff[f32]): tid = vc.global_invocation_id().x @@ -219,6 +245,35 @@ You can pass mapping functions into APIs that accept ``mapping_function``, Inspecting Generated Shader Source ---------------------------------- +``get_src()`` returns the generated source for the currently selected runtime/codegen +configuration. In the default mode, that means the generated shader is tied to the +current machine/runtime by design. + +For explicit codegen-only workflows, initialize the dummy backend first and select the +output backend you want: + +.. code-block:: python + + import vkdispatch as vd + import vkdispatch.codegen as vc + from vkdispatch.codegen.abbreviations import Buff, Const, f32 + + vd.initialize(backend="dummy") + vd.set_dummy_context_params( + subgroup_size=32, + max_workgroup_size=(128, 1, 1), + max_workgroup_count=(65535, 65535, 65535), + ) + vc.set_codegen_backend("cuda") + + # @vd.shader(exec_size=lambda args: args.buff.size) + @vd.shader("buff.size") + def add_scalar(buff: Buff[f32], bias: Const[f32]): + tid = vc.global_invocation_id().x + buff[tid] = buff[tid] + bias + + print(add_scalar.get_src(line_numbers=True)) + A built shader can be printed for debugging: .. code-block:: python diff --git a/examples/pytorch_cuda_graph_cuda_python.py b/examples/pytorch_cuda_graph_cuda_python.py index 51a949f9..869a7ef3 100644 --- a/examples/pytorch_cuda_graph_cuda_python.py +++ b/examples/pytorch_cuda_graph_cuda_python.py @@ -12,7 +12,7 @@ import vkdispatch as vd import vkdispatch.codegen as vc -from vkdispatch.codegen.abreviations import Buff, Const, f32 +from vkdispatch.codegen.abbreviations import Buff, Const, f32 @vd.shader(exec_size=lambda args: args.x.size) diff --git a/fetch_dependencies.py b/fetch_dependencies.py index 05a21b66..203b01fc 100644 --- a/fetch_dependencies.py +++ b/fetch_dependencies.py @@ -43,7 +43,7 @@ def clone_and_checkout(repo_url, commit_hash, output_dir): sys.exit(1) dependencies = [ - ("https://github.com/sharhar/VkFFT.git", "9cb0da8ab98f0f3a10debd1466f13cab65ce9bc3", "deps/VkFFT"), # my fork of VkFFT, will change to official repo once PR is merged + ("https://github.com/DTolm/VkFFT.git", "e8e6a391288523a5ed03c4a46a79666543e30329", "deps/VkFFT"), ("https://github.com/GPUOpen-LibrariesAndSDKs/VulkanMemoryAllocator.git", "5677097bafb8477097c6e3354ce68b7a44fd01a4", "deps/VMA"), ("https://github.com/KhronosGroup/Vulkan-Headers.git", "eaa319dade959cb61ed2229c8ea42e307cc8f8b3", "deps/Vulkan-Headers"), ("https://github.com/KhronosGroup/Vulkan-Utility-Libraries.git", "ad7f699a7b2b5deb66eb3de19f24aa33597ed65b", "deps/Vulkan-Utility-Libraries"), @@ -60,7 +60,11 @@ def clone_and_checkout(repo_url, commit_hash, output_dir): os.makedirs("deps/MoltenVK", exist_ok=True) -molten_vk_url = "https://github.com/KhronosGroup/MoltenVK/releases/download/v1.4.0/MoltenVK-macos.tar" +# Custom MoltenVK v1.4.1 build that uses a custom +# patched SPIRV-Cross fork that elevates thread_barrier +# memory scope to Device, which is required for correct +# execution of compute shaders given a bug on Apple's side. +molten_vk_url = "https://github.com/sharhar/MoltenVK/releases/download/vkdispatch_patch/MoltenVK-macos.tar" molten_vk_path = "deps/MoltenVK" molten_vk_filename = "MoltenVK-macos.tar" molten_vk_full_file_path = os.path.join(molten_vk_path, molten_vk_filename) diff --git a/merge.py b/merge.py deleted file mode 100644 index 2ad25474..00000000 --- a/merge.py +++ /dev/null @@ -1,51 +0,0 @@ -import os - -def consolidate_repo(root_dir, output_file): - # Extensions to include - extensions = {'.cpp', '.h', '.hh', '.py', '.pxd', '.pyx', '.toml'} - - # Files to ignore (common venv or git directories) - ignore_dirs = {'.git', '__pycache__', 'build', 'dist', 'deps', 'venv', 'env', '.idea', '.vscode'} - - with open(output_file, 'w', encoding='utf-8') as outfile: - # Walk through the directory tree - for dirpath, dirnames, filenames in os.walk(root_dir): - # Modify dirnames in-place to skip ignored directories - dirnames[:] = [d for d in dirnames if d not in ignore_dirs] - - for filename in filenames: - if filename == "wrapper.cpp": - continue - _, ext = os.path.splitext(filename) - - if ext in extensions: - file_path = os.path.join(dirpath, filename) - # Create a relative path for cleaner metadata - rel_path = os.path.relpath(file_path, root_dir) - - try: - with open(file_path, 'r', encoding='utf-8', errors='replace') as infile: - content = infile.read() - - # Write metadata header - outfile.write(f"\n{'='*80}\n") - outfile.write(f"FILE: {rel_path}\n") - outfile.write(f"{'='*80}\n\n") - - # Write file content - outfile.write(content) - outfile.write("\n") # Ensure separation - - print(f"Processed: {rel_path}") - - except Exception as e: - print(f"Error reading {rel_path}: {e}") - -if __name__ == "__main__": - # You can change these paths as needed - source_directory = "." # Current directory - output_filename = "codebase.txt" - - print(f"Scanning directory: {os.path.abspath(source_directory)}") - consolidate_repo(source_directory, output_filename) - print(f"\nDone! All files consolidated into: {output_filename}") \ No newline at end of file diff --git a/setup.py b/setup.py index 422495ce..a6c48a42 100644 --- a/setup.py +++ b/setup.py @@ -65,7 +65,7 @@ def read_readme() -> str: "Programming Language :: Python :: 3", "License :: OSI Approved :: MIT License", "Operating System :: OS Independent", - "Development Status :: 2 - Pre-Alpha", + "Development Status :: 3 - Alpha", ] COMMON_PROJECT_URLS = { @@ -74,9 +74,8 @@ def read_readme() -> str: } COMMON_EXTRAS = { - "cuda": ["cuda-python"], + "cuda": ["cuda-python", "numpy"], "opencl": ["pyopencl", "numpy"], - "pycuda": ["pycuda"], "numpy": ["numpy"], } @@ -326,10 +325,7 @@ def base_setup_kwargs(): "version": VERSION, "author": "Shahar Sandhaus", "author_email": "shahar.sandhaus@gmail.com", - "description": ( - "A Python module for orchestrating and dispatching large computations " - "across multi-GPU systems using Vulkan." - ), + "description": "Python metaprogramming for GPU compute, with runtime-generated kernels, FFTs, and reductions.", "long_description": read_readme(), "long_description_content_type": "text/markdown", "python_requires": ">=3.6", @@ -373,7 +369,15 @@ def setup_for_target(target: str): "name": "vkdispatch-core", "packages": core_packages(), "install_requires": ["setuptools>=59.0"], - "extras_require": dict(COMMON_EXTRAS), + "extras_require": { + "cli": ["Click"], + **COMMON_EXTRAS, + }, + "entry_points": { + "console_scripts": [ + "vdlist=vkdispatch.cli:cli_entrypoint", + ] + }, } ) return kwargs @@ -405,11 +409,6 @@ def setup_for_target(target: str): "cli": ["Click"], **COMMON_EXTRAS, }, - "entry_points": { - "console_scripts": [ - "vdlist=vkdispatch.cli:cli_entrypoint", - ] - }, } ) return kwargs diff --git a/shader_run.py b/shader_run.py deleted file mode 100644 index 8c34a024..00000000 --- a/shader_run.py +++ /dev/null @@ -1,89 +0,0 @@ -import vkdispatch as vd - -from vkdispatch.base.command_list import CommandList -from vkdispatch.base.compute_plan import ComputePlan -from vkdispatch.base.descriptor_set import DescriptorSet - -import numpy as np - -def load_shader(path: str) -> ComputePlan: - shader_source = open(path, 'r').read() - - return ComputePlan( - shader_source=shader_source, - binding_type_list=[1, 1, 1], - pc_size=0, - shader_name=f"shader_{path.split('/')[-1].split('.')[0]}" - ) - -def make_descriptor(plan: ComputePlan, out_buff: vd.Buffer, in_buff: vd.Buffer, kern_buff: vd.Buffer): - descriptor_set = DescriptorSet(plan) - - descriptor_set.bind_buffer(out_buff, 0) - descriptor_set.bind_buffer(in_buff, 1) - descriptor_set.bind_buffer(kern_buff, 2) - - return descriptor_set - -def numpy_convolution(signal: np.ndarray, kernel: np.ndarray) -> np.ndarray: - return np.fft.ifft( - np.fft.fft(signal, axis=1).astype(np.complex64) - * - kernel.conjugate(), - axis=1 - ) - -BUFF_SHAPE = (4, 512, 257) - -np.random.seed(1337) - -in_data = (np.random.rand(*BUFF_SHAPE) + 1j * np.random.rand(*BUFF_SHAPE)).astype(np.complex64) -kern_data = (np.random.rand(*BUFF_SHAPE) + 1j * np.random.rand(*BUFF_SHAPE)).astype(np.complex64) - -reference_result_data = numpy_convolution(in_data, kern_data[0]) - -out_buff = vd.buffer_c64(BUFF_SHAPE) -in_buff = vd.buffer_c64(BUFF_SHAPE) -kern_buff = vd.buffer_c64(BUFF_SHAPE) - -in_buff.write(in_data) -kern_buff.write(kern_data) - -block_count = (1028, 32, 1) - -plan_bad = load_shader("conv_bad.comp") -plan_good = load_shader("conv_good.comp") - -cmd_list_bad = CommandList() - -cmd_list_bad.record_compute_plan( - plan_bad, - make_descriptor(plan_bad, out_buff, in_buff, kern_buff), - block_count -) - -cmd_list_bad.submit(instance_count=1) - -result_data_bad = out_buff.read(0) - -cmd_list_good = CommandList() - -cmd_list_good.record_compute_plan( - plan_good, - make_descriptor(plan_good, out_buff, in_buff, kern_buff), - block_count -) - -cmd_list_good.submit(instance_count=1) - -result_data_good = out_buff.read(0) - -for i in range(BUFF_SHAPE[0]): - np.save(f"result_bad_{i}.npy", result_data_bad[i]) - np.save(f"result_good_{i}.npy", result_data_good[i]) - np.save(f"reference_result_{i}.npy", reference_result_data[i]) - np.save(f"diff_bad_{i}.npy", result_data_bad[i] - reference_result_data[i]) - np.save(f"diff_good_{i}.npy", result_data_good[i] - reference_result_data[i]) - np.save(f"diff_{i}.npy", result_data_good[i] - result_data_bad[i]) - -assert np.allclose(result_data_good, result_data_bad, atol=1e-3) diff --git a/test.py b/test.py deleted file mode 100644 index b7f21622..00000000 --- a/test.py +++ /dev/null @@ -1,17 +0,0 @@ -import vkdispatch as vd -import vkdispatch.codegen as vc - -vd.initialize(backend="vulkan", log_level=vd.LogLevel.INFO) -vc.set_codegen_backend("glsl") - -SIZE = 4096 - -buff_shape = (2, SIZE, SIZE) - -buff = vd.Buffer(buff_shape, var_type=vd.complex64) - -vd.vkfft.fft(buff, axis=1) #, print_shader=True) - -vd.queue_wait_idle() - -#print(vd.fft.fft_src(buff_shape, axis=1).code) \ No newline at end of file diff --git a/test2.py b/test2.py deleted file mode 100644 index 5f494e18..00000000 --- a/test2.py +++ /dev/null @@ -1,109 +0,0 @@ -import vkdispatch as vd -import vkdispatch.codegen as vc - -#vd.initialize(debug_mode=True, backend="cuda") -#vc.set_codegen_backend("cuda") - -from typing import Callable, Union, Tuple - -import numpy as np - -import time -import dataclasses - -@dataclasses.dataclass -class Config: - data_size: int - iter_count: int - iter_batch: int - run_count: int - signal_factor: int - warmup: int = 10 - - def make_shape(self, fft_size: int) -> Tuple[int, ...]: - total_square_size = fft_size * fft_size - assert self.data_size % total_square_size == 0, "Data size must be a multiple of fft_size squared" - return (self.data_size // total_square_size, fft_size, fft_size) - - def make_random_data(self, fft_size: int): - shape = self.make_shape(fft_size) - return np.random.rand(*shape).astype(np.complex64) - -def run_vkdispatch(config: Config, - fft_size: int, - io_count: Union[int, Callable], - gpu_function: Callable) -> float: - shape = config.make_shape(fft_size) - - buffer = vd.Buffer(shape, var_type=vd.complex64) - kernel = vd.Buffer(shape, var_type=vd.complex64) - - graph = vd.CommandGraph() - old_graph = vd.set_global_graph(graph) - - gpu_function(config, fft_size, buffer, kernel) - - vd.set_global_graph(old_graph) - - for _ in range(config.warmup): - graph.submit(config.iter_batch) - - vd.queue_wait_idle() - - if callable(io_count): - io_count = io_count(buffer.size, fft_size) - - gb_byte_count = io_count * 8 * buffer.size / (1024 * 1024 * 1024) - - start_time = time.perf_counter() - - for _ in range(config.iter_count // config.iter_batch): - graph.submit(config.iter_batch) - - vd.queue_wait_idle() - - elapsed_time = time.perf_counter() - start_time - - buffer.destroy() - kernel.destroy() - graph.destroy() - vd.fft.cache_clear() - - time.sleep(1) - - vd.queue_wait_idle() - - return gb_byte_count, elapsed_time - - -def run_test(config: Config, - io_count: Union[int, Callable], - gpu_function: Callable): - fft_sizes = [8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096] - - for fft_size in fft_sizes: - rates = [] - - for _ in range(config.run_count): - gb_byte_count, elapsed_time = run_vkdispatch(config, fft_size, io_count, gpu_function) - gb_per_second = config.iter_count * gb_byte_count / elapsed_time - - print(f"FFT Size: {fft_size}, Throughput: {gb_per_second:.4f} GB/s") - rates.append(gb_per_second) - -def do_fft(config: Config, - fft_size: int, - buffer: vd.Buffer, - kernel: vd.Buffer): - vd.fft.fft(buffer) - - -conf = Config( - data_size=2**26, - iter_count=80, - iter_batch=10, - run_count=1, - signal_factor=8 -) - -run_test(conf, 2, do_fft) \ No newline at end of file diff --git a/tests/test_codegen.py b/tests/test_codegen.py index b95b4e83..8dbd7b3a 100644 --- a/tests/test_codegen.py +++ b/tests/test_codegen.py @@ -1,6 +1,6 @@ import vkdispatch as vd import vkdispatch.codegen as vc -from vkdispatch.codegen.abreviations import * +from vkdispatch.codegen.abbreviations import * import numpy as np diff --git a/tests/test_command_graph.py b/tests/test_command_graph.py index e2dd15ee..2c57d322 100644 --- a/tests/test_command_graph.py +++ b/tests/test_command_graph.py @@ -1,6 +1,6 @@ import vkdispatch as vd import vkdispatch.codegen as vc -from vkdispatch.codegen.abreviations import * +from vkdispatch.codegen.abbreviations import * vd.initialize(debug_mode=True) diff --git a/tests/test_image.py b/tests/test_image.py index 2a03478c..3d4c957d 100644 --- a/tests/test_image.py +++ b/tests/test_image.py @@ -1,7 +1,7 @@ import vkdispatch as vd import vkdispatch.codegen as vc -from vkdispatch.codegen.abreviations import * +from vkdispatch.codegen.abbreviations import * import numpy as np diff --git a/tests/test_ravel.py b/tests/test_ravel.py index b186bf5c..ffa7c21e 100644 --- a/tests/test_ravel.py +++ b/tests/test_ravel.py @@ -4,10 +4,10 @@ from vkdispatch.base.dtype import to_vector import numpy as np +import pytest from typing import Tuple - def run_index_ravel(shape: Tuple[int, ...], index: Tuple[int, ...], shape_static: bool): var_type = to_vector(vd.uint32, len(shape)) @@ -25,7 +25,7 @@ def test_shader(buff: vc.Buff[var_type]): # pyright: ignore[reportInvalidTypeFor result_value = buffer.read(0) - assert tuple(result_value[index]) == tuple(index), f"Expected index {index}, got {tuple(result_value[index])}" + assert tuple(result_value[index]) == tuple(index), f"Expected index {index}, got {tuple(result_value[index])} for shape {shape} with shape_static={shape_static}" buffer.destroy() diff --git a/tests/test_reductions.py b/tests/test_reductions.py index 3bed232d..7f257809 100644 --- a/tests/test_reductions.py +++ b/tests/test_reductions.py @@ -1,6 +1,6 @@ import vkdispatch as vd import vkdispatch.codegen as vc -from vkdispatch.codegen.abreviations import * +from vkdispatch.codegen.abbreviations import * import numpy as np vd.initialize(debug_mode=True) diff --git a/vkdispatch/__init__.py b/vkdispatch/__init__.py index 27e99e2a..3488df7a 100644 --- a/vkdispatch/__init__.py +++ b/vkdispatch/__init__.py @@ -6,6 +6,8 @@ from .base.init import is_initialized from .base.init import log, log_error, log_warning, log_info, log_verbose, set_log_level +from .backends.backend_selection import BackendUnavailableError + from .base.dtype import dtype from .base.dtype import float16, float32, float64, int16, uint16, int32, uint32, int64, uint64 from .base.dtype import complex32, complex64, complex128 @@ -54,8 +56,7 @@ from .execution_pipeline.command_graph import global_graph, set_global_graph, default_graph from .execution_pipeline.cuda_graph_capture import cuda_graph_capture, get_cuda_capture, CUDAGraphCapture -from .shader.shader_function import ShaderFunction, ShaderSource -from .shader.context import ShaderContext, shader_context +from .shader.shader_function import ShaderBuildError, ShaderFunction, ShaderSource, make_shader_function from .shader.map import map, MappingFunction from .shader.decorator import shader @@ -63,4 +64,4 @@ import vkdispatch.fft as fft import vkdispatch.reduce as reduce -__version__ = "0.0.34" +__version__ = "0.1.0" diff --git a/vkdispatch/backends/backend_selection.py b/vkdispatch/backends/backend_selection.py index 6a3836b9..a0538ad3 100644 --- a/vkdispatch/backends/backend_selection.py +++ b/vkdispatch/backends/backend_selection.py @@ -93,7 +93,7 @@ def _load_backend_module(backend_name: str) -> ModuleType: if backend_name == BACKEND_VULKAN: raise BackendUnavailableError( backend_name, - "Vulkan backend is unavailable because the 'vkdispatch_native' package " + "Vulkan backend is unavailable because the 'vkdispatch_vulkan_native' package " f"could not be imported ({exc}).", ) from exc if backend_name == BACKEND_CUDA: diff --git a/vkdispatch/backends/dummy_backend.py b/vkdispatch/backends/dummy_backend.py index 420a59f8..f79a4294 100644 --- a/vkdispatch/backends/dummy_backend.py +++ b/vkdispatch/backends/dummy_backend.py @@ -97,7 +97,7 @@ def _clear_error(): _DUMMY_CODEGEN_ONLY_ERROR = ( "The 'dummy' backend is codegen-only and does not support runtime GPU " - "operations. Use backend='vulkan', backend='pycuda', or backend='cuda-python' for execution." + "operations. Use backend='vulkan', backend='cuda', or backend='opencl' for execution." ) diff --git a/vkdispatch/base/buffer.py b/vkdispatch/base/buffer.py index 6f49b622..15cf4bb5 100644 --- a/vkdispatch/base/buffer.py +++ b/vkdispatch/base/buffer.py @@ -192,8 +192,11 @@ def write(self, data: Union[bytes, bytearray, memoryview, typing.Any], index: in :raises ValueError: If the data size exceeds the buffer size or if the index is invalid. """ if index is not None: - assert isinstance(index, int), "Index must be an integer or None!" - assert index >= 0 and index < self.context.queue_count, "Index must be valid!" + if not isinstance(index, int): + raise ValueError("Index must be an integer or None!") + + if index < 0 or index >= self.context.queue_count: + raise ValueError("Index must be a valid device index within the context!") if not getattr(self, "is_writable", True): raise ValueError("Cannot write to a read-only buffer alias.") @@ -214,8 +217,6 @@ def write(self, data: Union[bytes, bytearray, memoryview, typing.Any], index: in self._do_writes(true_data_object, index) def _do_reads(self, var_type: dtype, shape: List[int], index: int = None) -> bytes: - assert index is None or (isinstance(index, int) and index >= 0), "Index must be None or a non-negative integer!" - indicies = [index] if index is not None else range(self.context.queue_count) completed_stages = [0] * len(indicies) bytes_list: List[bytes] = [None] * len(indicies) @@ -275,6 +276,12 @@ def read(self, index: Union[int, None] = None): data_shape = list(self.shape) + list(self.var_type.true_numpy_shape) if index is not None: + if not isinstance(index, int): + raise ValueError("Index must be an integer or None!") + + if index < 0 or index >= self.context.queue_count: + raise ValueError("Index must be a valid device index within the context!") + return self._do_reads(true_scalar, data_shape, index) results = self._do_reads(true_scalar, data_shape, None) @@ -306,7 +313,8 @@ def from_cuda_array( writable: typing.Optional[bool] = None, keepalive: bool = True, ) -> Buffer: - assert is_cuda(), "__cuda_array_interface__ is only supported with CUDA backends." + if not is_cuda(): + raise RuntimeError("from_cuda_array is only supported with CUDA backends.") if not hasattr(obj, "__cuda_array_interface__"): raise TypeError("Expected an object with __cuda_array_interface__") @@ -402,8 +410,12 @@ def read_fourier(self, index: Union[int, None] = None): def write_real(self, data, index: int = None): npc.require_numpy("RFFTBuffer.write_real") np = npc.numpy_module() - assert data.shape == self.real_shape, "Data shape must match real shape!" - assert not np.issubdtype(data.dtype, np.complexfloating), "Data dtype must be scalar!" + + if data.shape != self.real_shape: + raise ValueError(f"Data shape {data.shape} must match real shape {self.real_shape}!") + + if np.issubdtype(data.dtype, np.complexfloating): + raise ValueError("Data dtype must be real, not complex!") real_dtype = to_numpy_dtype(self.real_type) true_data = np.zeros(self.shape[:-1] + (self.shape[-1] * 2,), dtype=real_dtype) @@ -414,8 +426,12 @@ def write_real(self, data, index: int = None): def write_fourier(self, data, index: int = None): npc.require_numpy("RFFTBuffer.write_fourier") np = npc.numpy_module() - assert data.shape == self.fourier_shape, f"Data shape {data.shape} must match fourier shape {self.fourier_shape}!" - assert np.issubdtype(data.dtype, np.complexfloating), "Data dtype must be complex!" + + if data.shape != self.fourier_shape: + raise ValueError(f"Data shape {data.shape} must match fourier shape {self.fourier_shape}!") + + if not np.issubdtype(data.dtype, np.complexfloating): + raise ValueError("Data dtype must be complex!") target_fourier_dtype = to_numpy_dtype(self.var_type) if npc.is_host_dtype(target_fourier_dtype): @@ -435,7 +451,9 @@ def write_fourier(self, data, index: int = None): def asrfftbuffer(data, fourier_type: Optional[dtype] = None) -> RFFTBuffer: npc.require_numpy("asrfftbuffer") np = npc.numpy_module() - assert not np.issubdtype(data.dtype, np.complexfloating), "Data dtype must be scalar!" + + if np.issubdtype(data.dtype, np.complexfloating): + raise ValueError("Input data to asrfftbuffer must be real-valued!") if fourier_type is None: scalar_dtype = from_numpy_dtype(data.dtype) diff --git a/vkdispatch/base/command_list.py b/vkdispatch/base/command_list.py index 99fa2799..e1fdde27 100644 --- a/vkdispatch/base/command_list.py +++ b/vkdispatch/base/command_list.py @@ -112,9 +112,9 @@ def submit( instance_count = len(data) // self.get_instance_size() - if self.get_instance_size() != 0: - assert self.get_instance_size() * instance_count == len(data), "Data length must be the product of the instance size and instance count!" - + if self.get_instance_size() != 0 and self.get_instance_size() * instance_count != len(data): + raise ValueError("Data length must be the product of the instance size and instance count!") + if cuda_stream is None and get_cuda_capture() is not None: cuda_stream = get_cuda_capture().cuda_stream diff --git a/vkdispatch/base/context.py b/vkdispatch/base/context.py index 45351b32..ebc0161b 100644 --- a/vkdispatch/base/context.py +++ b/vkdispatch/base/context.py @@ -11,7 +11,7 @@ import os, signal from .errors import check_for_errors, set_running -from .init import DeviceInfo, is_cuda, is_opencl, is_dummy, get_devices, initialize, log_info +from .init import DeviceInfo, is_cuda, is_opencl, is_dummy, get_devices, initialize, log_info, log_warning from ..backends.backend_selection import native VK_SHADER_STAGE_COMPUTE_BIT = 0x00000020 @@ -27,13 +27,14 @@ class Handle: children_dict: MutableMapping[int, "Handle"] def __init__(self): - self.context = get_context() + self.context = None self._handle = None self.destroyed = False self.parents = {} self.children_dict = weakref.WeakValueDictionary() self.canary = False + self.context = get_context() def register_handle(self, handle: int) -> None: """ @@ -87,7 +88,7 @@ def destroy(self) -> None: """ Destroys the context handle and cleans up resources. """ - if self.destroyed: + if getattr(self, "destroyed", True): return self.destroyed = True @@ -100,16 +101,22 @@ def destroy(self) -> None: child = self.children_dict[child_handle] child.destroy() - assert len(self.children_dict) == 0, "Not all children were destroyed!" + if len(self.children_dict) > 0: + log_warning(f"Warning: Not all child handles were destroyed for handle {self._handle}!") - assert not self.canary, "Handle was already destroyed!" + if self.canary: + raise RuntimeError("Handle was already destroyed!") + if self._handle is not None: self._destroy() check_for_errors() self.canary = True - if self._handle in self.context.handles_dict.keys(): + if ( + self.context is not None + and self._handle in self.context.handles_dict.keys() + ): self.context.handles_dict.pop(self._handle) @@ -406,18 +413,14 @@ def make_context( total_devices = len(get_devices()) - # Do type checking before passing to native code - assert len(device_ids) == len( - queue_families - ), "Device and submission thread count lists must be the same length!" + if len(device_ids) != len(queue_families): + raise ValueError("Device and submission thread count lists must be the same length!") - assert all( - [type(dev) == int for dev in device_ids] - ), "Device list must be a list of integers!" - - assert all( - [dev >= 0 and dev < total_devices for dev in device_ids] - ), f"All device indicies must between 0 and {total_devices}" + if not all([isinstance(dev, int) for dev in device_ids]): + raise ValueError("Device list must be a list of integers!") + + if not all([dev >= 0 and dev < total_devices for dev in device_ids]): + raise ValueError(f"All device indicies must between 0 and {total_devices}") __context = Context(device_ids, queue_families) @@ -565,9 +568,12 @@ def queue_wait_idle(queue_index: int = None, context: Context = None) -> None: if context is None: context = get_context() - assert queue_index is None or isinstance(queue_index, int), "queue_index must be an integer or None." - assert queue_index is None or queue_index >= 0, "queue_index must be a non-negative integer or None (for all queues)." - assert queue_index is None or queue_index < context.queue_count, f"Queue index {queue_index} is out of bounds for context with {context.queue_count} queues." + if queue_index is not None: + if not isinstance(queue_index, int): + raise ValueError("queue_index must be an integer or None!") + + if queue_index < 0 or queue_index >= context.queue_count: + raise ValueError(f"queue_index must be between 0 and {context.queue_count - 1} (inclusive) or None for all queues!") if queue_index is None: for i in range(context.queue_count): @@ -601,7 +607,8 @@ def destroy_context() -> None: log_info(f"Destroying handle {handle._handle}...") handle.destroy() - assert len(__context.handles_dict) == 0, "Not all handles were destroyed!" + if len(__context.handles_dict) > 0: + log_warning(f"Warning: Not all handles were destroyed for context handle {__context._handle}!") log_info("Calling native context destroy...") native.context_destroy(__context._handle) @@ -613,25 +620,60 @@ def stop_threads() -> None: """ Stops all threads in the context. """ - native.context_stop_threads(get_context_handle()) + global __context -_shutdown_once = False + if __context is None: + return -def _sig_handler(signum, frame): - print("Ctrl-C received, stopping threads...") + native.context_stop_threads(__context._handle) + +_sigint_shutdown_once = False +_previous_signal_handlers = {} + +def _call_previous_handler(signum, frame) -> None: + previous_handler = _previous_signal_handlers.get(signum) + + if callable(previous_handler) and previous_handler is not _sig_handler: + try: + previous_handler(signum, frame) + except Exception: + pass - global _shutdown_once +def _reraised_signal(signum: int, handler) -> None: + signal.signal(signum, handler) + os.kill(os.getpid(), signum) + +def _sig_handler(signum, frame): + global _sigint_shutdown_once set_running(False) - if not _shutdown_once: - _shutdown_once = True - # Flip the C++ atomic and notify all sleepers - stop_threads() + + if signum == signal.SIGINT: + print("Ctrl-C received, stopping threads...") + + if not _sigint_shutdown_once: + _sigint_shutdown_once = True + stop_threads() + return + + _reraised_signal(signum, signal.SIG_DFL) return - # Second Ctrl-C → default behavior (fast exit with right code) - signal.signal(signum, signal.SIG_DFL) - os.kill(os.getpid(), signum) + + if signum == signal.SIGTERM: + try: + destroy_context() + except Exception: + pass + + _call_previous_handler(signum, frame) + _reraised_signal(signum, signal.SIG_DFL) + return + + _call_previous_handler(signum, frame) + _reraised_signal(signum, signal.SIG_DFL) # No need to register signal handlers in Brython, since it runs in a browser if not sys.implementation.name == "Brython": + _previous_signal_handlers[signal.SIGINT] = signal.getsignal(signal.SIGINT) + _previous_signal_handlers[signal.SIGTERM] = signal.getsignal(signal.SIGTERM) signal.signal(signal.SIGINT, _sig_handler) signal.signal(signal.SIGTERM, _sig_handler) diff --git a/vkdispatch/base/dtype.py b/vkdispatch/base/dtype.py index 62ea81d3..1b5a1455 100644 --- a/vkdispatch/base/dtype.py +++ b/vkdispatch/base/dtype.py @@ -571,19 +571,28 @@ def vector_size(dtype: dtype) -> int: return dtype.child_count def cross_scalar_scalar(dtype1: dtype, dtype2: dtype) -> dtype: - assert is_scalar(dtype1) and is_scalar(dtype2), "Both types must be scalar types!" + if not is_scalar(dtype1): + raise ValueError("First type must be a scalar type!") + + if not is_scalar(dtype2): + raise ValueError("Second type must be a scalar type!") r1 = _SCALAR_RANK[dtype1] r2 = _SCALAR_RANK[dtype2] return dtype1 if r1 >= r2 else dtype2 def cross_vector_scalar(dtype1: dtype, dtype2: dtype) -> dtype: - assert is_vector(dtype1) and is_scalar(dtype2), "First type must be vector type and second type must be scalar type!" + if not is_vector(dtype1): + raise ValueError("First type must be a vector type!") + + if not is_scalar(dtype2): + raise ValueError("Second type must be a scalar type!") return to_vector(cross_scalar_scalar(dtype1.scalar, dtype2), dtype1.child_count) def cross_vector_vector(dtype1: dtype, dtype2: dtype) -> dtype: - assert is_vector(dtype1) and is_vector(dtype2), "Both types must be vector types!" + if not is_vector(dtype1) or not is_vector(dtype2): + raise ValueError("Both types must be vector types!") if dtype1.child_count != dtype2.child_count: raise ValueError(f"Cannot cross types of vectors of two sizes! ({dtype1.child_count} != {dtype2.child_count})") @@ -591,7 +600,8 @@ def cross_vector_vector(dtype1: dtype, dtype2: dtype) -> dtype: return to_vector(cross_scalar_scalar(dtype1.scalar, dtype2.scalar), dtype1.child_count) def cross_vector(dtype1: dtype, dtype2: dtype) -> dtype: - assert is_vector(dtype1), "First type must be vector type!" + if not is_vector(dtype1): + raise ValueError("First type must be a vector type!") if is_vector(dtype2): return cross_vector_vector(dtype1, dtype2) @@ -603,7 +613,8 @@ def cross_vector(dtype1: dtype, dtype2: dtype) -> dtype: raise ValueError("Second type must be vector or scalar type!") def cross_matrix(dtype1: dtype, dtype2: dtype) -> dtype: - assert is_matrix(dtype1), "Both types must be matrix types!" + if not is_matrix(dtype1): + raise ValueError("First type must be a matrix type!") if is_matrix(dtype2): if dtype1.shape != dtype2.shape: diff --git a/vkdispatch/base/errors.py b/vkdispatch/base/errors.py index ca6068b1..dae50bcd 100644 --- a/vkdispatch/base/errors.py +++ b/vkdispatch/base/errors.py @@ -40,7 +40,5 @@ def check_for_compute_stage_errors(): if not isinstance(error, str): raise RuntimeError("Unknown error occurred") - - print("Shader compilation error:\n", error) - raise RuntimeError("Error occurred in compute stage") + raise RuntimeError(error) diff --git a/vkdispatch/base/image.py b/vkdispatch/base/image.py index f78ec483..a1c856f7 100644 --- a/vkdispatch/base/image.py +++ b/vkdispatch/base/image.py @@ -59,7 +59,8 @@ class image_format(Enum): # TODO: Fix class naming scheme to adhere to conventi # TODO: This can be moved into the enum class as an indexing method def select_image_format(dtype: vdt.dtype, channels: int) -> image_format: - assert channels in [1, 2, 3, 4], f"Unsupported number of channels ({channels})! Must be 1, 2, 3 or 4!" + if channels < 1 or channels > 4: + raise ValueError(f"Unsupported number of channels ({channels})! Must be 1, 2, 3 or 4!") # NOTE: These large if-else statements can be better indexed and maintained by a # dictionary lookup scheme @@ -251,18 +252,23 @@ def __init__( ) -> None: super().__init__() - assert len(shape) == 1 or len(shape) == 2 or len(shape) == 3, "Shape must be 2D or 3D!" + if len(shape) < 1 or len(shape) > 3: + raise ValueError("Shape must be 1D, 2D or 3D!") - assert type(shape[0]) == int, "Shape must be a tuple of integers!" + if type(shape[0]) != int: + raise ValueError("Shape must be a tuple of integers!") - if len(shape) > 1: - assert type(shape[1]) == int, "Shape must be a tuple of integers!" + if len(shape) > 1 and type(shape[1]) != int: + raise ValueError("Shape must be a tuple of integers!") - if len(shape) == 3: - assert type(shape[2]) == int, "Shape must be a tuple of integers!" - - assert issubclass(dtype, vdt.dtype), "Dtype must be a dtype!" - assert type(channels) == int, "Channels must be an integer!" + if len(shape) >2 and type(shape[2]) != int: + raise ValueError("Shape must be a tuple of integers!") + + if not issubclass(dtype, vdt.dtype): + raise ValueError("Dtype must be a dtype!") + + if type(channels) != int: + raise ValueError("Channels must be an integer!") self.type = image_type.TYPE_1D @@ -390,7 +396,8 @@ class Image2D(Image): def __init__( self, shape: typing.Tuple[int, int], dtype: type = vdt.float32, channels: int = 1, enable_mipmaps: bool = False ) -> None: - assert len(shape) == 2, "Shape must be 2D!" + if len(shape) != 2: + raise ValueError("Shape must be 2D!") super().__init__(shape, 1, dtype, channels, image_view_type.VIEW_TYPE_2D, enable_mipmaps) @classmethod @@ -407,7 +414,9 @@ def __init__( channels: int = 1, enable_mipmaps: bool = False ) -> None: - assert len(shape) == 2, "Shape must be 2D!" + if len(shape) != 2: + raise ValueError("Shape must be 2D!") + super().__init__( shape, layers, dtype, channels, image_view_type.VIEW_TYPE_2D_ARRAY, enable_mipmaps ) @@ -421,7 +430,9 @@ class Image3D(Image): def __init__( self, shape: typing.Tuple[int, int, int], dtype: type = vdt.float32, channels: int = 1, enable_mipmaps: bool = False ) -> None: - assert len(shape) == 3, "Shape must be 3D!" + if len(shape) != 3: + raise ValueError("Shape must be 3D!") + super().__init__(shape, 1, dtype, channels, image_view_type.VIEW_TYPE_3D, enable_mipmaps) @classmethod diff --git a/vkdispatch/base/init.py b/vkdispatch/base/init.py index a4aa7c26..ebbd99d3 100644 --- a/vkdispatch/base/init.py +++ b/vkdispatch/base/init.py @@ -2,6 +2,7 @@ from enum import Enum import os from typing import Tuple, List, Optional +import sys import inspect @@ -552,6 +553,9 @@ def initialize( ) return + error_backend = None + error_string = None + if ( not backend_explicitly_selected and backend_name == BACKEND_VULKAN @@ -564,7 +568,9 @@ def initialize( loader_debug_logs=loader_debug_logs, ) return - except BackendUnavailableError as vulkan_error: + except BackendUnavailableError as _: + print("\033[33mVkDispatch Warning:\033[0m Vulkan backend unavailable, trying CUDA backend...", file=sys.stderr) + try: _initialize_with_backend( BACKEND_CUDA, @@ -573,7 +579,9 @@ def initialize( loader_debug_logs=loader_debug_logs, ) return - except Exception as cuda_python_error: + except BackendUnavailableError as _: + print("\033[33mVkDispatch Warning:\033[0m CUDA backend unavailable, trying OpenCL backend...", file=sys.stderr) + try: _initialize_with_backend( BACKEND_OPENCL, @@ -582,24 +590,33 @@ def initialize( loader_debug_logs=loader_debug_logs, ) return - except Exception as opencl_error: - raise _build_no_gpu_backend_error( - vulkan_error, - cuda_python_error, - opencl_error, - ) from opencl_error - - try: - _initialize_with_backend( - backend_name, - debug_mode=debug_mode, - log_level=log_level, - loader_debug_logs=loader_debug_logs, - ) - except BackendUnavailableError as backend_error: - if backend_name == BACKEND_VULKAN: - raise _build_vulkan_backend_error(backend_error) from backend_error - raise + except BackendUnavailableError as _: + error_backend = BACKEND_VULKAN + error_string = f"""No available backend! +Please install one of the three supported backends: + Vulkan (`pip install vkdispatch-vulkan-native`) + CUDA (`pip install vkdispatch-core[cuda]`) + OpenCL (`pip install vkdispatch-core[opencl]`)""" + else: + try: + _initialize_with_backend( + backend_name, + debug_mode=debug_mode, + log_level=log_level, + loader_debug_logs=loader_debug_logs, + ) + except BackendUnavailableError as _: + backend_error_dict = { + BACKEND_VULKAN: "Vulkan backend unavailable. It can be installed with `pip install vkdispatch-vulkan-native`.", + BACKEND_CUDA: "CUDA Python backend unavailable. It can be enabled by installing the `cuda-python` package (`pip install vkdispatch-core[cuda]`).", + BACKEND_OPENCL: "OpenCL backend unavailable. It can be enabled by installing the `pyopencl` package (`pip install vkdispatch-core[opencl]`)." + } + + error_backend = backend_name + error_string = f"{backend_error_dict.get(backend_name, 'Selected backend unavailable')}" + + if error_backend is not None and error_string is not None: + raise BackendUnavailableError(error_backend, error_string) def get_devices() -> List[DeviceInfo]: diff --git a/vkdispatch/cli.py b/vkdispatch/cli.py index 2687444c..ea7fc247 100644 --- a/vkdispatch/cli.py +++ b/vkdispatch/cli.py @@ -10,12 +10,35 @@ @click.option('--log_info', is_flag=True, help="Will print verbose messages.") @click.option('--vulkan_loader_debug_logs', '--loader_debug', is_flag=True, help="Enable debug logs for the vulkan loader.") @click.option('--debug', is_flag=True, help="Enable debug logs for the vulkan loader.") +@click.option('--vulkan', is_flag=True, help="Use the Vulkan backend (this is the default backend).") +@click.option('--cuda', is_flag=True, help="Use the CUDA backend.") +@click.option('--opencl', is_flag=True, help="Use the OpenCL backend.") @click.version_option(version=vd.__version__) -def cli_entrypoint(verbose, log_info, vulkan_loader_debug_logs, debug): - if log_info or debug: - vd.initialize(log_level=vd.LogLevel.INFO, loader_debug_logs=vulkan_loader_debug_logs or debug) - else: - vd.initialize(log_level=vd.LogLevel.WARNING, loader_debug_logs=vulkan_loader_debug_logs or debug) +def cli_entrypoint(verbose, log_info, vulkan_loader_debug_logs, debug, vulkan, cuda, opencl): + selected_backend = None + + if vulkan: + if cuda or opencl: + raise click.UsageError("Multiple backends selected. Please select only one backend.") + selected_backend = "vulkan" + + if cuda: + if opencl: + raise click.UsageError("Multiple backends selected. Please select only one backend.") + + selected_backend = "cuda" + + if opencl: + selected_backend = "opencl" + + log_level = vd.LogLevel.INFO if log_info or debug else vd.LogLevel.WARNING + loader_debug_logs = vulkan_loader_debug_logs or debug + + try: + vd.initialize(log_level=log_level, loader_debug_logs=loader_debug_logs, backend=selected_backend) + except vd.BackendUnavailableError as e: + print(f"\033[31mError initializing Vkdispatch:\033[0m {str(e)}") + return for dev in vd.get_devices(): print(dev.get_info_string(verbose)) diff --git a/vkdispatch/codegen/__init__.py b/vkdispatch/codegen/__init__.py index 1d07e8eb..ec5b4720 100644 --- a/vkdispatch/codegen/__init__.py +++ b/vkdispatch/codegen/__init__.py @@ -1,4 +1,4 @@ -from .arguments import Constant, Variable, ConstantArray, VariableArray +from .arguments import Constant, Variable from .arguments import Buffer, Image1D, Image2D, Image3D from .arguments import _ArgType @@ -22,8 +22,8 @@ from .functions.geometric import length, distance, dot, cross, normalize -from .functions.block_synchonization import barrier, memory_barrier, memory_barrier_buffer -from .functions.block_synchonization import memory_barrier_shared, memory_barrier_image, group_memory_barrier +from .functions.block_synchronization import barrier, memory_barrier, memory_barrier_buffer +from .functions.block_synchronization import memory_barrier_shared, memory_barrier_image, group_memory_barrier from .functions.matrix import matrix_comp_mult, outer_product, transpose from .functions.matrix import determinant, inverse @@ -62,10 +62,14 @@ from .functions.subgroups import subgroup_or, subgroup_xor, subgroup_elect from .functions.subgroups import subgroup_barrier -from .functions.control_flow import if_statement, if_any, if_all, else_statement -from .functions.control_flow import else_if_statement, else_if_any, else_if_all -from .functions.control_flow import return_statement, while_statement, new_scope, end -from .functions.control_flow import logical_and, logical_or +# from .functions.control_flow import if_statement, if_any, if_all, else_statement +# from .functions.control_flow import else_if_statement, else_if_any, else_if_all +# from .functions.control_flow import return_statement, while_statement, new_scope, end +# from .functions.control_flow import logical_and, logical_or, any, all + +from .functions.control_flow import if_block, else_block, else_if_block +from .functions.control_flow import return_statement, while_block, scope_block +from .functions.control_flow import any, all from .functions.complex_numbers import mult_complex, complex_from_euler_angle @@ -78,7 +82,9 @@ from .functions.printing import printf from .functions.printing import print_vars as print -from .builder import ShaderBinding, ShaderDescription +from .shader_description import ShaderDescription, BindingType, ShaderArgumentInfo + +from .builder import ShaderBinding from .builder import ShaderBuilder, ShaderFlags from .backends import CodeGenBackend, GLSLBackend, CUDABackend, OpenCLBackend @@ -86,4 +92,6 @@ from .global_builder import set_builder, get_builder, shared_buffer, set_shader_print_line_numbers, get_shader_print_line_numbers from .global_builder import set_codegen_backend, get_codegen_backend -from .abreviations import * +from .context import ShaderContext, shader_context + +from .abbreviations import * diff --git a/vkdispatch/codegen/abreviations.py b/vkdispatch/codegen/abbreviations.py similarity index 95% rename from vkdispatch/codegen/abreviations.py rename to vkdispatch/codegen/abbreviations.py index f9815812..812cb183 100644 --- a/vkdispatch/codegen/abreviations.py +++ b/vkdispatch/codegen/abbreviations.py @@ -1,7 +1,5 @@ from .arguments import Constant as Const from .arguments import Variable as Var -from .arguments import ConstantArray as ConstArr -from .arguments import VariableArray as VarArr from .arguments import Buffer as Buff from .arguments import Image1D as Img1 from .arguments import Image2D as Img2 @@ -43,4 +41,5 @@ from vkdispatch.base.dtype import uvec4 as uv4 from vkdispatch.base.dtype import mat2 as m2 +from vkdispatch.base.dtype import mat3 as m3 from vkdispatch.base.dtype import mat4 as m4 diff --git a/vkdispatch/codegen/arguments.py b/vkdispatch/codegen/arguments.py index fd1133da..a78001f0 100644 --- a/vkdispatch/codegen/arguments.py +++ b/vkdispatch/codegen/arguments.py @@ -1,6 +1,7 @@ import typing -from .builder import ShaderVariable, BufferVariable, ImageVariable +from .variables.variables import ShaderVariable +from .variables.bound_variables import BufferVariable, ImageVariable from vkdispatch.base.dtype import dtype _ArgType = typing.TypeVar('_ArgType', bound=dtype) @@ -14,30 +15,30 @@ class Variable(ShaderVariable, typing.Generic[_ArgType]): def __init__(self) -> None: pass -class ConstantArray(ShaderVariable, typing.Generic[_ArgType, _ArgCount]): - def __init__(self) -> None: - pass - -class VariableArray(ShaderVariable, typing.Generic[_ArgType, _ArgCount]): - def __init__(self) -> None: - pass - class Buffer(BufferVariable, typing.Generic[_ArgType]): def __init__(self) -> None: pass class Image1D(ImageVariable, typing.Generic[_ArgType]): + dimensions: int = 1 + def __init__(self) -> None: pass class Image2D(ImageVariable, typing.Generic[_ArgType]): + dimensions: int = 2 + def __init__(self) -> None: pass class Image2DArray(ImageVariable, typing.Generic[_ArgType]): + dimensions: int = 2 + def __init__(self) -> None: pass class Image3D(ImageVariable, typing.Generic[_ArgType]): + dimensions: int = 3 + def __init__(self) -> None: pass \ No newline at end of file diff --git a/vkdispatch/codegen/backends/base.py b/vkdispatch/codegen/backends/base.py index aafdab6f..5e909d47 100644 --- a/vkdispatch/codegen/backends/base.py +++ b/vkdispatch/codegen/backends/base.py @@ -22,6 +22,11 @@ def mark_feature_usage(self, feature_name: str) -> None: # Backends that emit optional helper code can override this. return + def uses_feature(self, feature_name: str) -> bool: + # Backends that track optional helper code can override this. + _ = feature_name + return False + def mark_composite_unary_op(self, var_type: dtypes.dtype, op: str) -> None: # Backends with composite helper/operator code can override this. return @@ -62,6 +67,25 @@ def buffer_component_expr( _ = (scalar_buffer_expr, base_type, element_index_expr, component_index_expr) return None + def packed_buffer_read_expr( + self, + scalar_buffer_expr: str, + var_type: dtypes.dtype, + element_index_expr: str, + ) -> Optional[str]: + _ = (scalar_buffer_expr, var_type, element_index_expr) + return None + + def packed_buffer_write_statements( + self, + scalar_buffer_expr: str, + var_type: dtypes.dtype, + element_index_expr: str, + value_expr: str, + ) -> Optional[str]: + _ = (scalar_buffer_expr, var_type, element_index_expr, value_expr) + return None + def fma_function_name(self, var_type: dtypes.dtype) -> str: return "fma" diff --git a/vkdispatch/codegen/backends/cuda/backend.py b/vkdispatch/codegen/backends/cuda/backend.py index 7cd91f29..d5b07a7f 100644 --- a/vkdispatch/codegen/backends/cuda/backend.py +++ b/vkdispatch/codegen/backends/cuda/backend.py @@ -39,6 +39,19 @@ class CUDABackend(CodeGenBackend): name = "cuda" + _SUBGROUP_FEATURE_NAMES = { + "num_subgroups", + "subgroup_id", + "subgroup_size", + "subgroup_invocation_id", + "subgroup_add", + "subgroup_mul", + "subgroup_min", + "subgroup_max", + "subgroup_and", + "subgroup_or", + "subgroup_xor", + } _CUDA_BUILTIN_UVEC3_SENTINELS: Dict[str, Dict[str, str]] = { "global_invocation_id": { "sentinel": "VKDISPATCH_CUDA_GLOBAL_INVOCATION_ID_SENTINEL()", @@ -79,11 +92,24 @@ def reset_state(self) -> None: self._sample_texture_dims: Set[int] = set() self._needs_cuda_fp16: bool = False self._feature_usage: Dict[str, bool] = initialize_feature_usage() + self._printf_used: bool = False def mark_feature_usage(self, feature_name: str) -> None: if feature_name in self._feature_usage: self._feature_usage[feature_name] = True + def uses_feature(self, feature_name: str) -> bool: + if feature_name == "subgroup_ops": + return any( + self._feature_usage.get(name, False) + for name in self._SUBGROUP_FEATURE_NAMES + ) + + if feature_name == "printf": + return self._printf_used + + return False + _DTYPE_TO_COMPOSITE_KEY = _CUDA_DTYPE_TO_COMPOSITE_KEY def _composite_key_for_dtype(self, var_type: dtypes.dtype) -> Optional[str]: @@ -399,7 +425,9 @@ def constructor( target_type = self.type_name(var_type) if dtypes.is_scalar(var_type): - assert len(args) > 0, f"Constructor for scalar type '{var_type.name}' needs at least one argument." + if len(args) == 0: + raise ValueError(f"Constructor for scalar type '{var_type.name}' needs at least one argument.") + return f"(({target_type})({args[0]}))" if var_type == dtypes.mat2: @@ -650,7 +678,9 @@ def _cuda_componentwise_binary_math_expr( return None helper_suffix = lhs_helper if lhs_helper is not None else rhs_helper - assert helper_suffix is not None + + if helper_suffix is None: + raise ValueError("At least one of the argument types should have a float vector helper suffix") signature = ("v" if lhs_helper is not None else "s") + ("v" if rhs_helper is not None else "s") self._record_vec_binary_math(helper_suffix, func_name, signature) @@ -861,6 +891,7 @@ def subgroup_barrier_statement(self) -> str: return "__syncwarp();" def printf_statement(self, fmt: str, args: List[str]) -> str: + self._printf_used = True #safe_fmt = fmt.replace("\\", "\\\\").replace('"', '\\"') if len(args) == 0: diff --git a/vkdispatch/codegen/backends/glsl.py b/vkdispatch/codegen/backends/glsl.py index c2187e06..4d085230 100644 --- a/vkdispatch/codegen/backends/glsl.py +++ b/vkdispatch/codegen/backends/glsl.py @@ -21,9 +21,24 @@ class GLSLBackend(CodeGenBackend): def __init__(self) -> None: super().__init__() self._needed_extensions: Set[str] = set() + self._feature_usage = { + "subgroup_ops": False, + "printf": False, + } def reset_state(self) -> None: self._needed_extensions = set() + self._feature_usage = { + "subgroup_ops": False, + "printf": False, + } + + def mark_feature_usage(self, feature_name: str) -> None: + if feature_name in self._feature_usage: + self._feature_usage[feature_name] = True + + def uses_feature(self, feature_name: str) -> bool: + return self._feature_usage.get(feature_name, False) def _track_type_extension(self, var_type: dtypes.dtype) -> None: """Record the GLSL extension required by *var_type* (if any).""" @@ -148,15 +163,19 @@ def num_workgroups_expr(self) -> str: return "gl_NumWorkGroups" def num_subgroups_expr(self) -> str: + self.mark_feature_usage("subgroup_ops") return "gl_NumSubgroups" def subgroup_id_expr(self) -> str: + self.mark_feature_usage("subgroup_ops") return "gl_SubgroupID" def subgroup_size_expr(self) -> str: + self.mark_feature_usage("subgroup_ops") return "gl_SubgroupSize" def subgroup_invocation_id_expr(self) -> str: + self.mark_feature_usage("subgroup_ops") return "gl_SubgroupInvocationID" def barrier_statement(self) -> str: @@ -179,39 +198,49 @@ def group_memory_barrier_statement(self) -> str: def subgroup_add_expr(self, arg_expr: str, arg_type: Optional[dtypes.dtype] = None) -> str: _ = arg_type + self.mark_feature_usage("subgroup_ops") return f"subgroupAdd({arg_expr})" def subgroup_mul_expr(self, arg_expr: str, arg_type: Optional[dtypes.dtype] = None) -> str: _ = arg_type + self.mark_feature_usage("subgroup_ops") return f"subgroupMul({arg_expr})" def subgroup_min_expr(self, arg_expr: str, arg_type: Optional[dtypes.dtype] = None) -> str: _ = arg_type + self.mark_feature_usage("subgroup_ops") return f"subgroupMin({arg_expr})" def subgroup_max_expr(self, arg_expr: str, arg_type: Optional[dtypes.dtype] = None) -> str: _ = arg_type + self.mark_feature_usage("subgroup_ops") return f"subgroupMax({arg_expr})" def subgroup_and_expr(self, arg_expr: str, arg_type: Optional[dtypes.dtype] = None) -> str: _ = arg_type + self.mark_feature_usage("subgroup_ops") return f"subgroupAnd({arg_expr})" def subgroup_or_expr(self, arg_expr: str, arg_type: Optional[dtypes.dtype] = None) -> str: _ = arg_type + self.mark_feature_usage("subgroup_ops") return f"subgroupOr({arg_expr})" def subgroup_xor_expr(self, arg_expr: str, arg_type: Optional[dtypes.dtype] = None) -> str: _ = arg_type + self.mark_feature_usage("subgroup_ops") return f"subgroupXor({arg_expr})" def subgroup_elect_expr(self) -> str: + self.mark_feature_usage("subgroup_ops") return "subgroupElect()" def subgroup_barrier_statement(self) -> str: + self.mark_feature_usage("subgroup_ops") return "subgroupBarrier();" def printf_statement(self, fmt: str, args: List[str]) -> str: + self.mark_feature_usage("printf") args_suffix = "" if len(args) > 0: diff --git a/vkdispatch/codegen/backends/opencl.py b/vkdispatch/codegen/backends/opencl.py index 907f0508..4401ae21 100644 --- a/vkdispatch/codegen/backends/opencl.py +++ b/vkdispatch/codegen/backends/opencl.py @@ -105,14 +105,17 @@ def constructor( target_type = self.type_name(var_type) if dtypes.is_scalar(var_type): - assert len(args) > 0, f"Constructor for scalar type '{var_type.name}' needs at least one argument." + if len(args) == 0: + raise ValueError(f"Constructor for scalar type '{var_type.name}' needs at least one argument.") + return f"(({target_type})({args[0]}))" if dtypes.is_matrix(var_type): dim = var_type.child_count - assert len(args) in (1, dim, dim * dim), ( - f"Constructor for matrix type '{var_type.name}' needs 1, {dim}, or {dim * dim} arguments." - ) + if len(args) not in (1, dim, dim * dim): + raise ValueError( + f"Constructor for matrix type '{var_type.name}' needs 1, {dim}, or {dim * dim} arguments, but got {len(args)}." + ) if len(args) == 1: single_arg = args[0] helper_name = self._matrix_helper_name( @@ -166,6 +169,48 @@ def buffer_component_expr( f"]" ) + @staticmethod + def _uses_packed_vector_storage(var_type: dtypes.dtype) -> bool: + return dtypes.is_vector(var_type) and var_type.child_count == 3 + + def packed_buffer_read_expr( + self, + scalar_buffer_expr: str, + var_type: dtypes.dtype, + element_index_expr: str, + ) -> Optional[str]: + if not self._uses_packed_vector_storage(var_type): + return None + + return self.constructor( + var_type, + [ + f"{scalar_buffer_expr}[(({element_index_expr}) * 3) + 0]", + f"{scalar_buffer_expr}[(({element_index_expr}) * 3) + 1]", + f"{scalar_buffer_expr}[(({element_index_expr}) * 3) + 2]", + ], + arg_types=[var_type.scalar, var_type.scalar, var_type.scalar], + ) + + def packed_buffer_write_statements( + self, + scalar_buffer_expr: str, + var_type: dtypes.dtype, + element_index_expr: str, + value_expr: str, + ) -> Optional[str]: + if not self._uses_packed_vector_storage(var_type): + return None + + statements: List[str] = [] + for component_index, component_name in enumerate(("x", "y", "z")): + statements.append( + f"{scalar_buffer_expr}[(({element_index_expr}) * 3) + {component_index}] = " + f"{self.component_access_expr(value_expr, component_name, var_type)};\n" + ) + + return "".join(statements) + def _cast_math_arg(self, arg_type: dtypes.dtype, arg_expr: str) -> str: if dtypes.is_scalar(arg_type) or dtypes.is_vector(arg_type) or dtypes.is_complex(arg_type): return self.constructor(arg_type, [arg_expr], arg_types=[arg_type]) diff --git a/vkdispatch/codegen/builder.py b/vkdispatch/codegen/builder.py index 44e50e48..f7805465 100644 --- a/vkdispatch/codegen/builder.py +++ b/vkdispatch/codegen/builder.py @@ -8,23 +8,15 @@ from enum import IntFlag, auto -from typing import Dict, List, Optional, Tuple - +from typing import Dict, List, Optional, Any, get_type_hints import dataclasses -import enum - +from ..base.dtype import is_dtype +from .arguments import Constant, Variable, Buffer from .variables.variables import BaseVariable, ShaderVariable, ScaledAndOfftsetIntVariable from .variables.bound_variables import BufferVariable, ImageVariable -_PUSH_CONSTANT_UNSUPPORTED_BACKENDS = set() - - -def _push_constant_not_supported_error(backend_name: str) -> str: - return ( - f"Push Constants are not supported for the {backend_name.upper()} backend. " - "Use Const instead." - ) +from .shader_description import ShaderDescription, BindingType, ShaderArgumentInfo, ShaderArgumentType @dataclasses.dataclass class SharedBuffer: @@ -40,68 +32,6 @@ class SharedBuffer: size: int name: str -class BindingType(enum.Enum): - """ - A dataclass that represents the type of a binding in a shader. Either a - STORAGE_BUFFER, UNIFORM_BUFFER, or SAMPLER. - """ - STORAGE_BUFFER = 1 - UNIFORM_BUFFER = 3 - SAMPLER = 5 - -@dataclasses.dataclass -class ShaderDescription: - """ - A dataclass that represents a description of a shader object. - - Attributes: - source (str): The source code of the shader. - pc_size (int): The size of the push constant buffer in bytes. - pc_structure (List[vc.StructElement]): The structure of the push constant buffer. - uniform_structure (List[vc.StructElement]): The structure of the uniform buffer. - binding_type_list (List[BindingType]): The list of binding types. - """ - - header: str - body: str - name: str - pc_size: int - pc_structure: List[StructElement] - uniform_structure: List[StructElement] - binding_type_list: List[BindingType] - binding_access: List[Tuple[bool, bool]] # List of tuples indicating read and write access for each binding - exec_count_name: Optional[str] - resource_binding_base: int - backend: Optional[CodeGenBackend] = None - - def make_source(self, x: int, y: int, z: int) -> str: - if self.backend is None: - layout_str = f"layout(local_size_x = {x}, local_size_y = {y}, local_size_z = {z}) in;" - shader_source = f"{self.header}\n{layout_str}\n{self.body}" - else: - shader_source = self.backend.make_source(self.header, self.body, x, y, z) - - # ff = open(f"sources/{self.name}.comp", "w") - # ff.write(shader_source) - # ff.close() - - return shader_source - - def __repr__(self): - description_string = "" - - description_string += f"Shader Name: {self.name}\n" - description_string += f"Push Constant Size: {self.pc_size} bytes\n" - description_string += f"Push Constant Structure: {self.pc_structure}\n" - description_string += f"Uniform Structure: {self.uniform_structure}\n" - description_string += f"Binding Types: {self.binding_type_list}\n" - description_string += f"Binding Access: {self.binding_access}\n" - description_string += f"Execution Count Name: {self.exec_count_name}\n" - description_string += f"Backend: {self.backend.name if self.backend is not None else 'none'}\n" - description_string += f"Header:\n{self.header}\n" - description_string += f"Body:\n{self.body}\n" - return description_string - @dataclasses.dataclass class ShaderBinding: """ @@ -129,6 +59,75 @@ class ShaderFlags(IntFlag): NO_PRINTF = auto() NO_EXEC_BOUNDS = auto() +def annotation_to_shader_arg_and_variable(builder: "ShaderBuilder", type_annotation: Any, name: str, default_value: Any): + # Dataclass case + if(dataclasses.is_dataclass(type_annotation)): + creation_args: Dict[str, ShaderVariable] = {} + value_name = {} + + for field_name, field_type in get_type_hints(type_annotation).items(): + if not is_dtype(field_type): + raise ValueError(f"Unsupported type '{field_type}' for field '{type_annotation}.{field_name}'") + + creation_args[field_name] = builder._declare_constant(field_type) + value_name[field_name] = creation_args[field_name].raw_name + + return ShaderArgumentInfo( + name, + ShaderArgumentType.CONSTANT_DATACLASS, + default_value, + value_name + ), type_annotation(**creation_args) + + # Buffer case + if(issubclass(type_annotation.__origin__, Buffer)): + shader_var = builder._declare_buffer(type_annotation.__args__[0]) + + return ShaderArgumentInfo( + name, + ShaderArgumentType.BUFFER, + default_value, + shader_var.raw_name, + shader_shape_name=shader_var.shape_name, + binding=shader_var.binding + ), shader_var + + # Image case + if(issubclass(type_annotation.__origin__, ImageVariable)): + shader_var = builder._declare_image( + type_annotation.__origin__.dimensions + ) + + return ShaderArgumentInfo( + name, + ShaderArgumentType.IMAGE, + default_value, + shader_var.raw_name, + binding=shader_var.binding + ), shader_var + + if(issubclass(type_annotation.__origin__, Constant)): + shader_var = builder._declare_constant(type_annotation.__args__[0]) + + return ShaderArgumentInfo( + name, + ShaderArgumentType.CONSTANT, + default_value, + shader_var.raw_name + ), shader_var + + if(issubclass(type_annotation.__origin__, Variable)): + shader_var = builder._declare_variable(type_annotation.__args__[0]) + + return ShaderArgumentInfo( + name, + ShaderArgumentType.VARIABLE, + default_value, + shader_var.raw_name + ), shader_var + + raise ValueError(f"Unsupported type '{type_annotation.__args__[0]}'") + class ShaderBuilder(ShaderWriter): binding_count: int binding_read_access: Dict[int, bool] @@ -141,20 +140,15 @@ class ShaderBuilder(ShaderWriter): exec_count: Optional[ShaderVariable] flags: ShaderFlags backend: CodeGenBackend + shader_arg_infos: List[ShaderArgumentInfo] + shader_args: List[ShaderVariable] + has_ubo: bool - def __init__(self, - flags: ShaderFlags = ShaderFlags.NONE, - is_apple_device: bool = False, - backend: Optional[CodeGenBackend] = None) -> None: + def __init__(self, flags: ShaderFlags = ShaderFlags.NONE): super().__init__() self.flags = flags - self.is_apple_device = is_apple_device - if backend is not None: - self.backend = backend - else: - # Use the selected backend type while keeping per-builder backend state isolated. - self.backend = get_codegen_backend().__class__() + self.backend = get_codegen_backend().__class__() self.reset() @@ -168,11 +162,14 @@ def reset(self) -> None: self.binding_write_access = {} self.shared_buffers = [] self.scope_num = 1 + self.shader_arg_infos = [] + self.shader_args = [] + self.has_ubo = False self.exec_count = None if not (self.flags & ShaderFlags.NO_EXEC_BOUNDS): - self.exec_count = self.declare_constant(dtypes.uvec4, var_name="exec_count") + self.exec_count = self._declare_constant(dtypes.uvec4, var_name="exec_count") self.append_contents(self.backend.exec_bounds_guard(self.exec_count.resolve())) def new_var(self, @@ -201,7 +198,28 @@ def new_scaled_var(self, offset=offset, parents=parents) - def declare_constant(self, var_type: dtypes.dtype, count: int = 1, var_name: Optional[str] = None): + def declare_shader_arguments(self, + type_annotations: List, + names: Optional[List[str]] = None, + defaults: Optional[List[Any]] = None): + if len(self.shader_args) > 0: + raise RuntimeError("Shader arguments have already been declared for this builder instance") + + for i in range(len(type_annotations)): + shader_arg_info, shader_var = annotation_to_shader_arg_and_variable( + self, + type_annotations[i], + names[i] if names is not None else f"param{i}", + defaults[i] if defaults is not None else None + ) + + self.shader_args.append(shader_var) + self.shader_arg_infos.append(shader_arg_info) + + def get_shader_arguments(self) -> List[ShaderVariable]: + return self.shader_args + + def _declare_constant(self, var_type: dtypes.dtype, count: int = 1, var_name: Optional[str] = None): if var_name is None: var_name = self.new_name() @@ -221,10 +239,7 @@ def declare_constant(self, var_type: dtypes.dtype, count: int = 1, var_name: Opt self.uniform_struct.register_element(new_var.raw_name, var_type, count) return new_var - def declare_variable(self, var_type: dtypes.dtype, count: int = 1, var_name: Optional[str] = None): - if self.backend.name in _PUSH_CONSTANT_UNSUPPORTED_BACKENDS: - raise NotImplementedError(_push_constant_not_supported_error(self.backend.name)) - + def _declare_variable(self, var_type: dtypes.dtype, count: int = 1, var_name: Optional[str] = None): if var_name is None: var_name = self.new_name() @@ -244,9 +259,7 @@ def declare_variable(self, var_type: dtypes.dtype, count: int = 1, var_name: Opt self.pc_struct.register_element(new_var.raw_name, var_type, count) return new_var - def declare_buffer(self, var_type: dtypes.dtype, var_name: Optional[str] = None): - self.binding_count += 1 - + def _declare_buffer(self, var_type: dtypes.dtype, var_name: Optional[str] = None): buffer_name = f"buf{self.binding_count}" if var_name is None else var_name shape_name = f"{buffer_name}_shape" scalar_expr = None @@ -267,11 +280,13 @@ def write_lambda(): self.binding_write_access[current_binding_count] = True def shape_var_factory(): - return self.declare_constant(dtypes.ivec4, var_name=shape_name) + return self._declare_constant(dtypes.ivec4, var_name=shape_name) + + self.binding_count += 1 return BufferVariable( var_type, - self.binding_count, + self.binding_count-1, f"{buffer_name}.data", shape_var_factory=shape_var_factory, shape_name=shape_name, @@ -281,23 +296,25 @@ def shape_var_factory(): write_lambda=write_lambda ) - def declare_image(self, dimensions: int, var_name: Optional[str] = None): - self.binding_count += 1 - + def _declare_image(self, dimensions: int, var_name: Optional[str] = None): image_name = f"tex{self.binding_count}" if var_name is None else var_name self.binding_list.append(ShaderBinding(dtypes.vec4, image_name, dimensions, BindingType.SAMPLER)) self.binding_read_access[self.binding_count] = False self.binding_write_access[self.binding_count] = False + current_binding_count = self.binding_count + def read_lambda(): - self.binding_read_access[self.binding_count] = True + self.binding_read_access[current_binding_count] = True def write_lambda(): - self.binding_write_access[self.binding_count] = True + self.binding_write_access[current_binding_count] = True + + self.binding_count += 1 return ImageVariable( dtypes.vec4, - self.binding_count, + self.binding_count-1, dimensions, f"{image_name}", read_lambda=read_lambda, @@ -353,50 +370,65 @@ def build(self, name: str) -> ShaderDescription: shared_buffer.size ) + "\n" - uniform_elements = self.uniform_struct.build() + uniform_elements = [] + binding_type_list: List[BindingType] = [] + binding_access = [] + binding_base = 0 - uniform_decleration_contents = self.compose_struct_decleration(uniform_elements) - has_uniform_buffer = len(uniform_decleration_contents) > 0 - if has_uniform_buffer: + if not self.uniform_struct.empty(): + uniform_elements = self.uniform_struct.build() + + uniform_decleration_contents = self.compose_struct_decleration(uniform_elements) header += self.backend.uniform_block_declaration(uniform_decleration_contents) - binding_base = 1 if has_uniform_buffer else 0 - binding_type_list = [] - binding_access = [] - if has_uniform_buffer: binding_type_list.append(BindingType.UNIFORM_BUFFER) binding_access.append((True, False)) # UBO is read-only + binding_base = 1 + + for shader_arg_info in self.shader_arg_infos: + if (shader_arg_info.arg_type == ShaderArgumentType.BUFFER or + shader_arg_info.arg_type == ShaderArgumentType.IMAGE): + shader_arg_info.binding += 1 # Shift bindings by 1 to account for UBO at binding 0 for ii, binding in enumerate(self.binding_list): - emitted_binding = ii + binding_base if binding.binding_type == BindingType.STORAGE_BUFFER: - header += self.backend.storage_buffer_declaration(emitted_binding, binding.dtype, binding.name) - binding_type_list.append(binding.binding_type) - binding_access.append(( - self.binding_read_access[ii + 1], - self.binding_write_access[ii + 1] - )) + header += self.backend.storage_buffer_declaration( + binding=ii + binding_base, + var_type=binding.dtype, + name=binding.name + ) else: - header += self.backend.sampler_declaration(emitted_binding, binding.dimension, binding.name) - binding_type_list.append(binding.binding_type) - binding_access.append(( - self.binding_read_access[ii + 1], - self.binding_write_access[ii + 1] - )) + header += self.backend.sampler_declaration( + binding=ii + binding_base, + dimensions=binding.dimension, + name=binding.name + ) + + binding_type_list.append(binding.binding_type) + binding_access.append(( + self.binding_read_access[ii], + self.binding_write_access[ii] + )) pc_elements = self.pc_struct.build() pc_decleration_contents = self.compose_struct_decleration(pc_elements) if len(pc_decleration_contents) > 0: - assert self.backend.name not in _PUSH_CONSTANT_UNSUPPORTED_BACKENDS, ( - _push_constant_not_supported_error(self.backend.name) - ) header += self.backend.push_constant_declaration(pc_decleration_contents) + enable_subgroup_ops = ( + not (self.flags & ShaderFlags.NO_SUBGROUP_OPS) + and self.backend.uses_feature("subgroup_ops") + ) + enable_printf = ( + not (self.flags & ShaderFlags.NO_PRINTF) + and self.backend.uses_feature("printf") + ) + pre_header = self.backend.pre_header( - enable_subgroup_ops=not (self.flags & ShaderFlags.NO_SUBGROUP_OPS), - enable_printf=not (self.flags & ShaderFlags.NO_PRINTF) + enable_subgroup_ops=enable_subgroup_ops, + enable_printf=enable_printf, ) return ShaderDescription( @@ -410,5 +442,6 @@ def build(self, name: str) -> ShaderDescription: binding_access=binding_access, exec_count_name=self.exec_count.raw_name if self.exec_count is not None else None, resource_binding_base=binding_base, - backend=self.backend + backend=self.backend, + shader_arg_infos=self.shader_arg_infos ) diff --git a/vkdispatch/codegen/context.py b/vkdispatch/codegen/context.py new file mode 100644 index 00000000..9ac90069 --- /dev/null +++ b/vkdispatch/codegen/context.py @@ -0,0 +1,41 @@ +import vkdispatch.codegen as vc + +from typing import List, Optional, Any + +import contextlib + +class ShaderContext: + builder: vc.ShaderBuilder + shader_description: vc.ShaderDescription + + def __init__(self, builder: vc.ShaderBuilder): + self.builder = builder + self.shader_description = None + + def get_description(self, name: Optional[str] = None): + if self.shader_description is not None: + return self.shader_description + + self.shader_description = self.builder.build("shader" if name is None else name) + + return self.shader_description + + def declare_input_arguments(self, + type_annotations: List, + names: Optional[List[str]] = None, + defaults: Optional[List[Any]] = None): + self.builder.declare_shader_arguments(type_annotations, names, defaults) + return self.builder.get_shader_arguments() + +@contextlib.contextmanager +def shader_context(flags: vc.ShaderFlags = vc.ShaderFlags.NONE): + + builder = vc.ShaderBuilder(flags=flags) + old_builder = vc.set_builder(builder) + + context = ShaderContext(builder) + + try: + yield context + finally: + vc.set_builder(old_builder) \ No newline at end of file diff --git a/vkdispatch/codegen/functions/atomic_memory.py b/vkdispatch/codegen/functions/atomic_memory.py index 7efb8590..e471b327 100644 --- a/vkdispatch/codegen/functions/atomic_memory.py +++ b/vkdispatch/codegen/functions/atomic_memory.py @@ -29,23 +29,32 @@ def _is_buffer_backed_target(var: ShaderVariable) -> bool: # https://docs.vulkan.org/glsl/latest/chapters/builtinfunctions.html#atomic-memory-functions def atomic_add(mem: ShaderVariable, y: Any) -> ShaderVariable: - assert isinstance(mem, ShaderVariable), f"atomic_add target must be a ShaderVariable, got {type(mem)}" - assert dtypes.is_scalar(mem.var_type), "atomic_add target must be a scalar lvalue" - assert mem.is_setable(), "atomic_add target must be a writable lvalue" - assert not mem.is_register(), "atomic_add does not support register/local variables as target" - assert _is_buffer_backed_target(mem), "atomic_add target must reference a buffer element (e.g., buf[idx])" - - assert mem.var_type in (dtypes.int32, dtypes.uint32), ( - f"atomic_add currently supports only int32/uint32 targets, got '{mem.var_type.name}'" - ) + if not isinstance(mem, ShaderVariable): + raise TypeError(f"atomic_add target must be a ShaderVariable, got {type(mem)}") + + if not dtypes.is_scalar(mem.var_type): + raise TypeError("atomic_add target must be a scalar lvalue") + + if not mem.is_setable(): + raise TypeError("atomic_add target must be a writable lvalue") + + if mem.is_register(): + raise TypeError("atomic_add does not support register/local variables as target") + + if not _is_buffer_backed_target(mem): + raise TypeError("atomic_add target must reference a buffer element (e.g., buf[idx])") + + if mem.var_type not in (dtypes.int32, dtypes.uint32): + raise TypeError(f"atomic_add currently supports only int32/uint32 targets, got '{mem.var_type.name}'") parents: List[BaseVariable] = [mem] if isinstance(y, ShaderVariable): - assert dtypes.is_scalar(y.var_type), "atomic_add increment variable must be scalar" - assert dtypes.is_integer_dtype(y.var_type), ( - f"atomic_add increment variable must be integer-typed, got '{y.var_type.name}'" - ) + if not dtypes.is_scalar(y.var_type): + raise TypeError(f"atomic_add increment variable must be scalar, got variable '{y.resolve()}' of type '{y.var_type.name}'") + + if not dtypes.is_integer_dtype(y.var_type): + raise TypeError(f"atomic_add increment variable must be integer-typed, got variable '{y.resolve()}' of type '{y.var_type.name}'") y.read_callback() parents.append(y) y_expr = utils.backend_constructor(mem.var_type, y) diff --git a/vkdispatch/codegen/functions/base_functions/arithmetic.py b/vkdispatch/codegen/functions/base_functions/arithmetic.py index 79e890e5..71583d28 100644 --- a/vkdispatch/codegen/functions/base_functions/arithmetic.py +++ b/vkdispatch/codegen/functions/base_functions/arithmetic.py @@ -42,7 +42,8 @@ def arithmetic_op_common(var: BaseVariable, other: Any, reverse: bool = False, inplace: bool = False) -> BaseVariable: - assert isinstance(var, BaseVariable), "First argument must be a ShaderVariable" + if not isinstance(var, BaseVariable): + raise TypeError(f"First argument must be a ShaderVariable, but got {type(var)}") result_type = None @@ -56,11 +57,17 @@ def arithmetic_op_common(var: BaseVariable, raise TypeError(f"Unsupported type for arithmetic op: ShaderVariable and {type(other)}") if inplace: - assert var.is_setable(), "Inplace arithmetic requires the variable to be settable." - assert not reverse, "Inplace arithmetic does not support reverse operations." + if not var.is_setable(): + raise ValueError("Inplace arithmetic requires the variable to be settable.") + + if reverse: + raise ValueError("Inplace arithmetic does not support reverse operations.") + var.read_callback() var.write_callback() - assert result_type == var.var_type, "Inplace arithmetic requires the result type to match the variable type." + + if result_type != var.var_type: + raise TypeError(f"Inplace arithmetic requires the result type to match the variable type, but got '{result_type.name}' and '{var.var_type.name}' respectively.") if base_utils.is_scalar_number(other): return result_type @@ -102,8 +109,10 @@ def add(var: BaseVariable, other: Any, inplace: bool = False) -> BaseVariable: else: base_utils.append_contents(f"{var.resolve()} += {scalar_expr};\n") return var - - assert isinstance(other, BaseVariable) + + if not isinstance(other, BaseVariable): + raise TypeError(f"Unsupported type for addition: ShaderVariable and {type(other)}") + _mark_arith_binary(var.var_type, other.var_type, "+", inplace=inplace) expr, use_assignment = _resolve_arithmetic_binary_expr( "+", @@ -163,7 +172,9 @@ def sub(var: BaseVariable, other: Any, reverse: bool = False, inplace: bool = Fa base_utils.append_contents(f"{var.resolve()} -= {scalar_expr};\n") return var - assert isinstance(other, BaseVariable) + if not isinstance(other, BaseVariable): + raise TypeError(f"Unsupported type for subtraction: ShaderVariable and {type(other)}") + lhs_type = var.var_type if not reverse else other.var_type rhs_type = other.var_type if not reverse else var.var_type _mark_arith_binary(lhs_type, rhs_type, "-", inplace=inplace) @@ -229,7 +240,8 @@ def mul(var: BaseVariable, other: Any, inplace: bool = False) -> BaseVariable: base_utils.append_contents(f"{var.resolve()} *= {scalar_expr};\n") return var - assert isinstance(other, BaseVariable) + if not isinstance(other, BaseVariable): + raise TypeError(f"Unsupported type for multiplication: ShaderVariable and {type(other)}") if dtypes.is_complex(var.var_type) and dtypes.is_complex(other.var_type): raise ValueError("Complex multiplication is not supported via the `*` operator.") @@ -239,11 +251,15 @@ def mul(var: BaseVariable, other: Any, inplace: bool = False) -> BaseVariable: return_type = dtypes.cross_multiply_type(var.var_type, other.var_type) if inplace: - assert var.is_setable(), "Inplace arithmetic requires the variable to be settable." + if not var.is_setable(): + raise ValueError("Inplace arithmetic requires the variable to be settable.") + var.read_callback() var.write_callback() other.read_callback() - assert return_type == var.var_type, "Inplace arithmetic requires the result type to match the variable type." + + if return_type != var.var_type: + raise TypeError(f"Inplace multiplication requires the result type to match the variable type, but got '{return_type.name}' and '{var.var_type.name}' respectively.") _mark_arith_binary(var.var_type, other.var_type, "*", inplace=inplace) expr, use_assignment = _resolve_arithmetic_binary_expr( @@ -309,7 +325,8 @@ def truediv(var: BaseVariable, other: Any, reverse: bool = False, inplace: bool base_utils.append_contents(f"{var.resolve()} /= {other_expr};\n") return var - assert isinstance(other, BaseVariable) + if not isinstance(other, BaseVariable): + raise TypeError(f"Unsupported type for true division: ShaderVariable and {type(other)}") if dtypes.is_complex(var.var_type) and dtypes.is_complex(other.var_type): raise ValueError("Complex division is not supported.") @@ -359,12 +376,17 @@ def truediv(var: BaseVariable, other: Any, reverse: bool = False, inplace: bool return var def floordiv(var: BaseVariable, other: Any, reverse: bool = False, inplace: bool = False) -> BaseVariable: - assert dtypes.is_integer_dtype(var.var_type), "Floor division is only supported for integer types." + if not dtypes.is_integer_dtype(var.var_type): + raise TypeError(f"Floor division is only supported for integer types, but variable '{var.resolve()}' has type '{var.var_type.name}'!") + return_type = arithmetic_op_common(var, other, reverse=reverse, inplace=inplace) - assert dtypes.is_integer_dtype(return_type), "Floor division is only supported for integer types." + + if not dtypes.is_integer_dtype(return_type): + raise TypeError(f"Floor division is only supported for integer types, but the result type of the operation is '{return_type.name}'!") if base_utils.is_scalar_number(other): - assert base_utils.is_int_number(other), "Floor division only supports integer scalar values." + if not base_utils.is_int_number(other): + raise TypeError("Floor division only supports integer scalar values.") if not inplace: if other == 1: @@ -390,7 +412,9 @@ def floordiv(var: BaseVariable, other: Any, reverse: bool = False, inplace: bool base_utils.append_contents(f"{var.resolve()} /= {other};\n") return var - assert isinstance(other, BaseVariable) + if not isinstance(other, BaseVariable): + raise TypeError(f"Unsupported type for floor division: ShaderVariable and {type(other)}") + _mark_arith_binary(var.var_type if not reverse else other.var_type, other.var_type if not reverse else var.var_type, "/", inplace=inplace) if not inplace: @@ -407,9 +431,14 @@ def floordiv(var: BaseVariable, other: Any, reverse: bool = False, inplace: bool return var def mod(var: BaseVariable, other: Any, reverse: bool = False, inplace: bool = False) -> BaseVariable: - assert dtypes.is_integer_dtype(var.var_type), "Modulus is only supported for integer types." + + if not dtypes.is_integer_dtype(var.var_type): + raise TypeError(f"Modulus is only supported for integer types, but variable '{var.resolve()}' has type '{var.var_type.name}'!") + return_type = arithmetic_op_common(var, other, reverse=reverse, inplace=inplace) - assert dtypes.is_integer_dtype(return_type), "Modulus is only supported for integer types." + + if not dtypes.is_integer_dtype(return_type): + raise TypeError(f"Modulus is only supported for integer types, but the result type of the operation is '{return_type.name}'!") if base_utils.is_scalar_number(other): scalar_type = base_utils.number_to_dtype(other) @@ -426,8 +455,10 @@ def mod(var: BaseVariable, other: Any, reverse: bool = False, inplace: bool = Fa base_utils.append_contents(f"{var.resolve()} %= {other};\n") return var - - assert isinstance(other, BaseVariable) + + if not isinstance(other, BaseVariable): + raise TypeError(f"Unsupported type for modulus op: ShaderVariable and {type(other)}") + _mark_arith_binary(var.var_type if not reverse else other.var_type, other.var_type if not reverse else var.var_type, "%", inplace=inplace) if not inplace: @@ -488,8 +519,8 @@ def pow_expr(x: Any, y: Any) -> Union[BaseVariable, float]: parents=[x] ) - assert isinstance(y, BaseVariable), "First argument must be a ShaderVariable or number" - assert isinstance(x, BaseVariable), "Second argument must be a ShaderVariable or number" + if not isinstance(x, BaseVariable) or not isinstance(y, BaseVariable): + raise TypeError("Both arguments must be ShaderVariables or numbers") result_type = base_utils.dtype_to_floating(dtypes.cross_type(x.var_type, y.var_type)) return base_utils.new_base_var( diff --git a/vkdispatch/codegen/functions/base_functions/base_utils.py b/vkdispatch/codegen/functions/base_functions/base_utils.py index 7a5d7d71..5610189f 100644 --- a/vkdispatch/codegen/functions/base_functions/base_utils.py +++ b/vkdispatch/codegen/functions/base_functions/base_utils.py @@ -108,12 +108,12 @@ def format_number_literal(var: numbers.Number, *, force_float32: bool = False, d return str(var) def resolve_input(var: Any, dtype: Optional[dtypes.dtype] = None) -> str: - #print("Resolving input:", var) - if is_number(var): return format_number_literal(var, dtype=dtype) - assert isinstance(var, BaseVariable), "Argument must be a ShaderVariable or number" + if not isinstance(var, BaseVariable): + raise TypeError(f"Argument must be a ShaderVariable or number, but got {type(var)}") + return var.resolve() def resolve_input_type(var: Any) -> Optional[dtypes.dtype]: diff --git a/vkdispatch/codegen/functions/base_functions/bitwise.py b/vkdispatch/codegen/functions/base_functions/bitwise.py index e272817f..14b92a1a 100644 --- a/vkdispatch/codegen/functions/base_functions/bitwise.py +++ b/vkdispatch/codegen/functions/base_functions/bitwise.py @@ -16,8 +16,11 @@ def bitwise_op_common(var: BaseVariable, other: Any, reverse: bool = False, inplace: bool = False) -> BaseVariable: - assert isinstance(var, BaseVariable), "First argument must be a ShaderVariable" - assert dtypes.is_integer_dtype(var.var_type), "Bitwise operations only supported on integer types." + if not isinstance(var, BaseVariable): + raise TypeError(f"First argument must be a ShaderVariable, but got {type(var)}") + + if not dtypes.is_integer_dtype(var.var_type): + raise TypeError(f"Bitwise operations only supported on integer types, but got '{var.var_type.name}'") result_type = None @@ -29,16 +32,23 @@ def bitwise_op_common(var: BaseVariable, raise TypeError(f"Unsupported type for bitwise op: ShaderVariable and {type(other)}") if inplace: - assert var.is_setable(), "Inplace bitwise requires the variable to be settable." - assert not reverse, "Inplace bitwise does not support reverse operations." + if not var.is_setable(): + raise ValueError("Inplace bitwise requires the variable to be settable.") + + if reverse: + raise ValueError("Inplace bitwise does not support reverse operations.") + var.read_callback() var.write_callback() - assert result_type == var.var_type, "Inplace bitwise requires the result type to match the variable type." + + if result_type != var.var_type: + raise TypeError(f"Inplace bitwise requires the result type to match the variable type, but got '{result_type.name}' and '{var.var_type.name}' respectively.") if base_utils.is_int_number(other): return result_type - assert dtypes.is_integer_dtype(other.var_type), "Bitwise operations only supported on integer types." + if not dtypes.is_integer_dtype(other.var_type): + raise TypeError(f"Bitwise operations only supported on integer types, but got '{other.var_type.name}'") if inplace: other.read_callback() @@ -63,7 +73,9 @@ def lshift(var: BaseVariable, other: Any, reverse: bool = False, inplace: bool = base_utils.append_contents(f"{var.resolve()} <<= {other};\n") return var - assert isinstance(other, BaseVariable) + if not isinstance(other, BaseVariable): + raise TypeError(f"Unsupported type for left shift: ShaderVariable and {type(other)}") + _mark_bit_binary(var.var_type if not reverse else other.var_type, other.var_type if not reverse else var.var_type, "<<", inplace=inplace) if not inplace: @@ -97,7 +109,9 @@ def rshift(var: BaseVariable, other: Any, reverse: bool = False, inplace: bool = base_utils.append_contents(f"{var.resolve()} >>= {other};\n") return var - assert isinstance(other, BaseVariable) + if not isinstance(other, BaseVariable): + raise TypeError(f"Unsupported type for right shift: ShaderVariable and {type(other)}") + _mark_bit_binary(var.var_type if not reverse else other.var_type, other.var_type if not reverse else var.var_type, ">>", inplace=inplace) if not inplace: @@ -123,8 +137,10 @@ def and_bits(var: BaseVariable, other: Any, inplace: bool = False): base_utils.append_contents(f"{var.resolve()} &= {other};\n") return var - - assert isinstance(other, BaseVariable) + + if not isinstance(other, BaseVariable): + raise TypeError(f"Unsupported type for bitwise AND: ShaderVariable and {type(other)}") + _mark_bit_binary(var.var_type, other.var_type, "&", inplace=inplace) if not inplace: @@ -143,8 +159,10 @@ def xor_bits(var: BaseVariable, other: Any, inplace: bool = False): base_utils.append_contents(f"{var.resolve()} ^= {other};\n") return var - - assert isinstance(other, BaseVariable) + + if not isinstance(other, BaseVariable): + raise TypeError(f"Unsupported type for bitwise XOR: ShaderVariable and {type(other)}") + _mark_bit_binary(var.var_type, other.var_type, "^", inplace=inplace) if not inplace: @@ -164,7 +182,9 @@ def or_bits(var: BaseVariable, other: Any, inplace: bool = False): base_utils.append_contents(f"{var.resolve()} |= {other};\n") return var - assert isinstance(other, BaseVariable) + if not isinstance(other, BaseVariable): + raise TypeError(f"Unsupported type for bitwise OR: ShaderVariable and {type(other)}") + _mark_bit_binary(var.var_type, other.var_type, "|", inplace=inplace) if not inplace: @@ -174,8 +194,12 @@ def or_bits(var: BaseVariable, other: Any, inplace: bool = False): return var def invert(var: BaseVariable): - assert isinstance(var, BaseVariable), "First argument must be a ShaderVariable" - assert dtypes.is_integer_dtype(var.var_type), "Bitwise operations only supported on integer types." + if not isinstance(var, BaseVariable): + raise TypeError(f"Argument must be a ShaderVariable, but got {type(var)}") + + if not dtypes.is_integer_dtype(var.var_type): + raise TypeError(f"Bitwise operations only supported on integer types, but got '{var.var_type.name}'") + _mark_bit_unary(var, "~") return base_utils.new_base_var( diff --git a/vkdispatch/codegen/functions/block_synchonization.py b/vkdispatch/codegen/functions/block_synchronization.py similarity index 68% rename from vkdispatch/codegen/functions/block_synchonization.py rename to vkdispatch/codegen/functions/block_synchronization.py index 3deccc45..45506525 100644 --- a/vkdispatch/codegen/functions/block_synchonization.py +++ b/vkdispatch/codegen/functions/block_synchronization.py @@ -1,14 +1,6 @@ -from ..global_builder import get_builder, get_codegen_backend - from . import utils def barrier(): - # On Apple devices, a memory barrier is required before a barrier - # to ensure memory operations are visible to all threads - # (for some reason) - if get_builder().is_apple_device and get_codegen_backend().name == "glsl": - memory_barrier() - utils.append_contents(utils.codegen_backend().barrier_statement() + "\n") def memory_barrier(): diff --git a/vkdispatch/codegen/functions/common_builtins.py b/vkdispatch/codegen/functions/common_builtins.py index e801bdda..681b8364 100644 --- a/vkdispatch/codegen/functions/common_builtins.py +++ b/vkdispatch/codegen/functions/common_builtins.py @@ -34,7 +34,8 @@ def abs(var: Any) -> Union[ShaderVariable, float]: if utils.is_number(var): return abs(var) - assert isinstance(var, ShaderVariable), "Argument must be a ShaderVariable or number" + if not isinstance(var, ShaderVariable): + raise ValueError(f"Argument must be a ShaderVariable or number, but got {type(var)}!") return utils.new_var( utils.dtype_to_floating(var.var_type), @@ -47,7 +48,8 @@ def sign(var: Any) -> Union[ShaderVariable, float]: if utils.is_number(var): return se.sign(var) - assert isinstance(var, ShaderVariable), "Argument must be a ShaderVariable or number" + if not isinstance(var, ShaderVariable): + raise ValueError(f"Argument must be a ShaderVariable or number, but got {type(var)}!") return utils.new_var( utils.dtype_to_floating(var.var_type), @@ -60,7 +62,8 @@ def floor(var: Any) -> Union[ShaderVariable, float]: if utils.is_number(var): return se.floor(var) - assert isinstance(var, ShaderVariable), "Argument must be a ShaderVariable or number" + if not isinstance(var, ShaderVariable): + raise ValueError(f"Argument must be a ShaderVariable or number, but got {type(var)}!") return utils.new_var( utils.dtype_to_floating(var.var_type), @@ -73,7 +76,8 @@ def ceil(var: Any) -> Union[ShaderVariable, float]: if utils.is_number(var): return se.ceil(var) - assert isinstance(var, ShaderVariable), "Argument must be a ShaderVariable or number" + if not isinstance(var, ShaderVariable): + raise ValueError(f"Argument must be a ShaderVariable or number, but got {type(var)}!") return utils.new_var( utils.dtype_to_floating(var.var_type), @@ -86,7 +90,8 @@ def trunc(var: Any) -> Union[ShaderVariable, float]: if utils.is_number(var): return se.trunc(var) - assert isinstance(var, ShaderVariable), "Argument must be a ShaderVariable or number" + if not isinstance(var, ShaderVariable): + raise ValueError(f"Argument must be a ShaderVariable or number, but got {type(var)}!") return utils.new_var( utils.dtype_to_floating(var.var_type), @@ -99,7 +104,8 @@ def round(var: Any) -> Union[ShaderVariable, float]: if utils.is_number(var): return se.round(var) - assert isinstance(var, ShaderVariable), "Argument must be a ShaderVariable or number" + if not isinstance(var, ShaderVariable): + raise ValueError(f"Argument must be a ShaderVariable or number, but got {type(var)}!") return utils.new_var( utils.dtype_to_floating(var.var_type), @@ -111,8 +117,10 @@ def round(var: Any) -> Union[ShaderVariable, float]: def round_even(var: Any) -> Union[ShaderVariable, float]: if utils.is_number(var): return se.round(var) - - assert isinstance(var, ShaderVariable), "Argument must be a ShaderVariable or number" + + if not isinstance(var, ShaderVariable): + raise ValueError(f"Argument must be a ShaderVariable or number, but got {type(var)}!") + utils.mark_backend_feature("roundEven") return utils.new_var( @@ -126,7 +134,9 @@ def fract(var: Any) -> Union[ShaderVariable, float]: if utils.is_number(var): return float(var - se.floor(var)) - assert isinstance(var, ShaderVariable), "Argument must be a ShaderVariable or number" + if not isinstance(var, ShaderVariable): + raise ValueError(f"Argument must be a ShaderVariable or number, but got {type(var)}!") + utils.mark_backend_feature("fract") return utils.new_var( @@ -147,7 +157,7 @@ def mod(x: Any, y: Any) -> Union[ShaderVariable, float]: elif isinstance(x, ShaderVariable): base_var = x else: - raise AssertionError("Arguments must be ShaderVariables or numbers") + raise ValueError("Arguments must be ShaderVariables or numbers") utils.mark_backend_feature("mod") @@ -178,9 +188,13 @@ def modf(x: Any, y: Any) -> Tuple[ShaderVariable, ShaderVariable]: f"mod({x.resolve()}, {utils.resolve_input(y)})", parents=[x] ) - - assert isinstance(y, ShaderVariable), "First argument must be a ShaderVariable or number" - assert isinstance(x, ShaderVariable), "Second argument must be a ShaderVariable or number" + + if not isinstance(y, ShaderVariable): + raise ValueError(f"First argument must be a ShaderVariable or number, but got {type(y)}!") + + if not isinstance(x, ShaderVariable): + raise ValueError(f"Second argument must be a ShaderVariable or number, but got {type(x)}!") + utils.mark_backend_feature("mod") return utils.new_var( @@ -201,7 +215,7 @@ def min(x: Any, y: Any) -> Union[ShaderVariable, float]: elif isinstance(x, ShaderVariable): base_var = x else: - raise AssertionError("Arguments must be ShaderVariables or numbers") + raise ValueError("Arguments must be ShaderVariables or numbers") return utils.new_var( utils.dtype_to_floating(base_var.var_type), @@ -221,7 +235,7 @@ def max(x: Any, y: Any) -> Union[ShaderVariable, float]: elif isinstance(x, ShaderVariable): base_var = x else: - raise AssertionError("Arguments must be ShaderVariables or numbers") + raise ValueError("Arguments must be ShaderVariables or numbers") return utils.new_var( utils.dtype_to_floating(base_var.var_type), @@ -243,7 +257,7 @@ def clip(x: Any, min_val: Any, max_val: Any) -> Union[ShaderVariable, float]: elif isinstance(x, ShaderVariable): base_var = x else: - raise AssertionError("Arguments must be ShaderVariables or numbers") + raise ValueError("Arguments must be ShaderVariables or numbers") return utils.new_var( utils.dtype_to_floating(base_var.var_type), @@ -268,7 +282,7 @@ def mix(x: Any, y: Any, a: Any) -> Union[ShaderVariable, float]: elif isinstance(x, ShaderVariable): base_var = x else: - raise AssertionError("Arguments must be ShaderVariables or numbers") + raise ValueError("Arguments must be ShaderVariables or numbers") utils.mark_backend_feature("mix") @@ -290,7 +304,7 @@ def step(edge: Any, x: Any) -> Union[ShaderVariable, float]: elif isinstance(edge, ShaderVariable): base_var = edge else: - raise AssertionError("Arguments must be ShaderVariables or numbers") + raise ValueError("Arguments must be ShaderVariables or numbers") utils.mark_backend_feature("step") @@ -315,7 +329,7 @@ def smoothstep(edge0: Any, edge1: Any, x: Any) -> Union[ShaderVariable, float]: elif isinstance(edge0, ShaderVariable): base_var = edge0 else: - raise AssertionError("Arguments must be ShaderVariables or numbers") + raise ValueError("Arguments must be ShaderVariables or numbers") utils.mark_backend_feature("smoothstep") @@ -330,7 +344,8 @@ def isnan(var: Any) -> Union[ShaderVariable, bool]: if utils.is_number(var): return se.isnan(var) - assert isinstance(var, ShaderVariable), "Argument must be a ShaderVariable or number" + if not isinstance(var, ShaderVariable): + raise ValueError(f"Argument must be a ShaderVariable or number, but got {type(var)}!") return utils.new_var( dtypes.int32, @@ -343,7 +358,8 @@ def isinf(var: Any) -> Union[ShaderVariable, bool]: if utils.is_number(var): return se.isinf(var) - assert isinstance(var, ShaderVariable), "Argument must be a ShaderVariable or number" + if not isinstance(var, ShaderVariable): + raise ValueError(f"Argument must be a ShaderVariable or number, but got {type(var)}!") return utils.new_var( dtypes.int32, @@ -356,7 +372,8 @@ def float_bits_to_int(var: Any) -> Union[ShaderVariable, int]: if utils.is_number(var): return se.float_bits_to_int(var) - assert isinstance(var, ShaderVariable), "Argument must be a ShaderVariable or number" + if not isinstance(var, ShaderVariable): + raise ValueError(f"Argument must be a ShaderVariable or number, but got {type(var)}!") return utils.new_var( dtypes.int32, @@ -369,7 +386,8 @@ def float_bits_to_uint(var: Any) -> Union[ShaderVariable, int]: if utils.is_number(var): return se.float_bits_to_uint(var) - assert isinstance(var, ShaderVariable), "Argument must be a ShaderVariable or number" + if not isinstance(var, ShaderVariable): + raise ValueError(f"Argument must be a ShaderVariable or number, but got {type(var)}!") return utils.new_var( dtypes.uint32, @@ -381,8 +399,9 @@ def float_bits_to_uint(var: Any) -> Union[ShaderVariable, int]: def int_bits_to_float(var: Any) -> Union[ShaderVariable, float]: if utils.is_number(var): return se.int_bits_to_float(var) - - assert isinstance(var, ShaderVariable), "Argument must be a ShaderVariable or number" + + if not isinstance(var, ShaderVariable): + raise ValueError(f"Argument must be a ShaderVariable or number, but got {type(var)}!") return utils.new_var( dtypes.float32, @@ -395,7 +414,8 @@ def uint_bits_to_float(var: Any) -> Union[ShaderVariable, float]: if utils.is_number(var): return se.uint_bits_to_float(var) - assert isinstance(var, ShaderVariable), "Argument must be a ShaderVariable or number" + if not isinstance(var, ShaderVariable): + raise ValueError(f"Argument must be a ShaderVariable or number, but got {type(var)}!") return utils.new_var( dtypes.float32, @@ -417,7 +437,7 @@ def fma(a: Any, b: Any, c: Any) -> Union[ShaderVariable, float]: elif isinstance(a, ShaderVariable): base_var = a else: - raise AssertionError("Arguments must be ShaderVariables or numbers") + raise ValueError("Arguments must be ShaderVariables or numbers") result_type = utils.dtype_to_floating(base_var.var_type) fma_function = utils.codegen_backend().fma_function_name(result_type) diff --git a/vkdispatch/codegen/functions/complex_numbers.py b/vkdispatch/codegen/functions/complex_numbers.py index e99f3d7b..9984999f 100644 --- a/vkdispatch/codegen/functions/complex_numbers.py +++ b/vkdispatch/codegen/functions/complex_numbers.py @@ -18,11 +18,14 @@ def complex_from_euler_angle(angle: ShaderVariable): def validate_complex_number(arg1: Any) -> Union[ShaderVariable, complex]: if isinstance(arg1, ShaderVariable): - assert dtypes.is_complex(arg1.var_type), "Input variables to complex multiplication must be complex" + if not dtypes.is_complex(arg1.var_type): + raise TypeError(f"Input variable '{arg1.resolve()}' of type '{arg1.var_type.name}' is not a complex type!") + return arg1 - assert utils.is_number(arg1), "Argument must be ShaderVariable or number" - + if not utils.is_number(arg1): + raise TypeError(f"Argument must be a ShaderVariable or a number, got {type(arg1)}!") + return complex(arg1) def _new_big_complex(var_type: dtypes.dtype, arg1: Any, arg2: Any): diff --git a/vkdispatch/codegen/functions/control_flow.py b/vkdispatch/codegen/functions/control_flow.py index 88fcad45..52e99d6a 100644 --- a/vkdispatch/codegen/functions/control_flow.py +++ b/vkdispatch/codegen/functions/control_flow.py @@ -3,6 +3,8 @@ from typing import List, Optional, Union from . import utils +import contextlib + def proc_bool(arg: Union[ShaderVariable, bool]) -> ShaderVariable: if isinstance(arg, bool): return "true" if arg else "false" @@ -12,44 +14,29 @@ def proc_bool(arg: Union[ShaderVariable, bool]) -> ShaderVariable: raise TypeError(f"Argument of type {type(arg)} cannot be processed as a boolean.") -def if_statement(arg: ShaderVariable, command: Optional[str] = None): - if command is None: - utils.append_contents(f"if({proc_bool(arg)}) {'{'}\n") - utils.scope_increment() - return - - utils.append_contents(f"if({proc_bool(arg)})\n") - utils.scope_increment() - utils.append_contents(f"{command}\n") - utils.scope_decrement() - -def if_any(*args: List[ShaderVariable]): - utils.append_contents(f"if({' || '.join([str(proc_bool(elem)) for elem in args])}) {'{'}\n") - utils.scope_increment() - -def if_all(*args: List[ShaderVariable]): - utils.append_contents(f"if({' && '.join([str(proc_bool(elem)) for elem in args])}) {'{'}\n") +@contextlib.contextmanager +def if_block(arg: ShaderVariable): + utils.append_contents(f"if({proc_bool(arg)}) {{\n") utils.scope_increment() + yield + utils.scope_decrement() + utils.append_contents("}\n") -def else_statement(): - utils.scope_decrement() - utils.append_contents("} else {\n") +@contextlib.contextmanager +def else_if_block(arg: ShaderVariable): + utils.append_contents(f"else if({proc_bool(arg)}) {{\n") utils.scope_increment() - -def else_if_statement(arg: ShaderVariable): + yield utils.scope_decrement() - utils.append_contents(f"}} else if({proc_bool(arg)}) {'{'}\n") - utils.scope_increment() + utils.append_contents("}\n") -def else_if_any(*args: List[ShaderVariable]): - utils.scope_decrement() - utils.append_contents(f"}} else if({' || '.join([str(proc_bool(elem)) for elem in args])}) {'{'}\n") +@contextlib.contextmanager +def else_block(): + utils.append_contents("else {\n") utils.scope_increment() - -def else_if_all(*args: List[ShaderVariable]): + yield utils.scope_decrement() - utils.append_contents(f"}} else if({' && '.join([str(proc_bool(elem)) for elem in args])}) {'{'}\n") - utils.scope_increment() + utils.append_contents("}\n") def return_statement(arg=None): if arg is None: @@ -65,27 +52,33 @@ def return_statement(arg=None): utils.append_contents(f"return {arg_expr};\n") -def while_statement(arg: ShaderVariable): - utils.append_contents(f"while({proc_bool(arg)}) {'{'}\n") +@contextlib.contextmanager +def while_block(arg: ShaderVariable): + utils.append_contents(f"while({proc_bool(arg)}) {{\n") utils.scope_increment() + yield + utils.scope_decrement() + utils.append_contents("}\n") -def new_scope(indent: bool = True, comment: str = None): +@contextlib.contextmanager +def scope_block(indent: bool = True, comment: str = None): if comment is None: utils.append_contents("{\n") else: - utils.append_contents("{ " + f"/* {comment} */\n") + utils.append_contents(f"{{ /* {comment} */\n") if indent: utils.scope_increment() - -def end(indent: bool = True): + + yield + if indent: utils.scope_decrement() utils.append_contents("}\n") -def logical_and(arg1: ShaderVariable, arg2: ShaderVariable): - return utils.new_var(dtypes.int32, f"({proc_bool(arg1)} && {proc_bool(arg2)})", [arg1, arg2]) +def any(*args: List[ShaderVariable]): + return utils.new_var(dtypes.int32, f"({' || '.join([str(proc_bool(elem)) for elem in args])})", list(args)) -def logical_or(arg1: ShaderVariable, arg2: ShaderVariable): - return utils.new_var(dtypes.int32, f"({proc_bool(arg1)} || {proc_bool(arg2)})", [arg1, arg2]) +def all(*args: List[ShaderVariable]): + return utils.new_var(dtypes.int32, f"({' && '.join([str(proc_bool(elem)) for elem in args])})", list(args)) diff --git a/vkdispatch/codegen/functions/exponential.py b/vkdispatch/codegen/functions/exponential.py index 68b2ebc6..91f3a92f 100644 --- a/vkdispatch/codegen/functions/exponential.py +++ b/vkdispatch/codegen/functions/exponential.py @@ -87,8 +87,8 @@ def pow(x: Any, y: Any) -> Union[ShaderVariable, float]: parents=[x] ) - assert isinstance(y, ShaderVariable), "First argument must be a ShaderVariable or number" - assert isinstance(x, ShaderVariable), "Second argument must be a ShaderVariable or number" + if not isinstance(x, ShaderVariable) or not isinstance(y, ShaderVariable): + raise ValueError("Both arguments must be ShaderVariables or numbers") result_type = utils.dtype_to_floating(dtypes.cross_type(x.var_type, y.var_type)) return utils.new_var( @@ -108,28 +108,36 @@ def exp(var: Any) -> Union[ShaderVariable, float]: if utils.is_number(var): return se.exp(var) - assert isinstance(var, ShaderVariable), "Argument must be a ShaderVariable or number" + if not isinstance(var, ShaderVariable): + raise ValueError("Argument must be a ShaderVariable or number") + return _unary_math_var("exp", var) def exp2(var: Any) -> Union[ShaderVariable, float]: if utils.is_number(var): return se.exp2(var) - assert isinstance(var, ShaderVariable), "Argument must be a ShaderVariable or number" + if not isinstance(var, ShaderVariable): + raise ValueError("Argument must be a ShaderVariable or number") + return _unary_math_var("exp2", var) def log(var: Any) -> Union[ShaderVariable, float]: if utils.is_number(var): return se.log(var) - assert isinstance(var, ShaderVariable), "Argument must be a ShaderVariable or number" + if not isinstance(var, ShaderVariable): + raise ValueError("Argument must be a ShaderVariable or number") + return _unary_math_var("log", var) def log2(var: Any) -> Union[ShaderVariable, float]: if utils.is_number(var): return se.log2(var) - assert isinstance(var, ShaderVariable), "Argument must be a ShaderVariable or number" + if not isinstance(var, ShaderVariable): + raise ValueError("Argument must be a ShaderVariable or number") + return _unary_math_var("log2", var) # has double @@ -137,7 +145,9 @@ def sqrt(var: Any) -> Union[ShaderVariable, float]: if utils.is_number(var): return se.sqrt(var) - assert isinstance(var, ShaderVariable), "Argument must be a ShaderVariable or number" + if not isinstance(var, ShaderVariable): + raise ValueError("Argument must be a ShaderVariable or number") + return _unary_math_var("sqrt", var) # has double @@ -145,7 +155,9 @@ def inversesqrt(var: Any) -> Union[ShaderVariable, float]: if utils.is_number(var): return float(1.0 / se.sqrt(var)) - assert isinstance(var, ShaderVariable), "Argument must be a ShaderVariable or number" + if not isinstance(var, ShaderVariable): + raise ValueError("Argument must be a ShaderVariable or number") + utils.mark_backend_feature("inversesqrt") return utils.new_var( diff --git a/vkdispatch/codegen/functions/geometric.py b/vkdispatch/codegen/functions/geometric.py index 6992a8ad..687b14c6 100644 --- a/vkdispatch/codegen/functions/geometric.py +++ b/vkdispatch/codegen/functions/geometric.py @@ -9,7 +9,8 @@ def length(var: Any) -> Union[ShaderVariable, float]: if utils.is_number(var): return se.abs_value(var) - assert isinstance(var, ShaderVariable), "Argument must be a ShaderVariable or number" + if not isinstance(var, ShaderVariable): + raise ValueError("Argument must be a ShaderVariable or number") return utils.new_var( utils.dtype_to_floating(var.var_type), @@ -59,11 +60,11 @@ def dot(x: Any, y: Any) -> Union[ShaderVariable, float]: ) def cross(x: ShaderVariable, y: ShaderVariable) -> ShaderVariable: - assert isinstance(x, ShaderVariable), "Argument x must be a ShaderVariable" - assert isinstance(y, ShaderVariable), "Argument y must be a ShaderVariable" - - assert x.var_type == dtypes.vec3, "Argument x must be of type vec3 or dvec3" - assert y.var_type == dtypes.vec3, "Argument y must be of type vec3 or dvec3" + if not isinstance(x, ShaderVariable) or not isinstance(y, ShaderVariable): + raise ValueError("Both arguments must be ShaderVariables") + + if x.var_type != dtypes.vec3 or y.var_type != dtypes.vec3: + raise ValueError("Both arguments must be of type vec3 or dvec3") return utils.new_var( dtypes.vec3, @@ -73,7 +74,8 @@ def cross(x: ShaderVariable, y: ShaderVariable) -> ShaderVariable: ) def normalize(var: ShaderVariable) -> ShaderVariable: - assert isinstance(var, ShaderVariable), "Argument must be a ShaderVariable" + if not isinstance(var, ShaderVariable): + raise ValueError("Argument must be a ShaderVariable") return utils.new_var( var.var_type, diff --git a/vkdispatch/codegen/functions/index_raveling.py b/vkdispatch/codegen/functions/index_raveling.py index d1f38b86..e8f8c8ee 100644 --- a/vkdispatch/codegen/functions/index_raveling.py +++ b/vkdispatch/codegen/functions/index_raveling.py @@ -12,15 +12,19 @@ def sanitize_input(value: Union[ShaderVariable, Tuple[int, ...]]) -> Tuple[List[ axes_lengths = [] if isinstance(value, ShaderVariable): - assert dtypes.is_vector(value.var_type) or dtypes.is_scalar(value.var_type), f"Value is of type '{value.var_type.name}', but it must be a vector or integer!" - assert dtypes.is_integer_dtype(value.var_type), f"Value is of type '{value.var_type.name}', but it must be of integer type!" + if not (dtypes.is_vector(value.var_type) or dtypes.is_scalar(value.var_type)): + raise ValueError(f"Value is of type '{value.var_type.name}', but it must be a vector or integer!") + if not dtypes.is_integer_dtype(value.var_type): + raise ValueError(f"Value is of type '{value.var_type.name}', but it must be of integer type!") + if dtypes.is_scalar(value.var_type): axes_lengths.append(value) return axes_lengths elem_count = value.var_type.child_count - assert elem_count >= 2 and elem_count <= 4, f"Value is of type '{value.var_type.name}', but it must have 2, 3 or 4 components!" + if elem_count < 2 or elem_count > 4: + raise ValueError(f"Value is of type '{value.var_type.name}', but it must have 2, 3 or 4 components!") # Since buffer shapes store total elem count in the 4th component, we ignore it here. if elem_count == 4: @@ -32,13 +36,17 @@ def sanitize_input(value: Union[ShaderVariable, Tuple[int, ...]]) -> Tuple[List[ if utils.check_is_int(value): return [value] - assert isinstance(value, (list, tuple)), "Value must be a ShaderVariable or a list/tuple of integers!" + if not isinstance(value, (list, tuple)): + raise ValueError(f"Value must be a ShaderVariable or a list/tuple of integers, but got {type(value)}!") elem_count = len(value) - assert elem_count >= 1 or elem_count <= 3, f"Value has {elem_count} elements, but it must have 1, 2, or 3 elements!" + + if elem_count < 1 or elem_count > 3: + raise ValueError(f"Value has {elem_count} elements, but it must have 1, 2, or 3 elements!") for i in range(elem_count): - assert utils.check_is_int(value[i]), "When value is a list/tuple, all its elements must be integers!" + if not utils.check_is_int(value[i]): + raise ValueError(f"When value is a list/tuple, all its elements must be integers, but element {i} is of type '{type(value[i])}'!") axes_lengths.append(value[i]) @@ -48,8 +56,11 @@ def ravel_index(index: Union[ShaderVariable, int], shape: Union[ShaderVariable, sanitized_shape = sanitize_input(shape) sanitized_index = sanitize_input(index) - assert len(sanitized_index) == 1, f"Index must be a single integer value, not '{index}'!" - assert len(sanitized_shape) == 2 or len(sanitized_shape) == 3, f"Shape must have 2 or 3 elements, not '{shape}'!" + if len(sanitized_index) != 1: + raise ValueError(f"Index must be a single integer value, not '{index}'!") + + if len(sanitized_shape) != 2 and len(sanitized_shape) != 3: + raise ValueError(f"Shape must have 2 or 3 elements, not '{shape}'!") if len(sanitized_shape) == 2: x = sanitized_index[0] // sanitized_shape[1] @@ -69,7 +80,8 @@ def unravel_index(index: Union[ShaderVariable, Tuple[int, ...]], shape: Union[Sh sanitized_shape = sanitize_input(shape) sanitized_index = sanitize_input(index) - assert len(sanitized_index) <= len(sanitized_shape), f"Index ({index}) must have the same number of elements as shape ({sanitized_shape})!" + if len(sanitized_index) > len(sanitized_shape): + raise ValueError(f"Index ({index}) must have the same number of elements as shape ({sanitized_shape})!") if len(sanitized_index) == 1: return index diff --git a/vkdispatch/codegen/functions/matrix.py b/vkdispatch/codegen/functions/matrix.py index 6629bc25..b97be859 100644 --- a/vkdispatch/codegen/functions/matrix.py +++ b/vkdispatch/codegen/functions/matrix.py @@ -4,13 +4,14 @@ from . import utils def matrix_comp_mult(x: ShaderVariable, y: ShaderVariable) -> ShaderVariable: - assert isinstance(y, ShaderVariable), "Second argument must be a ShaderVariable" - assert isinstance(x, ShaderVariable), "First argument must be a ShaderVariable" + if not isinstance(x, ShaderVariable) or not isinstance(y, ShaderVariable): + raise ValueError("Both arguments must be ShaderVariables") - assert dtypes.is_matrix(x.var_type), "First argument must be a matrix" - assert dtypes.is_matrix(y.var_type), "Second argument must be a matrix" + if not dtypes.is_matrix(x.var_type) or not dtypes.is_matrix(y.var_type): + raise ValueError("Both arguments must be matrices") - assert x.var_type == y.var_type, "Matrices must have the same shape" + if x.var_type != y.var_type: + raise ValueError("Both matrices must have the same shape") return utils.new_var( x.var_type, @@ -20,13 +21,14 @@ def matrix_comp_mult(x: ShaderVariable, y: ShaderVariable) -> ShaderVariable: ) def outer_product(x: ShaderVariable, y: ShaderVariable) -> ShaderVariable: - assert isinstance(y, ShaderVariable), "Second argument must be a ShaderVariable" - assert isinstance(x, ShaderVariable), "First argument must be a ShaderVariable" + if not isinstance(x, ShaderVariable) or not isinstance(y, ShaderVariable): + raise ValueError("Both arguments must be ShaderVariables") - assert dtypes.is_vector(x.var_type), "First argument must be a matrix" - assert dtypes.is_vector(y.var_type), "Second argument must be a matrix" - - assert x.var_type == y.var_type, "Matrices must have the same shape" + if not dtypes.is_vector(x.var_type) or not dtypes.is_vector(y.var_type): + raise ValueError("Both arguments must be vectors") + + if x.var_type != y.var_type: + raise ValueError("Both arguments must be of the same vector type") out_type = None @@ -47,9 +49,11 @@ def outer_product(x: ShaderVariable, y: ShaderVariable) -> ShaderVariable: ) def transpose(var: ShaderVariable) ->ShaderVariable: - assert isinstance(var, ShaderVariable), "Argument must be a ShaderVariable" - - assert dtypes.is_matrix(var.var_type), "Argument must be a matrix" + if not isinstance(var, ShaderVariable): + raise ValueError("Argument must be a ShaderVariable") + + if not dtypes.is_matrix(var.var_type): + raise ValueError("Argument must be a matrix") return utils.new_var( var.var_type, @@ -59,9 +63,11 @@ def transpose(var: ShaderVariable) ->ShaderVariable: ) def determinant(var: ShaderVariable) -> ShaderVariable: - assert isinstance(var, ShaderVariable), "Argument must be a ShaderVariable" - - assert dtypes.is_matrix(var.var_type), "Argument must be a matrix" + if not isinstance(var, ShaderVariable): + raise ValueError("Argument must be a ShaderVariable") + + if not dtypes.is_matrix(var.var_type): + raise ValueError("Argument must be a matrix") return utils.new_var( dtypes.float32, @@ -71,9 +77,11 @@ def determinant(var: ShaderVariable) -> ShaderVariable: ) def inverse(var: ShaderVariable) -> ShaderVariable: - assert isinstance(var, ShaderVariable), "Argument must be a ShaderVariable" - - assert dtypes.is_matrix(var.var_type), "Argument must be a matrix" + if not isinstance(var, ShaderVariable): + raise ValueError("Argument must be a ShaderVariable") + + if not dtypes.is_matrix(var.var_type): + raise ValueError("Argument must be a matrix") return utils.new_var( var.var_type, diff --git a/vkdispatch/codegen/functions/printing.py b/vkdispatch/codegen/functions/printing.py index 2f1893fa..f6b2387d 100644 --- a/vkdispatch/codegen/functions/printing.py +++ b/vkdispatch/codegen/functions/printing.py @@ -12,7 +12,7 @@ def printf(format: str, *args: Any): resolved_args = [resolve_arg(arg) for arg in args] utils.append_contents(utils.codegen_backend().printf_statement(format, resolved_args) + "\n") -def print_vars(*args: Any, seperator=" "): +def print_vars(*args: Any, separator=" "): args_list = [] fmts = [] @@ -24,6 +24,6 @@ def print_vars(*args: Any, seperator=" "): else: fmts.append(str(arg)) - fmt = seperator.join(fmts) + fmt = separator.join(fmts) utils.append_contents(utils.codegen_backend().printf_statement(fmt, args_list) + "\n") diff --git a/vkdispatch/codegen/functions/trigonometry.py b/vkdispatch/codegen/functions/trigonometry.py index 19251db1..87bc02e1 100644 --- a/vkdispatch/codegen/functions/trigonometry.py +++ b/vkdispatch/codegen/functions/trigonometry.py @@ -108,7 +108,9 @@ def radians(var: Any) -> Union[ShaderVariable, float]: if utils.is_number(var): return var * (3.141592653589793 / 180.0) - assert isinstance(var, ShaderVariable), "Argument must be a ShaderVariable or number" + if not isinstance(var, ShaderVariable): + raise ValueError("Argument must be a ShaderVariable or number") + utils.mark_backend_feature("radians") return _unary_math_var("radians", var) @@ -116,7 +118,9 @@ def degrees(var: Any) -> Union[ShaderVariable, float]: if utils.is_number(var): return var * (180.0 / 3.141592653589793) - assert isinstance(var, ShaderVariable), "Argument must be a ShaderVariable or number" + if not isinstance(var, ShaderVariable): + raise ValueError("Argument must be a ShaderVariable or number") + utils.mark_backend_feature("degrees") return _unary_math_var("degrees", var) @@ -124,42 +128,54 @@ def sin(var: Any) -> Union[ShaderVariable, float]: if utils.is_number(var): return se.sin(var) - assert isinstance(var, ShaderVariable), "Argument must be a ShaderVariable or number" + if not isinstance(var, ShaderVariable): + raise ValueError("Argument must be a ShaderVariable or number") + return _unary_math_var("sin", var) def cos(var: Any) -> Union[ShaderVariable, float]: if utils.is_number(var): return se.cos(var) - assert isinstance(var, ShaderVariable), "Argument must be a ShaderVariable or number" + if not isinstance(var, ShaderVariable): + raise ValueError("Argument must be a ShaderVariable or number") + return _unary_math_var("cos", var) def tan(var: Any) -> Union[ShaderVariable, float]: if utils.is_number(var): return se.tan(var) - assert isinstance(var, ShaderVariable), "Argument must be a ShaderVariable or number" + if not isinstance(var, ShaderVariable): + raise ValueError("Argument must be a ShaderVariable or number") + return _unary_math_var("tan", var) def asin(var: Any) -> Union[ShaderVariable, float]: if utils.is_number(var): return se.arcsin(var) - assert isinstance(var, ShaderVariable), "Argument must be a ShaderVariable or number" + if not isinstance(var, ShaderVariable): + raise ValueError("Argument must be a ShaderVariable or number") + return _unary_math_var("asin", var) def acos(var: Any) -> Union[ShaderVariable, float]: if utils.is_number(var): return se.arccos(var) - assert isinstance(var, ShaderVariable), "Argument must be a ShaderVariable or number" + if not isinstance(var, ShaderVariable): + raise ValueError("Argument must be a ShaderVariable or number") + return _unary_math_var("acos", var) def atan(var: Any) -> Union[ShaderVariable, float]: if utils.is_number(var): return se.arctan(var) - assert isinstance(var, ShaderVariable), "Argument must be a ShaderVariable or number" + if not isinstance(var, ShaderVariable): + raise ValueError("Argument must be a ShaderVariable or number") + return _unary_math_var("atan", var) def atan2(y: Any, x: Any) -> Union[ShaderVariable, float]: @@ -192,8 +208,8 @@ def atan2(y: Any, x: Any) -> Union[ShaderVariable, float]: [x], ) - assert isinstance(y, ShaderVariable), "First argument must be a ShaderVariable or number" - assert isinstance(x, ShaderVariable), "Second argument must be a ShaderVariable or number" + if not isinstance(y, ShaderVariable) or not isinstance(x, ShaderVariable): + raise ValueError("Both arguments must be ShaderVariables or numbers") result_type = dtype_to_floating(dtypes.cross_type(y.var_type, x.var_type)) return _binary_math_var( @@ -211,40 +227,52 @@ def sinh(var: Any) -> Union[ShaderVariable, float]: if utils.is_number(var): return se.sinh(var) - assert isinstance(var, ShaderVariable), "Argument must be a ShaderVariable or number" + if not isinstance(var, ShaderVariable): + raise ValueError("Argument must be a ShaderVariable or number") + return _unary_math_var("sinh", var) def cosh(var: Any) -> Union[ShaderVariable, float]: if utils.is_number(var): return se.cosh(var) - assert isinstance(var, ShaderVariable), "Argument must be a ShaderVariable or number" + if not isinstance(var, ShaderVariable): + raise ValueError("Argument must be a ShaderVariable or number") + return _unary_math_var("cosh", var) def tanh(var: Any) -> Union[ShaderVariable, float]: if utils.is_number(var): return se.tanh(var) - - assert isinstance(var, ShaderVariable), "Argument must be a ShaderVariable or number" + + if not isinstance(var, ShaderVariable): + raise ValueError("Argument must be a ShaderVariable or number") + return _unary_math_var("tanh", var) def asinh(var: Any) -> Union[ShaderVariable, float]: if utils.is_number(var): return se.arcsinh(var) - - assert isinstance(var, ShaderVariable), "Argument must be a ShaderVariable or number" + + if not isinstance(var, ShaderVariable): + raise ValueError("Argument must be a ShaderVariable or number") + return _unary_math_var("asinh", var) def acosh(var: Any) -> Union[ShaderVariable, float]: if utils.is_number(var): return se.arccosh(var) - assert isinstance(var, ShaderVariable), "Argument must be a ShaderVariable or number" + if not isinstance(var, ShaderVariable): + raise ValueError("Argument must be a ShaderVariable or number") + return _unary_math_var("acosh", var) def atanh(var: Any) -> Union[ShaderVariable, float]: if utils.is_number(var): return se.arctanh(var) - assert isinstance(var, ShaderVariable), "Argument must be a ShaderVariable or number" + if not isinstance(var, ShaderVariable): + raise ValueError("Argument must be a ShaderVariable or number") + return _unary_math_var("atanh", var) diff --git a/vkdispatch/codegen/functions/type_casting.py b/vkdispatch/codegen/functions/type_casting.py index 276a479a..0fd38141 100644 --- a/vkdispatch/codegen/functions/type_casting.py +++ b/vkdispatch/codegen/functions/type_casting.py @@ -79,7 +79,8 @@ def _infer_complex_dtype(*args) -> dtypes.dtype: return complex_type def _to_complex_dtype(var_type: dtypes.dtype, *args): - assert len(args) == 1 or len(args) == 2, "Must give one of two arguments for complex init" + if len(args) != 1 and len(args) != 2: + raise ValueError("Must give one of two arguments for complex init") if len(args) == 1 and isinstance(args[0], ShaderVariable) and dtypes.is_complex(args[0].var_type): return to_dtype(var_type, args[0]) diff --git a/vkdispatch/codegen/global_builder.py b/vkdispatch/codegen/global_builder.py index 8a14b1b9..48876750 100644 --- a/vkdispatch/codegen/global_builder.py +++ b/vkdispatch/codegen/global_builder.py @@ -77,13 +77,18 @@ def set_builder(builder: 'ShaderBuilder'): set_shader_writer(None) return - assert _get_builder() is None, "A global ShaderBuilder is already set for the current thread!" + if _get_builder() is not None: + raise RuntimeError("A global ShaderBuilder is already set for the current thread!") + set_shader_writer(builder) _builder_context.active_builder = builder def get_builder() -> 'ShaderBuilder': builder = _get_builder() - assert builder is not None, "No global ShaderBuilder is set for the current thread!" + + if builder is None: + raise RuntimeError("No global ShaderBuilder is set for the current thread!") + return builder def shared_buffer(var_type: dtypes.dtype, size: int, var_name: Optional[str] = None): diff --git a/vkdispatch/codegen/shader_description.py b/vkdispatch/codegen/shader_description.py new file mode 100644 index 00000000..efe9d9b0 --- /dev/null +++ b/vkdispatch/codegen/shader_description.py @@ -0,0 +1,92 @@ +import enum +import dataclasses +from typing import List, Tuple, Optional, Any, Union, Dict + +from .backends import CodeGenBackend +from .struct_builder import StructElement + +class BindingType(enum.Enum): + """ + A dataclass that represents the type of a binding in a shader. Either a + STORAGE_BUFFER, UNIFORM_BUFFER, or SAMPLER. + """ + STORAGE_BUFFER = 1 + UNIFORM_BUFFER = 3 + SAMPLER = 5 + +class ShaderArgumentType(enum.Enum): + BUFFER = 0 + IMAGE = 1 + VARIABLE = 2 + CONSTANT = 3 + CONSTANT_DATACLASS = 4 + +@dataclasses.dataclass +class ShaderArgumentInfo: + name: str + arg_type: ShaderArgumentType + default_value: Any + shader_name: Union[str, Dict[str, str]] + shader_shape_name: Optional[str] + binding: Optional[int] + + def __init__(self, + name: str, + arg_type: ShaderArgumentType, + default_value: Any, + shader_name: Union[str, Dict[str, str]], + shader_shape_name: Optional[str] = None, + binding: Optional[int] = None): + self.name = name + self.arg_type = arg_type + self.default_value = default_value + self.shader_name = shader_name + self.shader_shape_name = shader_shape_name + self.binding = binding + +@dataclasses.dataclass +class ShaderDescription: + """ + A dataclass that represents a description of a shader object. + + Attributes: + source (str): The source code of the shader. + pc_size (int): The size of the push constant buffer in bytes. + pc_structure (List[vc.StructElement]): The structure of the push constant buffer. + uniform_structure (List[vc.StructElement]): The structure of the uniform buffer. + binding_type_list (List[BindingType]): The list of binding types. + """ + + header: str + body: str + name: str + pc_size: int + pc_structure: List[StructElement] + uniform_structure: List[StructElement] + binding_type_list: List[BindingType] + binding_access: List[Tuple[bool, bool]] # List of tuples indicating read and write access for each binding + exec_count_name: Optional[str] + resource_binding_base: int + backend: CodeGenBackend + shader_arg_infos: List[ShaderArgumentInfo] + + def make_source(self, x: int, y: int, z: int) -> str: + return self.backend.make_source(self.header, self.body, x, y, z) + + def get_arg_names_and_defaults(self) -> List[Tuple[str, Any]]: + return [(arg.name, arg.default_value) for arg in self.shader_arg_infos] + + def __repr__(self): + description_string = "" + + description_string += f"Shader Name: {self.name}\n" + description_string += f"Push Constant Size: {self.pc_size} bytes\n" + description_string += f"Push Constant Structure: {self.pc_structure}\n" + description_string += f"Uniform Structure: {self.uniform_structure}\n" + description_string += f"Binding Types: {self.binding_type_list}\n" + description_string += f"Binding Access: {self.binding_access}\n" + description_string += f"Execution Count Name: {self.exec_count_name}\n" + description_string += f"Backend: {self.backend.name if self.backend is not None else 'none'}\n" + description_string += f"Header:\n{self.header}\n" + description_string += f"Body:\n{self.body}\n" + return description_string \ No newline at end of file diff --git a/vkdispatch/codegen/shader_writer.py b/vkdispatch/codegen/shader_writer.py index b374588c..723f9ec2 100644 --- a/vkdispatch/codegen/shader_writer.py +++ b/vkdispatch/codegen/shader_writer.py @@ -10,15 +10,20 @@ def _get_shader_writer() -> Optional['ShaderWriter']: def shader_writer() -> 'ShaderWriter': writer = _get_shader_writer() - assert writer is not None, "No global ShaderWriter is set for the current thread!" + + if writer is None: + raise RuntimeError("No global ShaderWriter is set for the current thread!") + return writer def set_shader_writer(writer: 'ShaderWriter'): if writer is None: _thread_context.writer = None return - - assert _get_shader_writer() is None, "A global ShaderWriter is already set for the current thread!" + + if _get_shader_writer() is not None: + raise RuntimeError("A global ShaderWriter is already set for the current thread!") + _thread_context.writer = writer class ShaderWriter: diff --git a/vkdispatch/codegen/struct_builder.py b/vkdispatch/codegen/struct_builder.py index 71911298..759af889 100644 --- a/vkdispatch/codegen/struct_builder.py +++ b/vkdispatch/codegen/struct_builder.py @@ -41,6 +41,9 @@ def register_element(self, name: str, dtype: dtype.dtype, count: int) -> None: self.elements.append(StructElement(name, dtype, count)) self.size += dtype.item_size * count + def empty(self) -> bool: + return len(self.elements) == 0 and self.size == 0 + def build(self) -> List[StructElement]: # Sort the elements by size in descending order self.elements.sort(key=lambda x: x.dtype.item_size * x.count, reverse=True) diff --git a/vkdispatch/codegen/variables/base_variable.py b/vkdispatch/codegen/variables/base_variable.py index cb730815..dab3af70 100644 --- a/vkdispatch/codegen/variables/base_variable.py +++ b/vkdispatch/codegen/variables/base_variable.py @@ -25,13 +25,14 @@ def __init__(self, self.can_index = False self.use_child_type = True - assert name is not None, "Variable name cannot be None!" + if name is None: + raise ValueError("Variable name cannot be None!") self.name = name self.raw_name = raw_name if raw_name is not None else self.name - if register: - assert settable, "An unsettable register makes no sense" + if register and not settable: + raise ValueError("An unsettable register makes no sense!") self.settable = settable self.register = register diff --git a/vkdispatch/codegen/variables/bound_variables.py b/vkdispatch/codegen/variables/bound_variables.py index 228ff299..f6c6143d 100644 --- a/vkdispatch/codegen/variables/bound_variables.py +++ b/vkdispatch/codegen/variables/bound_variables.py @@ -58,7 +58,9 @@ def __init__(self, @property def shape(self) -> "ShaderVariable": if self._shape_var is None: - assert self._shape_var_factory is not None, "Buffer shape variable factory is not available!" + if self._shape_var_factory is None: + raise ValueError("Buffer shape variable factory is not available!") + self._shape_var = self._shape_var_factory() return self._shape_var @@ -70,15 +72,38 @@ def write_callback(self): self.write_lambda() def __getitem__(self, index) -> "ShaderVariable": - assert self.can_index, f"Variable '{self.resolve()}' of type '{self.var_type.name}' cannot be indexed into!" + if not self.can_index: + raise TypeError(f"Variable '{self.resolve()}' of type '{self.var_type.name}' cannot be indexed into!") return_type = self.var_type.child_type if self.use_child_type else self.var_type if isinstance(index, tuple): - assert len(index) == 1, "Only single index is supported, cannot use multi-dimentional indexing!" + if len(index) != 1: + raise ValueError("Only single index is supported, cannot use multi-dimentional indexing!") + index = index[0] if base_utils.is_int_number(index): + backend = self.codegen_backend if self.codegen_backend is not None else get_codegen_backend() + packed_expr = None + if self.scalar_expr is not None: + packed_expr = backend.packed_buffer_read_expr( + self.scalar_expr, + return_type, + str(index), + ) + + if packed_expr is not None: + return ShaderVariable( + return_type, + packed_expr, + parents=[self], + settable=self.settable, + lexical_unit=True, + buffer_root=self, + buffer_index_expr=str(index), + ) + return ShaderVariable( return_type, f"{self.resolve()}[{index}]", @@ -89,9 +114,34 @@ def __getitem__(self, index) -> "ShaderVariable": buffer_index_expr=str(index), ) - assert isinstance(index, ShaderVariable), f"Index must be a ShaderVariable or int type, not {type(index)}!" - assert dtypes.is_scalar(index.var_type), "Indexing variable must be a scalar!" - assert dtypes.is_integer_dtype(index.var_type), "Indexing variable must be an integer type!" + if not isinstance(index, ShaderVariable): + raise TypeError(f"Index must be a ShaderVariable or int type, not {type(index)}!") + + if not dtypes.is_scalar(index.var_type): + raise TypeError("Indexing variable must be a scalar!") + + if not dtypes.is_integer_dtype(index.var_type): + raise TypeError("Indexing variable must be an integer type!") + + backend = self.codegen_backend if self.codegen_backend is not None else get_codegen_backend() + packed_expr = None + if self.scalar_expr is not None: + packed_expr = backend.packed_buffer_read_expr( + self.scalar_expr, + return_type, + index.resolve(), + ) + + if packed_expr is not None: + return ShaderVariable( + return_type, + packed_expr, + parents=[self, index], + settable=self.settable, + lexical_unit=True, + buffer_root=self, + buffer_index_expr=index.resolve(), + ) return ShaderVariable( return_type, diff --git a/vkdispatch/codegen/variables/variables.py b/vkdispatch/codegen/variables/variables.py index e8e776ee..db03c6a2 100644 --- a/vkdispatch/codegen/variables/variables.py +++ b/vkdispatch/codegen/variables/variables.py @@ -97,12 +97,15 @@ def _buffer_component_expr(self, component_index_expr: str) -> Optional[str]: ) def __getitem__(self, index) -> "ShaderVariable": - assert self.can_index, f"Variable '{self.resolve()}' of type '{self.var_type.name}' cannot be indexed into!" + if not self.can_index: + raise ValueError(f"Variable '{self.resolve()}' of type '{self.var_type.name}' cannot be indexed into!") return_type = self.var_type.child_type if self.use_child_type else self.var_type if isinstance(index, tuple): - assert len(index) == 1, "Only single index is supported, cannot use multi-dimentional indexing!" + if len(index) != 1: + raise ValueError("Only single index is supported, cannot use multi-dimentional indexing!") + index = index[0] if base_utils.is_int_number(index): @@ -118,9 +121,14 @@ def __getitem__(self, index) -> "ShaderVariable": return ShaderVariable(return_type, f"{self.resolve()}[{index}]", parents=[self], settable=self.settable, lexical_unit=True) - assert isinstance(index, ShaderVariable), f"Index must be a ShaderVariable or int type, not {type(index)}!" - assert dtypes.is_scalar(index.var_type), "Indexing variable must be a scalar!" - assert dtypes.is_integer_dtype(index.var_type), "Indexing variable must be an integer type!" + if not isinstance(index, ShaderVariable): + raise ValueError(f"Index must be a ShaderVariable or int type, not {type(index)}!") + + if not dtypes.is_scalar(index.var_type): + raise ValueError(f"Indexing variable must be a scalar, but got variable '{index.resolve()}' of type '{index.var_type.name}'!") + + if not dtypes.is_integer_dtype(index.var_type): + raise ValueError(f"Indexing variable must be an integer type, but got variable '{index.resolve()}' of type '{index.var_type.name}'!") component_expr = self._buffer_component_expr(index.resolve()) if component_expr is not None: @@ -135,13 +143,18 @@ def __getitem__(self, index) -> "ShaderVariable": return ShaderVariable(return_type, f"{self.resolve()}[{index.resolve()}]", parents=[self, index], settable=self.settable, lexical_unit=True) def swizzle(self, components: str) -> "ShaderVariable": - assert dtypes.is_vector(self.var_type) or dtypes.is_complex(self.var_type) or dtypes.is_scalar(self.var_type), f"Variable '{self.resolve()}' of type '{self.var_type.name}' does not support swizzling!" - assert self.use_child_type, f"Variable '{self.resolve()}' does not support swizzling!" - - assert len(components) >= 1 and len(components) <= 4, f"Swizzle must have between 1 and 4 components, got {len(components)}!" + if not (dtypes.is_vector(self.var_type) or dtypes.is_complex(self.var_type) or dtypes.is_scalar(self.var_type)): + raise ValueError(f"Variable '{self.resolve()}' of type '{self.var_type.name}' does not support swizzling!") + + if not self.use_child_type: + raise ValueError(f"Variable '{self.resolve()}' does not support swizzling!") + + if len(components) < 1 or len(components) > 4: + raise ValueError(f"Swizzle must have between 1 and 4 components, got {len(components)}!") for c in components: - assert c in ['x', 'y', 'z', 'w'], f"Invalid swizzle component '{c}'!" + if c not in ['x', 'y', 'z', 'w']: + raise ValueError(f"Invalid swizzle component '{c}' in swizzle '{components}' for variable '{self.resolve()}'!") sample_type = self.var_type if dtypes.is_scalar(self.var_type) else self.var_type.child_type return_type = sample_type if len(components) == 1 else dtypes.to_vector(sample_type, len(components)) @@ -149,7 +162,8 @@ def swizzle(self, components: str) -> "ShaderVariable": base_expr = self.resolve() if dtypes.is_scalar(self.var_type): - assert all(c == 'x' for c in components), f"Cannot swizzle scalar variable '{self.resolve()}' with components other than 'x'!" + if any(c != 'x' for c in components): + raise ValueError(f"Cannot swizzle scalar variable '{self.resolve()}' with components other than 'x'!") scalar_x_expr = backend.component_access_expr(base_expr, "x", self.var_type) swizzle_expr = scalar_x_expr @@ -168,14 +182,14 @@ def swizzle(self, components: str) -> "ShaderVariable": register=self.register and len(components) == 1 ) - if self.var_type.shape[0] < 4: - assert 'w' not in components, f"Cannot swizzle variable '{self.resolve()}' of type '{self.var_type.name}' with component 'w'!" + if self.var_type.shape[0] < 4 and 'w' in components: + raise ValueError(f"Cannot swizzle variable '{self.resolve()}' of type '{self.var_type.name}' with component 'w' because it only has {self.var_type.shape[0]} components!") - if self.var_type.shape[0] < 3: - assert 'z' not in components, f"Cannot swizzle variable '{self.resolve()}' of type '{self.var_type.name}' with component 'z'!" + if self.var_type.shape[0] < 3 and 'z' in components: + raise ValueError(f"Cannot swizzle variable '{self.resolve()}' of type '{self.var_type.name}' with component 'z' because it only has {self.var_type.shape[0]} components!") - if self.var_type.shape[0] < 2: - assert 'y' not in components, f"Cannot swizzle variable '{self.resolve()}' of type '{self.var_type.name}' with component 'y'!" + if self.var_type.shape[0] < 2 and 'y' in components: + raise ValueError(f"Cannot swizzle variable '{self.resolve()}' of type '{self.resolve()}' with component 'y' because it only has {self.var_type.shape[0]} components!") if len(components) == 1: component_index = "xyzw".index(components) @@ -207,7 +221,8 @@ def swizzle(self, components: str) -> "ShaderVariable": ) def conjugate(self) -> "ShaderVariable": - assert self.is_complex, f"Variable '{self.resolve()}' of type '{self.var_type.name}' is not a complex variable and cannot be conjugated!" + if not self.is_complex: + raise ValueError(f"Variable '{self.resolve()}' of type '{self.var_type.name}' is not a complex variable and cannot be conjugated!") return ShaderVariable( var_type=self.var_type, @@ -221,7 +236,8 @@ def conjugate(self) -> "ShaderVariable": ) def set_value(self, value: "ShaderVariable") -> None: - assert self.settable, f"Cannot set value of '{self.resolve()}' because it is not a settable variable!" + if not self.settable: + raise ValueError(f"Cannot set value of '{self.resolve()}' because it is not a settable variable!") self.write_callback() self.read_callback() @@ -242,16 +258,39 @@ def set_value(self, value: "ShaderVariable") -> None: base_utils.append_contents(f"{self.resolve()} = {base_utils.format_number_literal(value)};\n") return - assert self.var_type == value.var_type, f"Cannot set variable of type '{self.var_type.name}' to value of type '{value.var_type.name}'!" + if self.var_type != value.var_type: + raise ValueError(f"Cannot set variable of type '{self.var_type.name}' to value of type '{value.var_type.name}'!") + value.read_callback() + if self.buffer_root is not None and self.buffer_index_expr is not None: + scalar_expr = getattr(self.buffer_root, "scalar_expr", None) + if scalar_expr is not None: + backend = getattr(self.buffer_root, "codegen_backend", None) + if backend is None: + backend = get_codegen_backend() + + packed_write = backend.packed_buffer_write_statements( + scalar_expr, + self.var_type, + self.buffer_index_expr, + value.resolve(), + ) + + if packed_write is not None: + base_utils.append_contents(packed_write) + return + base_utils.append_contents(f"{self.resolve()} = {value.resolve()};\n") def __setitem__(self, index, value: "ShaderVariable") -> None: - assert self.settable, f"Cannot set value of '{self.resolve()}' because it is not a settable variable!" + if not self.settable: + raise ValueError(f"Cannot set value of '{self.resolve()}' because it is not a settable variable!") if isinstance(index, slice): - assert index.start is None and index.stop is None and index.step is None, "Only full slice (:) is supported!" + if index.start is not None or index.stop is not None or index.step is not None: + raise ValueError("Only full slice (:) is supported for setting variable values!") + self.set_value(value) return @@ -283,10 +322,14 @@ def __setattr__(self, name: str, value: "ShaderVariable") -> "ShaderVariable": elif name == "y": self.y.set_value(value) elif name == "z": - assert self.var_type.shape[0] >= 3, f"Variable '{self.resolve()}' of type '{self.var_type.name}' does not have 'z' component!" + if self.var_type.shape[0] < 3: + raise ValueError(f"Variable '{self.resolve()}' of type '{self.var_type.name}' does not have 'z' component!") + self.z.set_value(value) elif name == "w": - assert self.var_type.shape[0] == 4, f"Variable '{self.resolve()}' of type '{self.var_type.name}' does not have 'w' component!" + if self.var_type.shape[0] < 4: + raise ValueError(f"Variable '{self.resolve()}' of type '{self.var_type.name}' does not have 'w' component!") + self.w.set_value(value) return diff --git a/vkdispatch/execution_pipeline/buffer_builder.py b/vkdispatch/execution_pipeline/buffer_builder.py index 01418bae..49f7f9d5 100644 --- a/vkdispatch/execution_pipeline/buffer_builder.py +++ b/vkdispatch/execution_pipeline/buffer_builder.py @@ -40,7 +40,8 @@ class BufferBuilder: element_map: Dict[Tuple[str, str], BufferedStructEntry] def __init__(self, struct_alignment: Optional[int] = None, usage: Optional[BufferUsage] = None) -> None: - assert struct_alignment is not None or usage is not None, "Either struct_alignment or usage must be provided!" + if struct_alignment is None and usage is None: + raise ValueError("Either 'struct_alignment' or 'usage' must be provided!") if struct_alignment is None: if usage == BufferUsage.PUSH_CONSTANT: @@ -110,7 +111,8 @@ def _setitem_numpy(self, key: Tuple[str, str], value: Any) -> None: arr = np.array(value, dtype=buffer_element.dtype) if self.instance_count != 1: - assert arr.shape[0] == self.instance_count, f"Invalid shape for {key}! Expected {self.instance_count} but got {arr.shape[0]}!" + if arr.shape[0] != self.instance_count: + raise ValueError(f"Invalid shape for {key}! Expected {self.instance_count} but got {arr.shape[0]}!") if buffer_element.shape == (1,): arr = arr.reshape(*arr.shape, 1) @@ -221,9 +223,9 @@ def _setitem_python(self, key: Tuple[str, str], value: Any) -> None: payload = self._pack_single_instance_value(value[instance_index], key, buffer_element) self._write_payload(instance_index, buffer_element.memory_slice, payload) - def __setitem__( - self, key: Tuple[str, str], value: Union[Any, list, tuple, int, float] - ) -> None: + def set_item(self, + key: Tuple[str, str], + value: Union[Any, list, tuple, int, float]): if key not in self.element_map: raise ValueError(f"Invalid buffer element name '{key}'!") diff --git a/vkdispatch/execution_pipeline/command_graph.py b/vkdispatch/execution_pipeline/command_graph.py index efdfc40f..36879c2a 100644 --- a/vkdispatch/execution_pipeline/command_graph.py +++ b/vkdispatch/execution_pipeline/command_graph.py @@ -18,9 +18,6 @@ import dataclasses -def _runtime_supports_push_constants() -> bool: - return True - @dataclasses.dataclass class BufferBindInfo: """A dataclass to hold information about a buffer binding.""" @@ -97,7 +94,7 @@ def __init__(self, reset_on_submit: bool = False, submit_on_record: bool = False self.uniform_constants_size = 0 self.uniform_constants_buffer = None - def _ensure_uniform_constants_capacity(self, uniform_word_size: int) -> None: + def ensure_uniform_constants_capacity(self, uniform_word_size: int) -> None: if self.uniform_constants_buffer is not None and uniform_word_size <= self.uniform_constants_size: return @@ -108,26 +105,18 @@ def _ensure_uniform_constants_capacity(self, uniform_word_size: int) -> None: self.uniform_constants_size = max(uniform_word_size, self.uniform_constants_size * 2) self.uniform_constants_buffer = vd.Buffer(shape=(self.uniform_constants_size,), var_type=vd.uint32) - def _prepare_submission_state(self, instance_count: int) -> None: - if len(self.pc_builder.element_map) > 0 and ( - self.pc_builder.instance_count != instance_count or not self.buffers_valid - ): - - assert _runtime_supports_push_constants(), ( - "Push constants not supported for backends without push-constant support " - "(OpenCL). Use UBO-backed variables instead." - ) - + def prepare_submission_state(self, instance_count: int) -> None: + if len(self.pc_builder.element_map) > 0 and (self.pc_builder.instance_count != instance_count or not self.buffers_valid): self.pc_builder.prepare(instance_count) for key, value in self.pc_values.items(): - self.pc_builder[key] = value + self.pc_builder.set_item(key, value) if len(self.uniform_builder.element_map) > 0 and not self.buffers_valid: self.uniform_builder.prepare(1) for key, value in self.uniform_values.items(): - self.uniform_builder[key] = value + self.uniform_builder.set_item(key, value) uniform_word_size = (self.uniform_builder.instance_bytes + 3) // 4 uniform_payload = self.uniform_builder.tobytes() @@ -136,8 +125,10 @@ def _prepare_submission_state(self, instance_count: int) -> None: for descriptor_set, offset, size in self.uniform_descriptors: descriptor_set.set_inline_uniform_payload(uniform_payload[offset:offset + size]) else: - self._ensure_uniform_constants_capacity(uniform_word_size) - assert self.uniform_constants_buffer is not None + self.ensure_uniform_constants_capacity(uniform_word_size) + + if self.uniform_constants_buffer is None: + raise RuntimeError("Failed to allocate uniform constants buffer!") for descriptor_set, offset, size in self.uniform_descriptors: descriptor_set.bind_buffer(self.uniform_constants_buffer, 0, offset, size, True, write_access=False) @@ -156,7 +147,7 @@ def prepare_for_cuda_graph_capture(self, instance_count: int = None) -> None: if instance_count is None: instance_count = 1 - self._prepare_submission_state(instance_count) + self.prepare_submission_state(instance_count) def reset(self) -> None: """Reset the command graph by clearing the push constant buffer and descriptor @@ -184,12 +175,6 @@ def _destroy(self) -> None: super()._destroy() def bind_var(self, name: str): - if not _runtime_supports_push_constants(): - raise RuntimeError( - "CommandGraph.bind_var() is disabled for backends without push-constant " - "support (OpenCL). Pass Variable values directly at shader invocation." - ) - def register_var(key: Tuple[str, str]): if not name in self.name_to_pc_key_dict.keys(): self.name_to_pc_key_dict[name] = [] @@ -199,12 +184,6 @@ def register_var(key: Tuple[str, str]): return register_var def set_var(self, name: str, value: Any): - if not _runtime_supports_push_constants(): - raise RuntimeError( - "CommandGraph.set_var() is disabled for backends without push-constant " - "support (OpenCL). Pass Variable values directly at shader invocation." - ) - if name not in self.name_to_pc_key_dict.keys(): raise ValueError("Variable not bound!") @@ -246,19 +225,7 @@ def record_shader(self, if shader_uuid is None: shader_uuid = shader_description.name + "_" + str(uuid.uuid4()) - if (not _runtime_supports_push_constants()) and len(pc_values) > 0: - raise RuntimeError( - "Push-constant Variable payloads are disabled for backends without " - "push-constant support (OpenCL). " - "Variable values must be UBO-backed and provided at shader invocation." - ) - if len(shader_description.pc_structure) != 0: - if not _runtime_supports_push_constants(): - raise RuntimeError( - "Kernels should not emit push-constant layouts for backends without " - "push-constant support (OpenCL). Use UBO-backed variables." - ) self.pc_builder.register_struct(shader_uuid, shader_description.pc_structure) uniform_field_names = {elem.name for elem in shader_description.uniform_structure} @@ -328,10 +295,10 @@ def submit( if instance_count is None: instance_count = 1 - self._prepare_submission_state(instance_count) + self.prepare_submission_state(instance_count) for key, val in self.queued_pc_values.items(): - self.pc_builder[key] = val + self.pc_builder.set_item(key, val) my_data = None @@ -375,5 +342,7 @@ def set_global_graph(graph: CommandGraph = None) -> CommandGraph: _global_graph.custom_graph = None return - assert _get_global_graph() is None, "A global CommandGraph is already set for the current thread!" + if _get_global_graph() is not None: + raise RuntimeError("A global CommandGraph is already set for the current thread!") + _global_graph.custom_graph = graph diff --git a/vkdispatch/execution_pipeline/cuda_graph_capture.py b/vkdispatch/execution_pipeline/cuda_graph_capture.py index a96f6a9e..59c35b1a 100644 --- a/vkdispatch/execution_pipeline/cuda_graph_capture.py +++ b/vkdispatch/execution_pipeline/cuda_graph_capture.py @@ -23,7 +23,8 @@ def get_cuda_capture() -> CUDAGraphCapture: @contextmanager def cuda_graph_capture(cuda_stream=None): - assert vd.is_cuda(), "CUDA graph capture is only supported when using the CUDA backend." + if not vd.is_cuda(): + raise RuntimeError("CUDA graph capture is only supported when using the CUDA backend.") cap = CUDAGraphCapture() cap.cuda_stream = cuda_stream diff --git a/vkdispatch/fft/config.py b/vkdispatch/fft/config.py index 02628e84..8e193101 100644 --- a/vkdispatch/fft/config.py +++ b/vkdispatch/fft/config.py @@ -13,7 +13,8 @@ def plan_fft_stages(N: int, max_register_count: int, compute_item_size: int) -> all_factors = prime_factors(N) for factor in all_factors: - assert factor <= default_max_prime(), f"A prime factor of {N} is {factor}, which exceeds the maximum prime supported {default_max_prime()}" + if factor > default_max_prime(): + raise ValueError(f"A prime factor of {N} is {factor}, which exceeds the maximum prime supported {default_max_prime()}") prime_groups = group_primes(all_factors, max_register_count) @@ -115,7 +116,8 @@ def select_fft_plan_candidate( compute_item_size=compute_item_size, ) - assert candidate.stages is not None, f"Failed to create an FFT plan candidate for N={N} with max_register_count={requested_limit}" + if candidate.stages is None: + raise ValueError(f"Failed to create an FFT plan candidate for N={N} with max_register_count={requested_limit}") if candidate.batch_threads <= batch_threads_limit: return candidate @@ -235,7 +237,8 @@ def __init__( all_factors = prime_factors(N) for factor in all_factors: - assert factor <= default_max_prime(), f"A prime factor of {N} is {factor}, which exceeds the maximum prime supported {default_max_prime()}" + if factor > default_max_prime(): + raise ValueError(f"A prime factor of {N} is {factor}, which exceeds the maximum prime supported {default_max_prime()}") self.max_prime_radix = max(all_factors) diff --git a/vkdispatch/fft/context.py b/vkdispatch/fft/context.py index 8a6bc7cc..fffcb2f2 100644 --- a/vkdispatch/fft/context.py +++ b/vkdispatch/fft/context.py @@ -3,7 +3,7 @@ import vkdispatch.base.dtype as dtypes import contextlib -from typing import Optional, Tuple, Union, List, Dict +from typing import Optional, Tuple, List from .io_manager import IOManager from .config import FFTConfig @@ -15,7 +15,7 @@ from .global_memory_iterators import global_reads_iterator, global_writes_iterator class FFTContext: - shader_context: vd.ShaderContext + shader_context: vc.ShaderContext config: FFTConfig grid: FFTGridManager registers: FFTRegisters @@ -28,7 +28,7 @@ class FFTContext: declarer: str def __init__(self, - shader_context: vd.ShaderContext, + shader_context: vc.ShaderContext, buffer_shape: Tuple, axis: int = None, max_register_count: int = None, @@ -50,7 +50,8 @@ def __init__(self, self.name = name if name is not None else f"fft_shader_{buffer_shape}_{axis}" def allocate_registers(self, name: str, count: int = None) -> FFTRegisters: - assert name is not None, "Must provide a name for allocated registers" + if name is None: + raise ValueError("Must provide a name for allocated registers") if count is None: count = self.config.register_count @@ -58,7 +59,9 @@ def allocate_registers(self, name: str, count: int = None) -> FFTRegisters: return FFTRegisters(self.resources, count, name) def declare_shader_args(self, types: List) -> List[vc.ShaderVariable]: - assert not self.declared_shader_args, f"Shader arguments already declared with {self.declarer}" + if self.declared_shader_args: + raise ValueError(f"Shader arguments already declared with {self.declarer}") + self.declared_shader_args = True self.declarer = "declare_shader_args" return self.shader_context.declare_input_arguments(types) @@ -69,7 +72,10 @@ def make_io_manager(self, input_type: Optional[dtypes.dtype] = None, input_map: Optional[vd.MappingFunction] = None, kernel_map: Optional[vd.MappingFunction] = None) -> IOManager: - assert not self.declared_shader_args, f"Shader arguments already declared with {self.declarer}" + + if self.declared_shader_args: + raise ValueError(f"Shader arguments already declared with {self.declarer}") + self.declared_shader_args = True self.declarer = "make_io_manager" return IOManager( @@ -131,14 +137,16 @@ def register_shuffle(self, ) def compile_shader(self): - self.fft_callable = self.shader_context.get_function( + self.fft_callable = vd.make_shader_function( + self.shader_context.get_description(self.name), local_size=self.grid.local_size, - exec_count=self.grid.exec_size, - name=self.name + exec_count=self.grid.exec_size ) def get_callable(self) -> vd.ShaderFunction: - assert self.fft_callable is not None, "Shader not compiled yet... something is wrong" + if self.fft_callable is None: + raise ValueError("Shader not compiled yet... something is wrong") + return self.fft_callable def execute(self, inverse: bool): @@ -178,7 +186,7 @@ def fft_context(buffer_shape: Tuple, name: Optional[str] = None): try: - with vd.shader_context(vc.ShaderFlags.NO_EXEC_BOUNDS) as context: + with vc.shader_context(vc.ShaderFlags.NO_EXEC_BOUNDS) as context: fft_context = FFTContext( shader_context=context, buffer_shape=buffer_shape, diff --git a/vkdispatch/fft/functions.py b/vkdispatch/fft/functions.py index 0818a8eb..7660a58b 100644 --- a/vkdispatch/fft/functions.py +++ b/vkdispatch/fft/functions.py @@ -117,8 +117,8 @@ def fft( input_type: vd.dtype = None, compute_type: vd.dtype = None, input_signal_range: Union[Tuple[Optional[int], Optional[int]], None] = None): - - assert len(buffers) >= 1, "At least one buffer must be provided" + if len(buffers) < 1: + raise ValueError("At least one buffer must be provided") if input_map is None and output_map is None and len(buffers) != 1: raise ValueError("fft() expects exactly one buffer unless input_map/output_map are used") @@ -171,7 +171,8 @@ def fft2( input_type: vd.dtype = None, compute_type: vd.dtype = None, ): - assert len(buffer.shape) == 2 or len(buffer.shape) == 3, 'Buffer must have 2 or 3 dimensions' + if len(buffer.shape) != 2 and len(buffer.shape) != 3: + raise ValueError('Buffer must have 2 or 3 dimensions') fft( buffer, @@ -204,7 +205,8 @@ def fft3( input_type: vd.dtype = None, compute_type: vd.dtype = None, ): - assert len(buffer.shape) == 3, 'Buffer must have 3 dimensions' + if len(buffer.shape) != 3: + raise ValueError('Buffer must have 3 dimensions') fft( buffer, @@ -275,7 +277,8 @@ def ifft2( input_type: vd.dtype = None, compute_type: vd.dtype = None, ): - assert len(buffer.shape) == 2 or len(buffer.shape) == 3, 'Buffer must have 2 or 3 dimensions' + if len(buffer.shape) != 2 and len(buffer.shape) != 3: + raise ValueError('Buffer must have 2 or 3 dimensions') ifft( buffer, @@ -311,7 +314,8 @@ def ifft3( input_type: vd.dtype = None, compute_type: vd.dtype = None, ): - assert len(buffer.shape) == 3, 'Buffer must have 3 dimensions' + if len(buffer.shape) != 3: + raise ValueError('Buffer must have 3 dimensions') ifft( buffer, @@ -367,7 +371,8 @@ def rfft( ) def rfft2(buffer: vd.RFFTBuffer, graph: vd.CommandGraph = None, print_shader: bool = False, compute_type: vd.dtype = None): - assert len(buffer.real_shape) == 2 or len(buffer.real_shape) == 3, 'Buffer must have 2 or 3 dimensions' + if len(buffer.real_shape) != 2 and len(buffer.real_shape) != 3: + raise ValueError('Buffer must have 2 or 3 dimensions') rfft(buffer, graph=graph, print_shader=print_shader, compute_type=compute_type) fft( @@ -381,7 +386,8 @@ def rfft2(buffer: vd.RFFTBuffer, graph: vd.CommandGraph = None, print_shader: bo ) def rfft3(buffer: vd.RFFTBuffer, graph: vd.CommandGraph = None, print_shader: bool = False, compute_type: vd.dtype = None): - assert len(buffer.real_shape) == 3, 'Buffer must have 3 dimensions' + if len(buffer.real_shape) != 3: + raise ValueError('Buffer must have 3 dimensions') rfft(buffer, graph=graph, print_shader=print_shader, compute_type=compute_type) fft( @@ -426,7 +432,8 @@ def irfft( ) def irfft2(buffer: vd.RFFTBuffer, graph: vd.CommandGraph = None, print_shader: bool = False, normalize: bool = True, compute_type: vd.dtype = None): - assert len(buffer.real_shape) == 2 or len(buffer.real_shape) == 3, 'Buffer must have 2 or 3 dimensions' + if len(buffer.real_shape) != 2 and len(buffer.real_shape) != 3: + raise ValueError('Buffer must have 2 or 3 dimensions') ifft( buffer, @@ -441,7 +448,8 @@ def irfft2(buffer: vd.RFFTBuffer, graph: vd.CommandGraph = None, print_shader: b irfft(buffer, graph=graph, print_shader=print_shader, normalize=normalize, compute_type=compute_type) def irfft3(buffer: vd.RFFTBuffer, graph: vd.CommandGraph = None, print_shader: bool = False, normalize: bool = True, compute_type: vd.dtype = None): - assert len(buffer.real_shape) == 3, 'Buffer must have 3 dimensions' + if len(buffer.real_shape) != 3: + raise ValueError('Buffer must have 3 dimensions') ifft( buffer, @@ -484,7 +492,8 @@ def convolve( kernel_type: vd.dtype = None, compute_type: vd.dtype = None, input_signal_range: Union[Tuple[Optional[int], Optional[int]], None] = None): - assert len(buffers) >= 1, "At least one buffer must be provided" + if len(buffers) < 1: + raise ValueError("At least one buffer must be provided") if kernel_map is None and len(buffers) < 2: raise ValueError("convolve() requires at least an output buffer and kernel buffer") @@ -554,7 +563,8 @@ def convolve2D( kernel_type: vd.dtype = None, compute_type: vd.dtype = None): - assert len(buffer.shape) == 2 or len(buffer.shape) == 3, 'Buffer must have 2 or 3 dimensions' + if len(buffer.shape) != 2 and len(buffer.shape) != 3: + raise ValueError('Buffer must have 2 or 3 dimensions') input_buffers = [buffer] @@ -612,8 +622,9 @@ def convolve2DR( print_shader: bool = False, normalize: bool = True, compute_type: vd.dtype = None): - - assert len(buffer.shape) == 2 or len(buffer.shape) == 3, 'Buffer must have 2 or 3 dimensions' + + if len(buffer.real_shape) != 2 and len(buffer.real_shape) != 3: + raise ValueError('Buffer must have 2 or 3 dimensions') rfft(buffer, graph=graph, print_shader=print_shader, compute_type=compute_type) convolve( @@ -677,7 +688,10 @@ def transpose( f"out_buffer type ({out_buffer.var_type.name}) does not match output_type ({resolved_output_type.name})" ) - assert out_buffer.size >= transposed_size, f"Output buffer size {out_buffer.size} does not match expected transposed size {transposed_size}" + if out_buffer.size < transposed_size: + raise ValueError( + f"Output buffer size {out_buffer.size} is smaller than expected transposed size {transposed_size}" + ) if conv_shape is None: conv_shape = in_buffer.shape diff --git a/vkdispatch/fft/global_memory_iterators.py b/vkdispatch/fft/global_memory_iterators.py index c621f6b6..4625b2a0 100644 --- a/vkdispatch/fft/global_memory_iterators.py +++ b/vkdispatch/fft/global_memory_iterators.py @@ -1,6 +1,6 @@ import vkdispatch.codegen as vc -from typing import Optional, Tuple +from typing import Optional, Tuple, ContextManager import dataclasses @@ -29,9 +29,14 @@ def global_batch_offset( outer_batch_stride = config.N * config.fft_stride if r2c: - assert inverse is not None, "Must specify inverse for r2c io" - assert is_output is not None, "Must specify is_output for r2c io" - assert config.fft_stride == 1, "R2C io only supported for contiguous data" + if inverse is None: + raise ValueError("Must specify inverse for r2c io") + + if is_output is None: + raise ValueError("Must specify is_output for r2c io") + + if config.fft_stride != 1: + raise ValueError("R2C io only supported for contiguous data") outer_batch_stride = (config.N // 2) + 1 @@ -78,9 +83,8 @@ def write_to_buffer(self, return if not self.inverse: - vc.if_statement(self.fft_index < (self.fft_size // 2) + 1) - buffer[io_index] = _cast_if_needed(register, buffer.var_type) - vc.end() + with vc.if_block(self.fft_index < (self.fft_size // 2) + 1): + buffer[io_index] = _cast_if_needed(register, buffer.var_type) return out_scalar_type = buffer.var_type.child_type @@ -95,7 +99,8 @@ def global_writes_iterator( extra_comment_lines = "" if r2c: - assert inverse is not None, "Must specify inverse for r2c io" + if inverse is None: + raise ValueError("Must specify inverse for r2c io") if inverse: extra_comment_lines = "\nDoing R2C inverse write, applying Hermitian reconstruction and packed-real rules as needed." @@ -134,6 +139,8 @@ class GlobalReadOp(MemoryOp): format_transposed: bool signal_range: Tuple[int, int] + signal_range_context: ContextManager + @classmethod def from_memory_op(cls, base: MemoryOp, @@ -153,30 +160,39 @@ def from_memory_op(cls, inverse=inverse, r2c_inverse_offset=r2c_inverse_offset, format_transposed=format_transposed, - signal_range=signal_range + signal_range=signal_range, + signal_range_context=None ) def check_in_signal_range(self) -> bool: if self.signal_range == (0, self.fft_size): + self.signal_range_context = None return - if self.signal_range[0] == 0: - vc.if_statement(self.fft_index < self.signal_range[1]) - return + assert self.signal_range_context is None, "Signal range context already active" - if self.signal_range[1] == self.fft_size: - vc.if_statement(self.fft_index >= self.signal_range[0]) - return + condition_check = None + + if self.signal_range[0] == 0: + condition_check = self.fft_index < self.signal_range[1] + elif self.signal_range[1] == self.fft_size: + condition_check = self.fft_index >= self.signal_range[0] + else: + condition_check = vc.all(self.fft_index >= self.signal_range[0], self.fft_index < self.signal_range[1]) - vc.if_all(self.fft_index >= self.signal_range[0], self.fft_index < self.signal_range[1]) + self.signal_range_context = vc.if_block(condition_check) + self.signal_range_context.__enter__() def signal_range_end(self, register: vc.ShaderVariable): if self.signal_range == (0, self.fft_size): return + + assert self.signal_range_context is not None, "Signal range context not active" + + self.signal_range_context.__exit__(None, None, None) - vc.else_statement() - register[:] = vc.to_dtype(register.var_type, 0) - vc.end() + with vc.else_block(): + register[:] = vc.to_dtype(register.var_type, 0) def read_from_buffer(self, buffer: vc.Buffer, @@ -202,13 +218,13 @@ def read_from_buffer(self, self.signal_range_end(register) return - vc.if_statement(self.fft_index >= (self.fft_size // 2) + 1) - self.io_index_2[:] = self.r2c_inverse_offset - io_index - register[:] = _cast_if_needed(buffer[self.io_index_2], register.var_type) - register.imag = -register.imag - vc.else_statement() - register[:] = _cast_if_needed(buffer[io_index], register.var_type) - vc.end() + with vc.if_block(self.fft_index >= (self.fft_size // 2) + 1): + self.io_index_2[:] = self.r2c_inverse_offset - io_index + register[:] = _cast_if_needed(buffer[self.io_index_2], register.var_type) + register.imag = -register.imag + + with vc.else_block(): + register[:] = _cast_if_needed(buffer[io_index], register.var_type) self.signal_range_end(register) @@ -254,8 +270,8 @@ def global_reads_iterator( vc.comment(f"""Reading input samples from global memory into FFT registers.{transpose_comment_str}{signal_range_comment_str}{r2c_comment_str}""") - if r2c: - assert not format_transposed, "R2C transposed format not supported" + if r2c and format_transposed: + raise ValueError("R2C format transposition not supported") resources = registers.resources config = registers.config diff --git a/vkdispatch/fft/io_manager.py b/vkdispatch/fft/io_manager.py index b91d6bd9..6f4d6c09 100644 --- a/vkdispatch/fft/io_manager.py +++ b/vkdispatch/fft/io_manager.py @@ -22,12 +22,18 @@ def _get_read_op() -> Optional[GlobalReadOp]: def write_op() -> GlobalWriteOp: op = _get_write_op() - assert op is not None, "No global write operation is set for the current thread!" + + if op is None: + raise ValueError("No global write operation is set for the current thread!") + return op def read_op() -> GlobalReadOp: op = _get_read_op() - assert op is not None, "No global read operation is set for the current thread!" + + if op is None: + raise ValueError("No global read operation is set for the current thread!") + return op def set_write_op(op: GlobalWriteOp): @@ -35,7 +41,9 @@ def set_write_op(op: GlobalWriteOp): _write_op.op = None return - assert _get_write_op() is None, "A global write operation is already set for the current thread!" + if _get_write_op() is not None: + raise ValueError("A global write operation is already set for the current thread!") + _write_op.op = op def set_read_op(op: GlobalReadOp): @@ -43,7 +51,8 @@ def set_read_op(op: GlobalReadOp): _read_op.op = None return - assert _get_read_op() is None, "A global read operation is already set for the current thread!" + if _get_read_op() is not None: + raise ValueError("A global read operation is already set for the current thread!") _read_op.op = op class IOManager: @@ -54,7 +63,7 @@ class IOManager: def __init__(self, default_registers: FFTRegisters, - shader_context: vd.ShaderContext, + shader_context: vc.ShaderContext, output_map: Optional[vd.MappingFunction], output_type: dtypes.dtype = vd.complex64, input_type: Optional[dtypes.dtype] = None, diff --git a/vkdispatch/fft/io_proxy.py b/vkdispatch/fft/io_proxy.py index 5744b1ba..fd65addb 100644 --- a/vkdispatch/fft/io_proxy.py +++ b/vkdispatch/fft/io_proxy.py @@ -33,13 +33,15 @@ def __init__(self, obj: Union[type, vd.MappingFunction], name: str): raise ValueError("IOObject must be initialized with a Buffer or MappingFunction") def set_variables(self, vars: List[vc.Buffer]) -> None: - assert len(vars) == len(self.buffer_types), "Number of buffer variables does not match number of buffer types" + if len(vars) != len(self.buffer_types): + raise ValueError(f"Number of buffer variables does not match number of buffer types. Expected {len(self.buffer_types)} but got {len(vars)}") + if len(vars) == 0: self.enabled = False return - if self.map_func is None: - assert len(vars) == 1, "Buffer IOObject must have exactly one buffer variable" + if self.map_func is None and len(vars) != 1: + raise ValueError("IOProxy initialized with a non-mapping function must have exactly one buffer variable") self.buffer_variables = vars @@ -47,5 +49,7 @@ def has_callback(self) -> bool: return self.map_func is not None def do_callback(self): - assert self.map_func is not None, "IOProxy does not have a mapping function" + if self.map_func is None: + raise ValueError("IOProxy does not have a mapping function") + self.map_func.callback(*self.buffer_variables) diff --git a/vkdispatch/fft/prime_utils.py b/vkdispatch/fft/prime_utils.py index ee1624fa..cc823d91 100644 --- a/vkdispatch/fft/prime_utils.py +++ b/vkdispatch/fft/prime_utils.py @@ -51,7 +51,8 @@ def pad_dim(dim: int, max_register_count: int = None): if max_register_count is None: max_register_count = default_register_limit() - assert dim > 0, 'Dimension must be greater than 0' + if dim <= 0: + raise ValueError('Dimension must be greater than 0') current_dim = dim current_primes = prime_factors(current_dim) diff --git a/vkdispatch/fft/registers.py b/vkdispatch/fft/registers.py index 31c79e32..55201342 100644 --- a/vkdispatch/fft/registers.py +++ b/vkdispatch/fft/registers.py @@ -79,7 +79,8 @@ def try_shuffle(self, output_stage: int = -1, input_stage: int = 0) -> bool: return True def read_from_registers(self, other: "FFTRegisters") -> "FFTRegisters": - assert self.count == other.count, "Register counts must match for copy" + if self.count != other.count: + raise ValueError("Register counts must match for copy") for i in range(self.count): self.registers[i][:] = other.registers[i] diff --git a/vkdispatch/fft/resources.py b/vkdispatch/fft/resources.py index f63bd04e..b570c79e 100644 --- a/vkdispatch/fft/resources.py +++ b/vkdispatch/fft/resources.py @@ -1,30 +1,31 @@ -import vkdispatch as vd import vkdispatch.codegen as vc -from vkdispatch.codegen.abreviations import * import dataclasses -from typing import List +from typing import List, ContextManager from .config import FFTConfig from .grid_manager import FFTGridManager +import contextlib + @dataclasses.dataclass class FFTResources: input_batch_offset: vc.ShaderVariable output_batch_offset: vc.ShaderVariable omega_register: vc.ShaderVariable - subsequence_offset: Const[u32] - io_index: Const[u32] - io_index_2: Const[u32] + subsequence_offset: vc.Const[vc.u32] + io_index: vc.Const[vc.u32] + io_index_2: vc.Const[vc.u32] radix_registers: List[vc.ShaderVariable] tid: vc.ShaderVariable - grid: FFTGridManager - config: FFTConfig + stage_context: ContextManager + invocation_context: ContextManager + def __init__(self, config: FFTConfig, grid: FFTGridManager): self.tid = grid.tid self.grid = grid @@ -36,30 +37,57 @@ def __init__(self, config: FFTConfig, grid: FFTGridManager): self.io_index = vc.new_uint_register(var_name="io_index") self.io_index_2 = vc.new_uint_register(var_name="io_index_2") + self.stage_context = None + self.invocation_context = None + self.radix_registers = [ vc.new_register(config.compute_type, var_name=f"radix_register_{i}") for i in range(config.max_prime_radix) ] def stage_begin(self, stage_index: int): + if self.stage_context is not None: + raise RuntimeError("Stage context is already active. Cannot begin a new stage before ending the previous one.") + thread_count = self.config.stages[stage_index].thread_count - if thread_count < self.config.batch_threads: - vc.if_statement(self.tid < thread_count) + if thread_count >= self.config.batch_threads: + return + + self.stage_context = vc.if_block(self.tid < thread_count) + self.stage_context.__enter__() def stage_end(self, stage_index: int): thread_count = self.config.stages[stage_index].thread_count - if thread_count < self.config.batch_threads: - vc.end() + if thread_count >= self.config.batch_threads: + return + + if self.stage_context is None: + raise RuntimeError("No active stage context to end.") + + self.stage_context.__exit__(None, None, None) + self.stage_context = None def invocation_gaurd(self, stage_index: int, invocation_index: int): stage = self.config.stages[stage_index] - if stage.remainder_offset == 1 and invocation_index == stage.extra_ffts: - vc.if_statement(self.tid < self.config.N // stage.registers_used) + if stage.remainder_offset == 0 or invocation_index != stage.extra_ffts: + return + + if self.invocation_context is not None: + raise RuntimeError("Invocation context is already active. Cannot begin a new invocation guard before ending the previous one.") + + self.invocation_context = vc.if_block(self.tid < self.config.N // stage.registers_used) + self.invocation_context.__enter__() def invocation_end(self, stage_index: int): stage = self.config.stages[stage_index] - if stage.remainder_offset == 1: - vc.end() + if stage.remainder_offset == 0: + return + + if self.invocation_context is None: + raise RuntimeError("No active invocation context to end.") + + self.invocation_context.__exit__(None, None, None) + self.invocation_context = None diff --git a/vkdispatch/fft/shader_factories.py b/vkdispatch/fft/shader_factories.py index 28a481fd..deb1af85 100644 --- a/vkdispatch/fft/shader_factories.py +++ b/vkdispatch/fft/shader_factories.py @@ -1,6 +1,6 @@ import vkdispatch as vd import vkdispatch.codegen as vc -from vkdispatch.codegen.abreviations import * +from vkdispatch.codegen.abbreviations import * from ..compat import numpy_compat as npc diff --git a/vkdispatch/fft/src_functions.py b/vkdispatch/fft/src_functions.py index e8952bb3..0fc79917 100644 --- a/vkdispatch/fft/src_functions.py +++ b/vkdispatch/fft/src_functions.py @@ -28,7 +28,8 @@ def fft_src( return fft_shader.get_src(line_numbers=line_numbers) def fft2_src(buffer_shape: Tuple, input_map: vd.MappingFunction = None, output_map: vd.MappingFunction = None): - assert len(buffer_shape) == 2 or len(buffer_shape) == 3, 'Buffer Shape must have 2 or 3 dimensions' + if len(buffer_shape) != 2 and len(buffer_shape) != 3: + raise ValueError('Buffer Shape must have 2 or 3 dimensions') return ( fft_src(axis=len(buffer_shape) - 2, input_map=input_map), @@ -36,7 +37,8 @@ def fft2_src(buffer_shape: Tuple, input_map: vd.MappingFunction = None, output_m ) def fft3_src(buffer_shape: Tuple, input_map: vd.MappingFunction = None, output_map: vd.MappingFunction = None): - assert len(buffer_shape) == 3, 'Buffer must have 3 dimensions' + if len(buffer_shape) != 3: + raise ValueError('Buffer Shape must have 3 dimensions') return ( fft_src(buffer_shape, axis=0, input_map=input_map), @@ -54,7 +56,8 @@ def ifft_src( return fft_src(buffer_shape, axis=axis, inverse=True, normalize_inverse=normalize, input_map=input_map, output_map=output_map) def ifft2_src(buffer_shape: Tuple, normalize: bool = True, input_map: vd.MappingFunction = None, output_map: vd.MappingFunction = None): - assert len(buffer_shape) == 2 or len(buffer_shape) == 3, 'Buffer must have 2 or 3 dimensions' + if len(buffer_shape) != 2 and len(buffer_shape) != 3: + raise ValueError('Buffer Shape must have 2 or 3 dimensions') return ( ifft_src(buffer_shape, axis=len(buffer_shape) - 2, normalize=normalize, input_map=input_map), @@ -62,7 +65,8 @@ def ifft2_src(buffer_shape: Tuple, normalize: bool = True, input_map: vd.Mapping ) def ifft3_src(buffer_shape: Tuple, normalize: bool = True, input_map: vd.MappingFunction = None, output_map: vd.MappingFunction = None): - assert len(buffer_shape) == 3, 'Buffer must have 3 dimensions' + if len(buffer_shape) != 3: + raise ValueError('Buffer Shape must have 3 dimensions') return ( ifft_src(buffer_shape, axis=0, normalize=normalize, input_map=input_map), @@ -75,7 +79,8 @@ def rfft_src(buffer_shape: Tuple): return fft_src(buffer_shape, r2c=True) def rfft2_src(buffer_shape: Tuple): - assert len(buffer_shape) == 2 or len(buffer_shape) == 3, 'Buffer must have 2 or 3 dimensions' + if len(buffer_shape) != 2 and len(buffer_shape) != 3: + raise ValueError('Buffer Shape must have 2 or 3 dimensions') return ( rfft_src(buffer_shape), @@ -83,7 +88,8 @@ def rfft2_src(buffer_shape: Tuple): ) def rfft3_src(buffer_shape: Tuple): - assert len(buffer_shape) == 3, 'Buffer must have 3 dimensions' + if len(buffer_shape) != 3: + raise ValueError('Buffer Shape must have 3 dimensions') return ( rfft_src(buffer_shape), @@ -95,7 +101,8 @@ def irfft_src(buffer_shape: Tuple, normalize: bool = True): return fft_src(buffer_shape, inverse=True, normalize_inverse=normalize, r2c=True) def irfft2_src(buffer_shape: Tuple, normalize: bool = True): - assert len(buffer_shape) == 2 or len(buffer_shape) == 3, 'Buffer must have 2 or 3 dimensions' + if len(buffer_shape) != 2 and len(buffer_shape) != 3: + raise ValueError('Buffer Shape must have 2 or 3 dimensions') return ( ifft_src(buffer_shape, axis=len(buffer_shape) - 2, normalize=normalize), @@ -103,7 +110,8 @@ def irfft2_src(buffer_shape: Tuple, normalize: bool = True): ) def irfft3_src(buffer_shape: Tuple, normalize: bool = True): - assert len(buffer_shape) == 3, 'Buffer must have 3 dimensions' + if len(buffer_shape) != 3: + raise ValueError('Buffer Shape must have 3 dimensions') return ( ifft_src(buffer_shape, axis=0, normalize=normalize), @@ -147,7 +155,8 @@ def convolve2D_src( input_map: vd.MappingFunction = None, output_map: vd.MappingFunction = None): - assert len(buffer_shape) == 2 or len(buffer_shape) == 3, 'Buffer must have 2 or 3 dimensions' + if len(buffer_shape) != 2 and len(buffer_shape) != 3: + raise ValueError('Buffer must have 2 or 3 dimensions') return ( fft_src(buffer_shape, input_map=input_map), @@ -169,7 +178,8 @@ def convolve2DR_src( kernel_inner_only: bool = False, normalize: bool = True): - assert len(buffer_shape) == 2 or len(buffer_shape) == 3, 'Buffer must have 2 or 3 dimensions' + if len(buffer_shape) != 2 and len(buffer_shape) != 3: + raise ValueError('Buffer must have 2 or 3 dimensions') return ( rfft_src(buffer_shape), diff --git a/vkdispatch/reduce/decorator.py b/vkdispatch/reduce/decorator.py index 0cc1e189..d8942e4f 100644 --- a/vkdispatch/reduce/decorator.py +++ b/vkdispatch/reduce/decorator.py @@ -33,8 +33,9 @@ def decorator(func: Callable[..., RetType]) -> Callable[[vd.Buffer[RetType]], vd func = lambda buffer: buffer[mapped_io_index()], return_type=func_signature.return_annotation, input_types=[vc.Buffer[func_signature.return_annotation]]) - else: - assert used_mapping_function.return_type == func_signature.return_annotation, "Mapping function return type must match the return type of the reduction function" + + elif used_mapping_function.return_type != func_signature.return_annotation: + raise ValueError("Mapping function return type must match the return type of the reduction function") return ReduceFunction( reduction=ReduceOp( diff --git a/vkdispatch/reduce/reduce_function.py b/vkdispatch/reduce/reduce_function.py index e8438498..34b10474 100644 --- a/vkdispatch/reduce/reduce_function.py +++ b/vkdispatch/reduce/reduce_function.py @@ -111,7 +111,8 @@ def __call__(self, *args, **kwargs) -> vd.Buffer: for i in skipped_axes: input_stride *= args[0].shape[i] - assert input_stride == 1, "Reduction axes must be contiguous!" + if input_stride != 1: + raise ValueError("Reduction axes must be contiguous!") workgroups_x = int(npc.ceil(input_size / (self.group_size * input_stride))) diff --git a/vkdispatch/reduce/stage.py b/vkdispatch/reduce/stage.py index 3bce6759..8e160049 100644 --- a/vkdispatch/reduce/stage.py +++ b/vkdispatch/reduce/stage.py @@ -54,20 +54,17 @@ def global_reduce( end_index = vc.new_uint_register(start_index + params.input_size, var_name="end_index") - vc.while_statement(current_index < end_index) + with vc.while_block(current_index < end_index): + mapped_value = buffers[0][current_index] - mapped_value = buffers[0][current_index] + if map_func is not None: + set_mapped_io_index(current_index) + mapped_value = map_func.callback(*buffers) + set_mapped_io_index(None) - if map_func is not None: - set_mapped_io_index(current_index) - mapped_value = map_func.callback(*buffers) - set_mapped_io_index(None) + reduction_aggregate[:] = reduction.reduction(reduction_aggregate, mapped_value) - reduction_aggregate[:] = reduction.reduction(reduction_aggregate, mapped_value) - - current_index += vc.workgroup_size().x * vc.num_workgroups().x - - vc.end() + current_index += vc.workgroup_size().x * vc.num_workgroups().x return reduction_aggregate @@ -91,19 +88,17 @@ def workgroup_reduce( current_size = group_size // 2 while current_size > subgroup_reduce_size: - vc.if_statement(tid < current_size) - sdata[tid] = reduction.reduction(sdata[tid], sdata[tid + current_size]) - if current_size // 2 > subgroup_reduce_size: - vc.end() - else: + with vc.if_block(tid < current_size): + sdata[tid] = reduction.reduction(sdata[tid], sdata[tid + current_size]) + + if current_size // 2 <= subgroup_reduce_size: tid_limit = 2 if subgroup_reduce_size != 1: tid_limit = 2*vc.subgroup_size() - vc.else_if_statement(tid < tid_limit) - sdata[tid] = vc.new_register(out_type, 0) - vc.end() + with vc.else_if_block(tid < tid_limit): + sdata[tid] = vc.new_register(out_type, 0) vc.barrier() @@ -122,10 +117,9 @@ def subgroup_reduce( subgroup_reduce_size = 1 if group_size > subgroup_reduce_size: - vc.if_statement(tid < subgroup_reduce_size) - sdata[tid] = reduction.reduction(sdata[tid], sdata[tid + subgroup_reduce_size]) - vc.end() - + with vc.if_block(tid < subgroup_reduce_size): + sdata[tid] = reduction.reduction(sdata[tid], sdata[tid + subgroup_reduce_size]) + if subgroup_reduce_size == 1: return sdata[tid].to_register("local_var") @@ -139,9 +133,9 @@ def subgroup_reduce( else: current_size = subgroup_reduce_size // 2 while current_size > 1: - vc.if_statement(tid < current_size) - sdata[tid] = reduction.reduction(sdata[tid], sdata[tid + current_size]) - vc.end() + with vc.if_block(tid < current_size): + sdata[tid] = reduction.reduction(sdata[tid], sdata[tid + current_size]) + vc.subgroup_barrier() current_size //= 2 @@ -160,7 +154,7 @@ def make_reduction_stage( name = f"reduction_stage_{reduction.name}_{out_type.name}_{input_types}_{group_size}" - with vd.shader_context() as context: + with vc.shader_context() as context: signature_type_array = [] signature_type_array.append(vc.Buffer[out_type]) @@ -183,8 +177,10 @@ def make_reduction_stage( batch_offset = vc.workgroup_id().y * params.output_y_batch_stride output_offset = vc.workgroup_id().x * params.output_stride - vc.if_statement(vc.local_invocation_id().x == 0) - input_variables[0][batch_offset + output_offset + params.output_offset] = local_var - vc.end() + with vc.if_block(vc.local_invocation_id().x == 0): + input_variables[0][batch_offset + output_offset + params.output_offset] = local_var - return context.get_function(local_size=(group_size, 1, 1), name=name) + return vd.make_shader_function( + context.get_description(name), + local_size=(group_size, 1, 1) + ) diff --git a/vkdispatch/shader/context.py b/vkdispatch/shader/context.py deleted file mode 100644 index 9bd5713c..00000000 --- a/vkdispatch/shader/context.py +++ /dev/null @@ -1,88 +0,0 @@ -import vkdispatch as vd -import vkdispatch.codegen as vc - -from .signature import ShaderSignature, ShaderArgumentType -from typing import List, Optional, Any - -import contextlib - -class ShaderContext: - builder: vc.ShaderBuilder - signature: ShaderSignature - shader_function: vd.ShaderFunction - - def __init__(self, builder: vc.ShaderBuilder): - self.builder = builder - self.signature = None - self.shader_function = None - - def get_function(self, - local_size=None, - workgroups=None, - exec_count=None, - name: Optional[str] = None) -> vd.ShaderFunction: - if self.shader_function is not None: - return self.shader_function - - description = self.builder.build("shader" if name is None else name) - - # Resource bindings are declared before final shader layout is known. - # For some shader construction paths (e.g. from_description), signatures are - # pre-populated and still hold logical bindings assuming a reserved UBO at 0. - binding_shift = description.resource_binding_base - 1 - if binding_shift != 0: - binding_access_len = len(description.binding_access) - needs_remap = False - - for shader_arg in self.signature.arguments: - if ( - shader_arg.binding is not None - and ( - shader_arg.arg_type == ShaderArgumentType.BUFFER - or shader_arg.arg_type == ShaderArgumentType.IMAGE - ) - and shader_arg.binding >= binding_access_len - ): - needs_remap = True - break - - if needs_remap: - for shader_arg in self.signature.arguments: - if ( - shader_arg.binding is not None - and ( - shader_arg.arg_type == ShaderArgumentType.BUFFER - or shader_arg.arg_type == ShaderArgumentType.IMAGE - ) - ): - shader_arg.binding += binding_shift - - self.shader_function = vd.ShaderFunction( - description, - self.signature, - local_size=local_size, - workgroups=workgroups, - exec_count=exec_count - ) - - return self.shader_function - - def declare_input_arguments(self, - annotations: List, - names: Optional[List[str]] = None, - defaults: Optional[List[Any]] = None): - self.signature = ShaderSignature.from_type_annotations(self.builder, annotations, names, defaults) - return self.signature.get_variables() - -@contextlib.contextmanager -def shader_context(flags: vc.ShaderFlags = vc.ShaderFlags.NONE): - - builder = vc.ShaderBuilder(flags=flags, is_apple_device=vd.get_context().is_apple()) - old_builder = vc.set_builder(builder) - - context = ShaderContext(builder) - - try: - yield context - finally: - vc.set_builder(old_builder) \ No newline at end of file diff --git a/vkdispatch/shader/decorator.py b/vkdispatch/shader/decorator.py index 0dbe5239..603e153d 100644 --- a/vkdispatch/shader/decorator.py +++ b/vkdispatch/shader/decorator.py @@ -3,9 +3,7 @@ import dataclasses import inspect -from typing import Callable, TypeVar - -from .context import shader_context +from typing import Callable, Optional, List, Any import sys @@ -27,7 +25,7 @@ def inspect_function_signature(func: Callable): raise ValueError("All parameters must be annotated") - if not dataclasses.is_dataclass(param.annotation): # issubclass(param.annotation.__origin__, dataclasses.dataclass): + if not dataclasses.is_dataclass(param.annotation): if not hasattr(param.annotation, '__args__'): raise TypeError(f"Argument '{param.name}: vd.{param.annotation}' must have a type annotation") @@ -44,7 +42,10 @@ def shader( exec_size=None, local_size=None, workgroups=None, - flags: vc.ShaderFlags = vc.ShaderFlags.NONE): + flags: vc.ShaderFlags = vc.ShaderFlags.NONE, + arg_type_annotations: Optional[List[Any]] = None, + arg_names: Optional[List[str]] = None, + arg_defaults: Optional[List[Any]] = None): """ A decorator that transforms a Python function into a GPU Compute Shader. @@ -64,6 +65,17 @@ def shader( :type workgroups: Union[int, Tuple[int, ...], Callable] :param flags: Compilation flags (e.g., ``vc.ShaderFlags.NO_EXEC_BOUNDS``). :type flags: vkdispatch.codegen.ShaderFlags + :param arg_type_annotations: Optional list of type annotations for the shader function's + parameters. If not provided, annotations will be inferred + from the decorated function's signature. + :type arg_type_annotations: Optional[List[Any]] + :param arg_names: Optional list of parameter names corresponding to the type annotations. + If not provided, names will be inferred from the decorated + function's signature. + :type arg_names: Optional[List[str]] + :param arg_defaults: Optional list of default values for the parameters. If not provided, + defaults will be inferred from the decorated function's signature. + :type arg_defaults: Optional[List[Any]] :return: A ``ShaderFunction`` wrapper that can be called to execute the kernel. :raises ValueError: If both ``exec_size`` and ``workgroups`` are provided. """ @@ -71,11 +83,25 @@ def shader( raise ValueError("Cannot specify both 'workgroups' and 'exec_size'") def decorator_callback(func: Callable[P, None]) -> Callable[P, None]: - with shader_context(flags=flags) as context: - annotations, names, defaults = inspect_function_signature(func) - args = context.declare_input_arguments(annotations, names, defaults) + type_annotations = arg_type_annotations + names = arg_names + defaults = arg_defaults + + if type_annotations is None: + if names is not None or defaults is not None: + raise ValueError("If 'arg_type_annotations' is not provided, 'arg_names' and 'arg_defaults' must also be None") + + type_annotations, names, defaults = inspect_function_signature(func) + + with vc.shader_context(flags=flags) as context: + args = context.declare_input_arguments(type_annotations, names, defaults) func(*args) - return context.get_function(local_size=local_size, workgroups=workgroups, exec_count=exec_size, name=func.__name__) + return vd.make_shader_function( + context.get_description(func.__name__), + local_size=local_size, + workgroups=workgroups, + exec_count=exec_size + ) return decorator_callback diff --git a/vkdispatch/shader/map.py b/vkdispatch/shader/map.py index 6d27ccb6..e1c5af90 100644 --- a/vkdispatch/shader/map.py +++ b/vkdispatch/shader/map.py @@ -28,16 +28,14 @@ def __eq__(self, other): def callback(self, *args): if self.return_type is None: - vc.new_scope(indent=False) - self.mapping_function(*args) - vc.end(indent=False) + with vc.scope_block(): + self.mapping_function(*args) return return_var = vc.new_register(self.return_type) - vc.new_scope(indent=False) - return_var[:] = self.mapping_function(*args) - vc.end(indent=False) + with vc.scope_block(): + return_var[:] = self.mapping_function(*args) return return_var diff --git a/vkdispatch/shader/shader_function.py b/vkdispatch/shader/shader_function.py index 18e135ab..959d7fd5 100644 --- a/vkdispatch/shader/shader_function.py +++ b/vkdispatch/shader/shader_function.py @@ -1,7 +1,7 @@ import vkdispatch as vd import vkdispatch.codegen as vc -from typing import Tuple +from typing import Tuple, Optional from typing import Union from typing import Callable from typing import List @@ -9,7 +9,7 @@ from vkdispatch.base.compute_plan import ComputePlan -from .signature import ShaderArgumentType, ShaderSignature +from ..codegen.shader_description import ShaderArgumentType import uuid @@ -48,7 +48,7 @@ def get_values(self) -> List[Any]: def __getattr__(self, name: str): return self.ref_dict[name] -class ExectionBounds: +class ExecutionBounds: local_size: Tuple[int, int, int] workgroups: Union[Tuple[int, int, int], Callable, None] exec_size: Union[Tuple[int, int, int], Callable, None] @@ -77,7 +77,7 @@ def process_input(self, in_val, args, kwargs) -> Tuple[int, int, int]: if not isinstance(in_val, tuple): raise ValueError("Must provide a tuple of dimensions!") - if len(in_val) < 0 or len(in_val) > 4: + if len(in_val) <= 0 or len(in_val) >= 4: raise ValueError("Must provide a tuple of length 1, 2, or 3!") return_val = [1, 1, 1] @@ -127,8 +127,11 @@ def get_blocks_and_limits(self, args, kwargs) -> Tuple[Tuple[int, int, int], Tup my_block = my_blocks[i] max_block = vd.get_context().max_workgroup_count[i] - assert my_block > 0, f"Workgroup count for dimension {i} must be greater than 0!" - assert my_block <= max_block, f"Workgroup count ({my_block}) for dimension {i} exceeds maximum allowed size ({max_block})!" + if my_block <= 0: + raise ValueError(f"Workgroup count for dimension {i} must be greater than 0!") + + if my_block > max_block: + raise ValueError(f"Workgroup count ({my_block}) for dimension {i} exceeds maximum allowed size ({max_block})!") return (my_blocks, my_limits) @@ -141,39 +144,47 @@ class ShaderSource: def __repr__(self): return f"// ====== Source Code for '{self.name}', workgroup_size: {self.local_size} ======\n{self.code}" +class ShaderBuildError(RuntimeError): + shader_source: ShaderSource + compiler_log: Optional[str] + + def __init__( + self, + message: str, + *, + shader_source: ShaderSource, + compiler_log: Optional[str] = None, + ) -> None: + super().__init__(message) + self.shader_source = shader_source + self.compiler_log = compiler_log + class ShaderFunction: plan: ComputePlan shader_description: vc.ShaderDescription - shader_signature: ShaderSignature - bounds: ExectionBounds + bounds: ExecutionBounds ready: bool name: str source: str - flags: vc.ShaderFlags local_size: Union[Tuple[int, int, int], Callable, None] workgroups: Union[Tuple[int, int, int], Callable, None] exec_size: Union[Tuple[int, int, int], Callable, None] def __init__(self, shader_description: vc.ShaderDescription, - shader_signature: ShaderSignature, local_size=None, workgroups=None, - exec_count=None, - flags: vc.ShaderFlags = vc.ShaderFlags.NONE, - name: str = None) -> None: + exec_count=None) -> None: self.plan = None self.shader_description = shader_description - self.shader_signature = shader_signature self.bounds = None self.ready = False - self.name = name if name is not None else None + self.name = shader_description.name self.source = None self.local_size = local_size self.workgroups = workgroups self.exec_size = exec_count - self.flags = flags def build(self): if self.ready: @@ -185,7 +196,7 @@ def build(self): else [vd.get_context().max_workgroup_size[0], 1, 1] ) - self.bounds = ExectionBounds(self.shader_signature.get_names_and_defaults(), my_local_size, self.workgroups, self.exec_size) + self.bounds = ExecutionBounds(self.shader_description.get_arg_names_and_defaults(), my_local_size, self.workgroups, self.exec_size) shader_backend_name = ( self.shader_description.backend.name @@ -231,9 +242,13 @@ def build(self): self.shader_description.name ) except Exception as e: - print(f"Error building shader: {e}") - print(self.get_src(build=False, line_numbers=True)) - raise e + shader_source = self.get_src(build=False, line_numbers=True) + compiler_log = str(e) + raise ShaderBuildError( + f"Failed to build shader '{self.name}': {compiler_log}", + shader_source=shader_source, + compiler_log=compiler_log, + ) from e self.ready = True @@ -260,8 +275,9 @@ def print_src(self, line_numbers: bool = None): print(self.get_src(line_numbers)) def __call__(self, *args, **kwargs): - assert not vd.is_dummy(), "Cannot execute shader functions with dummy backend!" - + if vd.is_dummy(): + raise RuntimeError("Cannot execute shader functions with dummy backend!") + self.build() if not self.ready: @@ -287,7 +303,7 @@ def __call__(self, *args, **kwargs): shader_uuid = f"{self.shader_description.name}.{uuid.uuid4()}" - for ii, shader_arg in enumerate(self.shader_signature.arguments): + for ii, shader_arg in enumerate(self.shader_description.shader_arg_infos): arg = None if ii < len(args): @@ -304,7 +320,7 @@ def __call__(self, *args, **kwargs): if shader_arg.arg_type == ShaderArgumentType.BUFFER: if not isinstance(arg, vd.Buffer): raise ValueError(f"Expected a buffer for argument '{shader_arg.name}' but got '{arg}'!") - + bound_buffers.append(vd.BufferBindInfo( buffer=arg, binding=shader_arg.binding, @@ -362,3 +378,16 @@ def __call__(self, *args, **kwargs): pc_values, shader_uuid=shader_uuid ) + +def make_shader_function( + description: vc.ShaderDescription, + local_size=None, + workgroups=None, + exec_count=None +) -> ShaderFunction: + return ShaderFunction( + description, + local_size=local_size, + workgroups=workgroups, + exec_count=exec_count + ) diff --git a/vkdispatch/shader/signature.py b/vkdispatch/shader/signature.py deleted file mode 100644 index 8d6f4a46..00000000 --- a/vkdispatch/shader/signature.py +++ /dev/null @@ -1,179 +0,0 @@ -import vkdispatch.codegen as vc - -from ..base.dtype import is_dtype - -from typing import List -from typing import Any -from typing import Callable -from typing import Optional -from typing import Tuple -from typing import Union -from typing import Dict -#from types import GenericAlias - -from typing import get_type_hints - -import dataclasses - -import inspect - -import enum - -_PUSH_CONSTANT_UNSUPPORTED_BACKENDS = set() - - -def _push_constant_not_supported_error(backend_name: str) -> str: - return ( - f"Push Constants are not supported for the {backend_name.upper()} backend. " - "Use Const instead." - ) - - -class ShaderArgumentType(enum.Enum): - BUFFER = 0 - IMAGE = 1 - VARIABLE = 2 - CONSTANT = 3 - CONSTANT_DATACLASS = 4 - -@dataclasses.dataclass -class ShaderArgument: - name: str - arg_type: ShaderArgumentType - default_value: Any - shader_name: Union[str, Dict[str, str]] - shader_shape_name: Optional[str] - binding: Optional[int] - - -class ShaderSignature: - arguments: List[ShaderArgument] - variables: List[vc.ShaderVariable] - - def __init__(self): - raise NotImplementedError("This class is not meant to be instantiated") - - @classmethod - def from_argument_list(cls, arguments: Optional[List[ShaderArgument]] = None) -> 'ShaderSignature': - instance = cls.__new__(cls) # Bypasses __init__ - instance.arguments = arguments if arguments is not None else [] - instance.variables = None - return instance - - @classmethod - def from_inspectable_function(cls, builder: vc.ShaderBuilder, func: Callable) -> 'ShaderSignature': - func_signature = inspect.signature(func) - - annotations = [] - names = [] - defaults = [] - - for param in func_signature.parameters.values(): - if param.annotation == inspect.Parameter.empty: - raise ValueError("All parameters must be annotated") - - - if not dataclasses.is_dataclass(param.annotation): # issubclass(param.annotation.__origin__, dataclasses.dataclass): - if not hasattr(param.annotation, '__args__'): - raise TypeError(f"Argument '{param.name}: vd.{param.annotation}' must have a type annotation") - - if len(param.annotation.__args__) != 1: - raise ValueError(f"Type '{param.name}: vd.{param.annotation.__name__}' must have exactly one type argument") - - annotations.append(param.annotation) - names.append(param.name) - defaults.append(param.default if param.default != inspect.Parameter.empty else None) - - return ShaderSignature.from_type_annotations(builder, annotations, names, defaults) - - @classmethod - def from_type_annotations(cls, - builder: vc.ShaderBuilder, - annotations: List, - names: Optional[List[str]] = None, - defaults: Optional[List[Any]] = None) -> 'ShaderSignature': - - instance = cls.__new__(cls) # Bypasses __init__ - instance.arguments = [] - instance.variables = [] - - for i in range(len(annotations)): - shader_param = None - arg_type = None - shape_name = None - binding = None - value_name = None - - if(dataclasses.is_dataclass(annotations[i])): - creation_args: Dict[str, vc.ShaderVariable] = {} - arg_type = ShaderArgumentType.CONSTANT_DATACLASS - value_name = {} - - for field_name, field_type in get_type_hints(annotations[i]).items(): - assert is_dtype(field_type), f"Unsupported type '{field_type}' for field '{annotations[i]}.{field_name}'" - - creation_args[field_name] = builder.declare_constant(field_type) - value_name[field_name] = creation_args[field_name].raw_name - - shader_param = annotations[i](**creation_args) - - elif(issubclass(annotations[i].__origin__, vc.Buffer)): - shader_param = builder.declare_buffer(annotations[i].__args__[0]) - - arg_type = ShaderArgumentType.BUFFER - shape_name = shader_param.shape_name - binding = shader_param.binding - value_name = shader_param.raw_name - - elif(issubclass(annotations[i].__origin__, vc.Image1D)): - shader_param = builder.declare_image(1) - - arg_type = ShaderArgumentType.IMAGE - binding = shader_param.binding - value_name = shader_param.raw_name - - elif(issubclass(annotations[i].__origin__, vc.Image2D)): - shader_param = builder.declare_image(2) - arg_type = ShaderArgumentType.IMAGE - binding = shader_param.binding - value_name = shader_param.raw_name - - elif(issubclass(annotations[i].__origin__, vc.Image3D)): - shader_param = builder.declare_image(3) - arg_type = ShaderArgumentType.IMAGE - binding = shader_param.binding - value_name = shader_param.raw_name - - elif(issubclass(annotations[i].__origin__, vc.Constant)): - shader_param = builder.declare_constant(annotations[i].__args__[0]) - value_name = shader_param.raw_name - arg_type = ShaderArgumentType.CONSTANT - elif(issubclass(annotations[i].__origin__, vc.Variable)): - if builder.backend.name in _PUSH_CONSTANT_UNSUPPORTED_BACKENDS: - raise NotImplementedError(_push_constant_not_supported_error(builder.backend.name)) - - shader_param = builder.declare_variable(annotations[i].__args__[0]) - arg_type = ShaderArgumentType.VARIABLE - value_name = shader_param.raw_name - - else: - raise ValueError(f"Unsupported type '{annotations[i].__args__[0]}'") - - instance.variables.append(shader_param) - - instance.arguments.append(ShaderArgument( - names[i] if names is not None else f"param{i}", - arg_type, - defaults[i] if defaults is not None else None, - value_name, - shape_name, - binding - )) - - return instance - - def get_variables(self) -> List[vc.ShaderVariable]: - return self.variables - - def get_names_and_defaults(self) -> List[Tuple[str, Any]]: - return [(arg.name, arg.default_value) for arg in self.arguments] diff --git a/vkdispatch/vkfft/__init__.py b/vkdispatch/vkfft/__init__.py index 2d96d064..633dd90a 100644 --- a/vkdispatch/vkfft/__init__.py +++ b/vkdispatch/vkfft/__init__.py @@ -4,6 +4,4 @@ from .vkfft_dispatcher import ifft, ifft2, ifft3 from .vkfft_dispatcher import rfft, rfft2, rfft3 from .vkfft_dispatcher import irfft, irfft2, irfft3 -from .vkfft_dispatcher import clear_plan_cache, convolve2D, transpose_kernel2D -#from .fft_dispatcher import ifft, irfft, create_kernel_2Dreal, convolve_2Dreal -#from .fft_dispatcher import reset_fft_plans \ No newline at end of file +from .vkfft_dispatcher import clear_plan_cache, convolve2D, transpose_kernel2D \ No newline at end of file diff --git a/vkdispatch/vkfft/vkfft_dispatcher.py b/vkdispatch/vkfft/vkfft_dispatcher.py index e289293b..339efd55 100644 --- a/vkdispatch/vkfft/vkfft_dispatcher.py +++ b/vkdispatch/vkfft/vkfft_dispatcher.py @@ -106,137 +106,6 @@ def execute_fft_plan( if graph.submit_on_record: graph.submit() -def sanitize_2d_convolution_buffer_shape(in_shape: vd.Buffer): - if in_shape is None: - return None - - in_shape = in_shape.shape - - assert len(in_shape) == 2 or len(in_shape) == 3, "Input shape must be 2D or 3D!" - - if len(in_shape) == 2: - return (1, in_shape[0], in_shape[1]) - - return in_shape - -def convolve_2Dreal( - buffer: vd.RFFTBuffer, - kernel: Union[vd.Buffer[vd.float32], vd.RFFTBuffer], - input: Union[vd.Buffer[vd.float32], vd.RFFTBuffer] = None, - normalize: bool = False, - conjugate_kernel: bool = False, - graph: Optional[vd.CommandGraph] = None, - keep_shader_code: bool = False): - - buffer_shape = sanitize_2d_convolution_buffer_shape(buffer) - kernel_shape = sanitize_2d_convolution_buffer_shape(kernel) - - assert buffer_shape == kernel_shape, f"Buffer ({buffer_shape}) and Kernel ({kernel_shape}) shapes must match!" - - input_shape = sanitize_2d_convolution_buffer_shape(input) - - kernel_count = 1 - feature_count = 1 - - if input_shape is not None: - assert buffer_shape[0] % input_shape[0] == 0, f"Output count ({buffer_shape[0]}) must be divisible by input count ({input_shape[0]})!" - kernel_count = buffer_shape[0] // input_shape[0] - feature_count = input_shape[0] - else: - feature_count = buffer.shape[0] - - execute_fft_plan( - buffer, - False, - graph = graph, - config = FFTConfig( - buffer_handle=buffer._handle, - shape=sanitize_input_tuple(buffer.real_shape), - do_r2c=True, - normalize=normalize, - kernel_count=kernel_count, - conjugate_convolution=conjugate_kernel, - input_shape=sanitize_input_tuple(input.shape if input is not None else None), - input_type=input.var_type if input is not None else None, - convolution_features=feature_count, - keep_shader_code=keep_shader_code - ), - kernel=kernel, - input=input - ) - -def create_kernel_2Dreal( - kernel: vd.RFFTBuffer, - shape: Tuple[int, ...] = None, - feature_count: int = 1, - graph: Optional[vd.CommandGraph] = None, - keep_shader_code: bool = False) -> vd.RFFTBuffer: - - if shape is None: - shape = kernel.shape - - if len(shape) == 2: - assert feature_count == 1, "Feature count must be 1 for 2D kernels!" - shape = (1,) + shape - - execute_fft_plan( - kernel, - False, - graph = graph, - config = FFTConfig( - buffer_handle=kernel._handle, - shape=sanitize_input_tuple(kernel.real_shape), - do_r2c=True, - kernel_convolution=True, - convolution_features=feature_count, - num_batches=shape[0] // feature_count, - keep_shader_code=keep_shader_code - ) - ) - - return kernel - -def convolve_2D( - buffer: vd.Buffer, - kernel: vd.Buffer, - normalize: bool = False, - conjugate_kernel: bool = False, - graph: Optional[vd.CommandGraph] = None, - keep_shader_code: bool = False, - padding: Tuple[Tuple[int, int]] = None): - - buffer_shape = sanitize_2d_convolution_buffer_shape(buffer) - kernel_shape = sanitize_2d_convolution_buffer_shape(kernel) - - # assert buffer_shape == kernel_shape, f"Buffer ({buffer_shape}) and Kernel ({kernel_shape}) shapes must match!" - - kernel_count = kernel.shape[0] if len(kernel.shape) == 3 else 1 - feature_count = 1 - - if kernel_count > 1: - feature_count = buffer.shape[0] - - in_shape = sanitize_input_tuple(buffer.shape) - - execute_fft_plan( - buffer, - False, - graph = graph, - config = FFTConfig( - buffer_handle=buffer._handle, - shape=in_shape[1:], # if kernel_count == 1 else in_shape, - normalize=normalize, - kernel_count=1, #kernel_count, - conjugate_convolution=conjugate_kernel, - convolution_features=1, #feature_count, - keep_shader_code=keep_shader_code, - num_batches=buffer.shape[0], # if kernel_count == 1 else 1, - padding=padding - ), - kernel=kernel - ) - - def transpose_kernel2D( kernel: vd.Buffer, shape: Tuple[int, ...] = None, @@ -248,8 +117,9 @@ def transpose_kernel2D( if len(shape) == 2: shape = (1,) + shape - assert len(shape) == 3, "Kernel shape must be 2D or 3D!" - + if len(shape) != 3: + raise ValueError('Kernel shape must be 2D or 3D!') + execute_fft_plan( kernel, False, @@ -338,14 +208,16 @@ def fft( ) def fft2(buffer: vd.Buffer, graph: vd.CommandGraph = None, print_shader: bool = False): - assert len(buffer.shape) == 2 or len(buffer.shape) == 3, 'Buffer must have 2 or 3 dimensions' + if len(buffer.shape) < 2: + raise ValueError('Buffer must have at least 2 dimensions') axes = (len(buffer.shape) - 2, len(buffer.shape) - 1) fft(buffer, graph=graph, print_shader=print_shader, axis=axes) def fft3(buffer: vd.Buffer, graph: vd.CommandGraph = None, print_shader: bool = False): - assert len(buffer.shape) == 3, 'Buffer must have 3 dimensions' + if len(buffer.shape) < 3: + raise ValueError('Buffer must have at least 3 dimensions') fft(buffer, graph=graph, print_shader=print_shader, axis=(0, 1, 2)) @@ -358,14 +230,16 @@ def ifft( fft(buffer, graph=graph, print_shader=print_shader, axis=axis, inverse=True, normalize_inverse=normalize) def ifft2(buffer: vd.Buffer, graph: vd.CommandGraph = None, print_shader: bool = False): - assert len(buffer.shape) == 2 or len(buffer.shape) == 3, 'Buffer must have 2 or 3 dimensions' + if len(buffer.shape) < 2: + raise ValueError('Buffer must have at least 2 dimensions') axes = (len(buffer.shape) - 2, len(buffer.shape) - 1) ifft(buffer, graph=graph, print_shader=print_shader, axis=axes) def ifft3(buffer: vd.Buffer, graph: vd.CommandGraph = None, print_shader: bool = False): - assert len(buffer.shape) == 3, 'Buffer must have 3 dimensions' + if len(buffer.shape) != 3: + raise ValueError('Buffer must have 3 dimensions') ifft(buffer, graph=graph, print_shader=print_shader, axis=(0, 1, 2)) @@ -374,13 +248,15 @@ def rfft(buffer: vd.RFFTBuffer, graph: vd.CommandGraph = None, print_shader: boo fft(buffer, buffer_shape=buffer.real_shape, graph=graph, print_shader=print_shader, r2c=True, axis=axis) def rfft2(buffer: vd.RFFTBuffer, graph: vd.CommandGraph = None, print_shader: bool = False): - assert len(buffer.real_shape) == 2 or len(buffer.real_shape) == 3, 'Buffer must have 2 or 3 dimensions' + if len(buffer.real_shape) < 2: + raise ValueError('Buffer must have at least 2 dimensions') axes = (len(buffer.shape) - 2, len(buffer.shape) - 1) rfft(buffer, graph=graph, print_shader=print_shader, axis=axes) def rfft3(buffer: vd.RFFTBuffer, graph: vd.CommandGraph = None, print_shader: bool = False): - assert len(buffer.real_shape) == 3, 'Buffer must have 3 dimensions' + if len(buffer.real_shape) != 3: + raise ValueError('Buffer must have 3 dimensions') rfft(buffer, graph=graph, print_shader=print_shader, axis=(0, 1, 2)) @@ -388,12 +264,14 @@ def irfft(buffer: vd.RFFTBuffer, graph: vd.CommandGraph = None, print_shader: bo fft(buffer, buffer_shape=buffer.real_shape, graph=graph, print_shader=print_shader, inverse=True, normalize_inverse=normalize, r2c=True, axis=axis) def irfft2(buffer: vd.RFFTBuffer, graph: vd.CommandGraph = None, print_shader: bool = False): - assert len(buffer.real_shape) == 2 or len(buffer.real_shape) == 3, 'Buffer must have 2 or 3 dimensions' + if len(buffer.real_shape) < 2: + raise ValueError('Buffer must have at least 2 dimensions') axes = (len(buffer.shape) - 2, len(buffer.shape) - 1) irfft(buffer, graph=graph, print_shader=print_shader, axis=axes) def irfft3(buffer: vd.RFFTBuffer, graph: vd.CommandGraph = None, print_shader: bool = False): - assert len(buffer.real_shape) == 3, 'Buffer must have 3 dimensions' + if len(buffer.real_shape) != 3: + raise ValueError('Buffer must have 3 dimensions') irfft(buffer, graph=graph, print_shader=print_shader, axis=(0, 1, 2)) diff --git a/vkdispatch/vkfft/vkfft_plan.py b/vkdispatch/vkfft/vkfft_plan.py index 0ad12dea..72ef1fda 100644 --- a/vkdispatch/vkfft/vkfft_plan.py +++ b/vkdispatch/vkfft/vkfft_plan.py @@ -30,7 +30,8 @@ def __init__(self, keep_shader_code: bool = False): super().__init__() - assert len(shape) > 0 and len(shape) < 4, "shape must be 1D, 2D, or 3D" + if len(shape) == 0 or len(shape) > 3: + raise ValueError("Shape must be 1D, 2D, or 3D!") self.shape = shape self.do_r2c = do_r2c diff --git a/vkdispatch_native/context/context.cpp b/vkdispatch_native/context/context.cpp index 4d92935c..47fe0d3f 100644 --- a/vkdispatch_native/context/context.cpp +++ b/vkdispatch_native/context/context.cpp @@ -241,47 +241,6 @@ void signal_destroy_extern(void* signal_ptr) { delete signal; } - -// void wait_for_queue(struct Context* ctx, int queue_index) { -// LOG_INFO("Waiting for queue %d to finish execution...", queue_index); - -// uint64_t* p_timestamp = new uint64_t(); -// Signal* signal = new Signal(ctx); - -// *p_timestamp = 0; - -// context_submit_command(ctx, "queue-wait-idle", queue_index, RECORD_TYPE_SYNC, -// [ctx, signal, p_timestamp](VkCommandBuffer cmd_buffer, ExecIndicies indicies, void* pc_data, BarrierManager* barrier_manager, uint64_t timestamp){ -// LOG_VERBOSE("Waiting for queue %d to finish execution...", indicies.queue_index); -// *p_timestamp = timestamp; -// signal->notify(timestamp); -// } -// ); - -// signal->wait(); - -// if(*p_timestamp == 0) { -// if (ctx->running.load(std::memory_order_acquire)) -// LOG_WARNING("Queue %d did not finish execution", queue_index); -// } else { -// LOG_INFO("Queue %d finished execution at timestamp %llu", queue_index, *p_timestamp); -// } - -// ctx->queues[queue_index]->wait_for_timestamp(*p_timestamp); - -// delete signal; -// } - -// bool context_queue_wait_idle_extern(struct Context* context, int queue_index) { -// if(queue_index == -1) { -// for(int i = 0; i < context->queues.size(); i++) { -// wait_for_queue(context, i); -// } -// } else { -// wait_for_queue(context, queue_index); -// } -// } - void context_submit_command( Context* context, const char* name, diff --git a/vkdispatch_native/context/context_extern.pxd b/vkdispatch_native/context/context_extern.pxd index 873a38b7..f222e531 100644 --- a/vkdispatch_native/context/context_extern.pxd +++ b/vkdispatch_native/context/context_extern.pxd @@ -156,7 +156,8 @@ cpdef inline get_devices(): return device_list cpdef inline context_create(list[int] device_indicies, list[list[int]] queue_families): - assert len(device_indicies) == len(queue_families) + if len(device_indicies) != len(queue_families): + raise ValueError(f"Length of device_indicies ({len(device_indicies)}) must match length of queue_families ({len(queue_families)})") cdef int len_device_indicies = len(device_indicies) cdef int* device_indicies_c = malloc(len_device_indicies * sizeof(int)) @@ -177,8 +178,12 @@ cpdef inline context_create(list[int] device_indicies, list[list[int]] queue_fam for j in range(queue_counts_c[i]): queue_families_c[current_index] = queue_families[i][j] current_index += 1 - - assert current_index == total_queue_count + + if current_index != total_queue_count: + free(device_indicies_c) + free(queue_counts_c) + free(queue_families_c) + raise ValueError(f"Total queue count mismatch: expected {total_queue_count}, got {current_index}") cdef unsigned long long result = context_create_extern(device_indicies_c, queue_counts_c, queue_families_c, len_device_indicies) diff --git a/vkdispatch_native/context/init.cpp b/vkdispatch_native/context/init.cpp index 86ef05f2..fedd6fdc 100644 --- a/vkdispatch_native/context/init.cpp +++ b/vkdispatch_native/context/init.cpp @@ -80,11 +80,6 @@ static VkBool32 VKAPI_PTR vulkan_custom_debug_callback( log_message(log_level, "\n", "Vulkan", 0, pCallbackData->pMessage); - if(log_level == LOG_LEVEL_ERROR) { - exit(1); - } - - //printf("%s", pCallbackData->pMessage); return VK_FALSE; } diff --git a/vkdispatch_native/objects/buffer.cpp b/vkdispatch_native/objects/buffer.cpp index ede3347d..1f56d1bf 100644 --- a/vkdispatch_native/objects/buffer.cpp +++ b/vkdispatch_native/objects/buffer.cpp @@ -95,9 +95,6 @@ void buffer_destroy_extern(struct Buffer* buffer) { uint64_t signals_pointers_handle = buffer->signals_pointers_handle; Signal* signal = (Signal*)ctx->handle_manager->get_handle(queue_index, signals_pointers_handle, 0); - // wait for the recording thread to finish - //signal->wait(); - ctx->handle_manager->destroy_handle(queue_index, buffer->signals_pointers_handle); delete signal; @@ -188,21 +185,8 @@ void buffer_write_extern(struct Buffer* buffer, unsigned long long offset, unsig uint64_t signals_pointers_handle = buffer->signals_pointers_handle; Signal* signal = (Signal*)ctx->handle_manager->get_handle(queue_index, signals_pointers_handle, 0); - // wait for the recording thread to finish - //signal->wait(); signal->reset(); - // wait for the staging buffer to be ready - // uint64_t staging_buffer_timestamp = ctx->handle_manager->get_handle_timestamp(queue_index, buffer->staging_buffers_handle); - // ctx->queues[queue_index]->wait_for_timestamp(staging_buffer_timestamp); - - // VmaAllocation staging_allocation = (VmaAllocation)ctx->handle_manager->get_handle(queue_index, buffer->staging_allocations_handle, 0); - - // void* mapped; - // VK_CALL(vmaMapMemory(ctx->allocators[device_index], staging_allocation, &mapped)); - // memcpy(mapped, data, size); - // vmaUnmapMemory(ctx->allocators[device_index], staging_allocation); - uint64_t buffers_handle = buffer->buffers_handle; uint64_t staging_buffers_handle = buffer->staging_buffers_handle; @@ -262,8 +246,6 @@ void buffer_read_extern(struct Buffer* buffer, unsigned long long offset, unsign uint64_t signals_pointers_handle = buffer->signals_pointers_handle; Signal* signal = (Signal*)ctx->handle_manager->get_handle(queue_index, signals_pointers_handle, 0); - // wait for the recording thread to finish - //signal->wait(); signal->reset(); uint64_t buffers_handle = buffer->buffers_handle; @@ -315,20 +297,4 @@ void buffer_read_extern(struct Buffer* buffer, unsigned long long offset, unsign signal->notify(indicies.queue_index, timestamp); } ); - - // wait for the recording thread to finish again - // signal->wait(); - - // // wait for the staging buffer to be ready - // uint64_t staging_buffer_timestamp = ctx->handle_manager->get_handle_timestamp(queue_index, buffer->staging_buffers_handle); - // ctx->queues[queue_index]->wait_for_timestamp(staging_buffer_timestamp); - - // int device_index = ctx->queues[queue_index]->device_index; - - // VmaAllocation staging_allocation = (VmaAllocation)ctx->handle_manager->get_handle(queue_index, buffer->staging_allocations_handle, 0); - - // void* mapped; - // VK_CALL(vmaMapMemory(ctx->allocators[device_index], staging_allocation, &mapped)); - // memcpy(data, mapped, size); - // vmaUnmapMemory(ctx->allocators[device_index], staging_allocation); } diff --git a/vkdispatch_native/objects/objects_extern.pxd b/vkdispatch_native/objects/objects_extern.pxd index cbefeed7..170b45e4 100644 --- a/vkdispatch_native/objects/objects_extern.pxd +++ b/vkdispatch_native/objects/objects_extern.pxd @@ -67,10 +67,6 @@ cdef extern from "objects/objects_extern.hh": void image_write_extern(Image* image, void* data, VkOffset3D offset, VkExtent3D extent, unsigned int baseLayer, unsigned int layerCount, int device_index) void image_read_extern(Image* image, void* data, VkOffset3D offset, VkExtent3D extent, unsigned int baseLayer, unsigned int layerCount, int device_index) - #void image_copy_extern(Image* src, Image* dst, VkOffset3D src_offset, unsigned int src_baseLayer, unsigned int src_layerCount, - # VkOffset3D dst_offset, unsigned int dst_baseLayer, unsigned int dst_layerCount, VkExtent3D extent, int device_index) - - cpdef inline buffer_create(unsigned long long context, unsigned long long size, int per_device): return buffer_create_extern(context, size, per_device) @@ -152,7 +148,8 @@ cpdef inline descriptor_set_write_image( descriptor_set_write_image_extern(ds, binding, object, sampler_obj, read_access, write_access) cpdef inline image_create(unsigned long long context, tuple[unsigned int, unsigned int, unsigned int] extent, unsigned int layers, unsigned int format, unsigned int type, unsigned int view_type, unsigned int generate_mips): - assert len(extent) == 3 + if len(extent) != 3: + raise ValueError("Extent must be a tuple of three unsigned integers (width, height, depth)") cdef unsigned int width = extent[0] cdef unsigned int height = extent[1] @@ -170,8 +167,11 @@ cpdef inline image_destroy_sampler(unsigned long long sampler): image_destroy_sampler_extern(sampler) cpdef inline image_write(unsigned long long image, bytes data, tuple[int, int, int] offset, tuple[unsigned int, unsigned int, unsigned int] extent, unsigned int baseLayer, unsigned int layerCount, int device_index): - assert len(offset) == 3 - assert len(extent) == 3 + if len(offset) != 3: + raise ValueError("Offset must be a tuple of three integers (x, y, z)") + + if len(extent) != 3: + raise ValueError("Extent must be a tuple of three unsigned integers (width, height, depth)") cdef int x = offset[0] cdef int y = offset[1] @@ -189,8 +189,11 @@ cpdef inline unsigned int image_format_block_size(unsigned int format): return image_format_block_size_extern(format) cpdef inline image_read(unsigned long long image, unsigned long long out_size, tuple[int, int, int] offset, tuple[unsigned int, unsigned int, unsigned int] extent, unsigned int baseLayer, unsigned int layerCount, int device_index): - assert len(offset) == 3 - assert len(extent) == 3 + if len(offset) != 3: + raise ValueError("Offset must be a tuple of three integers (x, y, z)") + + if len(extent) != 3: + raise ValueError("Extent must be a tuple of three unsigned integers (width, height, depth)") cdef int x = offset[0] cdef int y = offset[1] diff --git a/vkdispatch_native/stages/stage_fft.cpp b/vkdispatch_native/stages/stage_fft.cpp index f0b98bc2..f9908d37 100644 --- a/vkdispatch_native/stages/stage_fft.cpp +++ b/vkdispatch_native/stages/stage_fft.cpp @@ -222,8 +222,8 @@ struct FFTPlan* stage_fft_plan_create_extern( config.maxComputeWorkGroupSize[2] = resource->max_compute_work_group_size_z; config.isCompilerInitialized = true; - config.glslang_mutex = &ctx->glslang_mutex; - config.queue_mutex = &ctx->queues[indicies.queue_index]->queue_usage_mutex; + // config.glslang_mutex = NULL;// &ctx->glslang_mutex; + // config.queue_mutex = NULL; //&ctx->queues[indicies.queue_index]->queue_usage_mutex; config.physicalDevice = &ctx->physicalDevices[indicies.device_index]; config.device = &ctx->devices[indicies.device_index]; config.queue = &ctx->queues[indicies.queue_index]->queue; @@ -250,8 +250,13 @@ struct FFTPlan* stage_fft_plan_create_extern( LOG_VERBOSE("Doing FFT Init"); VkFFTApplication* application = new VkFFTApplication(); - + + ctx->glslang_mutex.lock(); + ctx->queues[indicies.queue_index]->queue_usage_mutex.lock(); VkFFTResult resFFT = initializeVkFFT(application, config); + ctx->queues[indicies.queue_index]->queue_usage_mutex.unlock(); + ctx->glslang_mutex.unlock(); + if (resFFT != VKFFT_SUCCESS) { set_error("(VkFFTResult is %d) initializeVkFFT inside '%s' at %s:%d\n", resFFT, __FUNCTION__, __FILE__, __LINE__); } @@ -335,8 +340,10 @@ void stage_fft_record_extern( } VkFFTApplication* application = (VkFFTApplication*)ctx->handle_manager->get_handle(index, vkfft_applications_handle, timestamp); - + + ctx->queues[indicies.queue_index]->queue_usage_mutex.lock(); VkFFTResult fftRes = VkFFTAppend(application, inverse, &launchParams); + ctx->queues[indicies.queue_index]->queue_usage_mutex.unlock(); if (fftRes != VKFFT_SUCCESS) { set_error("(VkFFTResult is %d) VkFFTAppend inside '%s' at %s:%d\n", fftRes, __FUNCTION__, __FILE__, __LINE__); diff --git a/vkdispatch_native/stages/stages_extern.pxd b/vkdispatch_native/stages/stages_extern.pxd index 86539f4c..67349e9c 100644 --- a/vkdispatch_native/stages/stages_extern.pxd +++ b/vkdispatch_native/stages/stages_extern.pxd @@ -114,11 +114,16 @@ cpdef inline stage_fft_plan_create( int num_batches, bool single_kernel_multiple_batches, bool keep_shader_code): - assert len(dims) > 0 and len(dims) < 4, "dims must be a list of length 1, 2, or 3" - assert len(axes) <= 3, "axes must be a list of length less than or equal to 3" + + if len(dims) == 0 or len(dims) > 3: + raise ValueError("dims must be a list of length 1, 2, or 3") + + if len(axes) > 3: + raise ValueError("axes must be a list of length less than or equal to 3") for ax in axes: - assert ax < len(dims), "axes must be less than the length of dims" + if ax < 0 or ax >= len(dims): + raise ValueError("axes must be less than the length of dims") cdef Context* ctx = context cdef unsigned long long dims_ = len(dims) @@ -137,8 +142,9 @@ cpdef inline stage_fft_plan_create( if 0 <= axes[i] < 3: # Ensure the index is within bounds omits__[axes[i]] = 0 else: - print("Invalid axis index: ", axes[i]) - sys.exit(1) + free(dims__) + free(omits__) + raise ValueError("Axis index out of bounds. Must be between 0 and 2 inclusive.") cdef FFTPlan* plan = stage_fft_plan_create_extern( ctx, @@ -161,6 +167,7 @@ cpdef inline stage_fft_plan_create( 1 if keep_shader_code else 0) free(dims__) + free(omits__) return plan