diff --git a/.github/workflows/deploy_docs.yml b/.github/workflows/deploy_docs.yml index 77badf55..d2e25c74 100644 --- a/.github/workflows/deploy_docs.yml +++ b/.github/workflows/deploy_docs.yml @@ -42,7 +42,8 @@ jobs: # Always install sphinx and required extensions python -m pip install \ "sphinx>=7,<9" \ - sphinx-rtd-theme + sphinx-rtd-theme \ + "brython==3.12.*" pip install numpy diff --git a/.github/workflows/python-package.yml b/.github/workflows/python-package.yml index 7d0aa64b..94124a01 100644 --- a/.github/workflows/python-package.yml +++ b/.github/workflows/python-package.yml @@ -17,7 +17,7 @@ jobs: fail-fast: false matrix: os: [ubuntu-latest, macos-latest] - python-version: ["3.8", "3.9", "3.10", "3.11", "3.12", "3.13"] + python-version: ["3.8", "3.9", "3.10", "3.11", "3.12", "3.13", "3.14"] steps: - uses: actions/checkout@v4 @@ -36,7 +36,7 @@ jobs: - name: Install dependencies run: | python -m pip install --upgrade pip - python -m pip install pytest + python -m pip install pytest numpy python fetch_dependencies.py python -m pip install . 
#- name: Setup tmate session diff --git a/.github/workflows/python-publish.yml b/.github/workflows/python-publish.yml index d1c39dae..f6f99017 100644 --- a/.github/workflows/python-publish.yml +++ b/.github/workflows/python-publish.yml @@ -14,38 +14,13 @@ on: jobs: - #build_mac_and_windows: - # name: Build Python Package - # runs-on: ${{ matrix.os }} - # strategy: - # fail-fast: false - # matrix: - # os: [windows-latest, macos-latest] - # python-version: ["3.8", "3.9", "3.10", "3.11", "3.12"] - - # steps: - # - uses: actions/checkout@v4 - # - name: Set up Python ${{ matrix.python-version }} - # uses: actions/setup-python@v3 - # with: - # python-version: ${{ matrix.python-version }} - # - name: Install dependencies - # run: | - # python -m pip install --upgrade pip - # python fetch_dependencies.py - # python -m pip install build - # python -m build - # - name: Store the distribution packages - # uses: actions/upload-artifact@v3 - # with: - # name: python-package-distributions - # path: dist/ - build_wheels: - name: Build wheels on ${{ matrix.os }} + + build_native_wheels: + name: Build native wheels on ${{ matrix.os }} runs-on: ${{ matrix.os }} strategy: matrix: - os: [ubuntu-latest, windows-latest, macos-13, macos-14] + os: [ubuntu-latest, windows-latest, macos-15-intel, macos-15] steps: - uses: actions/checkout@v4 @@ -53,15 +28,17 @@ jobs: # Used to host cibuildwheel - uses: actions/setup-python@v5 - - name: Install cibuildwheel + - name: Install cibuildwheel and native deps run: | python -m pip install --upgrade pip - python -m pip install cibuildwheel==2.23.3 + python -m pip install cibuildwheel==3.2.1 python fetch_dependencies.py - - name: Build wheels + - name: Build native wheels env: CIBW_SKIP: 'pp* manylinux_i686 musllinux*' + VKDISPATCH_BUILD_TARGET: native + CIBW_ENVIRONMENT: VKDISPATCH_BUILD_TARGET=native run: python -m cibuildwheel --output-dir wheelhouse # to supply options, put them in 'env', like: @@ -72,28 +49,44 @@ jobs: with: name: cibw-wheels-${{ 
matrix.os }}-${{ strategy.job-index }} path: ./wheelhouse/*.whl - build_sdist: - name: Build source distribution + build_python_dists: + name: Build native/core/meta sdists and pure wheels runs-on: ubuntu-latest steps: - uses: actions/checkout@v4 - - name: Install dependencies + - uses: actions/setup-python@v5 + + - name: Install build tooling run: | python -m pip install --upgrade pip + python -m pip install build + + - name: Build native source distribution + env: + VKDISPATCH_BUILD_TARGET: native + run: | python fetch_dependencies.py + python -m build --sdist --outdir dist + + - name: Build core wheel and source distribution + env: + VKDISPATCH_BUILD_TARGET: core + run: python -m build --wheel --sdist --outdir dist - - name: Build sdist - run: pipx run build --sdist + - name: Build meta wheel and source distribution + env: + VKDISPATCH_BUILD_TARGET: meta + run: python -m build --wheel --sdist --outdir dist - uses: actions/upload-artifact@v4 with: - name: cibw-sdist - path: dist/*.tar.gz + name: cibw-python-dists + path: dist/* publish-to-pypi: name: Publish Python package to PyPI # if: startsWith(github.ref, 'refs/tags/') # only publish to PyPI on tag pushes - needs: [build_wheels, build_sdist] + needs: [build_native_wheels, build_python_dists] runs-on: ubuntu-latest environment: name: pypi diff --git a/.gitignore b/.gitignore index 654ae238..576b8d8c 100644 --- a/.gitignore +++ b/.gitignore @@ -8,6 +8,11 @@ __pycache__/ data/ deps/ +codebase.txt + +docs/special_pages/libs/vkdispatch +docs/special_pages/libs/vkdispatch.brython.js + *.png *.csv *.exec diff --git a/docs/Makefile b/docs/Makefile index d4bb2cbb..4c660da8 100644 --- a/docs/Makefile +++ b/docs/Makefile @@ -5,14 +5,39 @@ # from the environment for the first two. SPHINXOPTS ?= SPHINXBUILD ?= sphinx-build +PYTHON ?= python SOURCEDIR = . 
BUILDDIR = _build +# Define destination and filename for the Brython package bundle +LIB_DEST = special_pages/libs +LIB_BUNDLE = vkdispatch.brython.js +LIB_STAGE = $(LIB_DEST)/.vkdispatch_stage + # Put it first so that "make" without argument is like "make help". help: @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) -.PHONY: help Makefile +.PHONY: help Makefile bundle_lib + +# Target to bundle the library into a single Brython package file +bundle_lib: + @echo "Bundling vkdispatch for Brython..." + @$(PYTHON) -c "import brython" > /dev/null + @rm -rf "$(LIB_DEST)/vkdispatch" + @mkdir -p "$(LIB_DEST)" + @rm -f "$(LIB_DEST)/$(LIB_BUNDLE)" + @rm -rf "$(LIB_STAGE)" + @mkdir -p "$(LIB_STAGE)" + @cp -r ../vkdispatch "$(LIB_STAGE)/vkdispatch" + @cd "$(LIB_STAGE)" && $(PYTHON) -m brython make_package vkdispatch \ + --src-dir . \ + --output-path "$(CURDIR)/$(LIB_DEST)/$(LIB_BUNDLE)" + @rm -rf "$(LIB_STAGE)" + +# Intercept the "html" target to run bundle_lib first +html: bundle_lib + @$(SPHINXBUILD) -M html "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) # Catch-all target: route all unknown targets to Sphinx using the new # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). diff --git a/docs/conf.py b/docs/conf.py index 0bff39f5..9abc2f5a 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -57,3 +57,4 @@ html_theme = 'alabaster' html_static_path = ['_static'] +html_extra_path = ['special_pages'] diff --git a/docs/getting_started.rst b/docs/getting_started.rst index ecdf9b2f..79cdf173 100644 --- a/docs/getting_started.rst +++ b/docs/getting_started.rst @@ -76,7 +76,8 @@ Next Steps Now that you've got `vkdispatch` up and running, consider exploring the following: +* :doc:`Code Structure and Execution Flow`: A guided tour of how Python, codegen, and native layers fit together. * :doc:`Tutorials`: Our curated guide to the most commonly used classes and functions. 
* :doc:`Full Python API Reference`: A comprehensive list of all Python-facing components. -Happy GPU programming! \ No newline at end of file +Happy GPU programming! diff --git a/docs/index.rst b/docs/index.rst index 13302d57..fdab93aa 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -11,6 +11,7 @@ Welcome to vkdispatch's documentation! Welcome to the vkdispatch documentation website! To learn how to install vkdispatch, go to the :doc:`Getting Started` Section. +To understand the internals and module layout, start with :doc:`Code Structure and Execution Flow`. Additionally, below are a set of tutorials on vkdispatch usage and a full API reference. @@ -24,6 +25,11 @@ Additionally, below are a set of tutorials on vkdispatch usage and a full API re Tutorials +.. toctree:: + :maxdepth: 2 + + Special + .. toctree:: :maxdepth: 2 diff --git a/docs/internal_api.rst b/docs/internal_api.rst index 1ce0889a..a7d72195 100644 --- a/docs/internal_api.rst +++ b/docs/internal_api.rst @@ -9,4 +9,4 @@ and the underlying C++/Cython implementation. :maxdepth: 2 python_api -.. cpp_api \ No newline at end of file + cpp_api diff --git a/docs/special/brython_shader_lab.rst b/docs/special/brython_shader_lab.rst new file mode 100644 index 00000000..aeeffe87 --- /dev/null +++ b/docs/special/brython_shader_lab.rst @@ -0,0 +1,16 @@ +Brython Shader Lab +================== + +This page redirects to a standalone HTML app page. + +.. raw:: html + + + +

+ Redirecting to the Brython shader lab page. + If you are not redirected, open + the standalone HTML page. +

diff --git a/docs/special/index.rst b/docs/special/index.rst new file mode 100644 index 00000000..da840951 --- /dev/null +++ b/docs/special/index.rst @@ -0,0 +1,9 @@ +Special Pages +============= + +Standalone pages integrated into the docs navigation. + +.. toctree:: + :maxdepth: 1 + + brython_shader_lab diff --git a/docs/special_pages/brython_shader_lab.html b/docs/special_pages/brython_shader_lab.html new file mode 100644 index 00000000..22404647 --- /dev/null +++ b/docs/special_pages/brython_shader_lab.html @@ -0,0 +1,1073 @@ + + + + + + VkDispatch Shader Playground + + + + + + + + + + + + + + +
+ + + +
+ + +
+

VkDispatch Shader Playground

+ + +
+
+ + +
+
+ Help + +
+
+

+ This web-based shader compiler is designed for rapid shader-authoring workflows: write Python in the left pane, + run it in the browser, and inspect generated output on the right. It is especially useful + for researchers who want to iterate on kernel structure and inspect code generation without + switching to a full native setup. +

+
+

+ The key feature is shader visibility: when you print a decorated shader function (for + example, print(add_scalar)), the panel shows the generated GLSL. This makes + it easy to validate indexing logic, control flow, and type usage directly from the + high-level Python definition. +

+
+

+ The Options panel controls a dummy device model, not your physical GPU. You can adjust + limits such as subgroup size, workgroup limits, and shared memory to test how your shader + configuration behaves under different device constraints. +

+
+

+ Most standard vkdispatch APIs are available in this environment (buffers, + images, descriptor bindings, and dispatch calls), but this page is intended for codegen and + interface exploration. Many operations are simulated, and dispatch execution is not a real + GPU compute run. +

+
+
+ + +
+
+ VkDispatch Device Parameters + +
+
+
+ + +
+ +
+ +
+ + + +
+
+
+ + +
+
+ +
+ + + +
+
+
+ + +
+
+
+ +
+
+
Code
+ +
+
+
+
Output
+ +
+
+ + + + + + + diff --git a/docs/tutorials/code_structure.rst b/docs/tutorials/code_structure.rst new file mode 100644 index 00000000..b05cb6fe --- /dev/null +++ b/docs/tutorials/code_structure.rst @@ -0,0 +1,110 @@ +Code Structure and Execution Flow +================================= + +This page explains how the vkdispatch repository is organized and how a Python call +is translated into GPU work. If you are extending the project or debugging behavior, +this should be your first stop. + +In normal usage, ``vkdispatch`` will call ``initialize()`` and ``make_context()`` +automatically the first time you invoke most runtime APIs. You only need to call +them manually if you want non-default settings (for example debug logging, custom +device selection, or multi-queue behavior). + +Repository Layout +----------------- + +Top-level folders you will use most often: + +* ``vkdispatch/``: Public Python API and high-level runtime logic. +* ``vkdispatch_native/``: Native C++/Cython backend called by the Python layer. +* ``tests/``: End-to-end usage examples and regression coverage. +* ``docs/``: Sphinx docs (this site). +* ``deps/``: Third-party dependencies used for source builds. + +Python Package Layout +--------------------- + +Inside ``vkdispatch/``, modules are grouped by responsibility: + +* ``vkdispatch/base``: Core runtime objects and Vulkan-facing wrappers. + + * ``init.py``: Vulkan instance/device discovery and initialization. + * ``context.py``: Global context creation, queue/device selection, lifecycle. + * ``buffer.py`` / ``image.py``: GPU data containers. + * ``compute_plan.py`` / ``descriptor_set.py`` / ``command_list.py``: Low-level execution objects. + +* ``vkdispatch/shader``: Python-to-shader front-end. + + * ``decorator.py``: ``@vd.shader`` entry point. + * ``signature.py``: Type-annotated argument parsing and shader signature building. + * ``shader_function.py``: Build, bind, and dispatch compiled shader functions. 
+ * ``map.py``: Mapping-function abstraction shared by FFT/reduction paths. + +* ``vkdispatch/codegen``: GLSL code generation utilities and typed shader variables. + +* ``vkdispatch/execution_pipeline``: Higher-level command recording. + + * ``command_graph.py``: ``CommandGraph`` wrapper over ``CommandList`` with automatic buffer/constant management. + +* ``vkdispatch/reduce``: Reduction decorators and staged reduction pipeline generation. + +* ``vkdispatch/fft`` and ``vkdispatch/vkfft``: FFT/convolution front-ends. + + * ``fft``: vkdispatch shader-generated FFT path. + * ``vkfft``: VkFFT-backed path with plan caching. + +Native Backend Layout +--------------------- + +The compiled extension module is built from ``vkdispatch_native/``: + +* ``wrapper.pyx``: Cython bridge exposing native entry points to Python. +* ``context/``: Device/context creation and global state. +* ``objects/``: Native Buffer/Image/DescriptorSet/CommandList objects. +* ``stages/``: Compute/FFT stage planning and recording. +* ``queue/``: Queue management, signals, and barriers. +* ``libs/``: Third-party integration glue (Volk, VMA). + +During execution, most Python API methods forward to ``vkdispatch_native`` and then +call error checks to surface native failures as Python exceptions. + +End-to-End Runtime Flow +----------------------- + +Typical call path for a shader dispatch: + +1. First vkdispatch runtime call triggers ``initialize()`` and ``make_context()`` (unless you called them manually first). +2. ``@vd.shader`` wraps a Python function and records typed operations via ``vkdispatch.codegen``. +3. ``ShaderFunction.build()`` generates GLSL and creates a ``ComputePlan``. +4. A ``CommandGraph`` (default or explicit) records bindings and dispatch dimensions. +5. ``CommandGraph.submit()`` submits the command list to selected queue(s). +6. Data is read back with ``Buffer.read()`` or ``Image.read()``. + +Minimal Example (API Layer View) +-------------------------------- + +.. 
code-block:: python + + import numpy as np + import vkdispatch as vd + import vkdispatch.codegen as vc + from vkdispatch.codegen.abreviations import * + + @vd.shader("data.size") + def scale_inplace(data: Buff[f32], alpha: Const[f32]): + tid = vc.global_invocation_id().x + data[tid] = data[tid] * alpha + + arr = np.arange(16, dtype=np.float32) + buf = vd.asbuffer(arr) + scale_inplace(buf, 2.0) + + out = buf.read(0) + print(out) # [0, 2, 4, ...] + +Related Tutorials +----------------- + +* :doc:`Context System <context_system>` +* :doc:`Shader Authoring and Dispatch <shader_tutorial>` +* :doc:`Command Graph Recording <command_graph_tutorial>` diff --git a/docs/tutorials/command_graph_tutorial.rst b/docs/tutorials/command_graph_tutorial.rst new file mode 100644 index 00000000..51cdf98f --- /dev/null +++ b/docs/tutorials/command_graph_tutorial.rst @@ -0,0 +1,84 @@ +Command Graph Recording +======================= + +``CommandGraph`` is the high-level recording API in vkdispatch. It lets you queue +multiple shader dispatches and submit them together, with automatic descriptor/uniform +handling. + +When to Use a CommandGraph +-------------------------- + +Use ``CommandGraph`` when you want: + +* Multiple dispatches in one recorded sequence. +* Explicit control over when work is submitted. +* Lower overhead than immediate submit-per-call flows. + +Single Graph, Multiple Dispatches +--------------------------------- + +.. code-block:: python + + import numpy as np + import vkdispatch as vd + import vkdispatch.codegen as vc + from vkdispatch.codegen.abreviations import * + + graph = vd.CommandGraph() + + @vd.shader("buff.size") + def add_scalar(buff: Buff[f32], value: Const[f32]): + tid = vc.global_invocation_id().x + buff[tid] = buff[tid] + value + + arr = np.arange(32, dtype=np.float32) + buff = vd.asbuffer(arr) + + # Record 3 dispatches, then submit once. 
+ add_scalar(buff, 1.0, graph=graph) + add_scalar(buff, 1.0, graph=graph) + add_scalar(buff, 1.0, graph=graph) + + graph.submit() + vd.queue_wait_idle() + + out = buff.read(0) + print(np.allclose(out, arr + 3.0)) # True + +Immediate vs Deferred Submission +-------------------------------- + +``CommandGraph`` supports two common modes: + +* Deferred mode (default): record first, call ``submit()`` later. +* Immediate mode: ``submit_on_record=True`` to submit each record call. + +.. code-block:: python + + immediate_graph = vd.CommandGraph(reset_on_submit=True, submit_on_record=True) + +In practice, deferred mode is usually better for batching work and reducing submission +overhead. + +Global Graphs and Thread-Local Behavior +--------------------------------------- + +vkdispatch keeps a thread-local default graph used when no explicit ``graph=...`` is +provided. + +* ``vd.global_graph()`` returns the current graph for the thread. +* ``vd.default_graph()`` creates/returns the default immediate graph. +* ``vd.set_global_graph(graph)`` sets a custom graph for the current thread. + +For reproducible behavior in larger programs, passing ``graph=...`` explicitly is +recommended. + +CommandGraph API Reference +-------------------------- + +See the :doc:`Full Python API Reference <../python_api>` for complete API details on: + +* ``vkdispatch.CommandGraph`` +* ``vkdispatch.global_graph`` +* ``vkdispatch.default_graph`` +* ``vkdispatch.set_global_graph`` diff --git a/docs/tutorials/data_types.rst b/docs/tutorials/data_types.rst index e0482e57..73eab4a3 100644 --- a/docs/tutorials/data_types.rst +++ b/docs/tutorials/data_types.rst @@ -17,21 +17,21 @@ They also come in the following shapes: * Matricies (only :class:`vkdispatch.float32` at 2x2 and 4x4) Data Type API Reference ---------------------- +----------------------- -.. autofunction:: vkdispatch.is_dtype +.. autofunction:: vkdispatch.base.dtype.is_dtype -.. autofunction:: vkdispatch.is_scalar +.. 
autofunction:: vkdispatch.base.dtype.is_scalar -.. autofunction:: is_complex +.. autofunction:: vkdispatch.base.dtype.is_complex -.. autofunction:: vkdispatch.is_vector +.. autofunction:: vkdispatch.base.dtype.is_vector -.. autofunction:: vkdispatch.is_matrix +.. autofunction:: vkdispatch.base.dtype.is_matrix -.. autofunction:: vkdispatch.from_numpy_dtype +.. autofunction:: vkdispatch.base.dtype.from_numpy_dtype -.. autofunction:: vkdispatch.to_numpy_dtype +.. autofunction:: vkdispatch.base.dtype.to_numpy_dtype .. autoclass:: vkdispatch.dtype @@ -63,4 +63,4 @@ Data Type API Reference .. autoclass:: vkdispatch.mat2 -.. autoclass:: vkdispatch.mat4 \ No newline at end of file +.. autoclass:: vkdispatch.mat4 diff --git a/docs/tutorials/images_and_sampling.rst b/docs/tutorials/images_and_sampling.rst new file mode 100644 index 00000000..f60bc9b7 --- /dev/null +++ b/docs/tutorials/images_and_sampling.rst @@ -0,0 +1,86 @@ +Images and Sampling +=================== + +Buffers are the default data container in vkdispatch, but image objects are available +for texture-like sampling workflows. + +Image Types +----------- + +vkdispatch provides: + +* ``vd.Image1D`` +* ``vd.Image2D`` +* ``vd.Image2DArray`` +* ``vd.Image3D`` + +Each image supports host-side ``write(...)`` and ``read(...)`` as well as shader-side +sampling through ``image.sample()``. + +Basic Upload/Download Example +----------------------------- + +.. code-block:: python + + import numpy as np + import vkdispatch as vd + + data = np.sin( + np.array([[i / 8 + j / 17 for i in range(64)] for j in range(64)]) + ).astype(np.float32) + + img = vd.Image2D(data.shape, vd.float32) + img.write(data) + + roundtrip = img.read(0) + print(np.allclose(roundtrip, data)) + +Sampling in a Shader +-------------------- + +Use codegen image argument types (``Img1``, ``Img2``, ``Img3``) inside ``@vd.shader``: + +.. 
code-block:: python + + import vkdispatch.codegen as vc + from vkdispatch.codegen.abreviations import * + + upscale = 4 + out = vd.Buffer((data.shape[0] * upscale, data.shape[1] * upscale), vd.float32) + + @vd.shader("out.size") + def sample_2d(out: Buff[f32], src: Img2[f32], scale: Const[f32]): + tid = vc.global_invocation_id().x + ij = vc.ravel_index(tid, out.shape) + uv = vc.new_vec2_register(ij.y, ij.x) / scale + out[tid] = src.sample(uv).x + + sample_2d(out, img.sample(), float(upscale)) + sampled = out.read(0) + +``img.sample()`` creates a sampler object with configurable filtering/address modes. + +Sampler Configuration +--------------------- + +You can override sampling behavior: + +.. code-block:: python + + sampler = img.sample( + mag_filter=vd.Filter.LINEAR, + min_filter=vd.Filter.LINEAR, + address_mode=vd.AddressMode.CLAMP_TO_EDGE, + ) + + sample_2d(out, sampler, float(upscale)) + +Image API Reference +------------------- + +See the :doc:`Full Python API Reference <../python_api>` for complete API details on: + +* ``vkdispatch.Image``, ``vkdispatch.Image1D``, ``vkdispatch.Image2D`` +* ``vkdispatch.Image2DArray``, ``vkdispatch.Image3D`` +* ``vkdispatch.Sampler``, ``vkdispatch.Filter`` +* ``vkdispatch.AddressMode``, ``vkdispatch.BorderColor`` diff --git a/docs/tutorials/index.rst b/docs/tutorials/index.rst index 4522f2ec..04ecc5b1 100644 --- a/docs/tutorials/index.rst +++ b/docs/tutorials/index.rst @@ -6,9 +6,14 @@ A collection of tutorials covering how to use and modify the vkdispatch library. .. 
toctree:: :maxdepth: 2 + code_structure context_system buffer_tutorial + shader_tutorial + command_graph_tutorial data_types + reductions_and_fft + images_and_sampling logging - building_from_source \ No newline at end of file + building_from_source diff --git a/docs/tutorials/reductions_and_fft.rst b/docs/tutorials/reductions_and_fft.rst new file mode 100644 index 00000000..6b77430a --- /dev/null +++ b/docs/tutorials/reductions_and_fft.rst @@ -0,0 +1,287 @@ +Reductions and FFT Workflows +============================ + +This page covers common high-level numeric workflows in vkdispatch: + +* reductions with ``vd.reduce`` +* Fourier transforms with ``vd.fft`` +* VkFFT-backed transforms with ``vd.vkfft`` + +FFT Subsystem Overview +---------------------- + +vkdispatch provides two FFT backends: + +* ``vd.fft``: vkdispatch-generated shaders (runtime code generation). +* ``vd.vkfft``: VkFFT-backed plan execution. + +Use ``vd.fft`` when you want shader-level customization and fusion through mapping +hooks (``input_map``, ``output_map``, ``kernel_map``). Use ``vd.vkfft`` when you want +the VkFFT path with plan caching and a similar high-level API. + +Reduction Basics +---------------- + +Use ``@vd.reduce.reduce`` for pure binary reductions: + +.. code-block:: python + + import numpy as np + import vkdispatch as vd + from vkdispatch.codegen.abreviations import * + + @vd.reduce.reduce(0) + def sum_reduce(a: f32, b: f32) -> f32: + return a + b + + arr = np.random.rand(4096).astype(np.float32) + buf = vd.asbuffer(arr) + out = sum_reduce(buf).read(0) + + print("GPU sum:", float(out[0])) + print("CPU sum:", float(arr.sum(dtype=np.float32))) + +Mapped Reductions +----------------- + +Use ``@vd.reduce.map_reduce`` when you want a map stage before reduction: + +.. 
code-block:: python + + import vkdispatch.codegen as vc + + @vd.reduce.map_reduce(vd.reduce.SubgroupAdd) + def l2_energy_map(buffer: Buff[f32]) -> f32: + idx = vd.reduce.mapped_io_index() + v = buffer[idx] + return v * v + + energy_buf = l2_energy_map(buf) + energy = energy_buf.read(0)[0] + +This pattern is useful for sums of transformed values (norms, weighted sums, etc.). + +FFT with ``vd.fft`` +------------------- + +The ``vd.fft`` module dispatches vkdispatch-generated FFT shaders. + +.. code-block:: python + + import numpy as np + import vkdispatch as vd + + complex_signal = ( + np.random.rand(256) + 1j * np.random.rand(256) + ).astype(np.complex64) + + fft_buf = vd.asbuffer(complex_signal) + + vd.fft.fft(fft_buf) + freq = fft_buf.read(0) + + vd.fft.ifft(fft_buf) + recovered = fft_buf.read(0) + + print(np.allclose(recovered, complex_signal, atol=1e-3)) + +By default, inverse transforms use normalization (``normalize=True`` in ``vd.fft.ifft``). +Set ``normalize=False`` when you need raw inverse scaling behavior. + +To inspect generated FFT shaders, use: + +.. code-block:: python + + vd.fft.fft(fft_buf, print_shader=True) + +Axis and Dimensionality +----------------------- + +FFT routines accept an ``axis`` argument for explicit axis control and provide ``fft2`` +and ``fft3`` convenience functions. + +.. code-block:: python + + # Strided FFT over the second axis of a 2D batch (from performance-test workflows). + batch = ( + np.random.rand(8, 1024) + 1j * np.random.rand(8, 1024) + ).astype(np.complex64) + batch_buf = vd.asbuffer(batch) + + vd.fft.fft(batch_buf, axis=1) + + # 2D transform helper (last two axes). + image = ( + np.random.rand(512, 512) + 1j * np.random.rand(512, 512) + ).astype(np.complex64) + image_buf = vd.asbuffer(image) + vd.fft.fft2(image_buf) + vd.fft.ifft2(image_buf) + +Real FFT (RFFT) helpers: + +.. 
code-block:: python + + real_signal = np.random.rand(512).astype(np.float32) + rbuf = vd.asrfftbuffer(real_signal) + + vd.fft.rfft(rbuf) + spectrum = rbuf.read_fourier(0) + + vd.fft.irfft(rbuf) + restored = rbuf.read_real(0) + + print(np.allclose(restored, real_signal, atol=1e-3)) + +Fusion with ``kernel_map`` (Frequency-Domain In-Register Ops) +-------------------------------------------------------------- + +``vd.fft.convolve`` can inject custom frequency-domain logic via ``kernel_map``. +Inside a kernel map callback, ``vd.fft.read_op()`` exposes the current FFT register +being processed. + +.. code-block:: python + + import vkdispatch.codegen as vc + + @vd.map + def scale_spectrum(scale_factor: vc.Var[vc.f32]): + op = vd.fft.read_op() + op.register[:] = op.register * scale_factor + + # Fused forward FFT + frequency scaling + inverse FFT + vd.fft.convolve(fft_buf, np.float32(0.5), kernel_map=scale_spectrum) + +This pattern avoids a separate full-buffer dispatch for many pointwise spectral +operations. + +Input/Output Mapping for Padded or Sparse Regions +------------------------------------------------- + +For advanced workflows (for example padded 2D cross-correlation), use ``input_map`` and +``output_map`` to remap FFT I/O indices and ``input_signal_range`` to skip inactive +regions. + +Map argument annotations do not determine FFT compute precision. ``read_op.register`` +and ``write_op.register`` always use the internal FFT compute type; map callbacks should +cast user-chosen buffer values to and from that register type as needed. If both FFT I/O +paths are mapped and ``compute_type`` is not provided, ``vd.fft`` defaults to +``complex64`` (falling back to ``complex32`` when required by device support). +When ``output_map`` is provided without ``input_map``, pass an explicit input buffer +argument after the ``output_map`` arguments so read and write phases use different proxies. + +.. 
code-block:: python + + import vkdispatch.codegen as vc + + def padded_axis_fft(buffer: vd.Buffer, signal_cols: int): + # Example expects buffer shape: (batch, rows, cols) + trimmed_shape = (buffer.shape[0], signal_cols, buffer.shape[2]) + + def remap(io_index: vc.ShaderVariable): + return vc.unravel_index( + vc.ravel_index(io_index, trimmed_shape).to_register(), + buffer.shape + ) + + @vd.map + def input_map(input_buffer: vc.Buffer[vc.c64]): + op = vd.fft.read_op() + op.read_from_buffer(input_buffer, io_index=remap(op.io_index)) + + @vd.map + def output_map(output_buffer: vc.Buffer[vc.c64]): + op = vd.fft.write_op() + op.write_to_buffer(output_buffer, io_index=remap(op.io_index)) + + vd.fft.fft( + buffer, + buffer, + buffer_shape=trimmed_shape, + axis=1, + input_map=input_map, + output_map=output_map, + input_signal_range=(0, signal_cols), + ) + +Transposed Kernel Path for 2D Convolution +----------------------------------------- + +When convolving along a strided axis, pre-transposing kernel layout can improve access +patterns. ``vd.fft`` provides helper APIs used by the benchmark suite: + +.. code-block:: python + + # signal_buf and kernel_buf are complex buffers with compatible FFT shapes. + transposed_size = vd.fft.get_transposed_size(signal_buf.shape, axis=1) + kernel_t = vd.Buffer((transposed_size,), vd.complex64) + + vd.fft.transpose(kernel_buf, axis=1, out_buffer=kernel_t) + + vd.fft.fft(signal_buf) + vd.fft.convolve(signal_buf, kernel_t, axis=1, transposed_kernel=True) + vd.fft.ifft(signal_buf) + +Low-Level Procedural FFT Generation with ``fft_context`` +-------------------------------------------------------- + +For full control over read/compute/write staging, build FFT shaders procedurally using +``vd.fft.fft_context`` and iterators from ``vd.fft``: + +.. 
code-block:: python + + import vkdispatch.codegen as vc + + with vd.fft.fft_context(buffer_shape=(1024,), axis=0) as ctx: + args = ctx.declare_shader_args([vc.Buffer[vc.c64]]) + + for read_op in vd.fft.global_reads_iterator(ctx.registers): + read_op.read_from_buffer(args[0]) + + ctx.execute(inverse=False) + + for write_op in vd.fft.global_writes_iterator(ctx.registers): + write_op.write_to_buffer(args[0]) + + fft_kernel = ctx.get_callable() + fft_kernel(fft_buf) + +FFT with ``vd.vkfft`` +--------------------- + +``vd.vkfft`` exposes a similar API but routes operations through VkFFT plan objects +with internal plan caching. + +.. code-block:: python + + vkfft_buf = vd.asbuffer(complex_signal.copy()) + vd.vkfft.fft(vkfft_buf) + vd.vkfft.ifft(vkfft_buf) + print(np.allclose(vkfft_buf.read(0), complex_signal, atol=1e-3)) + +After large parameter sweeps, clearing cached plans can be helpful: + +.. code-block:: python + + vd.vkfft.clear_plan_cache() + vd.fft.cache_clear() + +Convolution Helpers +------------------- + +vkdispatch also includes FFT-based convolution helpers: + +* ``vd.fft.convolve`` / ``vd.fft.convolve2D`` / ``vd.fft.convolve2DR`` +* ``vd.vkfft.convolve2D`` and ``vd.vkfft.transpose_kernel2D`` + +These APIs are most useful when you repeatedly convolve signals/images with known +kernel layouts. + +Reduction and FFT API Reference +------------------------------- + +See the :doc:`Full Python API Reference <../python_api>` for complete API details on: + +* ``vkdispatch.reduce`` +* ``vkdispatch.fft`` +* ``vkdispatch.vkfft`` diff --git a/docs/tutorials/shader_tutorial.rst b/docs/tutorials/shader_tutorial.rst new file mode 100644 index 00000000..060425dc --- /dev/null +++ b/docs/tutorials/shader_tutorial.rst @@ -0,0 +1,246 @@ +Shader Authoring and Dispatch +============================= + +vkdispatch lets you write compute logic in Python syntax and compile it to GLSL at +runtime. 
This page covers shader launch patterns and the key semantics of vkdispatch's +runtime shader generation model. + +Examples below omit ``vd.initialize()`` and ``vd.make_context()`` because vkdispatch +creates them automatically on first runtime use. Call them manually only when you need +custom initialization/context settings. + +Runtime Generation Model +------------------------ + +``@vd.shader`` executes your Python function with tracing objects and emits shader code +as each operation runs. In practice: + +1. vkdispatch inspects type-annotated arguments and creates shader variables. +2. arithmetic, indexing, swizzles, and assignment append GLSL statements. +3. the generated source is compiled into a compute plan and then dispatched. + +This is different from AST/IR compilers: it is a forward streaming model, so explicit +register materialization and explicit shader control-flow helpers matter for performance +and correctness. + +Imports and Type Annotations +---------------------------- + +Most shader examples use these imports: + +.. code-block:: python + + import vkdispatch as vd + import vkdispatch.codegen as vc + from vkdispatch.codegen.abreviations import * + +* ``Buff[...]`` is a shader buffer argument type. +* ``Const[...]`` is a uniform/constant argument type. +* Dtype aliases such as ``f32``, ``i32``, and ``v2`` come from abbreviations. + +Basic In-Place Kernel +--------------------- + +.. 
code-block:: python + + import numpy as np + import vkdispatch as vd + import vkdispatch.codegen as vc + from vkdispatch.codegen.abreviations import * + + @vd.shader("buff.size") + def add_scalar(buff: Buff[f32], bias: Const[f32]): + tid = vc.global_invocation_id().x + buff[tid] = buff[tid] + bias + + arr = np.arange(32, dtype=np.float32) + buff = vd.asbuffer(arr) + add_scalar(buff, 1.5) + + result = buff.read(0) + print(result[:4]) # [1.5 2.5 3.5 4.5] + +Launch Configuration +-------------------- + +Use one of these launch patterns: + +* String expression (evaluated from function argument names): + + .. code-block:: python + + @vd.shader("in_buf.size") + def kernel(in_buf: Buff[f32], out_buf: Buff[f32]): + ... + +* Fixed total dispatch size: + + .. code-block:: python + + @vd.shader(exec_size=(1024, 1, 1)) + def kernel(...): + ... + +* Dynamic size from call arguments: + + .. code-block:: python + + @vd.shader(exec_size=lambda args: args.in_buf.size) + def kernel(in_buf: Buff[f32], out_buf: Buff[f32]): + ... + +* Explicit workgroups instead of ``exec_size``: + + .. code-block:: python + + @vd.shader(workgroups=(64, 1, 1), local_size=(128, 1, 1)) + def kernel(...): + ... + +``exec_size`` and ``workgroups`` are mutually exclusive. +The string form is often the most concise option for argument-dependent dispatch size. + +You can also override launch parameters per call: + +.. code-block:: python + + # Reuse the same compiled shader with different dispatch sizes. + add_scalar(buff, 1.5, exec_size=buff.size) + +Symbolic Expressions vs Mutable Registers +----------------------------------------- + +vkdispatch variables are symbolic by default. Reusing an expression in multiple places +inlines that expression each time in generated code. + +To materialize a value once and mutate it, convert it to a register with +``to_register()``: + +.. 
code-block:: python + + @vd.shader("buff.size") + def register_example(buff: Buff[f32]): + tid = vc.global_invocation_id().x + + # Expression variable: may be inlined at each use. + expr = vc.sin(tid * 0.1) + + # Register variable: emitted once, then reused. + cached = expr.to_register("cached") + + buff[tid] = cached * 2.0 + cached / 3.0 + +Register Store Syntax (``[:]``) +------------------------------- + +Python assignment rebinding (``x = ...``) changes the Python name, not the generated +shader register. To emit a GLSL assignment into an existing register, use full-slice +store syntax ``x[:] = ...``. + +.. code-block:: python + + @vd.shader("buff.size") + def register_store(buff: Buff[f32]): + tid = vc.global_invocation_id().x + value = buff[tid].to_register("value") + value[:] = value * 0.5 + 1.0 + buff[tid] = value + +Shader Control Flow vs Python Control Flow +------------------------------------------ + +Native Python control flow with vkdispatch variables is intentionally blocked: + +.. code-block:: python + + @vd.shader("buff.size") + def bad_branch(buff: Buff[f32]): + tid = vc.global_invocation_id().x + if tid < 10: # Raises ValueError: vkdispatch variables are not Python booleans. + buff[tid] = 1.0 + +Use shader control-flow helpers so both branches are emitted into generated code: + +.. code-block:: python + + @vd.shader("buff.size") + def threshold(buff: Buff[f32], cutoff: Const[f32]): + tid = vc.global_invocation_id().x + + vc.if_statement(buff[tid] > cutoff) + buff[tid] = 1.0 + vc.else_statement() + buff[tid] = 0.0 + vc.end() + +Generation-Time Specialization (Meta-Programming) +------------------------------------------------- + +Because kernel bodies execute as normal Python during generation, Python loops and +conditionals are useful for specialization and unrolling. + +.. 
code-block:: python + + def make_unrolled_sum(unroll: int): + @vd.shader("dst.size") + def unrolled_sum(src: Buff[f32], dst: Buff[f32]): + tid = vc.global_invocation_id().x + base = (tid * unroll).to_register("base") + acc = vc.new_float_register(0.0) + + # Unrolled at generation time. + for i in range(unroll): + acc += src[base + i] + + dst[tid] = acc + + return unrolled_sum + + sum4 = make_unrolled_sum(4) + sum8 = make_unrolled_sum(8) + + # sum4 and sum8 compile to different shaders with different unrolled bodies. + +Mapping Functions +----------------- + +Mapping functions are reusable typed snippets (often used with reductions and FFT I/O). + +.. code-block:: python + + @vd.map + def square_value(x: Buff[f32]) -> f32: + idx = vd.reduce.mapped_io_index() + return x[idx] * x[idx] + +You can pass mapping functions into APIs that accept ``mapping_function``, +``input_map``, or ``output_map`` arguments. + +Inspecting Generated Shader Source +---------------------------------- + +A built shader can be printed for debugging: + +.. code-block:: python + + print(add_scalar) + +This prints GLSL-like generated source with line numbers, which is useful when debugging +type issues or unsupported expressions. + +Common Notes +------------ + +* All shader parameters must be type annotated. +* Buffer/image arguments must use codegen types (for example, ``Buff[f32]``, ``Img2[f32]``). +* If you need batched submissions, prefer :doc:`Command Graph Recording `. 
+ +Shader API Reference +-------------------- + +See the :doc:`Full Python API Reference <../python_api>` for complete API details on: + +* ``vkdispatch.shader`` +* ``vkdispatch.map`` +* ``vkdispatch.ShaderFunction`` +* ``vkdispatch.MappingFunction`` diff --git a/examples/pytorch_cuda_graph_cuda_python.py b/examples/pytorch_cuda_graph_cuda_python.py new file mode 100644 index 00000000..51a949f9 --- /dev/null +++ b/examples/pytorch_cuda_graph_cuda_python.py @@ -0,0 +1,80 @@ +#!/usr/bin/env python3 +"""Capture and replay a vkdispatch CUDA kernel inside a PyTorch CUDA Graph. + +This example uses: + - vkdispatch runtime backend: "cuda" + - a custom vkdispatch shader recorded into CommandGraph + - torch.cuda.CUDAGraph capture + replay + - zero-copy tensor sharing via __cuda_array_interface__ +""" + +import torch + +import vkdispatch as vd +import vkdispatch.codegen as vc +from vkdispatch.codegen.abreviations import Buff, Const, f32 + + +@vd.shader(exec_size=lambda args: args.x.size) +def custom_shader(out: Buff[f32], x: Buff[f32], bias: Const[f32]): + tid = vc.global_invocation_id().x + out[tid] = x[tid] * 1.5 + vc.sin(x[tid]) + bias + + +def main() -> None: + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required for this example.") + + torch.cuda.set_device(0) + torch.manual_seed(0) + + vd.initialize(backend="cuda") + vd.make_context(device_ids=torch.cuda.current_device()) + + n = 16 + bias = 0.25 + + # Static allocations are required for CUDA Graph replay. + x = torch.empty(n, device="cuda", dtype=torch.float32) + out = torch.empty_like(x) + x.fill_(0.0) + + x_vd = vd.from_cuda_array(x) + out_vd = vd.from_cuda_array(out) + + cmd_graph = vd.CommandGraph() + + # Record one vkdispatch kernel launch into the command graph. + # For backend="cuda-python", Const/Var payloads are fixed at record time. 
+ custom_shader(out=out_vd, x=x_vd, bias=bias, graph=cmd_graph) + + torch.cuda.synchronize() + # Pre-stage internal uniform uploads outside torch capture so only dispatch is captured. + #cmd_graph.prepare_for_cuda_graph_capture() + + graph = torch.cuda.CUDAGraph() + with torch.cuda.graph(graph): + # torch.cuda.graph(...) may switch to an internal capture stream. + # Bind vkdispatch to the active stream from inside that context. + with vd.cuda_graph_capture(torch.cuda.current_stream()): + print("Submitting vkdispatch CommandGraph to CUDA Graph...") + cmd_graph.submit() + print("Done recording.") + + replay_inputs = [0.0, 1.0, 2.0, 3.0] + for i, value in enumerate(replay_inputs, start=1): + x.fill_(value) + graph.replay() + torch.cuda.synchronize() + + expected = x * 1.5 + torch.sin(x) + bias + torch.testing.assert_close(out, expected, rtol=1e-5, atol=1e-5) + print( + f"replay {i} input={value:.1f} output[:8]={out[:8].detach().cpu().tolist()}" + ) + + print("Done.") + + +if __name__ == "__main__": + main() diff --git a/fetch_dependencies.py b/fetch_dependencies.py index 436f392d..05a21b66 100644 --- a/fetch_dependencies.py +++ b/fetch_dependencies.py @@ -60,7 +60,7 @@ def clone_and_checkout(repo_url, commit_hash, output_dir): os.makedirs("deps/MoltenVK", exist_ok=True) -molten_vk_url = "https://github.com/KhronosGroup/MoltenVK/releases/download/v1.2.8/MoltenVK-macos.tar" +molten_vk_url = "https://github.com/KhronosGroup/MoltenVK/releases/download/v1.4.0/MoltenVK-macos.tar" molten_vk_path = "deps/MoltenVK" molten_vk_filename = "MoltenVK-macos.tar" molten_vk_full_file_path = os.path.join(molten_vk_path, molten_vk_filename) diff --git a/merge.py b/merge.py new file mode 100644 index 00000000..2ad25474 --- /dev/null +++ b/merge.py @@ -0,0 +1,51 @@ +import os + +def consolidate_repo(root_dir, output_file): + # Extensions to include + extensions = {'.cpp', '.h', '.hh', '.py', '.pxd', '.pyx', '.toml'} + + # Files to ignore (common venv or git directories) + ignore_dirs = 
{'.git', '__pycache__', 'build', 'dist', 'deps', 'venv', 'env', '.idea', '.vscode'} + + with open(output_file, 'w', encoding='utf-8') as outfile: + # Walk through the directory tree + for dirpath, dirnames, filenames in os.walk(root_dir): + # Modify dirnames in-place to skip ignored directories + dirnames[:] = [d for d in dirnames if d not in ignore_dirs] + + for filename in filenames: + if filename == "wrapper.cpp": + continue + _, ext = os.path.splitext(filename) + + if ext in extensions: + file_path = os.path.join(dirpath, filename) + # Create a relative path for cleaner metadata + rel_path = os.path.relpath(file_path, root_dir) + + try: + with open(file_path, 'r', encoding='utf-8', errors='replace') as infile: + content = infile.read() + + # Write metadata header + outfile.write(f"\n{'='*80}\n") + outfile.write(f"FILE: {rel_path}\n") + outfile.write(f"{'='*80}\n\n") + + # Write file content + outfile.write(content) + outfile.write("\n") # Ensure separation + + print(f"Processed: {rel_path}") + + except Exception as e: + print(f"Error reading {rel_path}: {e}") + +if __name__ == "__main__": + # You can change these paths as needed + source_directory = "." # Current directory + output_filename = "codebase.txt" + + print(f"Scanning directory: {os.path.abspath(source_directory)}") + consolidate_repo(source_directory, output_filename) + print(f"\nDone! All files consolidated into: {output_filename}") \ No newline at end of file diff --git a/performance_tests/conv_2d/conv_cufft.cu b/performance_tests/conv_2d/conv_cufft.cu deleted file mode 100644 index 6c88c92b..00000000 --- a/performance_tests/conv_2d/conv_cufft.cu +++ /dev/null @@ -1,237 +0,0 @@ -// actual_test_cuda.cu -// Usage: ./actual_test_cuda -// Output: fft_cuda__axis.csv with the same columns as your Torch script. 
-// -// Build (example): -// nvcc -O3 -std=c++17 actual_test_cuda.cu -lcufft -o actual_test_cuda - -#include -#include -#include - -#include -#include -#include -#include -#include -#include -#include -#include -#include - -__global__ void fill_randomish(cufftComplex* a, long long n){ - long long i = blockIdx.x * 1LL * blockDim.x + threadIdx.x; - if(i \n"; - std::exit(1); - } - Config c; - c.data_size = std::stoll(argv[1]); - c.iter_count = std::stoi(argv[2]); - c.iter_batch = std::stoi(argv[3]); - c.run_count = std::stoi(argv[4]); - return c; -} - -static std::vector get_fft_sizes() { - std::vector sizes; - for (int p = 6; p <= 12; ++p) sizes.push_back(1 << p); // 64..4096 - return sizes; -} - -// Compute GB processed per single FFT execution (read + write) for shape (dim0, dim1) -static double gb_per_exec(long long dim0, long long dim1, long long dim2) { - // complex64 = 8 bytes; count both read and write -> *2 - const double bytes = static_cast(dim0) * static_cast(dim1) * static_cast(dim2) * 8.0; - return bytes / (1024.0 * 1024.0 * 1024.0); -} - -static double run_cufft_case(const Config& cfg, int fft_size) { - const long long total_fft_area = fft_size * fft_size; - - const long long dim0 = cfg.data_size / total_fft_area; - const long long dim1 = fft_size; - const long long dim2 = fft_size; - const long long total_elems = dim0 * dim1 * dim2; - - // Device buffers (in-place transform will overwrite input) - cufftComplex* d_data = nullptr; - checkCuda(cudaMalloc(&d_data, total_elems * sizeof(cufftComplex)), "cudaMalloc d_data"); - // Optionally zero-fill - checkCuda(cudaMemset(d_data, 0, total_elems * sizeof(cufftComplex)), "cudaMemset d_data"); - - cufftComplex* d_kernel = nullptr; - checkCuda(cudaMalloc(&d_kernel, (total_elems) * sizeof(cufftComplex)), "cudaMalloc d_kernel"); - // Optionally zero-fill - checkCuda(cudaMemset(d_kernel, 0, (total_elems) * sizeof(cufftComplex)), "cudaMemset d_kernel"); - - { - int t = 256, b = int((total_elems + t - 1) / t); - 
fill_randomish<<>>(d_data, total_elems); - checkCuda(cudaGetLastError(), "fill launch"); - checkCuda(cudaDeviceSynchronize(), "fill sync"); - - int kt = 256, kb = int((total_elems + kt - 1) / kt); - fill_randomish<<>>(d_kernel, total_elems); - checkCuda(cudaGetLastError(), "fill kernel launch"); - checkCuda(cudaDeviceSynchronize(), "fill kernel sync"); - } - - // --- plan bound to the stream --- - cufftHandle plan; - checkCuFFT(cufftCreate(&plan), "cufftCreate"); - - int n[2] = { int(dim1), int(dim2) }; - int inembed[2] = { int(dim1), int(dim2) }; // physical layout (same as n for tight pack) - int onembed[2] = { int(dim1), int(dim2) }; - int istride = 1; // contiguous within each 2D image - int ostride = 1; - int idist = int(dim1)* int(dim2); // distance between images - int odist = int(dim1)* int(dim2); - - checkCuFFT(cufftPlanMany(&plan, 2, n, - inembed, istride, idist, - onembed, ostride, odist, - CUFFT_C2C, int(dim0)), "plan2d"); - - // --- warmup on the stream --- - for (int i = 0; i < cfg.warmup; ++i) { - checkCuFFT(cufftExecC2C(plan, d_data, d_data, CUFFT_FORWARD), "warmup"); - convolve_arrays<<<(total_elems+255)/256,256>>>(d_data, d_kernel, total_elems); - checkCuFFT(cufftExecC2C(plan, d_data, d_data, CUFFT_INVERSE), "warmup"); - } - - checkCuda(cudaDeviceSynchronize(), "warmup sync"); - - // === OPTION A: plain single-stream timing (simple & robust) === - cudaEvent_t evA, evB; - checkCuda(cudaEventCreate(&evA), "evA"); - checkCuda(cudaEventCreate(&evB), "evB"); - checkCuda(cudaEventRecord(evA), "record A"); - for (int it = 0; it < cfg.iter_count; ++it) { - checkCuFFT(cufftExecC2C(plan, d_data, d_data, CUFFT_FORWARD), "exec"); - convolve_arrays<<<(total_elems+255)/256,256>>>(d_data, d_kernel, total_elems); - checkCuFFT(cufftExecC2C(plan, d_data, d_data, CUFFT_INVERSE), "exec"); - } - checkCuda(cudaEventRecord(evB), "record B"); - checkCuda(cudaEventSynchronize(evB), "sync B"); - checkCuda(cudaDeviceSynchronize(), "warmup sync"); - float ms = 0.f; 
checkCuda(cudaEventElapsedTime(&ms, evA, evB), "elapsed"); - checkCuda(cudaEventDestroy(evA), "dA"); - checkCuda(cudaEventDestroy(evB), "dB"); - - // Convert elapsed to seconds - const double seconds = static_cast(ms) / 1000.0; - - // Compute throughput in GB/s (same accounting as Torch: 2 * elems * 8 bytes per exec) - const double gb_per_exec_once = 11 * gb_per_exec(dim0, dim1, dim2); - const double total_execs = static_cast(cfg.iter_count); // * static_cast(cfg.iter_batch); - const double gb_per_second = (total_execs * gb_per_exec_once) / seconds; - - // Cleanup - cufftDestroy(plan); - cudaFree(d_data); - cudaFree(d_kernel); - - return gb_per_second; -} - -int main(int argc, char** argv) { - const Config cfg = parse_args(argc, argv); - const auto sizes = get_fft_sizes(); - - const std::string output_name = "conv_cufft.csv"; - std::ofstream out(output_name); - if (!out) { - std::cerr << "Failed to open output file: " << output_name << "\n"; - return 1; - } - - std::cout << "Running cuFFT tests with data size " << cfg.data_size - << ", iter_count " << cfg.iter_count - << ", iter_batch " << cfg.iter_batch - << ", run_count " << cfg.run_count << "\n"; - - // Header: Backend, FFT Size, Run 1..N, Mean, Std Dev - out << "Backend,FFT Size"; - for (int i = 0; i < cfg.run_count; ++i) out << ",Run " << (i + 1) << " (GB/s)"; - out << ",Mean,Std Dev\n"; - - for (int fft_size : sizes) { - std::vector rates; - rates.reserve(cfg.run_count); - - for (int r = 0; r < cfg.run_count; ++r) { - const double gbps = run_cufft_case(cfg, fft_size); - std::cout << "FFT Size: " << fft_size << ", Throughput: " << std::fixed << std::setprecision(2) - << gbps << " GB/s\n"; - rates.push_back(gbps); - } - - // Compute mean/std - double mean = 0.0; - for (double v : rates) mean += v; - mean /= static_cast(rates.size()); - - double var = 0.0; - for (double v : rates) { - const double d = v - mean; - var += d * d; - } - var /= static_cast(rates.size()); - const double stdev = std::sqrt(var); - - // 
Round to 2 decimals like your Torch script - out << "cufft," << fft_size; - out << std::fixed << std::setprecision(2); - for (double v : rates) out << "," << v; - out << "," << mean << "," << stdev << "\n"; - } - - std::cout << "Results saved to " << output_name << "\n"; - return 0; -} diff --git a/performance_tests/conv_2d/conv_cufft_callback.cu b/performance_tests/conv_2d/conv_cufft_callback.cu deleted file mode 100644 index fb14be84..00000000 --- a/performance_tests/conv_2d/conv_cufft_callback.cu +++ /dev/null @@ -1,266 +0,0 @@ -// actual_test_cuda.cu -// Usage: ./actual_test_cuda -// Output: fft_cuda__axis.csv with the same columns as your Torch script. -// -// Build (example): -// nvcc -O3 -std=c++17 actual_test_cuda.cu -lcufft -o actual_test_cuda - -#include -#include -#include - -#include -#include -#include -#include -#include -#include -#include -#include -#include - -struct CallbackParams { - cufftComplex* filter; // device pointer, length = NX * NY - size_t elemsPerImage; // NX * NY -}; - -__global__ void fill_randomish(cufftComplex* a, long long n){ - long long i = blockIdx.x * 1LL * blockDim.x + threadIdx.x; - if(i(callerInfo); - const size_t idxInImage = offset; - - // Multiply element by filter[idxInImage] - const cufftComplex h = p->filter[idxInImage]; - cufftComplex y; - y.x = element.x * h.x - element.y * h.y; - y.y = element.x * h.y + element.y * h.x; - - static_cast(dataOut)[offset] = y; -} - -__device__ cufftCallbackStoreC d_store_cb_ptr = store_mul_cb; - -static inline void checkCuda(cudaError_t err, const char* what) { - if (err != cudaSuccess) { - std::cerr << "[CUDA] " << what << " failed: " << cudaGetErrorString(err) << "\n"; - std::exit(1); - } -} - -static inline void checkCuFFT(cufftResult err, const char* what) { - if (err != CUFFT_SUCCESS) { - std::cerr << "[cuFFT] " << what << " failed: " << err << "\n"; - std::exit(1); - } -} - -struct Config { - long long data_size; - int iter_count; - int iter_batch; - int run_count; - int warmup 
= 10; // match Torch script’s warmup -}; - -static Config parse_args(int argc, char** argv) { - if (argc != 5) { - std::cerr << "Usage: " << argv[0] - << " \n"; - std::exit(1); - } - Config c; - c.data_size = std::stoll(argv[1]); - c.iter_count = std::stoi(argv[2]); - c.iter_batch = std::stoi(argv[3]); - c.run_count = std::stoi(argv[4]); - return c; -} - -static std::vector get_fft_sizes() { - std::vector sizes; - for (int p = 6; p <= 12; ++p) sizes.push_back(1 << p); // 64..4096 - return sizes; -} - -// Compute GB processed per single FFT execution (read + write) for shape (dim0, dim1) -static double gb_per_exec(long long dim0, long long dim1, long long dim2) { - // complex64 = 8 bytes; count both read and write -> *2 - const double bytes = static_cast(dim0) * static_cast(dim1) * static_cast(dim2) * 8.0; - return bytes / (1024.0 * 1024.0 * 1024.0); -} - -static double run_cufft_case(const Config& cfg, int fft_size) { - const long long total_fft_area = fft_size * fft_size; - - const long long dim0 = cfg.data_size / total_fft_area; - const long long dim1 = fft_size; - const long long dim2 = fft_size; - const long long total_elems = dim0 * dim1 * dim2; - - // Device buffers (in-place transform will overwrite input) - cufftComplex* d_data = nullptr; - checkCuda(cudaMalloc(&d_data, total_elems * sizeof(cufftComplex)), "cudaMalloc d_data"); - // Optionally zero-fill - checkCuda(cudaMemset(d_data, 0, total_elems * sizeof(cufftComplex)), "cudaMemset d_data"); - - cufftComplex* d_kernel = nullptr; - checkCuda(cudaMalloc(&d_kernel, (total_elems) * sizeof(cufftComplex)), "cudaMalloc d_kernel"); - // Optionally zero-fill - checkCuda(cudaMemset(d_kernel, 0, (total_elems) * sizeof(cufftComplex)), "cudaMemset d_kernel"); - - { - int t = 256, b = int((total_elems + t - 1) / t); - fill_randomish<<>>(d_data, total_elems); - checkCuda(cudaGetLastError(), "fill launch"); - checkCuda(cudaDeviceSynchronize(), "fill sync"); - - int kt = 256, kb = int((total_elems + kt - 1) / kt); - 
fill_randomish<<>>(d_kernel, total_elems); - checkCuda(cudaGetLastError(), "fill kernel launch"); - checkCuda(cudaDeviceSynchronize(), "fill kernel sync"); - } - - CallbackParams h_params{ d_kernel, size_t(dim1) * size_t(dim2) }; - CallbackParams* d_params = nullptr; - checkCuda(cudaMalloc(&d_params, sizeof(CallbackParams)), "cudaMalloc params"); - checkCuda(cudaMemcpy(d_params, &h_params, sizeof(CallbackParams), cudaMemcpyHostToDevice), "cudaMemcpy params"); - - // --- plan bound to the stream --- - cufftHandle plans[2]; - checkCuFFT(cufftCreate(&plans[0]), "cufftCreate"); - checkCuFFT(cufftCreate(&plans[1]), "cufftCreate"); - - int n[2] = { int(dim1), int(dim2) }; - int inembed[2] = { int(dim1), int(dim2) }; // physical layout (same as n for tight pack) - int onembed[2] = { int(dim1), int(dim2) }; - int istride = 1; // contiguous within each 2D image - int ostride = 1; - int idist = int(dim1)* int(dim2); // distance between images - int odist = int(dim1)* int(dim2); - - checkCuFFT(cufftPlanMany(&plans[0], 2, n, - inembed, istride, idist, - onembed, ostride, odist, - CUFFT_C2C, int(dim0)), "plan2d"); - - checkCuFFT(cufftPlanMany(&plans[1], 2, n, - inembed, istride, idist, - onembed, ostride, odist, - CUFFT_C2C, int(dim0)), "plan2d"); - - cufftCallbackStoreC h_store_cb_ptr; - checkCuda(cudaMemcpyFromSymbol(&h_store_cb_ptr, d_store_cb_ptr, sizeof(h_store_cb_ptr)), "memcpy from symbol"); - - void* cb_ptrs[1] = { (void*)h_store_cb_ptr }; - void* cb_data[1] = { (void*)d_params }; // single pointer: our params struct - checkCuFFT(cufftXtSetCallback(plans[0], cb_ptrs, CUFFT_CB_ST_COMPLEX, cb_data), "set callback"); - - // --- warmup on the stream --- - for (int i = 0; i < cfg.warmup; ++i) { - checkCuFFT(cufftExecC2C(plans[0], d_data, d_data, CUFFT_FORWARD), "warmup"); - checkCuFFT(cufftExecC2C(plans[1], d_data, d_data, CUFFT_INVERSE), "warmup"); - } - - checkCuda(cudaDeviceSynchronize(), "warmup sync"); - - // === OPTION A: plain single-stream timing (simple & robust) 
=== - cudaEvent_t evA, evB; - checkCuda(cudaEventCreate(&evA), "evA"); - checkCuda(cudaEventCreate(&evB), "evB"); - checkCuda(cudaEventRecord(evA), "record A"); - for (int it = 0; it < cfg.iter_count; ++it) { - checkCuFFT(cufftExecC2C(plans[0], d_data, d_data, CUFFT_FORWARD), "exec"); - checkCuFFT(cufftExecC2C(plans[1], d_data, d_data, CUFFT_INVERSE), "exec"); - } - checkCuda(cudaEventRecord(evB), "record B"); - checkCuda(cudaEventSynchronize(evB), "sync B"); - checkCuda(cudaDeviceSynchronize(), "warmup sync"); - float ms = 0.f; checkCuda(cudaEventElapsedTime(&ms, evA, evB), "elapsed"); - checkCuda(cudaEventDestroy(evA), "dA"); - checkCuda(cudaEventDestroy(evB), "dB"); - - // Convert elapsed to seconds - const double seconds = static_cast(ms) / 1000.0; - - // Compute throughput in GB/s (same accounting as Torch: 2 * elems * 8 bytes per exec) - const double gb_per_exec_once = 11 * gb_per_exec(dim0, dim1, dim2); - const double total_execs = static_cast(cfg.iter_count); // * static_cast(cfg.iter_batch); - const double gb_per_second = (total_execs * gb_per_exec_once) / seconds; - - // Cleanup - cufftDestroy(plans[0]); - cufftDestroy(plans[1]); - cudaFree(d_data); - cudaFree(d_kernel); - - return gb_per_second; -} - -int main(int argc, char** argv) { - const Config cfg = parse_args(argc, argv); - const auto sizes = get_fft_sizes(); - - const std::string output_name = "conv_cufft_callback.csv"; - std::ofstream out(output_name); - if (!out) { - std::cerr << "Failed to open output file: " << output_name << "\n"; - return 1; - } - - std::cout << "Running cuFFT tests with data size " << cfg.data_size - << ", iter_count " << cfg.iter_count - << ", iter_batch " << cfg.iter_batch - << ", run_count " << cfg.run_count << "\n"; - - // Header: Backend, FFT Size, Run 1..N, Mean, Std Dev - out << "Backend,FFT Size"; - for (int i = 0; i < cfg.run_count; ++i) out << ",Run " << (i + 1) << " (GB/s)"; - out << ",Mean,Std Dev\n"; - - for (int fft_size : sizes) { - std::vector rates; - 
rates.reserve(cfg.run_count); - - for (int r = 0; r < cfg.run_count; ++r) { - const double gbps = run_cufft_case(cfg, fft_size); - std::cout << "FFT Size: " << fft_size << ", Throughput: " << std::fixed << std::setprecision(2) - << gbps << " GB/s\n"; - rates.push_back(gbps); - } - - // Compute mean/std - double mean = 0.0; - for (double v : rates) mean += v; - mean /= static_cast(rates.size()); - - double var = 0.0; - for (double v : rates) { - const double d = v - mean; - var += d * d; - } - var /= static_cast(rates.size()); - const double stdev = std::sqrt(var); - - // Round to 2 decimals like your Torch script - out << "cufft_callback," << fft_size; - out << std::fixed << std::setprecision(2); - for (double v : rates) out << "," << v; - out << "," << mean << "," << stdev << "\n"; - } - - std::cout << "Results saved to " << output_name << "\n"; - return 0; -} diff --git a/performance_tests/conv_2d/conv_make_graph.py b/performance_tests/conv_2d/conv_make_graph.py deleted file mode 100644 index 50f3ba41..00000000 --- a/performance_tests/conv_2d/conv_make_graph.py +++ /dev/null @@ -1,92 +0,0 @@ -import glob -import csv -from typing import Dict, Tuple, Set -from matplotlib import pyplot as plt -import numpy as np -import sys - -# Nested structure: -# merged[backend][fft_size] = (mean, std) -MergedType = Dict[str, Dict[int, Tuple[float, float]]] - -def read_bench_csvs() -> Tuple[MergedType, Set[str], Set[int]]: - pattern = f"conv_*.csv" - files = glob.glob(pattern) - - merged: MergedType = {} - backends: Set[str] = set() - fft_sizes: Set[int] = set() - - for filename in files: - print(f"Reading: {filename}") - with open(filename, newline="") as f: - reader = csv.DictReader(f) - for row in reader: - backend = row["Backend"].strip() - size = int(row["FFT Size"]) - mean = float(row["Mean"]) - std = float(row["Std Dev"]) - - backends.add(backend) - fft_sizes.add(size) - - if backend not in merged: - merged[backend] = {} - - # last one wins if duplicates appear across 
files - merged[backend][size] = (mean, std) - - return merged, backends, fft_sizes - -def save_graph(backends: Set[str], fft_sizes: Set[int], merged: MergedType, min_fft_size: int = None): - plt.figure(figsize=(10, 6)) - - if min_fft_size is not None: - used_fft_sizes = [size for size in fft_sizes if size >= min_fft_size] - else: - used_fft_sizes = fft_sizes - - for backend_name in backends: - means = [ - merged[backend_name][i][0] - for i in used_fft_sizes - ] - stds = [ - merged[backend_name][i][1] - for i in used_fft_sizes - ] - - plt.errorbar( - used_fft_sizes, - means, - yerr=stds, - label=backend_name, - capsize=5, - ) - plt.xscale('log', base=2) - plt.xlabel('Convolution Size') - plt.ylabel('GB/s') - plt.title('Convolution Performance Comparison') - plt.legend() - plt.grid(True) - if min_fft_size is not None: - plt.savefig(f"conv_graph_min_size{min_fft_size}.png") - return - plt.savefig(f"conv_graph.png") - -if __name__ == "__main__": - # Example usage (change the number as needed) - merged, backends, fft_sizes = read_bench_csvs() - - print("\nSummary:") - print(f"Backends found: {sorted(backends)}") - print(f"Convolution sizes found: {sorted(fft_sizes)}") - print(f"Total entries: {sum(len(v) for v in merged.values())}") - - sorted_backends = sorted(backends) - sorted_fft_sizes = sorted(fft_sizes) - - save_graph(sorted_backends, sorted_fft_sizes, merged) - #save_graph(sorted_backends, sorted_fft_sizes, merged, min_fft_size=256) - - diff --git a/performance_tests/conv_2d/conv_torch.py b/performance_tests/conv_2d/conv_torch.py deleted file mode 100644 index 35a4e718..00000000 --- a/performance_tests/conv_2d/conv_torch.py +++ /dev/null @@ -1,81 +0,0 @@ -import csv -import time -import conv_utils as fu -import numpy as np -import torch - -def run_torch(config: fu.Config, fft_size: int) -> float: - shape = config.make_shape(fft_size) - random_data = config.make_random_data(fft_size) - random_data_kernel = config.make_random_data(fft_size) - - buffer = 
torch.empty( - shape, - dtype=torch.complex64, - device='cuda' - ) - - kernel = torch.empty( - shape, - dtype=torch.complex64, - device='cuda' - ) - - buffer.copy_(torch.from_numpy(random_data).to('cuda')) - kernel.copy_(torch.from_numpy(random_data_kernel).to('cuda')) - - stream = torch.cuda.Stream() - - torch.cuda.synchronize() - - with torch.cuda.stream(stream): - for _ in range(config.warmup): - buffer = torch.fft.ifft2(torch.fft.fft2(buffer) * kernel) - - torch.cuda.synchronize() - - gb_byte_count = 11 * np.prod(shape) * 8 / (1024 * 1024 * 1024) - - g = torch.cuda.CUDAGraph() - - # We capture either 1 or K FFTs back-to-back. All on the same stream. - with torch.cuda.graph(g, stream=stream): - for _ in range(max(1, config.iter_batch)): - buffer = torch.fft.ifft2(torch.fft.fft2(buffer) * kernel) - - torch.cuda.synchronize() - start_time = time.perf_counter() - - with torch.cuda.stream(stream): - for _ in range(config.iter_count // max(1, config.iter_batch)): - g.replay() - - torch.cuda.synchronize() - elapsed_time = time.perf_counter() - start_time - - return config.iter_count * gb_byte_count / elapsed_time - -if __name__ == "__main__": - config = fu.parse_args() - fft_sizes = fu.get_fft_sizes() - - output_name = f"conv_torch.csv" - with open(output_name, 'w', newline='') as csvfile: - writer = csv.writer(csvfile) - writer.writerow(['Backend', 'FFT Size'] + [f'Run {i + 1} (GB/s)' for i in range(config.run_count)] + ['Mean', 'Std Dev']) - - for fft_size in fft_sizes: - rates = [] - - for _ in range(config.run_count): - gb_per_second = run_torch(config, fft_size) - print(f"FFT Size: {fft_size}, Throughput: {gb_per_second:.2f} GB/s") - rates.append(gb_per_second) - - rounded_data = [round(rate, 2) for rate in rates] - rounded_mean = round(np.mean(rates), 2) - rounded_std = round(np.std(rates), 2) - - writer.writerow(["torch", fft_size] + rounded_data + [rounded_mean, rounded_std]) - - print(f"Results saved to {output_name}.csv") \ No newline at end of file diff 
--git a/performance_tests/conv_2d/conv_utils.py b/performance_tests/conv_2d/conv_utils.py deleted file mode 100644 index e749346b..00000000 --- a/performance_tests/conv_2d/conv_utils.py +++ /dev/null @@ -1,38 +0,0 @@ -import sys -from typing import Tuple -import dataclasses - -import numpy as np - -@dataclasses.dataclass -class Config: - data_size: int - iter_count: int - iter_batch: int - run_count: int - warmup: int = 10 - - def make_shape(self, fft_size: int) -> Tuple[int, ...]: - total_square_size = fft_size * fft_size - assert self.data_size % total_square_size == 0, "Data size must be a multiple of fft_size squared" - return (self.data_size // total_square_size, fft_size, fft_size) - - def make_random_data(self, fft_size: int): - shape = self.make_shape(fft_size) - return np.random.rand(*shape).astype(np.complex64) - -def parse_args() -> Config: - if len(sys.argv) != 5: - print(f"Usage: {sys.argv[0]} ") - sys.exit(1) - - return Config( - data_size=int(sys.argv[1]), - iter_count=int(sys.argv[2]), - iter_batch=int(sys.argv[3]), - run_count=int(sys.argv[4]), - ) - -def get_fft_sizes(): - return [2**i for i in range(6, 13)] # FFT sizes from 64 to 4096 (inclusive) - diff --git a/performance_tests/conv_2d/conv_vkdispatch.py b/performance_tests/conv_2d/conv_vkdispatch.py deleted file mode 100644 index d3246408..00000000 --- a/performance_tests/conv_2d/conv_vkdispatch.py +++ /dev/null @@ -1,104 +0,0 @@ -import csv -import time -import conv_utils as fu -import vkdispatch as vd -import vkdispatch.codegen as vc -import numpy as np - -def run_vkdispatch(config: fu.Config, fft_size: int) -> float: - shape = config.make_shape(fft_size) - random_data = config.make_random_data(fft_size) - random_data_2 = config.make_random_data(fft_size) - - buffer = vd.Buffer(shape, var_type=vd.complex64) - buffer.write(random_data) - - kernel = vd.Buffer(shape, var_type=vd.complex64) - kernel.write(random_data_2) - - graph = vd.CommandGraph() - - @vd.map_registers([vc.c64]) - def 
kernel_mapping(kernel_buffer: vc.Buffer[vc.c64]): - img_val = vc.mapping_registers()[0] - read_register = vc.mapping_registers()[1] - - # Calculate the invocation within this FFT batch - in_group_index = vc.local_invocation().y * vc.workgroup_size().x + vc.local_invocation().x - out_group_index = vc.workgroup().y * vc.num_workgroups().x + vc.workgroup().x - workgroup_index = in_group_index + out_group_index * ( - vc.workgroup_size().x * vc.workgroup_size().y - ) - - # Calculate the batch index of the FFT - batch_index = ( - vc.mapping_index() - ) / ( - vc.workgroup_size().x * vc.workgroup_size().y * - vc.num_workgroups().x * vc.num_workgroups().y - ) - - # Calculate the transposed index - transposed_index = workgroup_index + batch_index * ( - vc.workgroup_size().x * vc.workgroup_size().y * - vc.num_workgroups().x * vc.num_workgroups().y - ) - - read_register[:] = kernel_buffer[transposed_index] - img_val[:] = vc.mult_conj_c64(read_register, img_val) - - vd.fft.convolve2D(buffer, kernel, graph=graph, kernel_map=kernel_mapping) - - for _ in range(config.warmup): - graph.submit(config.iter_batch) - - vd.queue_wait_idle() - - gb_byte_count = 11 * 8 * buffer.size / (1024 * 1024 * 1024) - - start_time = time.perf_counter() - - for _ in range(config.iter_count // config.iter_batch): - graph.submit(config.iter_batch) - - vd.queue_wait_idle() - - elapsed_time = time.perf_counter() - start_time - - buffer.destroy() - graph.destroy() - vd.fft.cache_clear() - - time.sleep(1) - - vd.queue_wait_idle() - - return config.iter_count * gb_byte_count / elapsed_time - -if __name__ == "__main__": - config = fu.parse_args() - fft_sizes = fu.get_fft_sizes() - - output_name = f"conv_vkdispatch.csv" - with open(output_name, 'w', newline='') as csvfile: - writer = csv.writer(csvfile) - writer.writerow(['Backend', 'FFT Size'] + [f'Run {i + 1} (GB/s)' for i in range(config.run_count)] + ['Mean', 'Std Dev']) - - for fft_size in fft_sizes: - rates = [] - - for _ in range(config.run_count): - 
gb_per_second = run_vkdispatch(config, fft_size) - print(f"FFT Size: {fft_size}, Throughput: {gb_per_second:.2f} GB/s") - rates.append(gb_per_second) - - rounded_data = [round(rate, 2) for rate in rates] - rounded_mean = round(np.mean(rates), 2) - rounded_std = round(np.std(rates), 2) - - writer.writerow(["vkdispatch", fft_size] + rounded_data + [rounded_mean, rounded_std]) - - print(f"Results saved to {output_name}.csv") - - - \ No newline at end of file diff --git a/performance_tests/conv_2d/conv_vkfft.py b/performance_tests/conv_2d/conv_vkfft.py deleted file mode 100644 index 38478048..00000000 --- a/performance_tests/conv_2d/conv_vkfft.py +++ /dev/null @@ -1,71 +0,0 @@ -import csv -import time -import conv_utils as fu -import vkdispatch as vd -import numpy as np - -def run_vkfft(config: fu.Config, fft_size: int) -> float: - shape = config.make_shape(fft_size) - random_data = config.make_random_data(fft_size) - random_data_2 = config.make_random_data(fft_size) - - buffer = vd.Buffer(shape, var_type=vd.complex64) - buffer.write(random_data) - - kernel = vd.Buffer(shape, var_type=vd.complex64) - kernel.write(random_data_2) - - graph = vd.CommandGraph() - - vd.vkfft.convolve_2D(buffer, kernel, graph=graph) - - for _ in range(config.warmup): - graph.submit(config.iter_batch) - - vd.queue_wait_idle() - - gb_byte_count = 11 * 8 * buffer.size / (1024 * 1024 * 1024) - - start_time = time.perf_counter() - - for _ in range(config.iter_count // config.iter_batch): - graph.submit(config.iter_batch) - - vd.queue_wait_idle() - - elapsed_time = time.perf_counter() - start_time - - buffer.destroy() - graph.destroy() - vd.vkfft.clear_plan_cache() - - time.sleep(1) - - vd.queue_wait_idle() - - return config.iter_count * gb_byte_count / elapsed_time - -if __name__ == "__main__": - config = fu.parse_args() - fft_sizes = fu.get_fft_sizes() - - output_name = f"conv_vkfft.csv" - with open(output_name, 'w', newline='') as csvfile: - writer = csv.writer(csvfile) - 
writer.writerow(['Backend', 'FFT Size'] + [f'Run {i + 1} (GB/s)' for i in range(config.run_count)] + ['Mean', 'Std Dev']) - - for fft_size in fft_sizes: - rates = [] - - for _ in range(config.run_count): - gb_per_second = run_vkfft(config, fft_size) - print(f"FFT Size: {fft_size}, Throughput: {gb_per_second:.2f} GB/s") - rates.append(gb_per_second) - - rounded_data = [round(rate, 2) for rate in rates] - rounded_mean = round(np.mean(rates), 2) - rounded_std = round(np.std(rates), 2) - - writer.writerow(["vkfft", fft_size] + rounded_data + [rounded_mean, rounded_std]) - - print(f"Results saved to {output_name}.csv") \ No newline at end of file diff --git a/performance_tests/conv_2d/conv_zipfft.py b/performance_tests/conv_2d/conv_zipfft.py deleted file mode 100644 index c423af5b..00000000 --- a/performance_tests/conv_2d/conv_zipfft.py +++ /dev/null @@ -1,95 +0,0 @@ -import csv -import time -import conv_utils as fu -import numpy as np -import torch - -try: - from zipfft import cfft1d - from zipfft import conv1d_strided_padded -except ImportError: - print("zipfft is not installed. Please install it via 'pip install zipfft'.") - exit(0) - -def run_zipfft(config: fu.Config, fft_size: int) -> float: - shape = config.make_shape(fft_size) - random_data = config.make_random_data(fft_size) - - buffer = torch.empty( - shape, - dtype=torch.complex64, - device='cuda' - ) - - - kernel = torch.empty( - shape, - dtype=torch.complex64, - device='cuda' - ) - - - buffer.copy_(torch.from_numpy(random_data).to('cuda')) - kernel.copy_(torch.from_numpy(random_data).to('cuda')) - - stream = torch.cuda.Stream() - - torch.cuda.synchronize() - - with torch.cuda.stream(stream): - for _ in range(config.warmup): - cfft1d.fft(buffer.view(-1, buffer.size(2))) - conv1d_strided_padded.conv(buffer, kernel, fft_size) - cfft1d.ifft(buffer.view(-1, buffer.size(2))) - - - torch.cuda.synchronize() - - g = torch.cuda.CUDAGraph() - - # We capture either 1 or K FFTs back-to-back. All on the same stream. 
- with torch.cuda.graph(g, stream=stream): - for _ in range(max(1, config.iter_batch)): - cfft1d.fft(buffer.view(-1, buffer.size(2))) - conv1d_strided_padded.conv(buffer, kernel, fft_size) - cfft1d.ifft(buffer.view(-1, buffer.size(2))) - - torch.cuda.synchronize() - - gb_byte_count = 11 * np.prod(shape) * 8 / (1024 * 1024 * 1024) - - start_time = time.perf_counter() - - for _ in range(config.iter_count // max(1, config.iter_batch)): - g.replay() - - torch.cuda.synchronize() - - elapsed_time = time.perf_counter() - start_time - - return config.iter_count * gb_byte_count / elapsed_time - -if __name__ == "__main__": - config = fu.parse_args() - fft_sizes = fu.get_fft_sizes() - - output_name = f"conv_zipfft.csv" - with open(output_name, 'w', newline='') as csvfile: - writer = csv.writer(csvfile) - writer.writerow(['Backend', 'FFT Size'] + [f'Run {i + 1} (GB/s)' for i in range(config.run_count)] + ['Mean', 'Std Dev']) - - for fft_size in fft_sizes: - rates = [] - - for _ in range(config.run_count): - gb_per_second = run_zipfft(config, fft_size) - print(f"FFT Size: {fft_size}, Throughput: {gb_per_second:.2f} GB/s") - rates.append(gb_per_second) - - rounded_data = [round(rate, 2) for rate in rates] - rounded_mean = round(np.mean(rates), 2) - rounded_std = round(np.std(rates), 2) - - writer.writerow(["zipfft", fft_size] + rounded_data + [rounded_mean, rounded_std]) - - print(f"Results saved to {output_name}.csv") \ No newline at end of file diff --git a/performance_tests/conv_2d/conv_zipfft_no_transpose.py b/performance_tests/conv_2d/conv_zipfft_no_transpose.py deleted file mode 100644 index a278cda5..00000000 --- a/performance_tests/conv_2d/conv_zipfft_no_transpose.py +++ /dev/null @@ -1,95 +0,0 @@ -import csv -import time -import conv_utils as fu -import numpy as np -import torch - -try: - from zipfft import fft_nonstrided - from zipfft import conv1d_strided_padded -except ImportError: - print("zipfft is not installed. 
Please install it via 'pip install zipfft'.") - exit(0) - -def run_zipfft(config: fu.Config, fft_size: int) -> float: - shape = config.make_shape(fft_size) - random_data = config.make_random_data(fft_size) - - buffer = torch.empty( - shape, - dtype=torch.complex64, - device='cuda' - ) - - - kernel = torch.empty( - shape, - dtype=torch.complex64, - device='cuda' - ) - - - buffer.copy_(torch.from_numpy(random_data).to('cuda')) - kernel.copy_(torch.from_numpy(random_data).to('cuda')) - - stream = torch.cuda.Stream() - - torch.cuda.synchronize() - - with torch.cuda.stream(stream): - for _ in range(config.warmup): - fft_nonstrided.fft(buffer.view(-1, buffer.size(2))) - conv1d_strided_padded.conv(buffer, kernel, fft_size, True) - fft_nonstrided.fft(buffer.view(-1, buffer.size(2))) - - - torch.cuda.synchronize() - - g = torch.cuda.CUDAGraph() - - # We capture either 1 or K FFTs back-to-back. All on the same stream. - with torch.cuda.graph(g, stream=stream): - for _ in range(max(1, config.iter_batch)): - fft_nonstrided.fft(buffer.view(-1, buffer.size(2))) - conv1d_strided_padded.conv(buffer, kernel, fft_size, True) - fft_nonstrided.fft(buffer.view(-1, buffer.size(2))) - - torch.cuda.synchronize() - - gb_byte_count = 11 * np.prod(shape) * 8 / (1024 * 1024 * 1024) - - start_time = time.perf_counter() - - for _ in range(config.iter_count // max(1, config.iter_batch)): - g.replay() - - torch.cuda.synchronize() - - elapsed_time = time.perf_counter() - start_time - - return config.iter_count * gb_byte_count / elapsed_time - -if __name__ == "__main__": - config = fu.parse_args() - fft_sizes = fu.get_fft_sizes() - - output_name = f"conv_zipfft_no_transpose.csv" - with open(output_name, 'w', newline='') as csvfile: - writer = csv.writer(csvfile) - writer.writerow(['Backend', 'FFT Size'] + [f'Run {i + 1} (GB/s)' for i in range(config.run_count)] + ['Mean', 'Std Dev']) - - for fft_size in fft_sizes: - rates = [] - - for _ in range(config.run_count): - gb_per_second = 
run_zipfft(config, fft_size) - print(f"FFT Size: {fft_size}, Throughput: {gb_per_second:.2f} GB/s") - rates.append(gb_per_second) - - rounded_data = [round(rate, 2) for rate in rates] - rounded_mean = round(np.mean(rates), 2) - rounded_std = round(np.std(rates), 2) - - writer.writerow(["zipfft_no_transpose", fft_size] + rounded_data + [rounded_mean, rounded_std]) - - print(f"Results saved to {output_name}.csv") \ No newline at end of file diff --git a/performance_tests/conv_2d/run_tests.sh b/performance_tests/conv_2d/run_tests.sh deleted file mode 100644 index 2f87467e..00000000 --- a/performance_tests/conv_2d/run_tests.sh +++ /dev/null @@ -1,42 +0,0 @@ -#!/bin/bash - -mkdir -p test_results - -cd test_results - -#DATA_SIZE=134217728 -DATA_SIZE=67108864 -#DATA_SIZE=33554432 -SIGNAL_FACTOR=8 -ITER_COUNT=80 -BATCH_SIZE=10 -REPEATS=3 - -# /usr/local/cuda/bin/nvcc -O2 -std=c++17 ../conv_cufft.cu -rdc=true -lcufft_static -lculibos -o conv_cufft.exec -# /usr/local/cuda/bin/nvcc -O2 -std=c++17 ../conv_cufft_callback.cu -rdc=true -lcufft_static -lculibos -o conv_cufft_callback.exec - -echo "Running performance tests with the following parameters:" -echo "Data Size: $DATA_SIZE" -echo "Iteration Count: $ITER_COUNT" -echo "Batch Size: $BATCH_SIZE" -echo "Repeats: $REPEATS" - -# echo "Running cuFFT FFT..." -# ./conv_cufft.exec $DATA_SIZE $ITER_COUNT $BATCH_SIZE $REPEATS - -# echo "Running cuFFT with callbacks FFT..." -# ./conv_cufft_callback.exec $DATA_SIZE $ITER_COUNT $BATCH_SIZE $REPEATS - -# echo "Running VKFFT FFT..." -# python3 ../conv_vkfft.py $DATA_SIZE $ITER_COUNT $BATCH_SIZE $REPEATS - -echo "Running Vkdispatch FFT..." -python3 ../conv_vkdispatch.py $DATA_SIZE $ITER_COUNT $BATCH_SIZE $REPEATS - -# echo "Running PyTorch FFT..." -# python3 ../conv_torch.py $DATA_SIZE $ITER_COUNT $BATCH_SIZE $REPEATS - -# echo "Running ZipFFT FFT..." 
-# python3 ../conv_zipfft.py $DATA_SIZE $ITER_COUNT $BATCH_SIZE $REPEATS - -python3 ../conv_make_graph.py \ No newline at end of file diff --git a/performance_tests/conv_padded_2d/conv_padded_cufft.cu b/performance_tests/conv_padded_2d/conv_padded_cufft.cu deleted file mode 100644 index 9ee51c3a..00000000 --- a/performance_tests/conv_padded_2d/conv_padded_cufft.cu +++ /dev/null @@ -1,237 +0,0 @@ -// actual_test_cuda.cu -// Usage: ./actual_test_cuda -// Output: fft_cuda__axis.csv with the same columns as your Torch script. -// -// Build (example): -// nvcc -O3 -std=c++17 actual_test_cuda.cu -lcufft -o actual_test_cuda - -#include -#include -#include - -#include -#include -#include -#include -#include -#include -#include -#include -#include - -__global__ void fill_randomish(cufftComplex* a, long long n){ - long long i = blockIdx.x * 1LL * blockDim.x + threadIdx.x; - if(i \n"; - std::exit(1); - } - Config c; - c.data_size = std::stoll(argv[1]); - c.iter_count = std::stoi(argv[3]); - c.iter_batch = std::stoi(argv[4]); - c.run_count = std::stoi(argv[5]); - return c; -} - -static std::vector get_fft_sizes() { - std::vector sizes; - for (int p = 6; p <= 12; ++p) sizes.push_back(1 << p); // 64..4096 - return sizes; -} - -// Compute GB processed per single FFT execution (read + write) for shape (dim0, dim1) -static double gb_per_exec(long long dim0, long long dim1, long long dim2) { - // complex64 = 8 bytes; count both read and write -> *2 - const double bytes = static_cast(dim0) * static_cast(dim1) * static_cast(dim2) * 8.0; - return bytes / (1024.0 * 1024.0 * 1024.0); -} - -static double run_cufft_case(const Config& cfg, int fft_size) { - const long long total_fft_area = fft_size * fft_size; - - const long long dim0 = cfg.data_size / total_fft_area; - const long long dim1 = fft_size; - const long long dim2 = fft_size; - const long long total_elems = dim0 * dim1 * dim2; - - // Device buffers (in-place transform will overwrite input) - cufftComplex* d_data = nullptr; - 
checkCuda(cudaMalloc(&d_data, total_elems * sizeof(cufftComplex)), "cudaMalloc d_data"); - // Optionally zero-fill - checkCuda(cudaMemset(d_data, 0, total_elems * sizeof(cufftComplex)), "cudaMemset d_data"); - - cufftComplex* d_kernel = nullptr; - checkCuda(cudaMalloc(&d_kernel, total_elems * sizeof(cufftComplex)), "cudaMalloc d_kernel"); - // Optionally zero-fill - checkCuda(cudaMemset(d_kernel, 0, total_elems * sizeof(cufftComplex)), "cudaMemset d_kernel"); - - { - int t = 256, b = int((total_elems + t - 1) / t); - fill_randomish<<>>(d_data, total_elems); - checkCuda(cudaGetLastError(), "fill launch"); - checkCuda(cudaDeviceSynchronize(), "fill sync"); - - int kt = 256, kb = int((total_elems + kt - 1) / kt); - fill_randomish<<>>(d_kernel, total_elems); - checkCuda(cudaGetLastError(), "fill kernel launch"); - checkCuda(cudaDeviceSynchronize(), "fill kernel sync"); - } - - // --- plan bound to the stream --- - cufftHandle plan; - checkCuFFT(cufftCreate(&plan), "cufftCreate"); - - int n[2] = { int(dim1), int(dim2) }; - int inembed[2] = { int(dim1), int(dim2) }; // physical layout (same as n for tight pack) - int onembed[2] = { int(dim1), int(dim2) }; - int istride = 1; // contiguous within each 2D image - int ostride = 1; - int idist = int(dim1)* int(dim2); // distance between images - int odist = int(dim1)* int(dim2); - - checkCuFFT(cufftPlanMany(&plan, 2, n, - inembed, istride, idist, - onembed, ostride, odist, - CUFFT_C2C, int(dim0)), "plan2d"); - - // --- warmup on the stream --- - for (int i = 0; i < cfg.warmup; ++i) { - checkCuFFT(cufftExecC2C(plan, d_data, d_data, CUFFT_FORWARD), "warmup"); - convolve_arrays<<<(total_elems+255)/256,256>>>(d_data, d_kernel, total_elems); - checkCuFFT(cufftExecC2C(plan, d_data, d_data, CUFFT_INVERSE), "warmup"); - } - - checkCuda(cudaDeviceSynchronize(), "warmup sync"); - - // === OPTION A: plain single-stream timing (simple & robust) === - cudaEvent_t evA, evB; - checkCuda(cudaEventCreate(&evA), "evA"); - 
checkCuda(cudaEventCreate(&evB), "evB"); - checkCuda(cudaEventRecord(evA), "record A"); - for (int it = 0; it < cfg.iter_count; ++it) { - checkCuFFT(cufftExecC2C(plan, d_data, d_data, CUFFT_FORWARD), "exec"); - convolve_arrays<<<(total_elems+255)/256,256>>>(d_data, d_kernel, total_elems); - checkCuFFT(cufftExecC2C(plan, d_data, d_data, CUFFT_INVERSE), "exec"); - } - checkCuda(cudaEventRecord(evB), "record B"); - checkCuda(cudaEventSynchronize(evB), "sync B"); - checkCuda(cudaDeviceSynchronize(), "warmup sync"); - float ms = 0.f; checkCuda(cudaEventElapsedTime(&ms, evA, evB), "elapsed"); - checkCuda(cudaEventDestroy(evA), "dA"); - checkCuda(cudaEventDestroy(evB), "dB"); - - // Convert elapsed to seconds - const double seconds = static_cast(ms) / 1000.0; - - // Compute throughput in GB/s (same accounting as Torch: 2 * elems * 8 bytes per exec) - const double gb_per_exec_once = 11 * gb_per_exec(dim0, dim1, dim2); - const double total_execs = static_cast(cfg.iter_count); // * static_cast(cfg.iter_batch); - const double gb_per_second = (total_execs * gb_per_exec_once) / seconds; - - // Cleanup - cufftDestroy(plan); - cudaFree(d_data); - cudaFree(d_kernel); - - return gb_per_second; -} - -int main(int argc, char** argv) { - const Config cfg = parse_args(argc, argv); - const auto sizes = get_fft_sizes(); - - const std::string output_name = "conv_padded_cufft.csv"; - std::ofstream out(output_name); - if (!out) { - std::cerr << "Failed to open output file: " << output_name << "\n"; - return 1; - } - - std::cout << "Running cuFFT tests with data size " << cfg.data_size - << ", iter_count " << cfg.iter_count - << ", iter_batch " << cfg.iter_batch - << ", run_count " << cfg.run_count << "\n"; - - // Header: Backend, FFT Size, Run 1..N, Mean, Std Dev - out << "Backend,FFT Size"; - for (int i = 0; i < cfg.run_count; ++i) out << ",Run " << (i + 1) << " (GB/s)"; - out << ",Mean,Std Dev\n"; - - for (int fft_size : sizes) { - std::vector rates; - rates.reserve(cfg.run_count); - - 
for (int r = 0; r < cfg.run_count; ++r) { - const double gbps = run_cufft_case(cfg, fft_size); - std::cout << "FFT Size: " << fft_size << ", Throughput: " << std::fixed << std::setprecision(2) - << gbps << " GB/s\n"; - rates.push_back(gbps); - } - - // Compute mean/std - double mean = 0.0; - for (double v : rates) mean += v; - mean /= static_cast(rates.size()); - - double var = 0.0; - for (double v : rates) { - const double d = v - mean; - var += d * d; - } - var /= static_cast(rates.size()); - const double stdev = std::sqrt(var); - - // Round to 2 decimals like your Torch script - out << "cufft," << fft_size; - out << std::fixed << std::setprecision(2); - for (double v : rates) out << "," << v; - out << "," << mean << "," << stdev << "\n"; - } - - std::cout << "Results saved to " << output_name << "\n"; - return 0; -} diff --git a/performance_tests/conv_padded_2d/conv_padded_cufft_callback.cu b/performance_tests/conv_padded_2d/conv_padded_cufft_callback.cu deleted file mode 100644 index 54b12578..00000000 --- a/performance_tests/conv_padded_2d/conv_padded_cufft_callback.cu +++ /dev/null @@ -1,297 +0,0 @@ -// actual_test_cuda.cu -// Usage: ./actual_test_cuda -// Output: fft_cuda__axis.csv with the same columns as your Torch script. 
-// -// Build (example): -// nvcc -O3 -std=c++17 actual_test_cuda.cu -lcufft -o actual_test_cuda - -#include -#include -#include - -#include -#include -#include -#include -#include -#include -#include -#include -#include - -struct CallbackParams { - cufftComplex* filter; // device pointer, length = NX * NY - size_t NX; - size_t NY; - size_t signal_factor; // = NX * NY -}; - -__global__ void fill_randomish(cufftComplex* a, long long n){ - long long i = blockIdx.x * 1LL * blockDim.x + threadIdx.x; - if(i(callerInfo); - const size_t idxInImage = offset;// % (p->NX * p->NY); - - // Multiply element by filter[idxInImage] - const cufftComplex h = p->filter[idxInImage]; - cufftComplex y; - y.x = element.x * h.x - element.y * h.y; - y.y = element.x * h.y + element.y * h.x; - - static_cast(dataOut)[offset] = y; -} - -__device__ cufftCallbackStoreC d_store_cb_ptr = store_mul_cb; - -__device__ __noinline__ cufftComplex load_cb(void* dataOut, - size_t offset, - void* callerInfo, - void* /*sharedPtr*/) -{ - const CallbackParams* p = static_cast(callerInfo); - //const size_t idxInImage = offset; - - const size_t signal_size = p->NX / p->signal_factor; - - if (offset % p->NY >= signal_size || (offset / p->NY) % p->NX >= signal_size) { - return make_float2(0.f, 0.f); - - } - - return static_cast(dataOut)[offset]; -} - -__device__ cufftCallbackLoadC d_load_ptr = load_cb; - -static inline void checkCuda(cudaError_t err, const char* what) { - if (err != cudaSuccess) { - std::cerr << "[CUDA] " << what << " failed: " << cudaGetErrorString(err) << "\n"; - std::exit(1); - } -} - -static inline void checkCuFFT(cufftResult err, const char* what) { - if (err != CUFFT_SUCCESS) { - std::cerr << "[cuFFT] " << what << " failed: " << err << "\n"; - std::exit(1); - } -} - -struct Config { - long long data_size; - long long signal_factor; - int iter_count; - int iter_batch; - int run_count; - int warmup = 10; // match Torch script’s warmup -}; - -static Config parse_args(int argc, char** argv) { - 
if (argc != 6) { - std::cerr << "Usage: " << argv[0] - << " \n"; - std::exit(1); - } - Config c; - c.data_size = std::stoll(argv[1]); - c.signal_factor = std::stoll(argv[2]); - c.iter_count = std::stoi(argv[3]); - c.iter_batch = std::stoi(argv[4]); - c.run_count = std::stoi(argv[5]); - return c; -} - -static std::vector get_fft_sizes() { - std::vector sizes; - for (int p = 6; p <= 12; ++p) sizes.push_back(1 << p); // 64..4096 - return sizes; -} - -// Compute GB processed per single FFT execution (read + write) for shape (dim0, dim1) -static double gb_per_exec(long long dim0, long long dim1, long long dim2) { - // complex64 = 8 bytes; count both read and write -> *2 - const double bytes = static_cast(dim0) * static_cast(dim1) * static_cast(dim2) * 8.0; - return bytes / (1024.0 * 1024.0 * 1024.0); -} - -static double run_cufft_case(const Config& cfg, int fft_size) { - const size_t total_fft_area = fft_size * fft_size; - - const size_t dim0 = cfg.data_size / total_fft_area; - const size_t dim1 = fft_size; - const size_t dim2 = fft_size; - const size_t total_elems = dim0 * dim1 * dim2; - - // Device buffers (in-place transform will overwrite input) - cufftComplex* d_data = nullptr; - checkCuda(cudaMalloc(&d_data, total_elems * sizeof(cufftComplex)), "cudaMalloc d_data"); - // Optionally zero-fill - checkCuda(cudaMemset(d_data, 0, total_elems * sizeof(cufftComplex)), "cudaMemset d_data"); - - cufftComplex* d_kernel = nullptr; - checkCuda(cudaMalloc(&d_kernel, (total_elems) * sizeof(cufftComplex)), "cudaMalloc d_kernel"); - // Optionally zero-fill - checkCuda(cudaMemset(d_kernel, 0, (total_elems) * sizeof(cufftComplex)), "cudaMemset d_kernel"); - - { - int t = 256, b = int((total_elems + t - 1) / t); - fill_randomish<<>>(d_data, total_elems); - checkCuda(cudaGetLastError(), "fill launch"); - checkCuda(cudaDeviceSynchronize(), "fill sync"); - - int kt = 256, kb = int((total_elems + kt - 1) / kt); - fill_randomish<<>>(d_kernel, total_elems); - checkCuda(cudaGetLastError(), 
"fill kernel launch"); - checkCuda(cudaDeviceSynchronize(), "fill kernel sync"); - } - - CallbackParams h_params{ d_kernel, size_t(dim1), size_t(dim2), cfg.signal_factor }; - CallbackParams* d_params = nullptr; - checkCuda(cudaMalloc(&d_params, sizeof(CallbackParams)), "cudaMalloc params"); - checkCuda(cudaMemcpy(d_params, &h_params, sizeof(CallbackParams), cudaMemcpyHostToDevice), "cudaMemcpy params"); - - // --- plan bound to the stream --- - cufftHandle plans[2]; - checkCuFFT(cufftCreate(&plans[0]), "cufftCreate"); - checkCuFFT(cufftCreate(&plans[1]), "cufftCreate"); - - int n[2] = { int(dim1), int(dim2) }; - int inembed[2] = { int(dim1), int(dim2) }; // physical layout (same as n for tight pack) - int onembed[2] = { int(dim1), int(dim2) }; - int istride = 1; // contiguous within each 2D image - int ostride = 1; - int idist = int(dim1)* int(dim2); // distance between images - int odist = int(dim1)* int(dim2); - - checkCuFFT(cufftPlanMany(&plans[0], 2, n, - inembed, istride, idist, - onembed, ostride, odist, - CUFFT_C2C, int(dim0)), "plan2d"); - - checkCuFFT(cufftPlanMany(&plans[1], 2, n, - inembed, istride, idist, - onembed, ostride, odist, - CUFFT_C2C, int(dim0)), "plan2d"); - - cufftCallbackStoreC h_store_cb_ptr; - checkCuda(cudaMemcpyFromSymbol(&h_store_cb_ptr, d_store_cb_ptr, sizeof(h_store_cb_ptr)), "memcpy from symbol"); - - cufftCallbackLoadC h_load_ptr; - checkCuda(cudaMemcpyFromSymbol(&h_load_ptr, d_load_ptr, sizeof(h_load_ptr)), "memcpy from symbol"); - - void* cb_ptrs[1] = { (void*)h_store_cb_ptr }; - void* cb_data[1] = { (void*)d_params }; // single pointer: our params struct - checkCuFFT(cufftXtSetCallback(plans[0], cb_ptrs, CUFFT_CB_ST_COMPLEX, cb_data), "set callback"); - - void* cb_ptrs_ld[1] = { (void*)h_load_ptr }; - void* cb_data_ld[1] = { (void*)d_params }; // single pointer: our params struct - checkCuFFT(cufftXtSetCallback(plans[0], cb_ptrs_ld, CUFFT_CB_LD_COMPLEX, cb_data_ld), "load callback"); - - // --- warmup on the stream --- - for 
(int i = 0; i < cfg.warmup; ++i) { - checkCuFFT(cufftExecC2C(plans[0], d_data, d_data, CUFFT_FORWARD), "warmup"); - checkCuFFT(cufftExecC2C(plans[1], d_data, d_data, CUFFT_INVERSE), "warmup"); - } - - checkCuda(cudaDeviceSynchronize(), "warmup sync"); - - // === OPTION A: plain single-stream timing (simple & robust) === - cudaEvent_t evA, evB; - checkCuda(cudaEventCreate(&evA), "evA"); - checkCuda(cudaEventCreate(&evB), "evB"); - checkCuda(cudaEventRecord(evA), "record A"); - for (int it = 0; it < cfg.iter_count; ++it) { - checkCuFFT(cufftExecC2C(plans[0], d_data, d_data, CUFFT_FORWARD), "exec"); - checkCuFFT(cufftExecC2C(plans[1], d_data, d_data, CUFFT_INVERSE), "exec"); - } - checkCuda(cudaEventRecord(evB), "record B"); - checkCuda(cudaEventSynchronize(evB), "sync B"); - checkCuda(cudaDeviceSynchronize(), "warmup sync"); - float ms = 0.f; checkCuda(cudaEventElapsedTime(&ms, evA, evB), "elapsed"); - checkCuda(cudaEventDestroy(evA), "dA"); - checkCuda(cudaEventDestroy(evB), "dB"); - - // Convert elapsed to seconds - const double seconds = static_cast(ms) / 1000.0; - - // Compute throughput in GB/s (same accounting as Torch: 2 * elems * 8 bytes per exec) - const double gb_per_exec_once = 11 * gb_per_exec(dim0, dim1, dim2); - const double total_execs = static_cast(cfg.iter_count); // * static_cast(cfg.iter_batch); - const double gb_per_second = (total_execs * gb_per_exec_once) / seconds; - - // Cleanup - cufftDestroy(plans[0]); - cufftDestroy(plans[1]); - cudaFree(d_data); - cudaFree(d_kernel); - - return gb_per_second; -} - -int main(int argc, char** argv) { - const Config cfg = parse_args(argc, argv); - const auto sizes = get_fft_sizes(); - - const std::string output_name = "conv_padded_cufft_callback.csv"; - std::ofstream out(output_name); - if (!out) { - std::cerr << "Failed to open output file: " << output_name << "\n"; - return 1; - } - - std::cout << "Running cuFFT tests with data size " << cfg.data_size - << ", iter_count " << cfg.iter_count - << ", 
iter_batch " << cfg.iter_batch - << ", run_count " << cfg.run_count << "\n"; - - // Header: Backend, FFT Size, Run 1..N, Mean, Std Dev - out << "Backend,FFT Size"; - for (int i = 0; i < cfg.run_count; ++i) out << ",Run " << (i + 1) << " (GB/s)"; - out << ",Mean,Std Dev\n"; - - for (int fft_size : sizes) { - std::vector rates; - rates.reserve(cfg.run_count); - - for (int r = 0; r < cfg.run_count; ++r) { - const double gbps = run_cufft_case(cfg, fft_size); - std::cout << "FFT Size: " << fft_size << ", Throughput: " << std::fixed << std::setprecision(2) - << gbps << " GB/s\n"; - rates.push_back(gbps); - } - - // Compute mean/std - double mean = 0.0; - for (double v : rates) mean += v; - mean /= static_cast(rates.size()); - - double var = 0.0; - for (double v : rates) { - const double d = v - mean; - var += d * d; - } - var /= static_cast(rates.size()); - const double stdev = std::sqrt(var); - - // Round to 2 decimals like your Torch script - out << "cufft_callback," << fft_size; - out << std::fixed << std::setprecision(2); - for (double v : rates) out << "," << v; - out << "," << mean << "," << stdev << "\n"; - } - - std::cout << "Results saved to " << output_name << "\n"; - return 0; -} diff --git a/performance_tests/conv_padded_2d/conv_padded_make_graph.py b/performance_tests/conv_padded_2d/conv_padded_make_graph.py deleted file mode 100644 index 2e9c79fc..00000000 --- a/performance_tests/conv_padded_2d/conv_padded_make_graph.py +++ /dev/null @@ -1,92 +0,0 @@ -import glob -import csv -from typing import Dict, Tuple, Set -from matplotlib import pyplot as plt -import numpy as np -import sys - -# Nested structure: -# merged[backend][fft_size] = (mean, std) -MergedType = Dict[str, Dict[int, Tuple[float, float]]] - -def read_bench_csvs() -> Tuple[MergedType, Set[str], Set[int]]: - pattern = f"conv_padded_*.csv" - files = glob.glob(pattern) - - merged: MergedType = {} - backends: Set[str] = set() - fft_sizes: Set[int] = set() - - for filename in files: - print(f"Reading: 
{filename}") - with open(filename, newline="") as f: - reader = csv.DictReader(f) - for row in reader: - backend = row["Backend"].strip() - size = int(row["FFT Size"]) - mean = float(row["Mean"]) - std = float(row["Std Dev"]) - - backends.add(backend) - fft_sizes.add(size) - - if backend not in merged: - merged[backend] = {} - - # last one wins if duplicates appear across files - merged[backend][size] = (mean, std) - - return merged, backends, fft_sizes - -def save_graph(backends: Set[str], fft_sizes: Set[int], merged: MergedType, min_fft_size: int = None): - plt.figure(figsize=(10, 6)) - - if min_fft_size is not None: - used_fft_sizes = [size for size in fft_sizes if size >= min_fft_size] - else: - used_fft_sizes = fft_sizes - - for backend_name in backends: - means = [ - merged[backend_name][i][0] - for i in used_fft_sizes - ] - stds = [ - merged[backend_name][i][1] - for i in used_fft_sizes - ] - - plt.errorbar( - used_fft_sizes, - means, - yerr=stds, - label=backend_name, - capsize=5, - ) - plt.xscale('log', base=2) - plt.xlabel('Convolution Size') - plt.ylabel('GB/s') - plt.title('Padded Convolution Performance Comparison') - plt.legend() - plt.grid(True) - if min_fft_size is not None: - plt.savefig(f"conv_padded_graph_min_size{min_fft_size}.png") - return - plt.savefig(f"conv_padded_graph.png") - -if __name__ == "__main__": - # Example usage (change the number as needed) - merged, backends, fft_sizes = read_bench_csvs() - - print("\nSummary:") - print(f"Backends found: {sorted(backends)}") - print(f"Convolution sizes found: {sorted(fft_sizes)}") - print(f"Total entries: {sum(len(v) for v in merged.values())}") - - sorted_backends = sorted(backends) - sorted_fft_sizes = sorted(fft_sizes) - - save_graph(sorted_backends, sorted_fft_sizes, merged) - #save_graph(sorted_backends, sorted_fft_sizes, merged, min_fft_size=256) - - diff --git a/performance_tests/conv_padded_2d/conv_padded_torch.py b/performance_tests/conv_padded_2d/conv_padded_torch.py deleted file mode 
100644 index 772042a1..00000000 --- a/performance_tests/conv_padded_2d/conv_padded_torch.py +++ /dev/null @@ -1,94 +0,0 @@ -import csv -import time -import conv_padded_utils as fu -import numpy as np -import torch - -def run_torch(config: fu.Config, fft_size: int) -> float: - shape = config.make_shape(fft_size) - random_data = config.make_random_data(fft_size) - random_data_kernel = config.make_random_data(fft_size) - - signal_size = fft_size // config.signal_factor - - signal_shape = (shape[0], signal_size, signal_size) - - buffer = torch.empty( - signal_shape, - dtype=torch.complex64, - device='cuda' - ) - - buffer_out = torch.empty( - shape, - dtype=torch.complex64, - device='cuda' - ) - - kernel = torch.empty( - shape, - dtype=torch.complex64, - device='cuda' - ) - - buffer.copy_(torch.from_numpy(random_data[:, :signal_size, :signal_size]).to('cuda')) - buffer_out.copy_(torch.from_numpy(random_data).to('cuda')) - kernel.copy_(torch.from_numpy(random_data_kernel).to('cuda')) - - stream = torch.cuda.Stream() - - torch.cuda.synchronize() - - - - with torch.cuda.stream(stream): - for _ in range(config.warmup): - buffer_out = torch.fft.ifft2(torch.fft.fft2(buffer, s=(fft_size, fft_size)) * kernel) - - torch.cuda.synchronize() - - gb_byte_count = 11 * np.prod(shape) * 8 / (1024 * 1024 * 1024) - - g = torch.cuda.CUDAGraph() - - # We capture either 1 or K FFTs back-to-back. All on the same stream. 
- with torch.cuda.graph(g, stream=stream): - for _ in range(max(1, config.iter_batch)): - buffer_out = torch.fft.ifft2(torch.fft.fft2(buffer, s=(fft_size, fft_size)) * kernel) - - torch.cuda.synchronize() - start_time = time.perf_counter() - - with torch.cuda.stream(stream): - for _ in range(config.iter_count // max(1, config.iter_batch)): - g.replay() - - torch.cuda.synchronize() - elapsed_time = time.perf_counter() - start_time - - return config.iter_count * gb_byte_count / elapsed_time - -if __name__ == "__main__": - config = fu.parse_args() - fft_sizes = fu.get_fft_sizes() - - output_name = f"conv_padded_torch.csv" - with open(output_name, 'w', newline='') as csvfile: - writer = csv.writer(csvfile) - writer.writerow(['Backend', 'FFT Size'] + [f'Run {i + 1} (GB/s)' for i in range(config.run_count)] + ['Mean', 'Std Dev']) - - for fft_size in fft_sizes: - rates = [] - - for _ in range(config.run_count): - gb_per_second = run_torch(config, fft_size) - print(f"FFT Size: {fft_size}, Throughput: {gb_per_second:.2f} GB/s") - rates.append(gb_per_second) - - rounded_data = [round(rate, 2) for rate in rates] - rounded_mean = round(np.mean(rates), 2) - rounded_std = round(np.std(rates), 2) - - writer.writerow(["torch", fft_size] + rounded_data + [rounded_mean, rounded_std]) - - print(f"Results saved to {output_name}.csv") \ No newline at end of file diff --git a/performance_tests/conv_padded_2d/conv_padded_utils.py b/performance_tests/conv_padded_2d/conv_padded_utils.py deleted file mode 100644 index ebaef5fe..00000000 --- a/performance_tests/conv_padded_2d/conv_padded_utils.py +++ /dev/null @@ -1,40 +0,0 @@ -import sys -from typing import Tuple -import dataclasses - -import numpy as np - -@dataclasses.dataclass -class Config: - data_size: int - signal_factor: int - iter_count: int - iter_batch: int - run_count: int - warmup: int = 10 - - def make_shape(self, fft_size: int) -> Tuple[int, ...]: - total_square_size = fft_size * fft_size - assert self.data_size % 
total_square_size == 0, "Data size must be a multiple of fft_size squared" - return (self.data_size // total_square_size, fft_size, fft_size) - - def make_random_data(self, fft_size: int): - shape = self.make_shape(fft_size) - return np.random.rand(*shape).astype(np.complex64) - -def parse_args() -> Config: - if len(sys.argv) != 6: - print(f"Usage: {sys.argv[0]} ") - sys.exit(1) - - return Config( - data_size=int(sys.argv[1]), - signal_factor=int(sys.argv[2]), - iter_count=int(sys.argv[3]), - iter_batch=int(sys.argv[4]), - run_count=int(sys.argv[5]), - ) - -def get_fft_sizes(): - return [2**i for i in range(6, 13)] # FFT sizes from 64 to 4096 (inclusive) - diff --git a/performance_tests/conv_padded_2d/conv_padded_vkdispatch.py b/performance_tests/conv_padded_2d/conv_padded_vkdispatch.py deleted file mode 100644 index 505022a4..00000000 --- a/performance_tests/conv_padded_2d/conv_padded_vkdispatch.py +++ /dev/null @@ -1,174 +0,0 @@ -import csv -import time -import conv_padded_utils as fu -import vkdispatch as vd -import vkdispatch.codegen as vc -import numpy as np - -def padded_cross_correlation( - buffer: vd.Buffer, - kernel: vd.Buffer, - signal_shape: tuple, - graph: vd.CommandGraph): - - - # Fill input buffer with zeros where needed - @vd.map_registers([vc.c64]) - def initial_input_mapping(input_buffer: vc.Buffer[vc.c64]): - vc.if_statement(vc.mapping_index() % buffer.shape[2] < signal_shape[1]) - - in_layer_index = vc.mapping_index() % (signal_shape[1] * buffer.shape[2]) - out_layer_index = vc.mapping_index() / (signal_shape[1] * buffer.shape[2]) - actual_index = in_layer_index + out_layer_index * (buffer.shape[1] * buffer.shape[2]) - - vc.mapping_registers()[0][:] = input_buffer[actual_index] - vc.else_statement() - vc.mapping_registers()[0][:] = "vec2(0)" - vc.end() - - # Remap output indicies to match the actual buffer shape - @vd.map_registers([vc.c64]) - def initial_output_mapping(output_buffer: vc.Buffer[vc.c64]): - in_layer_index = vc.mapping_index() % 
(signal_shape[1] * buffer.shape[2]) - out_layer_index = vc.mapping_index() / (signal_shape[1] * buffer.shape[2]) - actual_index = in_layer_index + out_layer_index * (buffer.shape[1] * buffer.shape[2]) - output_buffer[actual_index] = vc.mapping_registers()[0] - - # Do the first FFT on the correlation buffer accross the first axis - vd.fft.fft( - buffer, - buffer, - buffer_shape=( - buffer.shape[0], - signal_shape[1], - buffer.shape[2] - ), - input_map=initial_input_mapping, - output_map=initial_output_mapping, - graph=graph - ) - - # Again, we skip reading the zero-padded values from the input - @vd.map_registers([vc.c64]) - def input_mapping(input_buffer: vc.Buffer[vc.c64]): - in_layer_index = vc.mapping_index() % ( - buffer.shape[1] * buffer.shape[2] - ) - - vc.if_statement(in_layer_index / buffer.shape[2] < signal_shape[1]) - vc.mapping_registers()[0][:] = input_buffer[vc.mapping_index()] - vc.else_statement() - vc.mapping_registers()[0][:] = "vec2(0)" - vc.end() - - @vd.map_registers([vc.c64]) - def kernel_mapping(kernel_buffer: vc.Buffer[vc.c64]): - img_val = vc.mapping_registers()[0] - read_register = vc.mapping_registers()[1] - - # Calculate the invocation within this FFT batch - in_group_index = vc.local_invocation().y * vc.workgroup_size().x + vc.local_invocation().x - out_group_index = vc.workgroup().y * vc.num_workgroups().x + vc.workgroup().x - workgroup_index = in_group_index + out_group_index * ( - vc.workgroup_size().x * vc.workgroup_size().y - ) - - # Calculate the batch index of the FFT - batch_index = ( - vc.mapping_index() - ) / ( - vc.workgroup_size().x * vc.workgroup_size().y * - vc.num_workgroups().x * vc.num_workgroups().y - ) - - # Calculate the transposed index - transposed_index = workgroup_index + batch_index * ( - vc.workgroup_size().x * vc.workgroup_size().y * - vc.num_workgroups().x * vc.num_workgroups().y - ) - - read_register[:] = kernel_buffer[transposed_index] - img_val[:] = vc.mult_conj_c64(read_register, img_val) - - 
vd.fft.convolve( - buffer, - buffer, - kernel, - input_map=input_mapping, - kernel_map=kernel_mapping, - axis=1, - graph=graph - ) - - vd.fft.ifft(buffer, graph=graph) - -def run_vkdispatch(config: fu.Config, fft_size: int) -> float: - shape = config.make_shape(fft_size) - random_data = config.make_random_data(fft_size) - random_data_2 = config.make_random_data(fft_size) - - buffer = vd.Buffer(shape, var_type=vd.complex64) - buffer.write(random_data) - - kernel = vd.Buffer(shape, var_type=vd.complex64) - kernel.write(random_data_2) - - graph = vd.CommandGraph() - - signal_size = fft_size // config.signal_factor - - padded_cross_correlation(buffer, kernel, (signal_size, signal_size), graph) - - for _ in range(config.warmup): - graph.submit(config.iter_batch) - - vd.queue_wait_idle() - - gb_byte_count = 11 * 8 * buffer.size / (1024 * 1024 * 1024) - - start_time = time.perf_counter() - - for _ in range(config.iter_count // config.iter_batch): - graph.submit(config.iter_batch) - - vd.queue_wait_idle() - - elapsed_time = time.perf_counter() - start_time - - buffer.destroy() - graph.destroy() - vd.fft.cache_clear() - - time.sleep(1) - - vd.queue_wait_idle() - - return config.iter_count * gb_byte_count / elapsed_time - -if __name__ == "__main__": - config = fu.parse_args() - fft_sizes = fu.get_fft_sizes() - - output_name = f"conv_padded_vkdispatch.csv" - with open(output_name, 'w', newline='') as csvfile: - writer = csv.writer(csvfile) - writer.writerow(['Backend', 'FFT Size'] + [f'Run {i + 1} (GB/s)' for i in range(config.run_count)] + ['Mean', 'Std Dev']) - - for fft_size in fft_sizes: - rates = [] - - for _ in range(config.run_count): - gb_per_second = run_vkdispatch(config, fft_size) - print(f"FFT Size: {fft_size}, Throughput: {gb_per_second:.2f} GB/s") - rates.append(gb_per_second) - - rounded_data = [round(rate, 2) for rate in rates] - rounded_mean = round(np.mean(rates), 2) - rounded_std = round(np.std(rates), 2) - - writer.writerow(["vkdispatch", fft_size] + 
rounded_data + [rounded_mean, rounded_std]) - - print(f"Results saved to {output_name}.csv") - - - \ No newline at end of file diff --git a/performance_tests/conv_padded_2d/conv_padded_zipfft.py b/performance_tests/conv_padded_2d/conv_padded_zipfft.py deleted file mode 100644 index 54b8b12a..00000000 --- a/performance_tests/conv_padded_2d/conv_padded_zipfft.py +++ /dev/null @@ -1,97 +0,0 @@ -import csv -import time -import conv_padded_utils as fu -import numpy as np -import torch - -try: - from zipfft import cfft1d - from zipfft import conv1d_strided_padded - from zipfft import padded_fft1d -except ImportError: - print("zipfft is not installed. Please install it via 'pip install zipfft'.") - exit(0) - -def run_zipfft(config: fu.Config, fft_size: int) -> float: - shape = config.make_shape(fft_size) - random_data = config.make_random_data(fft_size) - - buffer = torch.empty( - shape, - dtype=torch.complex64, - device='cuda' - ) - - - kernel = torch.empty( - shape, - dtype=torch.complex64, - device='cuda' - ) - - buffer.copy_(torch.from_numpy(random_data).to('cuda')) - kernel.copy_(torch.from_numpy(random_data).to('cuda')) - - stream = torch.cuda.Stream() - - signal_size = fft_size // config.signal_factor - - torch.cuda.synchronize() - - with torch.cuda.stream(stream): - for _ in range(config.warmup): - padded_fft1d.pfft_layered(buffer, signal_size, signal_size) - conv1d_strided_padded.conv(buffer, kernel, signal_size) - cfft1d.ifft(buffer.view(-1, buffer.size(2))) - - - torch.cuda.synchronize() - - g = torch.cuda.CUDAGraph() - - # We capture either 1 or K FFTs back-to-back. All on the same stream. 
- with torch.cuda.graph(g, stream=stream): - for _ in range(max(1, config.iter_batch)): - padded_fft1d.pfft_layered(buffer, signal_size, signal_size) - conv1d_strided_padded.conv(buffer, kernel, signal_size) - cfft1d.ifft(buffer.view(-1, buffer.size(2))) - - torch.cuda.synchronize() - - gb_byte_count = 11 * np.prod(shape) * 8 / (1024 * 1024 * 1024) - - start_time = time.perf_counter() - - for _ in range(config.iter_count // max(1, config.iter_batch)): - g.replay() - - torch.cuda.synchronize() - - elapsed_time = time.perf_counter() - start_time - - return config.iter_count * gb_byte_count / elapsed_time - -if __name__ == "__main__": - config = fu.parse_args() - fft_sizes = fu.get_fft_sizes() - - output_name = f"conv_padded_zipfft.csv" - with open(output_name, 'w', newline='') as csvfile: - writer = csv.writer(csvfile) - writer.writerow(['Backend', 'FFT Size'] + [f'Run {i + 1} (GB/s)' for i in range(config.run_count)] + ['Mean', 'Std Dev']) - - for fft_size in fft_sizes: - rates = [] - - for _ in range(config.run_count): - gb_per_second = run_zipfft(config, fft_size) - print(f"FFT Size: {fft_size}, Throughput: {gb_per_second:.2f} GB/s") - rates.append(gb_per_second) - - rounded_data = [round(rate, 2) for rate in rates] - rounded_mean = round(np.mean(rates), 2) - rounded_std = round(np.std(rates), 2) - - writer.writerow(["zipfft", fft_size] + rounded_data + [rounded_mean, rounded_std]) - - print(f"Results saved to {output_name}.csv") \ No newline at end of file diff --git a/performance_tests/conv_padded_2d/run_tests.sh b/performance_tests/conv_padded_2d/run_tests.sh deleted file mode 100644 index f111bbbf..00000000 --- a/performance_tests/conv_padded_2d/run_tests.sh +++ /dev/null @@ -1,40 +0,0 @@ -#!/bin/bash - -mkdir -p test_results - -cd test_results - -#DATA_SIZE=134217728 -DATA_SIZE=67108864 -#DATA_SIZE=33554432 -SIGNAL_FACTOR=8 -ITER_COUNT=150 -BATCH_SIZE=10 -REPEATS=4 - -# /usr/local/cuda/bin/nvcc -O2 -std=c++17 ../conv_padded_cufft.cu -rdc=true -lcufft_static 
-lculibos -o conv_padded_cufft.exec -# /usr/local/cuda/bin/nvcc -O2 -std=c++17 ../conv_padded_cufft_callback.cu -rdc=true -lcufft_static -lculibos -o conv_padded_cufft_callback.exec - -echo "Running performance tests with the following parameters:" -echo "Data Size: $DATA_SIZE" -echo "Signal Factor: $SIGNAL_FACTOR" -echo "Iteration Count: $ITER_COUNT" -echo "Batch Size: $BATCH_SIZE" -echo "Repeats: $REPEATS" - -echo "Running Vkdispatch FFT..." -python3 ../conv_padded_vkdispatch.py $DATA_SIZE $SIGNAL_FACTOR $ITER_COUNT $BATCH_SIZE $REPEATS - -# echo "Running cuFFT FFT..." -# ./conv_padded_cufft.exec $DATA_SIZE $SIGNAL_FACTOR $ITER_COUNT $BATCH_SIZE $REPEATS - -# echo "Running cuFFT callback FFT..." -# ./conv_padded_cufft_callback.exec $DATA_SIZE $SIGNAL_FACTOR $ITER_COUNT $BATCH_SIZE $REPEATS - -# echo "Running PyTorch FFT..." -# python3 ../conv_padded_torch.py $DATA_SIZE $SIGNAL_FACTOR $ITER_COUNT $BATCH_SIZE $REPEATS - -# echo "Running ZipFFT FFT..." -# python3 ../conv_padded_zipfft.py $DATA_SIZE $SIGNAL_FACTOR $ITER_COUNT $BATCH_SIZE $REPEATS - -python3 ../conv_padded_make_graph.py \ No newline at end of file diff --git a/performance_tests/conv_padded_2d/run_tests_old.sh b/performance_tests/conv_padded_2d/run_tests_old.sh deleted file mode 100644 index 48f4cdee..00000000 --- a/performance_tests/conv_padded_2d/run_tests_old.sh +++ /dev/null @@ -1,39 +0,0 @@ -#!/bin/bash - -mkdir -p test_results - -cd test_results - -DATA_SIZE=134217728 -#DATA_SIZE=33554432 #134217728 -SIGNAL_FACTOR=8 -ITER_COUNT=200 -BATCH_SIZE=10 -REPEATS=5 - -/usr/local/cuda/bin/nvcc -O2 -std=c++17 ../conv_padded_cufft.cu -rdc=true -lcufft_static -lculibos -o conv_padded_cufft.exec -/usr/local/cuda/bin/nvcc -O2 -std=c++17 ../conv_padded_cufft_callback.cu -rdc=true -lcufft_static -lculibos -o conv_padded_cufft_callback.exec - -echo "Running performance tests with the following parameters:" -echo "Data Size: $DATA_SIZE" -echo "Signal Factor: $SIGNAL_FACTOR" -echo "Iteration Count: $ITER_COUNT" 
-echo "Batch Size: $BATCH_SIZE" -echo "Repeats: $REPEATS" - -echo "Running Vkdispatch FFT..." -python3 ../conv_padded_vkdispatch.py $DATA_SIZE $SIGNAL_FACTOR $ITER_COUNT $BATCH_SIZE $REPEATS - -echo "Running cuFFT FFT..." -./conv_padded_cufft.exec $DATA_SIZE $SIGNAL_FACTOR $ITER_COUNT $BATCH_SIZE $REPEATS - -echo "Running cuFFT callback FFT..." -./conv_padded_cufft_callback.exec $DATA_SIZE $SIGNAL_FACTOR $ITER_COUNT $BATCH_SIZE $REPEATS - -echo "Running PyTorch FFT..." -python3 ../conv_padded_torch.py $DATA_SIZE $SIGNAL_FACTOR $ITER_COUNT $BATCH_SIZE $REPEATS - -echo "Running ZipFFT FFT..." -python3 ../conv_padded_zipfft.py $DATA_SIZE $SIGNAL_FACTOR $ITER_COUNT $BATCH_SIZE $REPEATS - -python3 ../conv_padded_make_graph.py \ No newline at end of file diff --git a/performance_tests/fft_2d/fft_cufft.cu b/performance_tests/fft_2d/fft_cufft.cu deleted file mode 100644 index 3ce18d9b..00000000 --- a/performance_tests/fft_2d/fft_cufft.cu +++ /dev/null @@ -1,208 +0,0 @@ -// actual_test_cuda.cu -// Usage: ./actual_test_cuda -// Output: fft_cuda__axis.csv with the same columns as your Torch script. 
-// -// Build (example): -// nvcc -O3 -std=c++17 actual_test_cuda.cu -lcufft -o actual_test_cuda - -#include -#include -#include - -#include -#include -#include -#include -#include -#include -#include -#include -#include - -__global__ void fill_randomish(cufftComplex* a, long long n){ - long long i = blockIdx.x * 1LL * blockDim.x + threadIdx.x; - if(i \n"; - std::exit(1); - } - Config c; - c.data_size = std::stoll(argv[1]); - c.iter_count = std::stoi(argv[2]); - c.iter_batch = std::stoi(argv[3]); - c.run_count = std::stoi(argv[4]); - return c; -} - -static std::vector get_fft_sizes() { - std::vector sizes; - for (int p = 6; p <= 12; ++p) sizes.push_back(1 << p); // 64..4096 - return sizes; -} - -// Compute GB processed per single FFT execution (read + write) for shape (dim0, dim1) -static double gb_per_exec(long long dim0, long long dim1, long long dim2) { - // complex64 = 8 bytes; count both read and write -> *2 - const double bytes = 2.0 * static_cast(dim0) * static_cast(dim1) * static_cast(dim2) * 8.0; - return bytes / (1024.0 * 1024.0 * 1024.0); -} - -static double run_cufft_case(const Config& cfg, int fft_size) { - const long long total_fft_area = fft_size * fft_size; - - const long long dim0 = cfg.data_size / total_fft_area; - const long long dim1 = fft_size; - const long long dim2 = fft_size; - const long long total_elems = dim0 * dim1 * dim2; - - // Device buffers (in-place transform will overwrite input) - cufftComplex* d_data = nullptr; - checkCuda(cudaMalloc(&d_data, total_elems * sizeof(cufftComplex)), "cudaMalloc d_data"); - // Optionally zero-fill - checkCuda(cudaMemset(d_data, 0, total_elems * sizeof(cufftComplex)), "cudaMemset d_data"); - - { - int t = 256, b = int((total_elems + t - 1) / t); - fill_randomish<<>>(d_data, total_elems); - checkCuda(cudaGetLastError(), "fill launch"); - checkCuda(cudaDeviceSynchronize(), "fill sync"); - } - - // --- plan bound to the stream --- - cufftHandle plan; - checkCuFFT(cufftCreate(&plan), "cufftCreate"); - - 
int n[2] = { int(dim1), int(dim2) }; - int inembed[2] = { int(dim1), int(dim2) }; // physical layout (same as n for tight pack) - int onembed[2] = { int(dim1), int(dim2) }; - int istride = 1; // contiguous within each 2D image - int ostride = 1; - int idist = int(dim1)* int(dim2); // distance between images - int odist = int(dim1)* int(dim2); - - checkCuFFT(cufftPlanMany(&plan, 2, n, - inembed, istride, idist, - onembed, ostride, odist, - CUFFT_C2C, int(dim0)), "plan2d"); - - // --- warmup on the stream --- - for (int i = 0; i < cfg.warmup; ++i) - checkCuFFT(cufftExecC2C(plan, d_data, d_data, CUFFT_FORWARD), "warmup"); - - checkCuda(cudaDeviceSynchronize(), "warmup sync"); - - // === OPTION A: plain single-stream timing (simple & robust) === - cudaEvent_t evA, evB; - checkCuda(cudaEventCreate(&evA), "evA"); - checkCuda(cudaEventCreate(&evB), "evB"); - checkCuda(cudaEventRecord(evA), "record A"); - for (int it = 0; it < cfg.iter_count; ++it) - checkCuFFT(cufftExecC2C(plan, d_data, d_data, CUFFT_FORWARD), "exec"); - checkCuda(cudaEventRecord(evB), "record B"); - checkCuda(cudaEventSynchronize(evB), "sync B"); - checkCuda(cudaDeviceSynchronize(), "warmup sync"); - float ms = 0.f; checkCuda(cudaEventElapsedTime(&ms, evA, evB), "elapsed"); - checkCuda(cudaEventDestroy(evA), "dA"); - checkCuda(cudaEventDestroy(evB), "dB"); - - // Convert elapsed to seconds - const double seconds = static_cast(ms) / 1000.0; - - // Compute throughput in GB/s (same accounting as Torch: 2 * elems * 8 bytes per exec) - const double gb_per_exec_once = 2 * gb_per_exec(dim0, dim1, dim2); - const double total_execs = static_cast(cfg.iter_count); // * static_cast(cfg.iter_batch); - const double gb_per_second = (total_execs * gb_per_exec_once) / seconds; - - // Cleanup - cufftDestroy(plan); - cudaFree(d_data); - - return gb_per_second; -} - -int main(int argc, char** argv) { - const Config cfg = parse_args(argc, argv); - const auto sizes = get_fft_sizes(); - - const std::string output_name = 
"fft_cufft.csv"; - std::ofstream out(output_name); - if (!out) { - std::cerr << "Failed to open output file: " << output_name << "\n"; - return 1; - } - - std::cout << "Running cuFFT tests with data size " << cfg.data_size - << ", iter_count " << cfg.iter_count - << ", iter_batch " << cfg.iter_batch - << ", run_count " << cfg.run_count << "\n"; - - // Header: Backend, FFT Size, Run 1..N, Mean, Std Dev - out << "Backend,FFT Size"; - for (int i = 0; i < cfg.run_count; ++i) out << ",Run " << (i + 1) << " (GB/s)"; - out << ",Mean,Std Dev\n"; - - for (int fft_size : sizes) { - std::vector rates; - rates.reserve(cfg.run_count); - - for (int r = 0; r < cfg.run_count; ++r) { - const double gbps = run_cufft_case(cfg, fft_size); - std::cout << "FFT Size: " << fft_size << ", Throughput: " << std::fixed << std::setprecision(2) - << gbps << " GB/s\n"; - rates.push_back(gbps); - } - - // Compute mean/std - double mean = 0.0; - for (double v : rates) mean += v; - mean /= static_cast(rates.size()); - - double var = 0.0; - for (double v : rates) { - const double d = v - mean; - var += d * d; - } - var /= static_cast(rates.size()); - const double stdev = std::sqrt(var); - - // Round to 2 decimals like your Torch script - out << "cufft," << fft_size; - out << std::fixed << std::setprecision(2); - for (double v : rates) out << "," << v; - out << "," << mean << "," << stdev << "\n"; - } - - std::cout << "Results saved to " << output_name << "\n"; - return 0; -} diff --git a/performance_tests/fft_2d/fft_make_graph.py b/performance_tests/fft_2d/fft_make_graph.py deleted file mode 100644 index 2284d0c2..00000000 --- a/performance_tests/fft_2d/fft_make_graph.py +++ /dev/null @@ -1,92 +0,0 @@ -import glob -import csv -from typing import Dict, Tuple, Set -from matplotlib import pyplot as plt -import numpy as np -import sys - -# Nested structure: -# merged[backend][fft_size] = (mean, std) -MergedType = Dict[str, Dict[int, Tuple[float, float]]] - -def read_bench_csvs() -> Tuple[MergedType, 
Set[str], Set[int]]: - pattern = f"fft_*.csv" - files = glob.glob(pattern) - - merged: MergedType = {} - backends: Set[str] = set() - fft_sizes: Set[int] = set() - - for filename in files: - print(f"Reading: {filename}") - with open(filename, newline="") as f: - reader = csv.DictReader(f) - for row in reader: - backend = row["Backend"].strip() - size = int(row["FFT Size"]) - mean = float(row["Mean"]) - std = float(row["Std Dev"]) - - backends.add(backend) - fft_sizes.add(size) - - if backend not in merged: - merged[backend] = {} - - # last one wins if duplicates appear across files - merged[backend][size] = (mean, std) - - return merged, backends, fft_sizes - -def save_graph(backends: Set[str], fft_sizes: Set[int], merged: MergedType, min_fft_size: int = None): - plt.figure(figsize=(10, 6)) - - if min_fft_size is not None: - used_fft_sizes = [size for size in fft_sizes if size >= min_fft_size] - else: - used_fft_sizes = fft_sizes - - for backend_name in backends: - means = [ - merged[backend_name][i][0] - for i in used_fft_sizes - ] - stds = [ - merged[backend_name][i][1] - for i in used_fft_sizes - ] - - plt.errorbar( - used_fft_sizes, - means, - yerr=stds, - label=backend_name, - capsize=5, - ) - plt.xscale('log', base=2) - plt.xlabel('FFT Size') - plt.ylabel('GB/s') - plt.title('FFT Performance Comparison') - plt.legend() - plt.grid(True) - if min_fft_size is not None: - plt.savefig(f"fft_graph_min_size{min_fft_size}.png") - return - plt.savefig(f"fft_graph.png") - -if __name__ == "__main__": - # Example usage (change the number as needed) - merged, backends, fft_sizes = read_bench_csvs() - - print("\nSummary:") - print(f"Backends found: {sorted(backends)}") - print(f"FFT sizes found: {sorted(fft_sizes)}") - print(f"Total entries: {sum(len(v) for v in merged.values())}") - - sorted_backends = sorted(backends) - sorted_fft_sizes = sorted(fft_sizes) - - save_graph(sorted_backends, sorted_fft_sizes, merged) - save_graph(sorted_backends, sorted_fft_sizes, merged, 
min_fft_size=256) - - diff --git a/performance_tests/fft_2d/fft_torch.py b/performance_tests/fft_2d/fft_torch.py deleted file mode 100644 index af3162d1..00000000 --- a/performance_tests/fft_2d/fft_torch.py +++ /dev/null @@ -1,73 +0,0 @@ -import csv -import time -import ffts_utils as fu -import numpy as np -import torch - -def run_torch(config: fu.Config, fft_size: int) -> float: - shape = config.make_shape(fft_size) - random_data = config.make_random_data(fft_size) - - buffer = torch.empty( - shape, - dtype=torch.complex64, - device='cuda' - ) - - buffer.copy_(torch.from_numpy(random_data).to('cuda')) - - stream = torch.cuda.Stream() - - torch.cuda.synchronize() - - with torch.cuda.stream(stream): - for _ in range(config.warmup): - buffer = torch.fft.fft2(buffer) - - torch.cuda.synchronize() - - gb_byte_count = 4 * np.prod(shape) * 8 / (1024 * 1024 * 1024) - - g = torch.cuda.CUDAGraph() - - # We capture either 1 or K FFTs back-to-back. All on the same stream. - with torch.cuda.graph(g, stream=stream): - for _ in range(max(1, config.iter_batch)): - buffer = torch.fft.fft2(buffer) # creates a tensor once during capture - - torch.cuda.synchronize() - start_time = time.perf_counter() - - with torch.cuda.stream(stream): - for _ in range(config.iter_count // max(1, config.iter_batch)): - g.replay() - - torch.cuda.synchronize() - elapsed_time = time.perf_counter() - start_time - - return config.iter_count * gb_byte_count / elapsed_time - -if __name__ == "__main__": - config = fu.parse_args() - fft_sizes = fu.get_fft_sizes() - - output_name = f"fft_torch.csv" - with open(output_name, 'w', newline='') as csvfile: - writer = csv.writer(csvfile) - writer.writerow(['Backend', 'FFT Size'] + [f'Run {i + 1} (GB/s)' for i in range(config.run_count)] + ['Mean', 'Std Dev']) - - for fft_size in fft_sizes: - rates = [] - - for _ in range(config.run_count): - gb_per_second = run_torch(config, fft_size) - print(f"FFT Size: {fft_size}, Throughput: {gb_per_second:.2f} GB/s") - 
rates.append(gb_per_second) - - rounded_data = [round(rate, 2) for rate in rates] - rounded_mean = round(np.mean(rates), 2) - rounded_std = round(np.std(rates), 2) - - writer.writerow(["torch", fft_size] + rounded_data + [rounded_mean, rounded_std]) - - print(f"Results saved to {output_name}.csv") \ No newline at end of file diff --git a/performance_tests/fft_2d/fft_vkdispatch.py b/performance_tests/fft_2d/fft_vkdispatch.py deleted file mode 100644 index 4444a45f..00000000 --- a/performance_tests/fft_2d/fft_vkdispatch.py +++ /dev/null @@ -1,70 +0,0 @@ -import csv -import time -import ffts_utils as fu -import vkdispatch as vd -import numpy as np - -def run_vkdispatch(config: fu.Config, fft_size: int) -> float: - shape = config.make_shape(fft_size) - random_data = config.make_random_data(fft_size) - - buffer = vd.Buffer(shape, var_type=vd.complex64) - buffer.write(random_data) - - graph = vd.CommandGraph() - - vd.fft.fft2(buffer, graph=graph) - - for _ in range(config.warmup): - graph.submit(config.iter_batch) - - vd.queue_wait_idle() - - gb_byte_count = 4 * 8 * buffer.size / (1024 * 1024 * 1024) - - start_time = time.perf_counter() - - for _ in range(config.iter_count // config.iter_batch): - graph.submit(config.iter_batch) - - vd.queue_wait_idle() - - elapsed_time = time.perf_counter() - start_time - - buffer.destroy() - graph.destroy() - vd.fft.cache_clear() - - time.sleep(1) - - vd.queue_wait_idle() - - return config.iter_count * gb_byte_count / elapsed_time - -if __name__ == "__main__": - config = fu.parse_args() - fft_sizes = fu.get_fft_sizes() - - output_name = f"fft_vkdispatch.csv" - with open(output_name, 'w', newline='') as csvfile: - writer = csv.writer(csvfile) - writer.writerow(['Backend', 'FFT Size'] + [f'Run {i + 1} (GB/s)' for i in range(config.run_count)] + ['Mean', 'Std Dev']) - - for fft_size in fft_sizes: - rates = [] - - for _ in range(config.run_count): - gb_per_second = run_vkdispatch(config, fft_size) - print(f"FFT Size: {fft_size}, 
Throughput: {gb_per_second:.2f} GB/s") - rates.append(gb_per_second) - - rounded_data = [round(rate, 2) for rate in rates] - rounded_mean = round(np.mean(rates), 2) - rounded_std = round(np.std(rates), 2) - - writer.writerow(["vkdispatch", fft_size] + rounded_data + [rounded_mean, rounded_std]) - - print(f"Results saved to {output_name}.csv") - - - \ No newline at end of file diff --git a/performance_tests/fft_2d/fft_vkfft.py b/performance_tests/fft_2d/fft_vkfft.py deleted file mode 100644 index 5ca93a81..00000000 --- a/performance_tests/fft_2d/fft_vkfft.py +++ /dev/null @@ -1,66 +0,0 @@ -import csv -import time -import ffts_utils as fu -import vkdispatch as vd -import numpy as np - -def run_vkfft(config: fu.Config, fft_size: int) -> float: - shape = config.make_shape(fft_size) - random_data = config.make_random_data(fft_size) - - buffer = vd.Buffer(shape, var_type=vd.complex64) - buffer.write(random_data) - graph = vd.CommandGraph() - - vd.vkfft.fft2(buffer, graph=graph) - - for _ in range(config.warmup): - graph.submit(config.iter_batch) - - vd.queue_wait_idle() - - gb_byte_count = 4 * 8 * buffer.size / (1024 * 1024 * 1024) - - start_time = time.perf_counter() - - for _ in range(config.iter_count // config.iter_batch): - graph.submit(config.iter_batch) - - vd.queue_wait_idle() - - elapsed_time = time.perf_counter() - start_time - - buffer.destroy() - graph.destroy() - vd.vkfft.clear_plan_cache() - - time.sleep(1) - - vd.queue_wait_idle() - - return config.iter_count * gb_byte_count / elapsed_time - -if __name__ == "__main__": - config = fu.parse_args() - fft_sizes = fu.get_fft_sizes() - - output_name = f"fft_vkfft.csv" - with open(output_name, 'w', newline='') as csvfile: - writer = csv.writer(csvfile) - writer.writerow(['Backend', 'FFT Size'] + [f'Run {i + 1} (GB/s)' for i in range(config.run_count)] + ['Mean', 'Std Dev']) - - for fft_size in fft_sizes: - rates = [] - - for _ in range(config.run_count): - gb_per_second = run_vkfft(config, fft_size) - print(f"FFT 
Size: {fft_size}, Throughput: {gb_per_second:.2f} GB/s") - rates.append(gb_per_second) - - rounded_data = [round(rate, 2) for rate in rates] - rounded_mean = round(np.mean(rates), 2) - rounded_std = round(np.std(rates), 2) - - writer.writerow(["vkfft", fft_size] + rounded_data + [rounded_mean, rounded_std]) - - print(f"Results saved to {output_name}.csv") \ No newline at end of file diff --git a/performance_tests/fft_2d/fft_zipfft.py b/performance_tests/fft_2d/fft_zipfft.py deleted file mode 100644 index eee58e16..00000000 --- a/performance_tests/fft_2d/fft_zipfft.py +++ /dev/null @@ -1,83 +0,0 @@ -import csv -import time -import ffts_utils as fu -import numpy as np -import torch - -try: - from zipfft import cfft1d - from zipfft import cfft1d_strided -except ImportError: - print("zipfft is not installed. Please install it via 'pip install zipfft'.") - exit(0) - -def run_zipfft(config: fu.Config, fft_size: int) -> float: - shape = config.make_shape(fft_size) - random_data = config.make_random_data(fft_size) - - buffer = torch.empty( - shape, - dtype=torch.complex64, - device='cuda' - ) - - buffer.copy_(torch.from_numpy(random_data).to('cuda')) - - stream = torch.cuda.Stream() - - torch.cuda.synchronize() - - with torch.cuda.stream(stream): - for _ in range(config.warmup): - cfft1d.fft(buffer.view(-1, buffer.size(2))) - cfft1d_strided.fft(buffer) - - torch.cuda.synchronize() - - g = torch.cuda.CUDAGraph() - - # We capture either 1 or K FFTs back-to-back. All on the same stream. 
- with torch.cuda.graph(g, stream=stream): - for _ in range(max(1, config.iter_batch)): - cfft1d.fft(buffer.view(-1, buffer.size(2))) - cfft1d_strided.fft(buffer) - - torch.cuda.synchronize() - - gb_byte_count = 4 * np.prod(shape) * 8 / (1024 * 1024 * 1024) - - start_time = time.perf_counter() - - for _ in range(config.iter_count // max(1, config.iter_batch)): - g.replay() - - torch.cuda.synchronize() - - elapsed_time = time.perf_counter() - start_time - - return config.iter_count * gb_byte_count / elapsed_time - -if __name__ == "__main__": - config = fu.parse_args() - fft_sizes = fu.get_fft_sizes() - - output_name = f"fft_zipfft.csv" - with open(output_name, 'w', newline='') as csvfile: - writer = csv.writer(csvfile) - writer.writerow(['Backend', 'FFT Size'] + [f'Run {i + 1} (GB/s)' for i in range(config.run_count)] + ['Mean', 'Std Dev']) - - for fft_size in fft_sizes: - rates = [] - - for _ in range(config.run_count): - gb_per_second = run_zipfft(config, fft_size) - print(f"FFT Size: {fft_size}, Throughput: {gb_per_second:.2f} GB/s") - rates.append(gb_per_second) - - rounded_data = [round(rate, 2) for rate in rates] - rounded_mean = round(np.mean(rates), 2) - rounded_std = round(np.std(rates), 2) - - writer.writerow(["zipfft", fft_size] + rounded_data + [rounded_mean, rounded_std]) - - print(f"Results saved to {output_name}.csv") \ No newline at end of file diff --git a/performance_tests/fft_2d/ffts_utils.py b/performance_tests/fft_2d/ffts_utils.py deleted file mode 100644 index e749346b..00000000 --- a/performance_tests/fft_2d/ffts_utils.py +++ /dev/null @@ -1,38 +0,0 @@ -import sys -from typing import Tuple -import dataclasses - -import numpy as np - -@dataclasses.dataclass -class Config: - data_size: int - iter_count: int - iter_batch: int - run_count: int - warmup: int = 10 - - def make_shape(self, fft_size: int) -> Tuple[int, ...]: - total_square_size = fft_size * fft_size - assert self.data_size % total_square_size == 0, "Data size must be a multiple of 
fft_size squared" - return (self.data_size // total_square_size, fft_size, fft_size) - - def make_random_data(self, fft_size: int): - shape = self.make_shape(fft_size) - return np.random.rand(*shape).astype(np.complex64) - -def parse_args() -> Config: - if len(sys.argv) != 5: - print(f"Usage: {sys.argv[0]} ") - sys.exit(1) - - return Config( - data_size=int(sys.argv[1]), - iter_count=int(sys.argv[2]), - iter_batch=int(sys.argv[3]), - run_count=int(sys.argv[4]), - ) - -def get_fft_sizes(): - return [2**i for i in range(6, 13)] # FFT sizes from 64 to 4096 (inclusive) - diff --git a/performance_tests/fft_2d/run_tests.sh b/performance_tests/fft_2d/run_tests.sh deleted file mode 100644 index a9f16908..00000000 --- a/performance_tests/fft_2d/run_tests.sh +++ /dev/null @@ -1,36 +0,0 @@ -#!/bin/bash - -mkdir -p test_results - -cd test_results - -DATA_SIZE=134217728 -#DATA_SIZE=33554432 -ITER_COUNT=500 -BATCH_SIZE=10 -REPEATS=5 - -/usr/local/cuda/bin/nvcc ../fft_cufft.cu -o fft_cufft.exec -lcufft - -echo "Running performance tests with the following parameters:" -echo "Data Size: $DATA_SIZE" -echo "Iteration Count: $ITER_COUNT" -echo "Batch Size: $BATCH_SIZE" -echo "Repeats: $REPEATS" - -echo "Running cuFFT FFT..." -./fft_cufft.exec $DATA_SIZE $ITER_COUNT $BATCH_SIZE $REPEATS - -echo "Running Vkdispatch FFT..." -python3 ../fft_vkdispatch.py $DATA_SIZE $ITER_COUNT $BATCH_SIZE $REPEATS - -echo "Running VKFFT FFT..." -python3 ../fft_vkfft.py $DATA_SIZE $ITER_COUNT $BATCH_SIZE $REPEATS - -echo "Running PyTorch FFT..." -python3 ../fft_torch.py $DATA_SIZE $ITER_COUNT $BATCH_SIZE $REPEATS - -echo "Running ZipFFT FFT..." 
-python3 ../fft_zipfft.py $DATA_SIZE $ITER_COUNT $BATCH_SIZE $REPEATS - -python3 ../fft_make_graph.py \ No newline at end of file diff --git a/performance_tests/kernel_overhead/kernels_per_batch_size.py b/performance_tests/kernel_overhead/kernels_per_batch_size.py deleted file mode 100644 index 2f456c5e..00000000 --- a/performance_tests/kernel_overhead/kernels_per_batch_size.py +++ /dev/null @@ -1,139 +0,0 @@ -import numpy as np -import vkdispatch as vd -import matplotlib.pyplot as plt -import sys -import time -import csv - -from kernels_utils import do_benchmark, adjust_lightness - -platforms = [ - "warp", - "vkdispatch" -] - -kernel_types = [ - "const", - "param_stream", -] - -test_configs = [ - ("warp", "const"), - ("warp", "param_stream"), - - ("vkdispatch", "const"), - ("vkdispatch", "param_stream"), -] - - -# ----------- Define kernels dictionary ----------------------------------- - -# Assign base colors for each platform -platform_colors = { - platform: plt.cm.tab10(i % 10) # tab10 colormap cycles nicely - for i, platform in enumerate(platforms) -} - -# Kernel lightness factors -kernel_factors = { - kernel_type: 0.50 + 0.5 * (i / max(1, len(kernel_types) - 1)) - for i, kernel_type in enumerate(kernel_types) -} - -stream_count = int(sys.argv[1]) -device_ids = list(range(int(sys.argv[2]))) - -vkdispatch_queue_families = [] - -for device_id in device_ids: - vkdispatch_queue_families.append(vd.select_queue_families(device_id, stream_count)) - -vd.make_context(devices=device_ids, queue_families=vkdispatch_queue_families) - -datas = {platform: {kernel_type: [] for kernel_type in kernel_types} for platform in platforms} - -iter_count = 1024 * 1024 # Total number of iterations for the benchmark -run_count = 3 # Number of times to run each benchmark - -identity_matrix = np.diag(np.ones(shape=(4,), dtype=np.float32)) - -params_host = np.zeros(shape=(2*iter_count, 4, 4), dtype=np.float32) -params_host[:] = identity_matrix - -batch_size_exponents = list(range(2, 14)) # 
Batch sizes from 8 to 1024 - -for batch_size_exp in batch_size_exponents: - batch_size = 2 ** batch_size_exp - - for platform, kernel_type in test_configs: - rates = [] - for i in range(run_count): - print(f"Benchmarking {kernel_type} kernel with batch size {batch_size} on {platform} Run {i + 1}/{run_count}...") - time.sleep(0.25) # Simulate some delay before starting the benchmark - rates.append(do_benchmark( - platform, - kernel_type, - params_host, - batch_size, - iter_count, - stream_count, - stream_count, - device_ids - )) - - datas[platform][kernel_type].append(rates) - -# ----------- Print results ------------------------------------------------ - -output_name = f"kernels_per_batch_size_{len(device_ids)}_devices_{stream_count}_streams" - -with open(output_name + ".csv", 'w', newline='') as csvfile: - writer = csv.writer(csvfile) - # Write header - writer.writerow(['Platform', 'Kernel Type', 'Batch Size'] + [f'Run {i + 1} (Kernels/second)' for i in range(run_count)] + ['Mean', 'Std Dev']) - for platform, kernel_type in test_configs: - test_data = datas[platform][kernel_type] - for batch_size_idx, rates in enumerate(test_data): - batch_size = 2 ** batch_size_exponents[batch_size_idx] - - rounded_rates = [int(round(rate, 0)) for rate in rates] - rounded_mean = round(np.mean(rates), 2) - rounded_std = round(np.std(rates), 2) - - writer.writerow([platform, kernel_type, batch_size] + rounded_rates + [rounded_mean, rounded_std]) -print(f"Raw benchmark data written to {output_name}.csv") - - -# ----------- Plot results (optional) ----------------------------- - -plt.figure(figsize=(10, 6)) -for platform, kernel_type in test_configs: - base_color = platform_colors[platform] - color = adjust_lightness(base_color, kernel_factors[kernel_type]) - - test_data = datas[platform][kernel_type] - - means = [np.mean(data) for data in test_data] - stds = [np.std(data) for data in test_data] - - plt.errorbar( - [2 ** (batch_size_exponents[i]) for i in range(len(means))], - means, 
- yerr=stds, - label=f"{platform} - {kernel_type}", - capsize=5, - color=color - ) - -plt.xscale('log', base=2) -plt.yscale('log') -plt.xlabel('Batch Size') -plt.ylabel('Kernels/second') -plt.title(f'Kernel Launch Overhead Benchmark (Stream Count: {stream_count}, Devices: {len(device_ids)}, Param Size: 128 bytes)') -plt.legend() -plt.grid(True) -plt.tight_layout() -plt.savefig(output_name + "_log.png") - -plt.yscale('linear') -plt.savefig(output_name + "_linear.png") diff --git a/performance_tests/kernel_overhead/kernels_per_streams.py b/performance_tests/kernel_overhead/kernels_per_streams.py deleted file mode 100644 index 862ab2cb..00000000 --- a/performance_tests/kernel_overhead/kernels_per_streams.py +++ /dev/null @@ -1,141 +0,0 @@ -import numpy as np -import vkdispatch as vd -import matplotlib.pyplot as plt -import sys -import time - -from kernels_utils import do_benchmark, adjust_lightness -import csv - -platforms = [ - "warp", - "vkdispatch" -] - -kernel_types = [ - "const", - "param_stream", -] - -test_configs = [ - ("warp", "const"), - ("warp", "param_stream"), - - ("vkdispatch", "const"), - ("vkdispatch", "param_stream"), -] - - -# ----------- Define kernels dictionary ----------------------------------- - -# Assign base colors for each platform -platform_colors = { - platform: plt.cm.tab10(i % 10) # tab10 colormap cycles nicely - for i, platform in enumerate(platforms) -} - -# Kernel lightness factors -kernel_factors = { - kernel_type: 0.50 + 0.5 * (i / max(1, len(kernel_types) - 1)) - for i, kernel_type in enumerate(kernel_types) -} - -total_stream_count = int(sys.argv[1]) -device_ids = list(range(int(sys.argv[2]))) - -vkdispatch_queue_families = [] - -#vd.initialize(log_level=vd.LogLevel.INFO) - -for device_id in device_ids: - vkdispatch_queue_families.append(vd.select_queue_families(device_id, total_stream_count)) - -vd.make_context(device_ids=device_ids, queue_families=vkdispatch_queue_families) - -datas = {platform: {kernel_type: [] for kernel_type 
in kernel_types} for platform in platforms} - -iter_count = 1024 * 1024 # Total number of iterations for the benchmark -run_count = 3 # Number of times to run each benchmark - -identity_matrix = np.diag(np.ones(shape=(4,), dtype=np.float32)) - -params_host = np.zeros(shape=(2*iter_count, 4, 4), dtype=np.float32) -params_host[:] = identity_matrix - -batch_size = 512 - -stream_counts = list(range(1, total_stream_count + 1)) # Stream counts from 1 to stream_count - -for streams in stream_counts: - for platform, kernel_type in test_configs: - rates = [] - for i in range(run_count): - print(f"Benchmarking {kernel_type} kernel with streams={streams} on {platform} Run {i + 1}/{run_count}...") - time.sleep(0.25) # Simulate some delay before starting the benchmark - rates.append(do_benchmark( - platform, - kernel_type, - params_host, - batch_size, - iter_count, - streams, - total_stream_count, - device_ids - )) - - datas[platform][kernel_type].append(rates) - - -# ----------- Print results ------------------------------------------------ - -output_name = f"kernels_per_streams_{len(device_ids)}_devices_{batch_size}_batch_size" - -with open(output_name + ".csv", 'w', newline='') as csvfile: - writer = csv.writer(csvfile) - # Write header - writer.writerow(['Platform', 'Kernel Type', 'Stream Count'] + [f'Run {i + 1} (Kernels/second)' for i in range(run_count)] + ['Mean', 'Std Dev']) - for platform, kernel_type in test_configs: - test_data = datas[platform][kernel_type] - for stream_idx, rates in enumerate(test_data): - stream_count = stream_counts[stream_idx] - #for run_idx, rate in enumerate(rates): - - rounded_rates = [int(round(rate, 0)) for rate in rates] - rounded_mean = round(np.mean(rates), 2) - rounded_std = round(np.std(rates), 2) - - writer.writerow([platform, kernel_type, stream_count] + rounded_rates + [rounded_mean, rounded_std]) -print(f"Raw benchmark data written to {output_name}.csv") - -# ----------- Plot results (optional) ----------------------------- - 
-plt.figure(figsize=(10, 6)) -for platform, kernel_type in test_configs: - base_color = platform_colors[platform] - color = adjust_lightness(base_color, kernel_factors[kernel_type]) - - test_data = datas[platform][kernel_type] - - means = [np.mean(data) for data in test_data] - stds = [np.std(data) for data in test_data] - - plt.errorbar( - [stream_counts[i] for i in range(len(test_data))], - means, - yerr=stds, - label=f"{platform} - {kernel_type}", - capsize=5, - color=color - ) - -plt.yscale('log') -plt.xlabel('Stream Count') -plt.ylabel('Kernels/second') -plt.title(f'Kernel Launch Overhead Benchmark (Devices: {len(device_ids)}, Param Size: 128 bytes, Batch Size: {batch_size})') -plt.legend() -plt.grid(True) -plt.tight_layout() -plt.savefig(output_name + "_log.png") - -plt.yscale('linear') -plt.savefig(output_name + "_linear.png") \ No newline at end of file diff --git a/performance_tests/kernel_overhead/kernels_utils.py b/performance_tests/kernel_overhead/kernels_utils.py deleted file mode 100644 index 7ac612bf..00000000 --- a/performance_tests/kernel_overhead/kernels_utils.py +++ /dev/null @@ -1,216 +0,0 @@ -import warp as wp -import time -import gc -import numpy as np -import vkdispatch as vd -import vkdispatch.codegen as vc -import matplotlib.colors as mcolors -import colorsys - -reference_list = [] - -def register_object(obj): - reference_list.append(obj) - -# ----------- Define kernels for measuring launch overheads --------------- - -@wp.kernel -def k_const_warp(out: wp.array(dtype=float), mat1: wp.mat44f, mat2: wp.mat44f): - i = wp.tid() - if i == 0: - out[i] = out[i] + wp.determinant(mat1) + wp.determinant(mat2) - -@wp.kernel -def k_param_stream_warp(out: wp.array(dtype=float), matricies: wp.array(dtype=wp.mat44f), param_index: int): - i = wp.tid() - if i == 0: - out[i] = out[i] + wp.determinant(matricies[param_index]) + wp.determinant(matricies[param_index + 1]) - -def make_graph_warp(kernel, out, matricies, batch_size, stream, device, do_streaming): - 
identity_matrix = np.diag(np.ones(shape=(4,), dtype=np.float32)) - - with wp.ScopedCapture(device=device, stream=stream) as capture: - for i in range(batch_size): - inputs = [out, identity_matrix, identity_matrix] if not do_streaming else [out, matricies, 2*i] - - wp.launch( - kernel, - dim=1, - inputs=inputs, - device=device, - stream=stream - ) - - return capture.graph - -def do_benchmark_warp(kernel, params_host, kernel_type, batch_size, iter_count, streams_per_device, stream_count, device_ids): - out_arrays = [] - params_arrays = [] - h_buffs = [] - graphs = [] - streams = [] - - devices = [wp.get_device(f"cuda:{device_id}") for device_id in device_ids] - - total_streams = streams_per_device * len(device_ids) - - for i in range(total_streams): - device = devices[i % len(device_ids)] - - stream = wp.Stream(device=device) - - streams.append(stream) - - out_arrays.append(wp.zeros(shape=(1,), dtype=wp.float32, device=device)) - - if kernel_type == "param_stream": - h_buffs.append(wp.zeros(shape=(2 * batch_size,), dtype=wp.mat44f, device=device, pinned=True)) - params_arrays.append(wp.zeros(shape=(2 * batch_size,), dtype=wp.mat44f, device=device)) - else: - h_buffs.append(None) - params_arrays.append(None) - - graphs.append(make_graph_warp( - kernel, - out_arrays[i], - params_arrays[i] , - batch_size, - stream, - device, - kernel_type == "param_stream" - )) - - assert iter_count % batch_size == 0, "iter_count must be a multiple of batch_size" - - num_graph_launches = iter_count // batch_size - - start_time = time.perf_counter() - for i in range(num_graph_launches): - device = devices[i % len(device_ids)] - stream_idx = i % total_streams - - if kernel_type == "param_stream": - h_buffs[stream_idx].numpy()[:] = params_host[2*i*batch_size:2*(i+1)*batch_size] - wp.copy(params_arrays[stream_idx], h_buffs[stream_idx], stream=streams[stream_idx]) - - wp.capture_launch(graphs[stream_idx], stream=streams[stream_idx]) - - for dev in devices: - wp.synchronize_device(dev) - 
end_time = time.perf_counter() - - # Cleanup - del graphs - del streams - del out_arrays - del params_arrays - - if kernel_type == "param_stream": - del h_buffs - - wp.synchronize_device("cuda:0") - gc.collect() - - return end_time - start_time - -# ----------- Define kernels for measuring launch overheads --------------- - - -@vd.shader(local_size=(1, 1, 1), workgroups=(1, 1, 1), enable_exec_bounds=False) -def k_const_vkdispatch(out: vc.Buff[vc.f32], mat1: vc.Const[vc.m4], mat2: vc.Const[vc.m4]): - i = vc.global_invocation().x - vc.if_statement(i == 0) - out[i] = out[i] + vc.determinant(mat1) + vc.determinant(mat2) - vc.end() - -@vd.shader(local_size=(1, 1, 1), workgroups=(1, 1, 1), enable_exec_bounds=False) -def k_param_stream_vkdispatch(out: vc.Buff[vc.f32], mat1: vc.Var[vc.m4], mat2: vc.Var[vc.m4]): - i = vc.global_invocation().x - vc.if_statement(i == 0) - out[i] = out[i] + vc.determinant(mat1) + vc.determinant(mat2) - vc.end() - -def do_benchmark_vkdispatch(kernel, params_host, kernel_type, batch_size, iter_count, streams_per_device, stream_count, device_ids): - out_buff = vd.Buffer(shape=(1,), var_type=vd.float32) - identity_matrix = np.diag(np.ones(shape=(4,), dtype=np.float32)) - - do_streaming = kernel_type == "param_stream" - - graph = vd.CommandGraph() - - kernel( - out_buff, - graph.bind_var("mat1") if do_streaming else identity_matrix, - graph.bind_var("mat2") if do_streaming else identity_matrix, - graph=graph - ) - - register_object(out_buff) - register_object(graph) - - assert iter_count % batch_size == 0, "iter_count must be a multiple of batch_size" - - num_graph_launches = iter_count // batch_size - - total_streams = streams_per_device * len(device_ids) - - vd.queue_wait_idle() - - start_time = time.perf_counter() - for i in range(num_graph_launches): - if kernel_type == "param_stream": - graph.set_var("mat1", params_host[2*i*batch_size:2*(i+1)*batch_size:2]) - graph.set_var("mat2", params_host[2*i*batch_size+1:2*(i+1)*batch_size:2]) - - 
raw_stream_index = i % total_streams - raw_stream_index = raw_stream_index + (stream_count - streams_per_device) * raw_stream_index // streams_per_device - graph.submit(instance_count=batch_size, queue_index=raw_stream_index) - - vd.queue_wait_idle() - end_time = time.perf_counter() - - gc.collect() - - return end_time - start_time - -kernels = { - "warp": { - "const": k_const_warp, - "param_stream": k_param_stream_warp, - }, - "vkdispatch": { - "const": k_const_vkdispatch, - "param_stream": k_param_stream_vkdispatch, - } -} - -benchmarks = { - "warp": do_benchmark_warp, - "vkdispatch": do_benchmark_vkdispatch -} - -def do_benchmark(platform, kernel_type, params_host, batch_size, iter_count, streams_per_device, stream_count, device_ids): - elapsed_time = benchmarks[platform]( - kernels[platform][kernel_type], - params_host, - kernel_type, - batch_size, - iter_count, - streams_per_device, - stream_count, - device_ids - ) - - return iter_count / elapsed_time - -def adjust_lightness(color, factor): - """Lighten or darken a given matplotlib color by multiplying its lightness by 'factor'.""" - try: - c = mcolors.cnames[color] - except KeyError: - c = color - r, g, b = mcolors.to_rgb(c) - h, l, s = colorsys.rgb_to_hls(r, g, b) - l = max(0, min(1, l * factor)) - r, g, b = colorsys.hls_to_rgb(h, l, s) - return (r, g, b) \ No newline at end of file diff --git a/performance_tests/kernel_overhead/run_performance_tests.sh b/performance_tests/kernel_overhead/run_performance_tests.sh deleted file mode 100644 index 14a1240a..00000000 --- a/performance_tests/kernel_overhead/run_performance_tests.sh +++ /dev/null @@ -1,18 +0,0 @@ -#!/bin/bash - -mkdir -p test_results - -cd test_results - -python3 ../kernels_per_streams.py 10 1 # Test with up to 10 streams and 1 device -python3 ../kernels_per_streams.py 10 2 # Test with up to 10 streams and 2 devices -python3 ../kernels_per_streams.py 10 3 # Test with up to 10 streams and 3 devices -python3 ../kernels_per_streams.py 10 4 # Test with 
up to 10 streams and 4 devices - -python3 ../kernels_per_batch_size.py 1 1 # Test batch sizes with 1 device and 1 stream -python3 ../kernels_per_batch_size.py 2 1 # Test batch sizes with 1 device and 2 streams -python3 ../kernels_per_batch_size.py 4 1 # Test batch sizes with 1 device and 4 streams - -python3 ../kernels_per_batch_size.py 1 4 # Test batch sizes with 4 device and 1 stream -python3 ../kernels_per_batch_size.py 2 4 # Test batch sizes with 4 device and 2 streams -python3 ../kernels_per_batch_size.py 4 4 # Test batch sizes with 4 device and 3 streams \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index 3867a051..7379c159 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -2,34 +2,7 @@ requires = [ "setuptools>=59.0", "wheel", - "Cython" + "Cython", + "packaging" ] build-backend = "setuptools.build_meta" - -[project] -name = "vkdispatch" -version = "0.0.30" -authors = [ - { name="Shahar Sandhaus", email="shahar.sandhaus@gmail.com" }, -] -description = "A Python module for orchestrating and dispatching large computations across multi-GPU systems using Vulkan." 
-readme = "README.md" -requires-python = ">=3.6" -classifiers = [ - "Programming Language :: Python :: 3", - "License :: OSI Approved :: MIT License", - "Operating System :: OS Independent", - "Development Status :: 2 - Pre-Alpha", -] -dependencies = [ - "setuptools>=59.0", - "numpy", -] -scripts = { vdlist = 'vkdispatch.cli:cli_entrypoint' } - -[project.urls] -Homepage = "https://github.com/sharhar/vkdispatch" -Issues = "https://github.com/sharhar/vkdispatch/issues" - -[project.optional-dependencies] -cli = ["Click"] diff --git a/setup.py b/setup.py index 40dd1841..422495ce 100644 --- a/setup.py +++ b/setup.py @@ -1,243 +1,129 @@ import os import platform +import re import subprocess +from pathlib import Path from setuptools import Extension +from setuptools import find_packages from setuptools import setup from setuptools.command.build_ext import build_ext -import re - -# Typically you'll put `packaging` in your setup_requires or pyproject.toml if needed. try: from packaging.version import Version except ImportError: - # As a fallback, if you absolutely can't rely on `packaging`, - # you could use distutils: from distutils.version import LooseVersion as Version print("Warning: 'packaging' not found; version comparisons might be less accurate.") from distutils.version import LooseVersion as Version -system = platform.system() +BUILD_TARGET_FULL = "full" +BUILD_TARGET_CORE = "core" +BUILD_TARGET_NATIVE = "native" +BUILD_TARGET_META = "meta" +VALID_BUILD_TARGETS = { + BUILD_TARGET_FULL, + BUILD_TARGET_CORE, + BUILD_TARGET_NATIVE, + BUILD_TARGET_META, +} -proj_root = os.path.abspath(os.path.dirname(__file__)) -molten_vk_path = "./deps/MoltenVK/MoltenVK/MoltenVK/static/MoltenVK.xcframework/macos-arm64_x86_64/" -vulkan_sdk_root = os.environ.get('VULKAN_SDK') -platform_name_dict = { - "Darwin": "MACOS", - "Windows": "WINDOWS", - "Linux": "LINUX" -} +def get_build_target() -> str: + target = os.environ.get("VKDISPATCH_BUILD_TARGET", BUILD_TARGET_FULL).strip().lower() + 
if target not in VALID_BUILD_TARGETS: + valid = ", ".join(sorted(VALID_BUILD_TARGETS)) + raise RuntimeError( + f"Invalid VKDISPATCH_BUILD_TARGET={target!r}. Expected one of: {valid}" + ) + return target -platform_library_dirs = [] -platform_define_macros = [] #[(f"__VKDISPATCH_PLATFORM_{platform_name_dict[system]}__", 1), ("LOG_VERBOSE_ENABLED", 1)] -platform_link_libraries = [] -platform_extra_link_args = [] -platform_extra_compile_args = ( - ["/W3", "/GL", "/DNDEBUG", "/MD", "/EHsc", "/std:c++17"] - if system == "Windows" - else [ - "-O0", - "-g", - "-std=c++17", - #"-fsanitize=address", - #"-fsanitize-address-use-after-scope", - ] -) -include_directories = [ - proj_root + "/deps/VMA/include", - proj_root + "/deps/volk", - proj_root + "/deps/VkFFT/vkFFT", -] +BUILD_TARGET = get_build_target() -if os.name == "posix": - platform_extra_link_args.append("-g") - platform_extra_link_args.append("-O0") - platform_extra_link_args.append("-fno-omit-frame-pointer") - #platform_extra_link_args.append("-fsanitize=address") - #platform_extra_link_args.append("-fsanitize-address-use-after-scope") - platform_link_libraries.extend(["dl", "pthread"]) - - -if vulkan_sdk_root is None: - include_directories.extend([ - proj_root + "/include_ext", - proj_root + "/deps/Vulkan-Headers/include", - proj_root + "/deps/Vulkan-Utility-Libraries/include", - proj_root + "/deps/glslang", - proj_root + "/deps/glslang/glslang/Include", - ]) - - if system == "Darwin": - platform_library_dirs.append(molten_vk_path) - platform_link_libraries.append("MoltenVK") - platform_extra_link_args.extend([ - "-framework", "Metal", - "-framework", "AVFoundation", - "-framework", "AppKit" - ]) - platform_extra_compile_args.append("-mmacosx-version-min=10.15") - else: - platform_define_macros.append(("VKDISPATCH_USE_VOLK", 1)) -else: - include_directories.extend([ - vulkan_sdk_root + '/include', - vulkan_sdk_root + '/include/utility', - vulkan_sdk_root + '/include/glslang/Include', - ]) +proj_root = 
Path(__file__).resolve().parent +system = platform.system() +molten_vk_path = "./deps/MoltenVK/MoltenVK/MoltenVK/static/MoltenVK.xcframework/macos-arm64_x86_64/" +vulkan_sdk_root = os.environ.get("VULKAN_SDK") - platform_define_macros.append(("VKDISPATCH_USE_VOLK", 1)) - platform_define_macros.append(("VKDISPATCH_LOADER_PATH", '"' + os.path.abspath(f"{vulkan_sdk_root}") + '/"')) - #if os.name == "posix": - # platform_link_libraries.append("vulkan") - #else: - # platform_link_libraries.append("vulkan-1") +def read_version() -> str: + init_path = proj_root / "vkdispatch" / "__init__.py" + text = init_path.read_text(encoding="utf-8") + match = re.search(r'^__version__\s*=\s*"([^"]+)"', text, re.MULTILINE) + if not match: + raise RuntimeError(f"Could not find __version__ in {init_path}") + return match.group(1) - platform_library_dirs.append(vulkan_sdk_root + '/lib') - platform_link_libraries.extend([ - "glslang", - "SPIRV", - "MachineIndependent", - "GenericCodeGen", - "SPIRV-Tools-opt", - "SPIRV-Tools-link", - "SPIRV-Tools-reduce", - "SPIRV-Tools", - "glslang-default-resource-limits" - ]) +def read_readme() -> str: + return (proj_root / "README.md").read_text(encoding="utf-8") -sources = [] +VERSION = read_version() -def append_to_sources(prefix, source_list): - global sources +COMMON_CLASSIFIERS = [ + "Programming Language :: Python :: 3", + "License :: OSI Approved :: MIT License", + "Operating System :: OS Independent", + "Development Status :: 2 - Pre-Alpha", +] + +COMMON_PROJECT_URLS = { + "Homepage": "https://github.com/sharhar/vkdispatch", + "Issues": "https://github.com/sharhar/vkdispatch/issues", +} + +COMMON_EXTRAS = { + "cuda": ["cuda-python"], + "opencl": ["pyopencl", "numpy"], + "pycuda": ["pycuda"], + "numpy": ["numpy"], +} - for source in source_list: - sources.append(prefix + source) - - -sources.append("vkdispatch_native/wrapper.pyx") - -append_to_sources("vkdispatch_native/", [ - "context/init.cpp", - "context/context.cpp", - "context/errors.cpp", - 
"context/handles.cpp", - - "objects/buffer.cpp", - "objects/image.cpp", - "objects/command_list.cpp", - "objects/descriptor_set.cpp", - - "stages/stage_fft.cpp", - "stages/stage_compute.cpp", - - "queue/queue.cpp", - "queue/signal.cpp", - "queue/work_queue.cpp", - "queue/barrier_manager.cpp", - - "libs/VMAImpl.cpp", - "libs/VolkImpl.cpp" -]) - -if vulkan_sdk_root is None: - append_to_sources("deps/glslang/glslang/", [ - "CInterface/glslang_c_interface.cpp", - "GenericCodeGen/CodeGen.cpp", - "GenericCodeGen/Link.cpp", - "MachineIndependent/glslang_tab.cpp", - "MachineIndependent/attribute.cpp", - "MachineIndependent/Constant.cpp", - "MachineIndependent/iomapper.cpp", - "MachineIndependent/InfoSink.cpp", - "MachineIndependent/Initialize.cpp", - "MachineIndependent/IntermTraverse.cpp", - "MachineIndependent/Intermediate.cpp", - "MachineIndependent/ParseContextBase.cpp", - "MachineIndependent/ParseHelper.cpp", - "MachineIndependent/PoolAlloc.cpp", - "MachineIndependent/RemoveTree.cpp", - "MachineIndependent/Scan.cpp", - "MachineIndependent/ShaderLang.cpp", - "MachineIndependent/SpirvIntrinsics.cpp", - "MachineIndependent/SymbolTable.cpp", - "MachineIndependent/Versions.cpp", - "MachineIndependent/intermOut.cpp", - "MachineIndependent/limits.cpp", - "MachineIndependent/linkValidate.cpp", - "MachineIndependent/parseConst.cpp", - "MachineIndependent/reflection.cpp", - "MachineIndependent/preprocessor/Pp.cpp", - "MachineIndependent/preprocessor/PpAtom.cpp", - "MachineIndependent/preprocessor/PpContext.cpp", - "MachineIndependent/preprocessor/PpScanner.cpp", - "MachineIndependent/preprocessor/PpTokens.cpp", - "MachineIndependent/propagateNoContraction.cpp", - "ResourceLimits/ResourceLimits.cpp", - "ResourceLimits/resource_limits_c.cpp" - ]) - - append_to_sources("deps/glslang/SPIRV/", [ - "GlslangToSpv.cpp", - "InReadableOrder.cpp", - "Logger.cpp", - "SpvBuilder.cpp", - "SpvPostProcess.cpp", - "doc.cpp", - "SpvTools.cpp", - "disassemble.cpp", - 
"CInterface/spirv_c_interface.cpp" - ]) def parse_compiler_version(version_output): if not isinstance(version_output, str): return None - - # Try to match either clang or gcc version string - clang_match = re.search(r'clang version ([^\s]+)', version_output) - gcc_match = re.search(r'gcc.+?([\d.]+(?:-[a-zA-Z0-9]+)?)', version_output, re.IGNORECASE) - + + clang_match = re.search(r"clang version ([^\s]+)", version_output) + gcc_match = re.search( + r"gcc.+?([\d.]+(?:-[a-zA-Z0-9]+)?)", version_output, re.IGNORECASE + ) + match = clang_match or gcc_match if not match: return None try: return Version(match.group(1)) - except Exception as e: - print(f"Invalid version: {e}") + except Exception as exc: + print(f"Invalid version: {exc}") return None + def detect_unix_compiler(compiler_exe): - """ - Given the 'compiler_exe' (like 'gcc', 'clang', etc.), returns a string - denoting the compiler family: 'clang', 'gcc', or 'unknown'. - """ try: - # Run e.g. `gcc --version` or `clang --version` - version_output = subprocess.check_output([compiler_exe, '--version'], - stderr=subprocess.STDOUT, - universal_newlines=True) - - if 'clang' in version_output: - return 'clang', parse_compiler_version(version_output) - elif 'gcc' in version_output or 'Free Software Foundation' in version_output: - return 'gcc', parse_compiler_version(version_output) - else: - return 'unknown', None + version_output = subprocess.check_output( + [compiler_exe, "--version"], + stderr=subprocess.STDOUT, + universal_newlines=True, + ) + + if "clang" in version_output: + return "clang", parse_compiler_version(version_output) + if "gcc" in version_output or "Free Software Foundation" in version_output: + return "gcc", parse_compiler_version(version_output) + return "unknown", None except Exception: - return 'unknown', None - + return "unknown", None + + class CustomBuildExt(build_ext): def build_extensions(self): compiler_type = self.compiler.compiler_type print(f"Detected compiler type: {compiler_type}") - if 
compiler_type == 'unix': + if compiler_type == "unix": print(f"Detected compiler: {self.compiler.compiler}") compiler_family, version = detect_unix_compiler(self.compiler.compiler[0]) print(f"Detected compiler family: {compiler_family}") @@ -245,43 +131,290 @@ def build_extensions(self): if version is not None: for ext in self.extensions: - if compiler_family == 'clang' and version < Version('9.0'): - ext.libraries.append('c++fs') - elif compiler_family == 'gcc' and version < Version('9.1'): - ext.libraries.append('stdc++fs') + if compiler_family == "clang" and version < Version("9.0"): + ext.libraries.append("c++fs") + elif compiler_family == "gcc" and version < Version("9.1"): + ext.libraries.append("stdc++fs") else: - print("WARNING: Unknown compiler family, not adding filesystem library") + print( + "WARNING: Unknown compiler family, not adding filesystem library" + ) - # Now actually build the extensions super().build_extensions() -setup( - name="vkdispatch", - packages=[ - "vkdispatch", - "vkdispatch.base", - "vkdispatch.codegen", - "vkdispatch.execution_pipeline", - "vkdispatch.shader_generation", - "vkdispatch.vkfft", - "vkdispatch.fft" - ], - ext_modules=[ - Extension( - "vkdispatch_native", - sources=sources, - language="c++", - define_macros=platform_define_macros, - library_dirs=platform_library_dirs, - libraries=platform_link_libraries, - extra_compile_args=platform_extra_compile_args, - extra_link_args=platform_extra_link_args, - include_dirs=include_directories, + +def append_to_sources(prefix, source_list, out_sources): + for source in source_list: + out_sources.append(prefix + source) + + +def build_native_extension(): + platform_library_dirs = [] + platform_define_macros = [] + platform_link_libraries = [] + platform_extra_link_args = [] + platform_extra_compile_args = ( + ["/W3", "/GL", "/DNDEBUG", "/MD", "/EHsc", "/std:c++17"] + if system == "Windows" + else ["-O2", "-g", "-std=c++17"] + ) + + include_directories = [ + str(proj_root / "deps" / 
"VMA" / "include"), + str(proj_root / "deps" / "volk"), + str(proj_root / "deps" / "VkFFT" / "vkFFT"), + ] + + if os.name == "posix": + platform_extra_link_args.extend(["-g", "-O0", "-fno-omit-frame-pointer"]) + platform_link_libraries.extend(["dl", "pthread"]) + + if vulkan_sdk_root is None: + include_directories.extend( + [ + str(proj_root / "include_ext"), + str(proj_root / "deps" / "Vulkan-Headers" / "include"), + str(proj_root / "deps" / "Vulkan-Utility-Libraries" / "include"), + str(proj_root / "deps" / "glslang"), + str(proj_root / "deps" / "glslang" / "glslang" / "Include"), + ] + ) + + if system == "Darwin": + platform_library_dirs.append(molten_vk_path) + platform_link_libraries.append("MoltenVK") + platform_extra_link_args.extend( + [ + "-framework", + "Metal", + "-framework", + "AVFoundation", + "-framework", + "AppKit", + ] + ) + platform_extra_compile_args.append("-mmacosx-version-min=10.15") + else: + platform_define_macros.append(("VKDISPATCH_USE_VOLK", 1)) + else: + include_directories.extend( + [ + vulkan_sdk_root + "/include", + vulkan_sdk_root + "/include/utility", + vulkan_sdk_root + "/include/glslang/Include", + ] ) - ], - cmdclass={ - 'build_ext': CustomBuildExt, - }, - version="0.0.30", - zip_safe=False, -) + + platform_define_macros.append(("VKDISPATCH_USE_VOLK", 1)) + platform_define_macros.append( + ("VKDISPATCH_LOADER_PATH", '"' + os.path.abspath(vulkan_sdk_root) + '/"') + ) + + platform_library_dirs.append(vulkan_sdk_root + "/lib") + platform_link_libraries.extend( + [ + "glslang", + "SPIRV", + "MachineIndependent", + "GenericCodeGen", + "SPIRV-Tools-opt", + "SPIRV-Tools-link", + "SPIRV-Tools-reduce", + "SPIRV-Tools", + "glslang-default-resource-limits", + ] + ) + + sources = [] + sources.append("vkdispatch_native/wrapper.pyx") + + append_to_sources( + "vkdispatch_native/", + [ + "context/init.cpp", + "context/context.cpp", + "context/errors.cpp", + "context/handles.cpp", + "objects/buffer.cpp", + "objects/image.cpp", + 
"objects/command_list.cpp", + "objects/descriptor_set.cpp", + "stages/stage_fft.cpp", + "stages/stage_compute.cpp", + "queue/queue.cpp", + "queue/signal.cpp", + "queue/work_queue.cpp", + "queue/barrier_manager.cpp", + "libs/VMAImpl.cpp", + "libs/VolkImpl.cpp", + ], + sources, + ) + + if vulkan_sdk_root is None: + append_to_sources( + "deps/glslang/glslang/", + [ + "CInterface/glslang_c_interface.cpp", + "GenericCodeGen/CodeGen.cpp", + "GenericCodeGen/Link.cpp", + "MachineIndependent/glslang_tab.cpp", + "MachineIndependent/attribute.cpp", + "MachineIndependent/Constant.cpp", + "MachineIndependent/iomapper.cpp", + "MachineIndependent/InfoSink.cpp", + "MachineIndependent/Initialize.cpp", + "MachineIndependent/IntermTraverse.cpp", + "MachineIndependent/Intermediate.cpp", + "MachineIndependent/ParseContextBase.cpp", + "MachineIndependent/ParseHelper.cpp", + "MachineIndependent/PoolAlloc.cpp", + "MachineIndependent/RemoveTree.cpp", + "MachineIndependent/Scan.cpp", + "MachineIndependent/ShaderLang.cpp", + "MachineIndependent/SpirvIntrinsics.cpp", + "MachineIndependent/SymbolTable.cpp", + "MachineIndependent/Versions.cpp", + "MachineIndependent/intermOut.cpp", + "MachineIndependent/limits.cpp", + "MachineIndependent/linkValidate.cpp", + "MachineIndependent/parseConst.cpp", + "MachineIndependent/reflection.cpp", + "MachineIndependent/preprocessor/Pp.cpp", + "MachineIndependent/preprocessor/PpAtom.cpp", + "MachineIndependent/preprocessor/PpContext.cpp", + "MachineIndependent/preprocessor/PpScanner.cpp", + "MachineIndependent/preprocessor/PpTokens.cpp", + "MachineIndependent/propagateNoContraction.cpp", + "ResourceLimits/ResourceLimits.cpp", + "ResourceLimits/resource_limits_c.cpp", + ], + sources, + ) + + append_to_sources( + "deps/glslang/SPIRV/", + [ + "GlslangToSpv.cpp", + "InReadableOrder.cpp", + "Logger.cpp", + "SpvBuilder.cpp", + "SpvPostProcess.cpp", + "doc.cpp", + "SpvTools.cpp", + "disassemble.cpp", + "CInterface/spirv_c_interface.cpp", + ], + sources, + ) + + 
return Extension( + "vkdispatch_vulkan_native", + sources=sources, + language="c++", + define_macros=platform_define_macros, + library_dirs=platform_library_dirs, + libraries=platform_link_libraries, + extra_compile_args=platform_extra_compile_args, + extra_link_args=platform_extra_link_args, + include_dirs=include_directories, + ) + + +def base_setup_kwargs(): + return { + "version": VERSION, + "author": "Shahar Sandhaus", + "author_email": "shahar.sandhaus@gmail.com", + "description": ( + "A Python module for orchestrating and dispatching large computations " + "across multi-GPU systems using Vulkan." + ), + "long_description": read_readme(), + "long_description_content_type": "text/markdown", + "python_requires": ">=3.6", + "classifiers": COMMON_CLASSIFIERS, + "project_urls": COMMON_PROJECT_URLS, + "zip_safe": False, + } + + +def core_packages(): + return find_packages(include=["vkdispatch", "vkdispatch.*"]) + + +def setup_for_target(target: str): + kwargs = base_setup_kwargs() + + if target == BUILD_TARGET_FULL: + kwargs.update( + { + "name": "vkdispatch", + "packages": core_packages(), + "install_requires": ["setuptools>=59.0"], + "extras_require": { + "cli": ["Click"], + **COMMON_EXTRAS, + }, + "entry_points": { + "console_scripts": [ + "vdlist=vkdispatch.cli:cli_entrypoint", + ] + }, + "ext_modules": [build_native_extension()], + "cmdclass": {"build_ext": CustomBuildExt}, + } + ) + return kwargs + + if target == BUILD_TARGET_CORE: + kwargs.update( + { + "name": "vkdispatch-core", + "packages": core_packages(), + "install_requires": ["setuptools>=59.0"], + "extras_require": dict(COMMON_EXTRAS), + } + ) + return kwargs + + if target == BUILD_TARGET_NATIVE: + kwargs.update( + { + "name": "vkdispatch-vulkan-native", + "packages": [], + "py_modules": [], + "install_requires": [], + "ext_modules": [build_native_extension()], + "cmdclass": {"build_ext": CustomBuildExt}, + } + ) + return kwargs + + if target == BUILD_TARGET_META: + kwargs.update( + { + "name": 
"vkdispatch", + "packages": [], + "py_modules": [], + "install_requires": [ + f"vkdispatch-core=={VERSION}", + f"vkdispatch-vulkan-native=={VERSION}", + ], + "extras_require": { + "cli": ["Click"], + **COMMON_EXTRAS, + }, + "entry_points": { + "console_scripts": [ + "vdlist=vkdispatch.cli:cli_entrypoint", + ] + }, + } + ) + return kwargs + + raise AssertionError(f"Unhandled build target: {target}") + + +setup(**setup_for_target(BUILD_TARGET)) diff --git a/shader_run.py b/shader_run.py new file mode 100644 index 00000000..8c34a024 --- /dev/null +++ b/shader_run.py @@ -0,0 +1,89 @@ +import vkdispatch as vd + +from vkdispatch.base.command_list import CommandList +from vkdispatch.base.compute_plan import ComputePlan +from vkdispatch.base.descriptor_set import DescriptorSet + +import numpy as np + +def load_shader(path: str) -> ComputePlan: + shader_source = open(path, 'r').read() + + return ComputePlan( + shader_source=shader_source, + binding_type_list=[1, 1, 1], + pc_size=0, + shader_name=f"shader_{path.split('/')[-1].split('.')[0]}" + ) + +def make_descriptor(plan: ComputePlan, out_buff: vd.Buffer, in_buff: vd.Buffer, kern_buff: vd.Buffer): + descriptor_set = DescriptorSet(plan) + + descriptor_set.bind_buffer(out_buff, 0) + descriptor_set.bind_buffer(in_buff, 1) + descriptor_set.bind_buffer(kern_buff, 2) + + return descriptor_set + +def numpy_convolution(signal: np.ndarray, kernel: np.ndarray) -> np.ndarray: + return np.fft.ifft( + np.fft.fft(signal, axis=1).astype(np.complex64) + * + kernel.conjugate(), + axis=1 + ) + +BUFF_SHAPE = (4, 512, 257) + +np.random.seed(1337) + +in_data = (np.random.rand(*BUFF_SHAPE) + 1j * np.random.rand(*BUFF_SHAPE)).astype(np.complex64) +kern_data = (np.random.rand(*BUFF_SHAPE) + 1j * np.random.rand(*BUFF_SHAPE)).astype(np.complex64) + +reference_result_data = numpy_convolution(in_data, kern_data[0]) + +out_buff = vd.buffer_c64(BUFF_SHAPE) +in_buff = vd.buffer_c64(BUFF_SHAPE) +kern_buff = vd.buffer_c64(BUFF_SHAPE) + 
+in_buff.write(in_data) +kern_buff.write(kern_data) + +block_count = (1028, 32, 1) + +plan_bad = load_shader("conv_bad.comp") +plan_good = load_shader("conv_good.comp") + +cmd_list_bad = CommandList() + +cmd_list_bad.record_compute_plan( + plan_bad, + make_descriptor(plan_bad, out_buff, in_buff, kern_buff), + block_count +) + +cmd_list_bad.submit(instance_count=1) + +result_data_bad = out_buff.read(0) + +cmd_list_good = CommandList() + +cmd_list_good.record_compute_plan( + plan_good, + make_descriptor(plan_good, out_buff, in_buff, kern_buff), + block_count +) + +cmd_list_good.submit(instance_count=1) + +result_data_good = out_buff.read(0) + +for i in range(BUFF_SHAPE[0]): + np.save(f"result_bad_{i}.npy", result_data_bad[i]) + np.save(f"result_good_{i}.npy", result_data_good[i]) + np.save(f"reference_result_{i}.npy", reference_result_data[i]) + np.save(f"diff_bad_{i}.npy", result_data_bad[i] - reference_result_data[i]) + np.save(f"diff_good_{i}.npy", result_data_good[i] - reference_result_data[i]) + np.save(f"diff_{i}.npy", result_data_good[i] - result_data_bad[i]) + +assert np.allclose(result_data_good, result_data_bad, atol=1e-3) diff --git a/test.py b/test.py new file mode 100644 index 00000000..9645b0b6 --- /dev/null +++ b/test.py @@ -0,0 +1,57 @@ +import vkdispatch as vd +import vkdispatch.codegen as vc +import numpy as np + +from typing import Tuple + +#vd.initialize(backend="vulkan") + +def make_shape(fft_size: int, data_size: int) -> Tuple[int, ...]: + total_square_size = fft_size * fft_size + assert data_size % total_square_size == 0, "Data size must be a multiple of fft_size squared" + return (data_size // total_square_size, fft_size, fft_size) + +def make_random_data(fft_size: int, run_index: int, data_size: int, seed: int = 1337) -> np.ndarray: + shape = make_shape(fft_size, data_size) + rng = np.random.default_rng(seed + fft_size * 1000 + run_index) + + real = rng.standard_normal(shape).astype(np.float32) + imag = 
rng.standard_normal(shape).astype(np.float32) + return (real + 1j * imag).astype(np.complex64) + +def compute_metrics(reference: np.ndarray, result: np.ndarray): + reference64 = reference.astype(np.complex128, copy=False) + result64 = result.astype(np.complex128, copy=False) + + delta = result64 - reference64 + abs_delta = np.abs(delta) + abs_reference = np.abs(reference64) + + eps = 1e-12 + relative_l2 = np.linalg.norm(delta.ravel()) / max(np.linalg.norm(reference64.ravel()), eps) + max_relative = np.max(abs_delta / np.maximum(abs_reference, eps)) + max_absolute = np.max(abs_delta) + + return float(relative_l2), float(max_relative), float(max_absolute) + +@vd.map +def kernel_mapping(scale_factor: vc.Var[vc.f32]): + read_op = vd.fft.read_op() + read_op.register[:] = read_op.register * scale_factor + +fft_size = 4096 +data_size = 16 * 1024 * 1024 + +input_data = make_random_data(fft_size, 0, data_size) +reference = np.fft.fft(input_data) + +shape = make_shape(fft_size, data_size) + +buffer = vd.buffer_c64(shape) #Buffer(shape, var_type=vd.complex64) + +buffer.write(input_data) +#vd.fft.fft(buffer, print_shader=True) +vd.fft.convolve(buffer, np.random.rand(), kernel_map=kernel_mapping, print_shader=True) +result_data = buffer.read(0) + +#print(compute_metrics(reference, result_data)) \ No newline at end of file diff --git a/test2.py b/test2.py index 994ff73a..5f494e18 100644 --- a/test2.py +++ b/test2.py @@ -1,42 +1,109 @@ import vkdispatch as vd import vkdispatch.codegen as vc + +#vd.initialize(debug_mode=True, backend="cuda") +#vc.set_codegen_backend("cuda") + +from typing import Callable, Union, Tuple + import numpy as np -SIZE = 512 +import time +import dataclasses + +@dataclasses.dataclass +class Config: + data_size: int + iter_count: int + iter_batch: int + run_count: int + signal_factor: int + warmup: int = 10 + + def make_shape(self, fft_size: int) -> Tuple[int, ...]: + total_square_size = fft_size * fft_size + assert self.data_size % total_square_size == 0, 
"Data size must be a multiple of fft_size squared" + return (self.data_size // total_square_size, fft_size, fft_size) + + def make_random_data(self, fft_size: int): + shape = self.make_shape(fft_size) + return np.random.rand(*shape).astype(np.complex64) + +def run_vkdispatch(config: Config, + fft_size: int, + io_count: Union[int, Callable], + gpu_function: Callable) -> float: + shape = config.make_shape(fft_size) + + buffer = vd.Buffer(shape, var_type=vd.complex64) + kernel = vd.Buffer(shape, var_type=vd.complex64) + + graph = vd.CommandGraph() + old_graph = vd.set_global_graph(graph) + + gpu_function(config, fft_size, buffer, kernel) + + vd.set_global_graph(old_graph) + + for _ in range(config.warmup): + graph.submit(config.iter_batch) + + vd.queue_wait_idle() + + if callable(io_count): + io_count = io_count(buffer.size, fft_size) + + gb_byte_count = io_count * 8 * buffer.size / (1024 * 1024 * 1024) + + start_time = time.perf_counter() + + for _ in range(config.iter_count // config.iter_batch): + graph.submit(config.iter_batch) + + vd.queue_wait_idle() + + elapsed_time = time.perf_counter() - start_time + + buffer.destroy() + kernel.destroy() + graph.destroy() + vd.fft.cache_clear() + + time.sleep(1) -buffer = vd.Buffer((SIZE, SIZE), vd.complex64) -kernel = vd.Buffer((SIZE, SIZE), vd.complex64) + vd.queue_wait_idle() -vd.fft.convolve2D(buffer, kernel, print_shader=True) + return gb_byte_count, elapsed_time -exit() -# make a square and circle signal in numpy -x = np.linspace(-1, 1, SIZE) -y = np.linspace(-1, 1, SIZE) -X, Y = np.meshgrid(x, y) -signal = np.zeros((SIZE, SIZE), dtype=np.complex64) -signal[np.abs(X) < 0.5] = 1.0 + 0j +def run_test(config: Config, + io_count: Union[int, Callable], + gpu_function: Callable): + fft_sizes = [8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096] -signal2 = np.zeros((SIZE, SIZE), dtype=np.complex64) -signal2[np.sqrt(X**2 + Y**2) < 0.5] = 1.0 + 0j + for fft_size in fft_sizes: + rates = [] -buffer.write(signal) 
-kernel.write(signal2) + for _ in range(config.run_count): + gb_byte_count, elapsed_time = run_vkdispatch(config, fft_size, io_count, gpu_function) + gb_per_second = config.iter_count * gb_byte_count / elapsed_time -# perform convolution in numpy for validation -f_signal = np.fft.fft2(signal) -f_kernel = np.fft.fft2(signal2) -f_convolved = f_signal * f_kernel -convolved = np.fft.ifft2(f_convolved) + print(f"FFT Size: {fft_size}, Throughput: {gb_per_second:.4f} GB/s") + rates.append(gb_per_second) -np.save("signal.npy", signal) -np.save("kernel.npy", signal2) -np.save("convolved.npy", convolved) +def do_fft(config: Config, + fft_size: int, + buffer: vd.Buffer, + kernel: vd.Buffer): + vd.fft.fft(buffer) -vd.fft.fft2(kernel) -vd.fft.convolve2D(buffer, kernel) -vk_convolved = buffer.read(0) +conf = Config( + data_size=2**26, + iter_count=80, + iter_batch=10, + run_count=1, + signal_factor=8 +) -np.save("vk_convolved.npy", vk_convolved) \ No newline at end of file +run_test(conf, 2, do_fft) \ No newline at end of file diff --git a/vkdispatch/tests/test_async_processing.py b/tests/test_async_processing.py similarity index 81% rename from vkdispatch/tests/test_async_processing.py rename to tests/test_async_processing.py index d76a21e4..83082142 100644 --- a/vkdispatch/tests/test_async_processing.py +++ b/tests/test_async_processing.py @@ -1,6 +1,8 @@ import vkdispatch as vd import vkdispatch.codegen as vc +vd.initialize(debug_mode=True) #, log_level=vd.LogLevel.INFO) + import dataclasses import enum @@ -12,6 +14,10 @@ #vd.initialize(debug_mode=True) vd.make_context(use_cpu=True) +from vkdispatch.base.compute_plan import ComputePlan +from vkdispatch.base.descriptor_set import DescriptorSet +from vkdispatch.base.command_list import CommandList + import numpy as np class CommandType(enum.Enum): @@ -123,8 +129,32 @@ def get_array(index: int, config: RunConfig) -> np.ndarray: def make_source(commands: List[ProgramCommand]): local_size_x = vd.get_context().max_workgroup_size[0] 
- - header = """ + is_cuda_python = vd.is_cuda() + + if is_cuda_python: + header = ( + f"#define VKDISPATCH_EXPECTED_LOCAL_SIZE_X {local_size_x}\n" + "#define VKDISPATCH_EXPECTED_LOCAL_SIZE_Y 1\n" + "#define VKDISPATCH_EXPECTED_LOCAL_SIZE_Z 1\n\n" + "struct PushConstant {\n" + " unsigned int exec_count;\n" + "};\n\n" + "extern \"C\" __global__ void vkdispatch_main(\n" + " float* vkdispatch_binding_0_ptr,\n" + " float* vkdispatch_binding_1_ptr,\n" + " const PushConstant* vkdispatch_pc_ptr\n" + ") {\n" + " const PushConstant& PC = *vkdispatch_pc_ptr;\n" + " unsigned int tid = (unsigned int)(blockIdx.x * blockDim.x + threadIdx.x);\n" + "\n" + " if (PC.exec_count <= tid) {\n" + " return;\n" + " }\n" + "\n" + " float value = vkdispatch_binding_1_ptr[tid];\n" + ) + else: + header = """ #version 450 #extension GL_ARB_separate_shader_objects : enable //#extension GL_EXT_debug_printf : enable @@ -164,20 +194,26 @@ def make_source(commands: List[ProgramCommand]): elif command.command_type == CommandType.COS_VALUE: body += f" value = cos(value);\n" - ending = """ + if is_cuda_python: + ending = """ + vkdispatch_binding_0_ptr[tid] = value; +} +""" + else: + ending = """ bufOut.data[tid] = value; } """ return header + body + ending -program_cache: Dict[int, vd.ComputePlan] = {} +program_cache: Dict[int, ComputePlan] = {} -def get_program(index: int, config: RunConfig) -> vd.ComputePlan: +def get_program(index: int, config: RunConfig) -> ComputePlan: global program_cache if index not in program_cache: - program_cache[index] = vd.ComputePlan( + program_cache[index] = ComputePlan( shader_source=make_source(config.program_commands[index]), binding_type_list=[1, 1], pc_size=4, @@ -186,9 +222,9 @@ def get_program(index: int, config: RunConfig) -> vd.ComputePlan: return program_cache[index] -descriptor_set_cache: Dict[Tuple[int, int, int], vd.DescriptorSet] = {} +descriptor_set_cache: Dict[Tuple[int, int, int], DescriptorSet] = {} -def get_descriptor_set(out_buffer: int, in_buffer: 
int, program: vd.ComputePlan, config: RunConfig) -> vd.DescriptorSet: +def get_descriptor_set(out_buffer: int, in_buffer: int, program: ComputePlan, config: RunConfig) -> DescriptorSet: global descriptor_set_cache dict_key = (out_buffer, in_buffer, program._handle) @@ -197,7 +233,7 @@ def get_descriptor_set(out_buffer: int, in_buffer: int, program: vd.ComputePlan, output_buffer = get_buffer(out_buffer, config) input_buffer = get_buffer(in_buffer, config) - descriptor_set = vd.DescriptorSet(program) + descriptor_set = DescriptorSet(program) descriptor_set.bind_buffer(output_buffer, 0) descriptor_set.bind_buffer(input_buffer, 1) @@ -216,7 +252,7 @@ def clear_caches(): program_cache.clear() descriptor_set_cache.clear() -def do_vkdispatch_command(cmd_list: vd.CommandList, out_buffer: int, in_buffer: int, program: int, config: RunConfig): +def do_vkdispatch_command(cmd_list: CommandList, out_buffer: int, in_buffer: int, program: int, config: RunConfig): compute_plan = get_program(program, config) descriptor_set = get_descriptor_set(out_buffer, in_buffer, compute_plan, config) @@ -266,12 +302,15 @@ def do_numpy_command(out_buffer: int, in_buffer: int, program: int, config: RunC output_array[:total_exec_size] = temp_array def test_async_commands(): + if not vd.is_vulkan(): + return + for _ in range(50): clear_caches() config = make_random_config() - cmd_list = vd.CommandList() + cmd_list = CommandList() exec_count = np.random.randint(1, 250) @@ -291,4 +330,4 @@ def test_async_commands(): assert np.allclose(vkbuffer, numpy_buffer, atol=1e-3) - clear_caches() \ No newline at end of file + clear_caches() diff --git a/vkdispatch/tests/test_buffer.py b/tests/test_buffer.py similarity index 100% rename from vkdispatch/tests/test_buffer.py rename to tests/test_buffer.py diff --git a/vkdispatch/tests/test_codegen.py b/tests/test_codegen.py similarity index 91% rename from vkdispatch/tests/test_codegen.py rename to tests/test_codegen.py index 477b0c09..b95b4e83 100644 --- 
a/vkdispatch/tests/test_codegen.py +++ b/tests/test_codegen.py @@ -25,10 +25,10 @@ def test_arithmetic(): def my_shader(a: Buff[f32], b: Buff[f32]): nonlocal signal, signal2 - tid = vc.global_invocation().x + tid = vc.global_invocation_id().x - out_val = a[tid].copy() - other_val = b[tid].copy() + out_val = a[tid].to_register() + other_val = b[tid].to_register() for _ in range(op_count): op_number = np.random.randint(0, 4) diff --git a/vkdispatch/tests/test_command_graph.py b/tests/test_command_graph.py similarity index 77% rename from vkdispatch/tests/test_command_graph.py rename to tests/test_command_graph.py index db0d62a4..e2dd15ee 100644 --- a/vkdispatch/tests/test_command_graph.py +++ b/tests/test_command_graph.py @@ -2,6 +2,9 @@ import vkdispatch.codegen as vc from vkdispatch.codegen.abreviations import * + +vd.initialize(debug_mode=True) + import numpy as np def test_basic(): @@ -9,7 +12,7 @@ def test_basic(): @vd.shader(exec_size=lambda args: args.buff.size) def test_shader(buff: Buff[f32], A: Const[f32]): - tid = vc.global_invocation().x + tid = vc.global_invocation_id().x buff[tid] = buff[tid] + A @@ -19,7 +22,8 @@ def test_shader(buff: Buff[f32], A: Const[f32]): buff.write(signal) test_shader(buff, 1.0, graph=graph) - test_shader(buff, 2.0, graph=graph) + test_shader(buff, 1.0, graph=graph) + test_shader(buff, 1.0, graph=graph) graph.submit() diff --git a/tests/test_conv.py b/tests/test_conv.py new file mode 100644 index 00000000..d159f63f --- /dev/null +++ b/tests/test_conv.py @@ -0,0 +1,263 @@ +import vkdispatch as vd +import numpy as np +import random + +from typing import List + +TEST_COUNT = 4 + +def numpy_convolution_1d(signal: np.ndarray, kernel: np.ndarray) -> np.ndarray: + return np.fft.ifft( + np.fft.fft(signal).astype(np.complex64) + * + np.fft.fft(kernel).astype(np.complex64).conjugate() + ) + +def numpy_convolution(signal: np.ndarray, kernel: np.ndarray) -> np.ndarray: + return np.fft.ifft2( + np.fft.fft2(signal).astype(np.complex64) + * + 
np.fft.fft2(kernel).astype(np.complex64).conjugate() + ) + +def pick_radix_prime(): + return random.choice([2, 3, 5, 7, 11, 13]) + +def pick_dim_count(min_dim): + return random.choice(list(range(min_dim, 4))) + +def pick_dimention(dims: int): + if dims == 1: + return 0 + + return random.choice(list(range(dims))) + +def check_fft_dims(fft_dims: List[int], max_fft_size: int): + return all([dim <= max_fft_size for dim in fft_dims]) and np.prod(fft_dims) * vd.complex64.item_size < 2 ** 20 + +def test_convolution_1d(): + max_fft_size = vd.get_context().max_shared_memory // vd.complex64.item_size + + max_fft_size = min(max_fft_size, vd.get_context().max_workgroup_size[0]) + + for _ in range(TEST_COUNT): + dims = pick_dim_count(2) + current_shape = [pick_radix_prime() for _ in range(dims)] + + while check_fft_dims(current_shape, max_fft_size): + data = np.random.rand(*current_shape).astype(np.complex64) + data2 = np.random.rand(*current_shape).astype(np.complex64) + + test_data = vd.asbuffer(data) + kernel_data = vd.asbuffer(data2) + + vd.fft.fft(kernel_data) + vd.fft.convolve(test_data, kernel_data) + + reference_data = numpy_convolution_1d(data, data2) + + assert np.allclose(reference_data, test_data.read(0), atol=1e-3) + + current_shape[pick_dimention(dims)] *= random.choice([2, 3, 5, 7, 11, 13]) + + vd.fft.cache_clear() + +def test_convolution_2d(): + max_fft_size = vd.get_context().max_shared_memory // vd.complex64.item_size + + max_fft_size = min(max_fft_size, vd.get_context().max_workgroup_size[0]) + + for _ in range(TEST_COUNT): + dims = pick_dim_count(2) + current_shape = [pick_radix_prime() for _ in range(dims)] + + while check_fft_dims(current_shape, max_fft_size): + data = np.random.rand(*current_shape).astype(np.complex64) + data2 = np.random.rand(*current_shape).astype(np.complex64) + + test_data = vd.asbuffer(data) + kernel_data = vd.asbuffer(data2) + + vd.fft.fft2(kernel_data) + vd.fft.convolve2D(test_data, kernel_data) + + reference_data = 
numpy_convolution(data, data2) + + assert np.allclose(reference_data, test_data.read(0), atol=1e-3) + + current_shape[pick_dimention(dims)] *= random.choice([2, 3, 5, 7, 11, 13]) + + vd.fft.cache_clear() + +def test_convolution_2d_transpose(): + max_fft_size = vd.get_context().max_shared_memory // vd.complex64.item_size + + max_fft_size = min(max_fft_size, vd.get_context().max_workgroup_size[0]) + + kernel_transposed_buffer = vd.Buffer((2048,), var_type=vd.complex64) + + for _ in range(TEST_COUNT): + dims = pick_dim_count(2) + current_shape = [pick_radix_prime() for _ in range(dims)] + + while check_fft_dims(current_shape, max_fft_size): + data = np.random.rand(*current_shape).astype(np.complex64) + data2 = np.random.rand(*current_shape).astype(np.complex64) + + test_data = vd.asbuffer(data) + kernel_data = vd.asbuffer(data2) + + transpose_size = vd.fft.get_transposed_size( + tuple(current_shape), + axis=len(kernel_data.shape)-2 + ) + + # Allocate new transposed buffer if needed + if transpose_size > kernel_transposed_buffer.size: + kernel_transposed_buffer.destroy() + kernel_transposed_buffer = vd.Buffer((transpose_size,), var_type=vd.complex64) + + vd.fft.fft2(kernel_data) + vd.fft.transpose(kernel_data, out_buffer=kernel_transposed_buffer, axis=len(kernel_data.shape)-2) + vd.fft.convolve2D(test_data, kernel_transposed_buffer, transposed_kernel=True) + + reference_data = numpy_convolution(data, data2) + + assert np.allclose(reference_data, test_data.read(0), atol=1e-3) + + current_shape[pick_dimention(dims)] *= random.choice([2, 3, 5, 7, 11, 13]) + + vd.fft.cache_clear() + +def test_convolution_2d_real(): + max_fft_size = vd.get_context().max_shared_memory // vd.complex64.item_size + + max_fft_size = min(max_fft_size, vd.get_context().max_workgroup_size[0]) + + for _ in range(TEST_COUNT): + dims = pick_dim_count(2) + current_shape = [pick_radix_prime() for _ in range(dims)] + + while check_fft_dims(current_shape, max_fft_size): + data = 
np.random.rand(*current_shape).astype(np.float32) + data2 = np.random.rand(*current_shape).astype(np.float32) + + test_data = vd.asrfftbuffer(data) + kernel_data = vd.asrfftbuffer(data2) + + vd.fft.rfft2(kernel_data) + vd.fft.convolve2DR(test_data, kernel_data) + + reference_data = numpy_convolution(data, data2).real + + assert np.allclose(reference_data, test_data.read_real(0), atol=1e-3) + + current_shape[pick_dimention(dims)] *= random.choice([2, 3, 5, 7, 11, 13]) + + vd.fft.cache_clear() + +def test_convolution_2d_real_register_shuffle_edge_case(): + max_fft_size = vd.get_context().max_shared_memory // vd.complex64.item_size + max_fft_size = min(max_fft_size, vd.get_context().max_workgroup_size[0]) + + # This shape triggers the register shuffle path where stage-local register usage + # is smaller than config.register_count (N=162 on convolution axis). + if max_fft_size < 162: + return + + shape = (162, 13) + data = np.random.rand(*shape).astype(np.float32) + data2 = np.random.rand(*shape).astype(np.float32) + + test_data = vd.asrfftbuffer(data) + kernel_data = vd.asrfftbuffer(data2) + + vd.fft.rfft2(kernel_data) + vd.fft.convolve2DR(test_data, kernel_data) + + reference_data = numpy_convolution(data, data2).real + assert np.allclose(reference_data, test_data.read_real(0), atol=1e-3) + + vd.fft.cache_clear() + +# def test_convolution_2d_inner(): +# max_fft_size = vd.get_context().max_shared_memory // vd.complex64.item_size + +# max_fft_size = min(max_fft_size, vd.get_context().max_workgroup_size[0]) + +# for _ in range(TEST_COUNT): +# dims = 3 +# current_shape = [pick_radix_prime() for _ in range(dims)] + +# while check_fft_dims(current_shape, max_fft_size): +# data = np.random.rand(*current_shape).astype(np.complex64) +# data2 = np.random.rand(*current_shape[1:]).astype(np.complex64) + +# test_data = vd.asbuffer(data) +# kernel_data = vd.asbuffer(data2) + +# vd.fft.fft2(kernel_data) +# vd.fft.convolve2D( +# test_data, +# kernel_data, +# kernel_inner_only=True 
+# ) + +# reference_data = numpy_convolution(data, data2) + +# assert np.allclose(reference_data, test_data.read(0), atol=1e-3) + +# current_shape[pick_dimention(dims)] *= random.choice([2, 3, 5, 7, 11, 13]) + +# vd.fft.cache_clear() + +# def test_convolution_2d_transpose_inner(): +# max_fft_size = vd.get_context().max_shared_memory // vd.complex64.item_size + +# max_fft_size = min(max_fft_size, vd.get_context().max_workgroup_size[0]) + +# kernel_transposed_buffer = vd.Buffer((2048,), var_type=vd.complex64) + +# for _ in range(TEST_COUNT): +# dims = 3 +# current_shape = [pick_radix_prime() for _ in range(dims)] + +# while check_fft_dims(current_shape, max_fft_size): +# data = np.random.rand(*current_shape).astype(np.complex64) +# data2 = np.random.rand(*current_shape[1:]).astype(np.complex64) + +# test_data = vd.asbuffer(data) +# kernel_data = vd.asbuffer(data2) + +# transpose_size = vd.fft.get_transposed_size( +# tuple(current_shape), +# axis=len(kernel_data.shape)-2 +# ) + +# # Allocate new transposed buffer if needed +# if transpose_size > kernel_transposed_buffer.size: +# kernel_transposed_buffer.destroy() +# kernel_transposed_buffer = vd.Buffer((transpose_size,), var_type=vd.complex64) + +# vd.fft.fft2(kernel_data) +# vd.fft.transpose( +# kernel_data, +# conv_shape=current_shape, +# out_buffer=kernel_transposed_buffer, +# axis=len(kernel_data.shape)-2, +# kernel_inner_only=True +# ) +# vd.fft.convolve2D( +# test_data, +# kernel_transposed_buffer, +# transposed_kernel=True, +# kernel_inner_only=True +# ) + +# reference_data = numpy_convolution(data, data2) + +# assert np.allclose(reference_data, test_data.read(0), atol=1e-3) + +# current_shape[pick_dimention(dims)] *= random.choice([2, 3, 5, 7, 11, 13]) + +# vd.fft.cache_clear() diff --git a/vkdispatch/tests/test_fft.py b/tests/test_fft.py similarity index 81% rename from vkdispatch/tests/test_fft.py rename to tests/test_fft.py index b50e0a3f..faff6f62 100644 --- a/vkdispatch/tests/test_fft.py +++ 
b/tests/test_fft.py @@ -4,6 +4,8 @@ from typing import List +TEST_COUNT = 4 + def numpy_convolution(signal: np.ndarray, kernel: np.ndarray) -> np.ndarray: return np.fft.ifft2( np.fft.fft2(signal).astype(np.complex64) @@ -31,7 +33,7 @@ def test_fft_1d(): max_fft_size = min(max_fft_size, vd.get_context().max_workgroup_size[0]) - for _ in range(20): + for _ in range(TEST_COUNT): dims = pick_dim_count(1) current_shape = [pick_radix_prime() for _ in range(dims)] @@ -50,12 +52,14 @@ def test_fft_1d(): vd.fft.cache_clear() +test_fft_1d() + def test_fft_2d(): max_fft_size = vd.get_context().max_shared_memory // vd.complex64.item_size max_fft_size = min(max_fft_size, vd.get_context().max_workgroup_size[0]) - for _ in range(20): + for _ in range(TEST_COUNT): dims = pick_dim_count(2) current_shape = [pick_radix_prime() for _ in range(dims)] @@ -78,7 +82,7 @@ def test_fft_3d(): max_fft_size = min(max_fft_size, vd.get_context().max_workgroup_size[0]) - for _ in range(20): + for _ in range(TEST_COUNT): dims = 3 current_shape = [pick_radix_prime() for _ in range(dims)] @@ -101,7 +105,7 @@ def test_ifft_1d(): max_fft_size = min(max_fft_size, vd.get_context().max_workgroup_size[0]) - for _ in range(20): + for _ in range(TEST_COUNT): dims = pick_dim_count(1) current_shape = [pick_radix_prime() for _ in range(dims)] @@ -125,7 +129,7 @@ def test_ifft_2d(): max_fft_size = min(max_fft_size, vd.get_context().max_workgroup_size[0]) - for _ in range(20): + for _ in range(TEST_COUNT): dims = pick_dim_count(2) current_shape = [pick_radix_prime() for _ in range(dims)] @@ -148,7 +152,7 @@ def test_ifft_3d(): max_fft_size = min(max_fft_size, vd.get_context().max_workgroup_size[0]) - for _ in range(20): + for _ in range(TEST_COUNT): dims = 3 current_shape = [pick_radix_prime() for _ in range(dims)] @@ -171,7 +175,7 @@ def test_rfft_1d(): max_fft_size = min(max_fft_size, vd.get_context().max_workgroup_size[0]) - for _ in range(20): + for _ in range(TEST_COUNT): dims = pick_dim_count(1) 
current_shape = [pick_radix_prime() for _ in range(dims)] @@ -194,7 +198,7 @@ def test_rfft_2d(): max_fft_size = min(max_fft_size, vd.get_context().max_workgroup_size[0]) - for _ in range(20): + for _ in range(TEST_COUNT): dims = pick_dim_count(2) current_shape = [pick_radix_prime() for _ in range(dims)] @@ -217,7 +221,7 @@ def test_rfft_3d(): max_fft_size = min(max_fft_size, vd.get_context().max_workgroup_size[0]) - for _ in range(20): + for _ in range(TEST_COUNT): dims = 3 current_shape = [pick_radix_prime() for _ in range(dims)] @@ -240,7 +244,7 @@ def test_irfft_1d(): max_fft_size = min(max_fft_size, vd.get_context().max_workgroup_size[0]) - for _ in range(20): + for _ in range(TEST_COUNT): dims = pick_dim_count(1) current_shape = [pick_radix_prime() for _ in range(dims)] @@ -263,7 +267,7 @@ def test_irfft_2d(): max_fft_size = min(max_fft_size, vd.get_context().max_workgroup_size[0]) - for _ in range(20): + for _ in range(TEST_COUNT): dims = pick_dim_count(2) current_shape = [pick_radix_prime() for _ in range(dims)] @@ -286,7 +290,7 @@ def test_irfft_3d(): max_fft_size = min(max_fft_size, vd.get_context().max_workgroup_size[0]) - for _ in range(20): + for _ in range(TEST_COUNT): dims = 3 current_shape = [pick_radix_prime() for _ in range(dims)] @@ -302,58 +306,4 @@ def test_irfft_3d(): current_shape[pick_dimention(dims)] *= random.choice([2, 3, 5, 7, 11, 13]) - vd.fft.cache_clear() - -def test_convolution_2d(): - max_fft_size = vd.get_context().max_shared_memory // vd.complex64.item_size - - max_fft_size = min(max_fft_size, vd.get_context().max_workgroup_size[0]) - - for _ in range(20): - dims = pick_dim_count(2) - current_shape = [pick_radix_prime() for _ in range(dims)] - - while check_fft_dims(current_shape, max_fft_size): - data = np.random.rand(*current_shape).astype(np.complex64) - data2 = np.random.rand(*current_shape).astype(np.complex64) - - test_data = vd.asbuffer(data) - kernel_data = vd.asbuffer(data2) - - vd.fft.fft2(kernel_data) - 
vd.fft.convolve2D(test_data, kernel_data) - - reference_data = numpy_convolution(data, data2) - - assert np.allclose(reference_data, test_data.read(0), atol=1e-3) - - current_shape[pick_dimention(dims)] *= random.choice([2, 3, 5, 7, 11, 13]) - - vd.fft.cache_clear() - -def test_convolution_2d_real(): - max_fft_size = vd.get_context().max_shared_memory // vd.complex64.item_size - - max_fft_size = min(max_fft_size, vd.get_context().max_workgroup_size[0]) - - for _ in range(20): - dims = pick_dim_count(2) - current_shape = [pick_radix_prime() for _ in range(dims)] - - while check_fft_dims(current_shape, max_fft_size): - data = np.random.rand(*current_shape).astype(np.float32) - data2 = np.random.rand(*current_shape).astype(np.float32) - - test_data = vd.asrfftbuffer(data) - kernel_data = vd.asrfftbuffer(data2) - - vd.fft.rfft2(kernel_data) - vd.fft.convolve2DR(test_data, kernel_data) - - reference_data = numpy_convolution(data, data2).real - - assert np.allclose(reference_data, test_data.read_real(0), atol=1e-3) - - current_shape[pick_dimention(dims)] *= random.choice([2, 3, 5, 7, 11, 13]) - vd.fft.cache_clear() \ No newline at end of file diff --git a/tests/test_fft_mixed_precision.py b/tests/test_fft_mixed_precision.py new file mode 100644 index 00000000..62dd969f --- /dev/null +++ b/tests/test_fft_mixed_precision.py @@ -0,0 +1,320 @@ +import numpy as np +import pytest +from types import SimpleNamespace + +import vkdispatch as vd +import vkdispatch.codegen as vc +import vkdispatch.fft.functions as fft_functions + + +@pytest.fixture(autouse=True) +def _clear_fft_cache(): + yield + try: + vd.fft.cache_clear() + except Exception: + pass + + +def _require_runtime_context(): + try: + context = vd.get_context() + except Exception as exc: + pytest.skip(f"No runtime backend available for mixed-precision FFT tests: {exc}") + + is_dummy = getattr(vd, "is_dummy", None) + if callable(is_dummy) and is_dummy(): + pytest.skip("Dummy backend is codegen-only and cannot execute FFT 
kernels.") + + return context + + +def _supports_complex32(context) -> bool: + for device in context.device_infos: + if device.float_16_support != 1: + return False + if ( + device.storage_buffer_16_bit_access != 1 + and device.uniform_and_storage_buffer_16_bit_access != 1 + ): + return False + return True + + +def _supports_complex128(context) -> bool: + return all(device.float_64_support == 1 for device in context.device_infos) + + +def _require_complex32_support(context): + if not _supports_complex32(context): + pytest.skip("Active device set does not support complex32 (fp16) FFT buffers.") + + +def _require_complex128_support(context): + if not _supports_complex128(context): + pytest.skip("Active device set does not support complex128 (fp64) FFT buffers.") + + +def _quantize_to_complex32(values: np.ndarray) -> np.ndarray: + real = values.real.astype(np.float16).astype(np.float32) + imag = values.imag.astype(np.float16).astype(np.float32) + return (real + (1j * imag)).astype(np.complex64) + + +def _write_complex32(buffer: vd.Buffer, values: np.ndarray): + packed = np.empty(values.shape + (2,), dtype=np.float16) + packed[..., 0] = values.real.astype(np.float16) + packed[..., 1] = values.imag.astype(np.float16) + buffer.write(np.ascontiguousarray(packed)) + + +def test_fft_complex32_io_with_complex64_compute(): + context = _require_runtime_context() + _require_complex32_support(context) + + rng = np.random.default_rng(7) + data = ( + rng.standard_normal(64) + 1j * rng.standard_normal(64) + ).astype(np.complex64) + quantized = _quantize_to_complex32(data) + + test_buffer = vd.Buffer(data.shape, vd.complex32) + _write_complex32(test_buffer, data) + + vd.fft.fft(test_buffer, compute_type=vd.complex64) + + result = test_buffer.read(0).astype(np.complex64) + reference = np.fft.fft(quantized).astype(np.complex64) + + assert np.allclose(result, reference, atol=3e-1, rtol=2e-2) + + +def test_fft_map_complex32_input_to_complex128_output_auto_compute(): + context = 
_require_runtime_context() + _require_complex32_support(context) + _require_complex128_support(context) + + rng = np.random.default_rng(11) + data = ( + rng.standard_normal(32) + 1j * rng.standard_normal(32) + ).astype(np.complex64) + quantized = _quantize_to_complex32(data) + + input_buffer = vd.Buffer(data.shape, vd.complex32) + _write_complex32(input_buffer, data) + output_buffer = vd.Buffer(data.shape, vd.complex128) + + def input_map(buffer: vc.Buffer[vd.complex32]): + vd.fft.read_op().read_from_buffer(buffer) + + def output_map(buffer: vc.Buffer[vd.complex128]): + vd.fft.write_op().write_to_buffer(buffer) + + vd.fft.fft( + output_buffer, + input_buffer, + input_map=vd.map(input_map), + output_map=vd.map(output_map), + ) + + result = output_buffer.read(0) + reference = np.fft.fft(quantized).astype(np.complex128) + + assert np.allclose(result, reference, atol=3e-1, rtol=2e-2) + + +def test_fft_input_output_maps_allow_float32_buffers(): + _require_runtime_context() + + rng = np.random.default_rng(23) + data = rng.standard_normal(64).astype(np.float32) + + input_buffer = vd.asbuffer(data) + output_buffer = vd.Buffer(data.shape, vd.float32) + + def input_map(buffer: vc.Buffer[vd.float32]): + read_op = vd.fft.read_op() + value = vc.to_dtype(read_op.register.var_type.child_type, buffer[read_op.io_index]) + read_op.register.real = value + read_op.register.imag = vc.to_dtype(read_op.register.var_type.child_type, 0) + + def output_map(buffer: vc.Buffer[vd.float32]): + write_op = vd.fft.write_op() + buffer[write_op.io_index] = vc.to_dtype(buffer.var_type, write_op.register.real) + + vd.fft.fft( + output_buffer, + input_buffer, + input_map=vd.map(input_map), + output_map=vd.map(output_map), + ) + + result = output_buffer.read(0).astype(np.float32) + reference = np.fft.fft(data.astype(np.complex64)).real.astype(np.float32) + + assert np.allclose(result, reference, atol=2e-3, rtol=1e-3) + + +def test_convolve_kernel_map_allows_float32_buffer(): + _require_runtime_context() 
+ + rng = np.random.default_rng(31) + data = ( + rng.standard_normal(64) + 1j * rng.standard_normal(64) + ).astype(np.complex64) + scale = np.float32(0.5) + + signal_buffer = vd.asbuffer(data.copy()) + scale_buffer = vd.asbuffer(np.full(data.shape, scale, dtype=np.float32)) + + def kernel_map(scale_values: vc.Buffer[vd.float32]): + read_op = vd.fft.read_op() + scale_value = vc.to_dtype( + read_op.register.var_type, + vc.to_complex(scale_values[read_op.io_index]), + ) + read_op.register[:] = vc.mult_complex(read_op.register, scale_value) + + vd.fft.convolve( + signal_buffer, + scale_buffer, + kernel_map=vd.map(kernel_map), + ) + + result = signal_buffer.read(0).astype(np.complex64) + reference = (data * scale).astype(np.complex64) + + assert np.allclose(result, reference, atol=2e-3, rtol=1e-3) + + +def test_fft_output_map_without_input_map_uses_explicit_input_buffer(): + _require_runtime_context() + + rng = np.random.default_rng(37) + data = ( + rng.standard_normal(64) + 1j * rng.standard_normal(64) + ).astype(np.complex64) + + input_buffer = vd.asbuffer(data.copy()) + output_buffer = vd.Buffer(data.shape, vd.complex64) + + @vd.map + def output_map(buffer: vc.Buffer[vd.complex64]): + vd.fft.write_op().write_to_buffer(buffer) + + vd.fft.fft( + output_buffer, + input_buffer, + output_map=output_map, + ) + + result = output_buffer.read(0).astype(np.complex64) + reference = np.fft.fft(data).astype(np.complex64) + + assert np.allclose(result, reference, atol=2e-3, rtol=1e-3) + + +def test_convolve_output_map_without_input_map_uses_explicit_input_buffer(): + if True: + return + _require_runtime_context() + + rng = np.random.default_rng(41) + data = ( + rng.standard_normal(64) + 1j * rng.standard_normal(64) + ).astype(np.complex64) + + input_buffer = vd.asbuffer(data.copy()) + output_buffer = vd.Buffer(data.shape, vd.complex64) + + @vd.map + def kernel_map(): + # Identity map: keep spectrum unchanged. 
+ return + + @vd.map + def output_map(buffer: vc.Buffer[vd.complex64]): + vd.fft.write_op().write_to_buffer(buffer) + + vd.fft.convolve( + output_buffer, + input_buffer, + kernel_map=kernel_map, + output_map=output_map, + ) + + result = output_buffer.read(0).astype(np.complex64) + reference = data.astype(np.complex64) + + assert np.allclose(result, reference, atol=2e-3, rtol=1e-3) + + +def test_fft_complex64_io_with_complex128_compute(): + context = _require_runtime_context() + _require_complex128_support(context) + + rng = np.random.default_rng(29) + data = ( + rng.standard_normal(64) + 1j * rng.standard_normal(64) + ).astype(np.complex64) + + test_buffer = vd.asbuffer(data) + vd.fft.fft(test_buffer, compute_type=vd.complex128) + + result = test_buffer.read(0).astype(np.complex64) + reference = np.fft.fft(data).astype(np.complex64) + + assert np.allclose(result, reference, atol=2e-3, rtol=1e-3) + + +def test_resolve_input_precision_output_map_infers_input_from_post_map_argument(monkeypatch): + monkeypatch.setattr( + fft_functions, + "ensure_supported_complex_precision", + lambda dtype, role: None, + ) + + class _FakeBuffer: + def __init__(self, var_type): + self.var_type = var_type + + output_map = SimpleNamespace( + buffer_types=[vc.Buffer[vd.complex64], vc.Buffer[vd.float32]], + ) + + resolved = fft_functions._resolve_input_precision( + ( + _FakeBuffer(vd.complex64), + _FakeBuffer(vd.float32), + _FakeBuffer(vd.complex128), + ), + input_map=None, + output_map=output_map, + input_type=None, + output_precision=None, + ) + + assert resolved is vd.complex128 + + +def test_resolve_input_precision_output_map_requires_input_buffer_after_map_args(monkeypatch): + monkeypatch.setattr( + fft_functions, + "ensure_supported_complex_precision", + lambda dtype, role: None, + ) + + class _FakeBuffer: + def __init__(self, var_type): + self.var_type = var_type + + output_map = SimpleNamespace(buffer_types=[vc.Buffer[vd.complex64]]) + + with pytest.raises(ValueError, match="input 
buffer argument must be provided"): + fft_functions._resolve_input_precision( + (_FakeBuffer(vd.complex64),), + input_map=None, + output_map=output_map, + input_type=None, + output_precision=None, + ) diff --git a/tests/test_fft_padded.py b/tests/test_fft_padded.py new file mode 100644 index 00000000..9eff033a --- /dev/null +++ b/tests/test_fft_padded.py @@ -0,0 +1,98 @@ +import vkdispatch as vd +import numpy as np +import random + +from typing import List + +TEST_COUNT = 4 + +def numpy_convolution(signal: np.ndarray, kernel: np.ndarray) -> np.ndarray: + return np.fft.ifft2( + np.fft.fft2(signal).astype(np.complex64) + * + np.fft.fft2(kernel).astype(np.complex64).conjugate() + ) + +def pick_radix_prime(): + return random.choice([2, 3, 5, 7, 11, 13]) + +def pick_dim_count(min_dim): + return random.choice(list(range(min_dim, 4))) + +def pick_dimention(dims: int): + if dims == 1: + return 0 + + return random.choice(list(range(dims))) + +def check_fft_dims(fft_dims: List[int], max_fft_size: int): + return all([dim <= max_fft_size for dim in fft_dims]) and np.prod(fft_dims) * vd.complex64.item_size < 2 ** 20 + +def apply_zeros_to_numpy(data: np.ndarray, axis: int, signal_start: int, signal_end: int) -> np.ndarray: + zeroed_data = data.copy() + zeroed_data_slices = [slice(None)] * data.ndim + zeroed_data_slices[axis] = slice(0, signal_start) + zeroed_data[tuple(zeroed_data_slices)] = 0 + zeroed_data_slices[axis] = slice(signal_end, data.shape[axis]) + zeroed_data[tuple(zeroed_data_slices)] = 0 + + return zeroed_data + +def test_fft_1d(): + max_fft_size = vd.get_context().max_shared_memory // vd.complex64.item_size + + max_fft_size = min(max_fft_size, vd.get_context().max_workgroup_size[0]) + + for _ in range(TEST_COUNT): + dims = pick_dim_count(1) + current_shape = [pick_radix_prime() for _ in range(dims)] + + while check_fft_dims(current_shape, max_fft_size): + data = np.random.rand(*current_shape).astype(np.complex64) + test_data = vd.Buffer(data.shape, vd.complex64) + 
+ for axis in range(dims): + test_data.write(data) + + signal_start = np.random.randint(0, data.shape[axis]-1) + signal_end = np.random.randint(signal_start + 1, data.shape[axis] + 1) + + vd.fft.fft(test_data, axis=axis, input_signal_range=(signal_start, signal_end)) + + zeroed_data = apply_zeros_to_numpy(data, axis, signal_start, signal_end) + + assert np.allclose(np.fft.fft(zeroed_data, axis=axis), test_data.read(0), atol=1e-3) + + current_shape[pick_dimention(dims)] *= random.choice([2, 3, 5, 7, 11, 13]) + + vd.fft.cache_clear() + + + +def test_rfft_1d(): + max_fft_size = vd.get_context().max_shared_memory // vd.complex64.item_size + + max_fft_size = min(max_fft_size, vd.get_context().max_workgroup_size[0]) + + for _ in range(TEST_COUNT): + dims = pick_dim_count(1) + current_shape = [pick_radix_prime() for _ in range(dims)] + + while check_fft_dims(current_shape, max_fft_size): + data = np.random.rand(*current_shape).astype(np.float32) + test_data = vd.RFFTBuffer(data.shape) + + test_data.write_real(data) + + signal_start = np.random.randint(0, data.shape[-1]-1) + signal_end = np.random.randint(signal_start + 1, data.shape[-1] + 1) + + vd.fft.fft(test_data, buffer_shape=test_data.real_shape, r2c=True, input_signal_range=(signal_start, signal_end)) + + zeroed_data = apply_zeros_to_numpy(data, -1, signal_start, signal_end) + + assert np.allclose(np.fft.rfft(zeroed_data), test_data.read_fourier(0), atol=1e-3) + + current_shape[pick_dimention(dims)] *= random.choice([2, 3, 5, 7, 11, 13]) + + vd.fft.cache_clear() diff --git a/vkdispatch/tests/test_image.py b/tests/test_image.py similarity index 85% rename from vkdispatch/tests/test_image.py rename to tests/test_image.py index de120a96..2a03478c 100644 --- a/vkdispatch/tests/test_image.py +++ b/tests/test_image.py @@ -8,6 +8,9 @@ vd.initialize(log_level=vd.LogLevel.WARNING, debug_mode=True) def test_1d_image_creation(): + if not vd.is_vulkan(): + return + # Create a 1D image signal = np.sin(np.array([i/8 for i in 
range(0, 50, 1)])).astype(np.float32) @@ -17,6 +20,8 @@ def test_1d_image_creation(): assert np.allclose(test_line.read(0), signal) def test_2d_image_creation(): + if not vd.is_vulkan(): + return # Create a 2D image signal_2d = np.sin(np.array([[i/8 + j/17 for i in range(0, 50, 1)] for j in range(0, 50, 1)])).astype(np.float32) @@ -26,6 +31,8 @@ def test_2d_image_creation(): assert np.allclose(test_img.read(0), signal_2d) def test_3d_image_creation(): + if not vd.is_vulkan(): + return # Create a 3D image signal_3d = np.sin(np.array([[[i/8 + j/17 + k/23 for i in range(0, 50, 1)] for j in range(0, 50, 1)] for k in range(0, 50, 1)])).astype(np.float32) @@ -35,6 +42,8 @@ def test_3d_image_creation(): assert np.allclose(test_img.read(0), signal_3d) def test_1d_image_linear_sampling(): + if not vd.is_vulkan(): + return # Create a 1D image signal = np.sin(np.array([i/8 for i in range(0, 50, 1)])).astype(np.float32) @@ -47,8 +56,8 @@ def test_1d_image_linear_sampling(): @vd.shader("buff.size") def do_approx(buff: Buff[f32], line: Img1[f32]): - ind = vc.global_invocation().x.copy() - buff[ind] = line.sample((ind.cast_to(f32)) / sample_factor).x + ind = vc.global_invocation_id().x.to_register() + buff[ind] = line.sample((ind.to_dtype(f32)) / sample_factor).x do_approx(result_arr, test_line.sample()) @@ -57,6 +66,8 @@ def do_approx(buff: Buff[f32], line: Img1[f32]): assert np.allclose(result_arr.read()[0], signal_full, atol=0.002) def test_2d_image_linear_sampling(): + if not vd.is_vulkan(): + return # Create a 2D image signal_2d = np.sin(np.array([[i/8 + j/17 for i in range(0, 50, 1)] for j in range(0, 50, 1)])).astype(np.float32) sample_factor = 10 @@ -68,9 +79,10 @@ def test_2d_image_linear_sampling(): @vd.shader("buff.size") def do_approx(buff: Buff[f32], img: Img2[f32]): - ind = vc.global_invocation().x.copy() - ind_2d = vc.unravel_index(ind, buff.shape) - buff[ind] = img.sample((ind_2d.cast_to(v2)) / sample_factor).x + ind = vc.global_invocation_id().x.to_register() + 
ind_2d = vc.ravel_index(ind, buff.shape).to_register() + ind_2d_transposed = vc.new_vec2_register(ind_2d.y, ind_2d.x) + buff[ind] = img.sample(ind_2d_transposed / sample_factor).x do_approx(result_arr, test_img.sample()) @@ -78,6 +90,7 @@ def do_approx(buff: Buff[f32], img: Img2[f32]): assert np.allclose(result_arr.read()[0], signal_full, atol=0.0025) + # def test_3d_image_linear_sampling(): # # Create a 3D image # signal_3d = np.sin(np.array([[[i/8 + j/17 + k/23 for i in range(0, 5, 1)] for j in range(0, 5, 1)] for k in range(0, 5, 1)]).astype(np.float32)) diff --git a/tests/test_ravel.py b/tests/test_ravel.py new file mode 100644 index 00000000..b186bf5c --- /dev/null +++ b/tests/test_ravel.py @@ -0,0 +1,92 @@ +import vkdispatch as vd +import vkdispatch.codegen as vc + +from vkdispatch.base.dtype import to_vector + +import numpy as np + +from typing import Tuple + + +def run_index_ravel(shape: Tuple[int, ...], index: Tuple[int, ...], shape_static: bool): + var_type = to_vector(vd.uint32, len(shape)) + + buffer = vd.Buffer(shape, var_type=var_type) + + @vd.shader("buff.size") + def test_shader(buff: vc.Buff[var_type]): # pyright: ignore[reportInvalidTypeForm] + ind = vc.global_invocation_id().x + buff[ind] = vc.ravel_index( + ind, + shape if shape_static else buff.shape + ).swizzle("xyz"[:len(shape)]) + + test_shader(buffer) + + result_value = buffer.read(0) + + assert tuple(result_value[index]) == tuple(index), f"Expected index {index}, got {tuple(result_value[index])}" + + buffer.destroy() + +def test_index_ravel(): + for _ in range(100): + shape_len = np.random.choice([2, 3]) + shape = tuple(np.random.randint(1, 100) for _ in range(shape_len)) + index = tuple(np.random.randint(0, shape[i]) for i in range(shape_len)) + + run_index_ravel(shape, index, False) + run_index_ravel(shape, index, True) + +def run_index_unravel(shape: Tuple[int, ...], index: Tuple[int, ...], input_static: bool, shape_static: bool): + data = np.random.rand(*shape).astype(np.float32) + 
buffer = vd.asbuffer(data) + + result_buffer = vd.Buffer((1,), var_type=vd.float32) + + index_type = vd.int32 + + if len(index) == 2: + index_type = vd.ivec2 + elif len(index) == 3: + index_type = vd.ivec3 + + if input_static and shape_static: + @vd.shader(1) + def test_shader(buff: vc.Buff[vc.f32], buff_in: vc.Buff[vc.f32]): + buff[0] = buff_in[vc.unravel_index(index, shape)] + elif input_static and not shape_static: + @vd.shader(1) + def test_shader(buff: vc.Buff[vc.f32], buff_in: vc.Buff[vc.f32]): + buff[0] = buff_in[vc.unravel_index(index, buff_in.shape)] + elif not input_static and shape_static: + @vd.shader(1) + def test_shader(buff: vc.Buff[vc.f32], buff_in: vc.Buff[vc.f32]): + index_vec = vc.new_register(index_type, *index) + buff[0] = buff_in[vc.unravel_index(index_vec, shape)] + elif not input_static and not shape_static: + @vd.shader(1) + def test_shader(buff: vc.Buff[vc.f32], buff_in: vc.Buff[vc.f32]): + index_vec = vc.new_register(index_type, *index) + buff[0] = buff_in[vc.unravel_index(index_vec, buff_in.shape)] + + test_shader(result_buffer, buffer) + + result_value = result_buffer.read(0)[0] + reference_value = data[index] + + assert np.isclose(result_value, reference_value, atol=1e-5), f"Expected {reference_value}, got {result_value}" + + buffer.destroy() + result_buffer.destroy() + +def test_index_unravel(): + for _ in range(100): + shape_len = np.random.choice([1, 2, 3]) + shape = tuple(np.random.randint(1, 100) for _ in range(shape_len)) + index = tuple(np.random.randint(0, shape[i]) for i in range(shape_len)) + + run_index_unravel(shape, index, False, False) + run_index_unravel(shape, index, False, True) + run_index_unravel(shape, index, True, False) + run_index_unravel(shape, index, True, True) \ No newline at end of file diff --git a/vkdispatch/tests/test_reductions.py b/tests/test_reductions.py similarity index 60% rename from vkdispatch/tests/test_reductions.py rename to tests/test_reductions.py index 6abf895b..3bed232d 100644 --- 
a/vkdispatch/tests/test_reductions.py +++ b/tests/test_reductions.py @@ -18,9 +18,9 @@ def test_reductions_sum(): # Write the data to the buffer buf.write(data) - @vd.map_reduce(vd.SubgroupAdd) + @vd.reduce.map_reduce(vd.reduce.SubgroupAdd) def sum_map(buffer: Buff[f32]) -> f32: - return buffer[vc.mapping_index()] + return buffer[vd.reduce.mapped_io_index()] res_buf = sum_map(buf) @@ -40,9 +40,9 @@ def test_mapped_reductions(): # Write the data to the buffer buf.write(data) - @vd.map_reduce(vd.SubgroupAdd) + @vd.reduce.map_reduce(vd.reduce.SubgroupAdd) def sum_map(buffer: Buff[f32]) -> f32: - return vc.sin(buffer[vc.mapping_index()]) + return vc.sin(buffer[vd.reduce.mapped_io_index()]) res_buf = sum_map(buf) @@ -65,9 +65,9 @@ def test_listed_reductions(): buf.write(data) buf2.write(data2) - @vd.map_reduce(vd.SubgroupAdd) + @vd.reduce.map_reduce(vd.reduce.SubgroupAdd) def sum_map(buffer: Buff[v2], buffer2: Buff[v2]) -> v2: - ind = vc.mapping_index() + ind = vd.reduce.mapped_io_index() return vc.sin(buffer[ind] + buffer2[ind]) graph = vd.CommandGraph() @@ -78,6 +78,8 @@ def sum_map(buffer: Buff[v2], buffer2: Buff[v2]) -> v2: graph.submit() + vd.queue_wait_idle() + # Read the data from the buffer read_data = res_buf.read(0) @@ -95,10 +97,9 @@ def test_pure_reductions(): # Write the data to the buffer buf = vd.asbuffer(data) - @vd.reduce(0) + @vd.reduce.reduce(0) def sum_reduce(a: f32, b: f32) -> f32: - result = (a + b).copy() - return result + return a + b res_buf = sum_reduce(buf) @@ -123,12 +124,11 @@ def test_pure_reductions_with_mapping_function(): @vd.map def reduction_map(input: Buff[f32]) -> f32: - return vc.sin(input[vc.mapping_index()]) + return vc.sin(input[vd.reduce.mapped_io_index()]) - @vd.reduce(0, mapping_function=reduction_map) + @vd.reduce.reduce(0, mapping_function=reduction_map) def sum_reduce(a: f32, b: f32) -> f32: - result = (a + b).copy() - return result + return a + b res_buf = sum_reduce(buf) @@ -150,9 +150,9 @@ def 
test_batched_mapped_reductions(): # Write the data to the buffer buf = vd.asbuffer(data) - @vd.map_reduce(vd.SubgroupAdd, axes=[1]) + @vd.reduce.map_reduce(vd.reduce.SubgroupAdd, axes=[1]) def sum_map(buffer: Buff[f32]) -> f32: - return vc.sin(buffer[vc.mapping_index()]) + return vc.sin(buffer[vd.reduce.mapped_io_index()]) res_buf = sum_map(buf) @@ -160,4 +160,63 @@ def sum_map(buffer: Buff[f32]) -> f32: read_data = res_buf.read(0)[0] # Check that the data is the same - assert np.allclose([np.sin(data).sum(axis=1)], [read_data]) \ No newline at end of file + assert np.allclose([np.sin(data).sum(axis=1)], [read_data]) + +def test_mapped_reductions_min(): + # Create a buffer + buf = vd.Buffer((1024,), vd.float32) + + # Create a numpy array + data = np.random.randn(1024).astype(np.float32) + + # Write the data to the buffer + buf.write(data) + + @vd.reduce.map_reduce(vd.reduce.SubgroupMin) + def min_map(buffer: Buff[f32]) -> f32: + return buffer[vd.reduce.mapped_io_index()] + + res_buf = min_map(buf) + + # Read the data from the buffer + read_data = res_buf.read(0) + + # Check that the data is the same + assert np.allclose([data.min()], [read_data[0]]) + +def test_mapped_reductions_max(): + # Create a buffer + buf = vd.Buffer((1024,), vd.float32) + + # Create a numpy array + data = np.random.randn(1024).astype(np.float32) + + # Write the data to the buffer + buf.write(data) + + @vd.reduce.map_reduce(vd.reduce.SubgroupMax) + def max_map(buffer: Buff[f32]) -> f32: + return buffer[vd.reduce.mapped_io_index()] + + res_buf = max_map(buf) + + # Read the data from the buffer + read_data = res_buf.read(0) + + # Check that the data is the same + assert np.allclose([data.max()], [read_data[0]]) + +def test_min_max_codegen_stage_creation(): + @vd.reduce.map_reduce(vd.reduce.SubgroupMin) + def min_map(buffer: Buff[f32]) -> f32: + return buffer[vd.reduce.mapped_io_index()] + + @vd.reduce.map_reduce(vd.reduce.SubgroupMax) + def max_map(buffer: Buff[f32]) -> f32: + return 
buffer[vd.reduce.mapped_io_index()] + + min_src_stage1, min_src_stage2 = min_map.get_src() + max_src_stage1, max_src_stage2 = max_map.get_src() + + assert min_src_stage1 and min_src_stage2 + assert max_src_stage1 and max_src_stage2 diff --git a/tests/test_threading.py b/tests/test_threading.py new file mode 100644 index 00000000..62d6c7f6 --- /dev/null +++ b/tests/test_threading.py @@ -0,0 +1,68 @@ +import vkdispatch as vd +import vkdispatch.codegen as vc +import numpy as np +import threading +import time + +def test_concurrent_shader_generation_robust(): + barrier_enter = threading.Barrier(2) + barrier_exit = threading.Barrier(2) + + thread_data = {} + thread_errors = [] + + def thread_task(thread_id): + try: + unique_name = f"var_thread_{thread_id}" + + @vd.shader(exec_size=(1,)) + def concurrent_shader(buf: vc.Buff[vc.f32]): + barrier_enter.wait() + + active_builder = vc.get_builder() + thread_data[f"builder_{thread_id}"] = active_builder + + reg = vc.new_float_register(1.0, var_name=unique_name) + buf[0] = reg + + barrier_exit.wait() + + concurrent_shader.build() + + thread_data[f"source_{thread_id}"] = concurrent_shader.source + + except Exception as e: + thread_errors.append(e) + + t1 = threading.Thread(target=thread_task, args=(1,)) + t2 = threading.Thread(target=thread_task, args=(2,)) + + t1.start() + t2.start() + + t1.join() + t2.join() + + if thread_errors: + raise RuntimeError(f"Thread failed: {thread_errors[0]}") + + b1 = thread_data["builder_1"] + b2 = thread_data["builder_2"] + + assert b1 is not b2, ( + f"THREAD SAFETY FAILURE: Both threads retrieved the exact same " + f"ShaderBuilder instance ({id(b1)}). This means `GlobalBuilder` is shared." + ) + + src_1 = thread_data["source_1"] + src_2 = thread_data["source_2"] + + assert "var_thread_1" in src_1, "Thread 1 failed to generate its own variable." + assert "var_thread_2" not in src_1, ( + "LEAK DETECTED: Thread 2's variable 'var_thread_2' appeared in Thread 1's source code." 
+ ) + + assert "var_thread_2" in src_2, "Thread 2 failed to generate its own variable." + assert "var_thread_1" not in src_2, ( + "LEAK DETECTED: Thread 1's variable 'var_thread_1' appeared in Thread 2's source code." + ) \ No newline at end of file diff --git a/vkdispatch/tests/test_vkfft.py b/tests/test_vkfft.py similarity index 95% rename from vkdispatch/tests/test_vkfft.py rename to tests/test_vkfft.py index 49b2bf70..caf8a480 100644 --- a/vkdispatch/tests/test_vkfft.py +++ b/tests/test_vkfft.py @@ -20,6 +20,8 @@ def check_fft_dims(fft_dims: List[int], max_fft_size: int): return all([dim <= max_fft_size for dim in fft_dims]) and np.prod(fft_dims) * vd.complex64.item_size < 2 ** 20 def test_fft_1d(): + if not vd.is_vulkan(): + return max_fft_size = vd.get_context().max_shared_memory // vd.complex64.item_size max_fft_size = min(max_fft_size, vd.get_context().max_workgroup_size[0]) @@ -44,6 +46,8 @@ def test_fft_1d(): vd.vkfft.clear_plan_cache() def test_fft_2d(): + if not vd.is_vulkan(): + return max_fft_size = vd.get_context().max_shared_memory // vd.complex64.item_size max_fft_size = min(max_fft_size, vd.get_context().max_workgroup_size[0]) @@ -67,6 +71,8 @@ def test_fft_2d(): vd.vkfft.clear_plan_cache() def test_fft_3d(): + if not vd.is_vulkan(): + return max_fft_size = vd.get_context().max_shared_memory // vd.complex64.item_size max_fft_size = min(max_fft_size, vd.get_context().max_workgroup_size[0]) @@ -90,6 +96,8 @@ def test_fft_3d(): vd.vkfft.clear_plan_cache() def test_ifft_1d(): + if not vd.is_vulkan(): + return max_fft_size = vd.get_context().max_shared_memory // vd.complex64.item_size max_fft_size = min(max_fft_size, vd.get_context().max_workgroup_size[0]) @@ -114,6 +122,8 @@ def test_ifft_1d(): vd.vkfft.clear_plan_cache() def test_ifft_2d(): + if not vd.is_vulkan(): + return max_fft_size = vd.get_context().max_shared_memory // vd.complex64.item_size max_fft_size = min(max_fft_size, vd.get_context().max_workgroup_size[0]) @@ -137,6 +147,8 @@ def 
test_ifft_2d(): vd.vkfft.clear_plan_cache() def test_ifft_3d(): + if not vd.is_vulkan(): + return max_fft_size = vd.get_context().max_shared_memory // vd.complex64.item_size max_fft_size = min(max_fft_size, vd.get_context().max_workgroup_size[0]) @@ -160,6 +172,8 @@ def test_ifft_3d(): vd.vkfft.clear_plan_cache() def test_rfft_1d(): + if not vd.is_vulkan(): + return max_fft_size = vd.get_context().max_shared_memory // vd.complex64.item_size max_fft_size = min(max_fft_size, vd.get_context().max_workgroup_size[0]) @@ -183,6 +197,8 @@ def test_rfft_1d(): vd.vkfft.clear_plan_cache() def test_rfft_2d(): + if not vd.is_vulkan(): + return max_fft_size = vd.get_context().max_shared_memory // vd.complex64.item_size max_fft_size = min(max_fft_size, vd.get_context().max_workgroup_size[0]) @@ -206,6 +222,8 @@ def test_rfft_2d(): vd.vkfft.clear_plan_cache() def test_rfft_3d(): + if not vd.is_vulkan(): + return max_fft_size = vd.get_context().max_shared_memory // vd.complex64.item_size max_fft_size = min(max_fft_size, vd.get_context().max_workgroup_size[0]) @@ -229,6 +247,8 @@ def test_rfft_3d(): vd.vkfft.clear_plan_cache() def test_irfft_1d(): + if not vd.is_vulkan(): + return max_fft_size = vd.get_context().max_shared_memory // vd.complex64.item_size max_fft_size = min(max_fft_size, vd.get_context().max_workgroup_size[0]) @@ -252,6 +272,8 @@ def test_irfft_1d(): vd.vkfft.clear_plan_cache() def test_irfft_2d(): + if not vd.is_vulkan(): + return max_fft_size = vd.get_context().max_shared_memory // vd.complex64.item_size max_fft_size = min(max_fft_size, vd.get_context().max_workgroup_size[0]) @@ -275,6 +297,8 @@ def test_irfft_2d(): vd.vkfft.clear_plan_cache() def test_irfft_3d(): + if not vd.is_vulkan(): + return max_fft_size = vd.get_context().max_shared_memory // vd.complex64.item_size max_fft_size = min(max_fft_size, vd.get_context().max_workgroup_size[0]) diff --git a/tests/test_vkfft_conv.py b/tests/test_vkfft_conv.py new file mode 100644 index 00000000..a4404c80 --- 
/dev/null +++ b/tests/test_vkfft_conv.py @@ -0,0 +1,74 @@ +import vkdispatch as vd +import random + +from typing import List +import numpy as np + +#vd.initialize(log_level=vd.LogLevel.INFO, debug_mode=True) +vd.initialize() + +def numpy_convolution(signal: np.ndarray, kernel: np.ndarray) -> np.ndarray: + return np.fft.ifft2( + np.fft.fft2(signal).astype(np.complex64) + * + np.fft.fft2(kernel).astype(np.complex64) + ) + +def pick_radix_prime(): + return random.choice([2, 3, 5, 7, 11, 13]) + +def pick_dim_count(min_dim): + return random.choice(list(range(min_dim, 4))) + +def pick_dimention(dims: int): + if dims == 1: + return 0 + + return random.choice(list(range(dims))) + +def check_fft_dims(fft_dims: List[int], max_fft_size: int): + return all([dim <= max_fft_size for dim in fft_dims]) and np.prod(fft_dims) * vd.complex64.item_size < 2 ** 29 + +def test_convolution_2d_powers_of_2(): + if not vd.is_vulkan(): + return + max_fft_size = vd.get_context().max_shared_memory // vd.complex64.item_size + + buffer_cache = {} + kernel_cache = {} + + for i in range(3): + current_shape = [512, 16, 16] + + while current_shape[1] <= 4096: + data = np.random.rand(*current_shape).astype(np.complex64) + data2 = np.random.rand(*current_shape).astype(np.complex64) + + shape_key = tuple(current_shape) + if shape_key in buffer_cache: + test_data = buffer_cache[shape_key] + test_data.write(data) + else: + test_data = vd.asbuffer(data) + buffer_cache[shape_key] = test_data + + if shape_key in kernel_cache: + kernel_data = kernel_cache[shape_key] + kernel_data.write(data2) + else: + kernel_data = vd.asbuffer(data2) + kernel_cache[shape_key] = kernel_data + + vd.vkfft.transpose_kernel2D(kernel_data) + vd.vkfft.convolve2D(test_data, kernel_data, normalize=True) + + reference_data = numpy_convolution(data, data2) + + assert np.allclose(reference_data, test_data.read(0), atol=1e-3) + + current_shape[0] //= 2 + current_shape[1] *= 2 + current_shape[2] *= 2 + + vd.fft.cache_clear() + \ No 
newline at end of file diff --git a/vkdispatch/__init__.py b/vkdispatch/__init__.py index 91ea0327..27e99e2a 100644 --- a/vkdispatch/__init__.py +++ b/vkdispatch/__init__.py @@ -1,30 +1,42 @@ -from .base.errors import check_for_errors -from .base.errors import check_for_compute_stage_errors - from .base.init import DeviceInfo from .base.init import LogLevel from .base.init import get_devices +from .base.init import get_backend, is_vulkan, is_cuda, is_opencl, is_dummy from .base.init import initialize from .base.init import is_initialized from .base.init import log, log_error, log_warning, log_info, log_verbose, set_log_level from .base.dtype import dtype -from .base.dtype import float32, int32, uint32, complex64 -from .base.dtype import vec2, vec3, vec4, ivec2, ivec3, ivec4, uvec2, uvec3, uvec4 -from .base.dtype import mat2, mat4 -from .base.dtype import is_scalar, is_complex, is_vector, is_matrix, is_dtype -from .base.dtype import to_numpy_dtype, from_numpy_dtype, to_vector - -from .base.context import get_context, queue_wait_idle +from .base.dtype import float16, float32, float64, int16, uint16, int32, uint32, int64, uint64 +from .base.dtype import complex32, complex64, complex128 +from .base.dtype import hvec2, hvec3, hvec4 +from .base.dtype import vec2, vec3, vec4 +from .base.dtype import dvec2, dvec3, dvec4 +from .base.dtype import ihvec2, ihvec3, ihvec4 +from .base.dtype import ivec2, ivec3, ivec4 +from .base.dtype import uhvec2, uhvec3, uhvec4 +from .base.dtype import uvec2, uvec3, uvec4 +from .base.dtype import mat2, mat3, mat4 + +from .base.context import get_context, queue_wait_idle, Signal from .base.context import get_context_handle -from .base.context import make_context, select_queue_families +from .base.context import make_context, select_queue_families, set_dummy_context_params from .base.context import is_context_initialized from .base.buffer import asbuffer +from .base.buffer import from_cuda_array from .base.buffer import Buffer from .base.buffer 
import asrfftbuffer from .base.buffer import RFFTBuffer +from .base.buffer_allocators import buffer_u32, buffer_uv2, buffer_uv3, buffer_uv4 +from .base.buffer_allocators import buffer_i32, buffer_iv2, buffer_iv3, buffer_iv4 +from .base.buffer_allocators import buffer_f32, buffer_v2, buffer_v3, buffer_v4, buffer_c64 +from .base.buffer_allocators import buffer_u16, buffer_uhv2, buffer_uhv3, buffer_uhv4 +from .base.buffer_allocators import buffer_i16, buffer_ihv2, buffer_ihv3, buffer_ihv4 +from .base.buffer_allocators import buffer_f16, buffer_hv2, buffer_hv3, buffer_hv4 +from .base.buffer_allocators import buffer_f64, buffer_dv2, buffer_dv3, buffer_dv4 + from .base.image import image_format from .base.image import image_type from .base.image import image_view_type @@ -38,39 +50,17 @@ from .base.image import AddressMode from .base.image import BorderColor -from .base.compute_plan import ComputePlan - -from .base.descriptor_set import DescriptorSet - -from .base.command_list import CommandList - -from .execution_pipeline.buffer_builder import BufferUsage, BufferedStructEntry, BufferBuilder - from .execution_pipeline.command_graph import CommandGraph, BufferBindInfo, ImageBindInfo from .execution_pipeline.command_graph import global_graph, set_global_graph, default_graph +from .execution_pipeline.cuda_graph_capture import cuda_graph_capture, get_cuda_capture, CUDAGraphCapture -from .shader_generation.signature import ShaderArgumentType -from .shader_generation.signature import ShaderArgument -from .shader_generation.signature import ShaderSignature - -from .shader_generation.shader_object import ShaderObject -from .shader_generation.shader_object import ExectionBounds -from .shader_generation.shader_object import LaunchParametersHolder - -from .shader_generation.mapping_shader import map, map_registers, MappingFunction - -from .shader_generation.reduction_operations import ReductionOperation, SubgroupAdd, SubgroupMul, SubgroupMin -from 
# Canonical backend identifiers accepted by the selection helpers below.
BACKEND_VULKAN = "vulkan"
BACKEND_CUDA = "cuda"
BACKEND_OPENCL = "opencl"
BACKEND_DUMMY = "dummy"

_VALID_BACKENDS = {BACKEND_VULKAN, BACKEND_CUDA, BACKEND_OPENCL, BACKEND_DUMMY}

# Process-wide selection; stays None until set_active_backend() succeeds.
_active_backend_name: Optional[str] = None

# Cache of already-imported backend modules, keyed by normalized backend name.
_backend_modules: Dict[str, ModuleType] = {}

# backend name -> (importable module path, human-readable label, noun used in
# the failure message).  A label of None means an ImportError is re-raised
# unchanged.  The failure message is built from the module path itself so the
# reported name cannot drift away from the name that is actually imported.
_BACKEND_IMPORTS = {
    BACKEND_VULKAN: ("vkdispatch_vulkan_native", "Vulkan backend", "package"),
    BACKEND_CUDA: ("vkdispatch.backends.cuda_backend", "CUDA Python backend", "module"),
    BACKEND_OPENCL: ("vkdispatch.backends.opencl_backend", "OpenCL backend", "module"),
    BACKEND_DUMMY: ("vkdispatch.backends.dummy_backend", None, None),
}


class BackendUnavailableError(ImportError):
    """ImportError raised when a known backend's module cannot be imported.

    Attributes:
        backend_name: normalized name of the backend that failed to load.
    """

    def __init__(self, backend_name: str, message: str):
        super().__init__(message)
        self.backend_name = backend_name


def normalize_backend_name(backend: Optional[str]) -> str:
    """Return the canonical (stripped, lower-case) backend name.

    ``None`` maps to the Vulkan backend.

    Raises:
        ValueError: if the name is not one of the known backends.
    """
    if backend is None:
        return BACKEND_VULKAN

    backend_name = backend.strip().lower()
    if backend_name not in _VALID_BACKENDS:
        valid = ", ".join(sorted(_VALID_BACKENDS))
        raise ValueError(f"Unknown backend '{backend}'. Expected one of: {valid}")

    return backend_name


def set_active_backend(backend: str) -> str:
    """Record the process-wide backend and return its canonical name.

    Re-selecting the same backend is a no-op.

    Raises:
        ValueError: if ``backend`` is not a known backend name.
        RuntimeError: if a different backend has already been selected.
    """
    global _active_backend_name

    backend_name = normalize_backend_name(backend)

    if _active_backend_name is not None and _active_backend_name != backend_name:
        raise RuntimeError(
            f"Backend is already set to '{_active_backend_name}' and cannot be changed to '{backend_name}' in this process."
        )

    _active_backend_name = backend_name
    return _active_backend_name


def clear_active_backend() -> None:
    """Forget the process-wide backend selection (cached modules are kept)."""
    global _active_backend_name
    _active_backend_name = None


def get_environment_backend() -> Optional[str]:
    """Return the backend named by ``VKDISPATCH_BACKEND``, or None if unset.

    Raises:
        ValueError: if the variable is set to an unknown backend name.
    """
    env_backend = os.environ.get("VKDISPATCH_BACKEND")
    if env_backend is not None:
        return normalize_backend_name(env_backend)
    return None


def get_active_backend_name(default: Optional[str] = None) -> str:
    """Resolve which backend name is in effect.

    Precedence: explicitly selected backend, then ``default``, then the
    ``VKDISPATCH_BACKEND`` environment variable, then Vulkan.
    """
    if _active_backend_name is not None:
        return _active_backend_name

    if default is not None:
        return normalize_backend_name(default)

    env_backend = get_environment_backend()
    if env_backend is not None:
        return env_backend

    return BACKEND_VULKAN


def _load_backend_module(backend_name: str) -> ModuleType:
    """Import (once) and return the module implementing ``backend_name``.

    Raises:
        ValueError: if ``backend_name`` has no registered module.
        BackendUnavailableError: if the Vulkan, CUDA or OpenCL module cannot
            be imported.
        ImportError: if the dummy backend module cannot be imported.
    """
    if backend_name in _backend_modules:
        return _backend_modules[backend_name]

    spec = _BACKEND_IMPORTS.get(backend_name)
    if spec is None:
        # Defensive guard for future refactors.
        raise ValueError(f"Unsupported backend '{backend_name}'")

    module_path, label, noun = spec

    try:
        module = importlib.import_module(module_path)
    except ImportError as exc:
        if label is None:
            raise
        raise BackendUnavailableError(
            backend_name,
            f"{label} is unavailable because the '{module_path}' {noun} "
            f"could not be imported ({exc}).",
        ) from exc

    _backend_modules[backend_name] = module
    return module


def get_backend_module(backend: Optional[str] = None) -> ModuleType:
    """Return the module for ``backend``, or for the active backend if None."""
    backend_name = normalize_backend_name(backend) if backend is not None else get_active_backend_name()
    return _load_backend_module(backend_name)


class _BackendProxy:
    """Forwards attribute lookups to whichever backend module is active at access time."""

    def __getattr__(self, name: str):
        return getattr(get_backend_module(), name)


# Late-bound handle: the backend is resolved on each attribute access.
native = _BackendProxy()
+""" + +from __future__ import annotations + +from .api_buffer import ( + buffer_create, + buffer_create_external, + buffer_destroy, + buffer_get_queue_signal, + buffer_read, + buffer_read_staging, + buffer_wait_staging_idle, + buffer_write, + buffer_write_staging, +) +from .api_command_list import ( + command_list_create, + command_list_destroy, + command_list_get_instance_size, + command_list_reset, + command_list_submit, + stage_compute_record +) +from .api_compute import ( + stage_compute_plan_create, + stage_compute_plan_destroy, +) +from .api_context import ( + context_create, + context_destroy, + context_stop_threads, + cuda_stream_override_begin, + cuda_stream_override_end, + get_devices, + get_error_string, + init, + log, + set_log_level, +) +from .descriptor_sets import ( + descriptor_set_create, + descriptor_set_destroy, + descriptor_set_write_buffer, + descriptor_set_write_image, + descriptor_set_write_inline_uniform, +) +from .image_fft_stubs import ( + image_create, + image_create_sampler, + image_destroy, + image_destroy_sampler, + image_format_block_size, + image_read, + image_write, + stage_fft_plan_create, + stage_fft_plan_destroy, + stage_fft_record, +) +from .signal import signal_destroy, signal_insert, signal_wait + +__all__ = [ + "init", + "log", + "set_log_level", + "get_devices", + "context_create", + "signal_wait", + "signal_insert", + "signal_destroy", + "context_destroy", + "get_error_string", + "context_stop_threads", + "cuda_stream_override_begin", + "cuda_stream_override_end", + "buffer_create", + "buffer_create_external", + "buffer_destroy", + "buffer_get_queue_signal", + "buffer_wait_staging_idle", + "buffer_write_staging", + "buffer_read_staging", + "buffer_write", + "buffer_read", + "command_list_create", + "command_list_destroy", + "command_list_get_instance_size", + "command_list_reset", + "command_list_submit", + "descriptor_set_create", + "descriptor_set_destroy", + "descriptor_set_write_buffer", + "descriptor_set_write_image", 
def buffer_create(context, size, per_device):
    """Allocate a device buffer of ``size`` bytes and return its handle.

    One staging area and one signal (created with ``done=True``) are made for
    every queue of the context. ``per_device`` is accepted but not used.
    Returns 0 when the context handle does not resolve, when ``size`` is not
    positive, or when allocation/registration fails; failures are reported
    through ``set_error`` instead of being raised.
    """
    del per_device  # intentionally unused

    context_handle = int(context)
    ctx = context_from_handle(context_handle)
    if ctx is None:
        return 0

    byte_count = int(size)
    if byte_count <= 0:
        set_error("Buffer size must be greater than zero")
        return 0

    try:
        with activate_context(ctx):
            allocation = cuda.mem_alloc(byte_count)

        # One signal per queue.
        queue_signals = []
        for queue_slot in range(ctx.queue_count):
            queue_signal = CUDASignal(context_handle=context_handle, queue_index=queue_slot, done=True)
            queue_signals.append(queue_signal.handle)

        record = CUDABuffer(
            context_handle=context_handle,
            size=byte_count,
            device_ptr=int(allocation),
            device_allocation=allocation,
            owns_allocation=True,
            staging_data=[allocate_staging_storage(byte_count) for _ in range(ctx.queue_count)],
            signal_handles=queue_signals,
        )
        return new_handle(state.buffers, record)
    except Exception as exc:
        set_error(f"Failed to create CUDA buffer: {exc}")
        return 0
def buffer_destroy(buffer):
    """Unregister a buffer, destroy its per-queue signals and free owned memory.

    Unknown handles are ignored. The device allocation is freed only when the
    buffer owns it and its context is still registered; errors raised while
    freeing are swallowed (best effort).
    """
    record = state.buffers.pop(int(buffer), None)
    if record is None:
        return

    for handle in record.signal_handles:
        signal_destroy(handle)

    ctx = state.contexts.get(record.context_handle)
    releasable = ctx is not None and record.owns_allocation and record.device_allocation is not None
    if not releasable:
        return

    try:
        with activate_context(ctx):
            record.device_allocation.free()
    except Exception:
        pass


def buffer_get_queue_signal(buffer, queue_index):
    """Return the signal handle associated with ``queue_index`` of a buffer.

    An out-of-range queue index falls back to queue 0. For an unknown buffer
    a fresh signal created with ``done=True`` is returned instead.
    """
    record = state.buffers.get(int(buffer))
    if record is None:
        return CUDASignal(context_handle=0, queue_index=0, done=True).handle

    handles = record.signal_handles
    slot = int(queue_index)
    if not (0 <= slot < len(handles)):
        slot = 0

    return handles[slot]


def buffer_wait_staging_idle(buffer, queue_index):
    """Query the buffer's signal for the given queue.

    Returns True when the signal handle no longer resolves to an object,
    otherwise whatever the signal's ``query()`` reports.
    """
    signal_obj = CUDASignal.from_handle(buffer_get_queue_signal(buffer, queue_index))
    return True if signal_obj is None else signal_obj.query()
def buffer_read_staging(buffer, queue_index, size):
    """Return ``size`` bytes from the host staging area of one queue.

    The result always has exactly ``max(0, size)`` bytes: it is truncated when
    the staging area is larger and zero-padded when it is smaller. An unknown
    buffer handle or an out-of-range queue index yields all zero bytes.

    A negative ``size`` is treated as 0 on every path. Previously only the
    normal path clamped it, so the two fallback paths raised ``ValueError``
    from ``bytes()``.
    """
    obj = state.buffers.get(int(buffer))

    # Clamp once so that every return path agrees on the result length.
    size = max(0, int(size))

    if obj is None:
        return bytes(size)

    queue_index = int(queue_index)
    if queue_index < 0 or queue_index >= len(obj.staging_data):
        return bytes(size)

    staging = obj.staging_data[queue_index]

    if size <= len(staging):
        return bytes(staging[:size])

    # Requested more than is staged: pad the tail with zeros.
    return bytes(staging) + bytes(size - len(staging))
@dataclasses.dataclass
class CUDAResolvedLaunch:
    # One recorded command, resolved against live state for a submission.
    plan: CUDAComputePlan
    # Grid dimensions passed as ``grid=`` at launch time.
    blocks: Tuple[int, int, int]
    # None when the command was recorded with descriptor set handle 0.
    descriptor_set: Optional[CUDADescriptorSet]
    # Size in bytes of this command's push-constant slice.
    pc_size: int
    # Byte offset of that slice inside one instance's payload.
    pc_offset: int
    # Pre-built kernel arguments, reused for every instance when set.
    static_args: Optional[Tuple[object, ...]] = None


def command_list_create(context):
    """Register an empty command list for ``context`` and return its handle.

    Returns 0 (after calling ``set_error``) when the context handle is unknown.
    """
    context_handle = int(context)
    if context_handle in state.contexts:
        return new_handle(state.command_lists, CUDACommandList(context_handle=context_handle))

    set_error("Invalid context handle for command_list_create")
    return 0
def command_list_get_instance_size(command_list):
    """Return the summed ``pc_size`` of all recorded commands (0 for an unknown handle)."""
    record = state.command_lists.get(int(command_list))
    if record is None:
        return 0

    total = 0
    for entry in record.commands:
        total += int(entry.pc_size)
    return int(total)


def command_list_reset(command_list):
    """Drop every recorded command; unknown handles are ignored."""
    record = state.command_lists.get(int(command_list))
    if record is not None:
        record.commands = []
+ ) + return True + + queue_targets = queue_indices(ctx, int(index), all_on_negative=True) + if len(queue_targets) == 0: + queue_targets = [0] + + try: + with activate_context(ctx): + for queue_index in queue_targets: + stream = stream_for_queue(ctx, queue_index) + resolved_launches: List[CUDAResolvedLaunch] = [] + per_instance_offset = 0 + + for command in obj.commands: + plan = state.compute_plans.get(command.plan_handle) + if plan is None: + raise RuntimeError(f"Invalid compute plan handle {command.plan_handle}") + + descriptor_set = None + if command.descriptor_set_handle != 0: + descriptor_set = CUDADescriptorSet.from_handle(command.descriptor_set_handle) + if descriptor_set is None: + raise RuntimeError( + f"Invalid descriptor set handle {command.descriptor_set_handle}" + ) + + command_pc_size = int(command.pc_size) + first_instance_payload = b"" + if command_pc_size > 0 and len(payload) > 0: + first_instance_payload = payload[per_instance_offset: per_instance_offset + command_pc_size] + + static_args = None + if command_pc_size == 0: + static_args = build_kernel_args_template(plan, descriptor_set, b"") + size_check_args = static_args + else: + size_check_args = build_kernel_args_template( + plan, + descriptor_set, + first_instance_payload, + ) + + estimated_param_size = estimate_kernel_param_size_bytes(size_check_args) + if estimated_param_size > int(ctx.max_kernel_param_size): + shader_name = plan.shader_name.decode("utf-8", errors="replace") + raise RuntimeError( + f"Kernel '{shader_name}' launch parameters require " + f"{estimated_param_size} bytes, exceeding device limit " + f"{ctx.max_kernel_param_size} bytes. " + "Reduce by-value uniform/push-constant payload size or switch large " + "uniform data to buffer-backed arguments." 
def stage_compute_record(command_list, plan, descriptor_set, blocks_x, blocks_y, blocks_z):
    """Append a compute dispatch to a command list.

    The record stores the plan and descriptor-set handles, the grid size and
    the plan's push-constant size. If either the command list or the plan
    handle is unknown, ``set_error`` is called and nothing is recorded.
    """
    target_list = state.command_lists.get(int(command_list))
    compute_plan = state.compute_plans.get(int(plan))

    if target_list is None or compute_plan is None:
        set_error("Invalid command list or compute plan handle for stage_compute_record")
        return

    record = CUDACommandRecord(
        plan_handle=int(plan),
        descriptor_set_handle=int(descriptor_set),
        blocks=(int(blocks_x), int(blocks_y), int(blocks_z)),
        pc_size=int(compute_plan.pc_size),
    )
    target_list.commands.append(record)
def _nvrtc_compile_options(ctx):
    """Build the compiler option list for kernels compiled on ``ctx``'s device.

    Always contains ``-w``; a ``--gpu-architecture=sm_XY`` entry is added when
    the device's compute capability can be queried.
    """
    flags = ["-w"]

    try:
        major, minor = cuda.Device(ctx.device_index).compute_capability()
        flags.append(f"--gpu-architecture=sm_{int(major)}{int(minor)}")
    except Exception:
        # Best effort: without the capability, no architecture flag is passed.
        pass

    return flags


def stage_compute_plan_create(context, shader_source, bindings, pc_size, shader_name):
    """Compile a CUDA kernel and register it as a compute plan.

    The entry point looked up in the compiled module is ``vkdispatch_main``.
    Kernel parameters and the local size are parsed from the source text.
    Returns the new plan handle, or 0 when the context handle does not
    resolve, compilation fails, or the metadata cannot be parsed (the latter
    two are reported via ``set_error``).
    """
    ctx = context_from_handle(int(context))
    if ctx is None:
        return 0

    raw_source = to_bytes(shader_source)
    raw_name = to_bytes(shader_name)
    cuda_source = raw_source.decode("utf-8", errors="replace")

    try:
        with activate_context(ctx):
            compiled = SourceModule(
                cuda_source,
                no_extern_c=True,
                options=_nvrtc_compile_options(ctx),
            )
            entry_point = compiled.get_function("vkdispatch_main")
    except Exception as exc:
        display_name = raw_name.decode(errors="ignore")
        set_error(f"Failed to compile CUDA kernel '{display_name}': {exc}")
        return 0

    try:
        kernel_params = parse_kernel_params(cuda_source)
        workgroup_size = parse_local_size(cuda_source)
    except Exception as exc:
        set_error(f"Failed to parse CUDA kernel metadata: {exc}")
        return 0

    return new_handle(
        state.compute_plans,
        CUDAComputePlan(
            context_handle=int(context),
            shader_source=raw_source,
            bindings=[int(binding) for binding in bindings],
            shader_name=raw_name,
            module=compiled,
            function=entry_point,
            local_size=workgroup_size,
            params=kernel_params,
            pc_size=int(pc_size),
        ),
    )


def stage_compute_plan_destroy(plan):
    """Unregister a compute plan; ``None`` and unknown handles are ignored."""
    if plan is not None:
        state.compute_plans.pop(int(plan), None)
def init(debug, log_level):
    """Store the debug flag and log level, clear the error state, and
    initialize CUDA the first time this is called."""
    state.debug_mode = bool(debug)
    state.log_level = int(log_level)
    clear_error()

    if not state.initialized:
        cuda.init()
        state.initialized = True


def log(log_level, text, file_str, line_str):
    """Accept a log record and discard it (this backend emits nothing here)."""
    del log_level, text, file_str, line_str


def set_log_level(log_level):
    """Update the stored log level."""
    state.log_level = int(log_level)
uuid_bytes = hashlib.md5(bus_id.encode("utf-8")).digest() + + devices.append( + ( + 0, # Vulkan variant + int(cc_major), # major + int(cc_minor), # minor + 0, # patch + driver_version, + 0, # vendor id unknown in this API layer + index, # device id + 2, # discrete gpu + str(dev.name()), + 1, # shader_buffer_float32_atomics + 1, # shader_buffer_float32_atomic_add + 1, # float64 support + 1 if (cc_major > 5 or (cc_major == 5 and cc_minor >= 3)) else 0, # float16 support + 1, # int64 + 1, # int16 + 1, # storage_buffer_16_bit_access + 1, # uniform_and_storage_buffer_16_bit_access + 1, # storage_push_constant_16 + 1, # storage_input_output_16 + max_workgroup_size, + int(attrs.get(cuda.device_attribute.MAX_THREADS_PER_BLOCK, 0)), + max_workgroup_count, + 8, # max descriptor sets (virtualized for parity) + 4096, # max push constant size + min(total_memory, (1 << 31) - 1), + 65536, + 16, + subgroup_size, + 0x7FFFFFFF, # supported stages (virtualized for parity) + 0x7FFFFFFF, # supported operations (virtualized for parity) + 1, + max_shared_memory, + [(1, 0x002)], # compute queue + 1, # scalar block layout + 1, # timeline semaphores equivalent + uuid_bytes, + ) + ) + + return devices + + +def context_create(device_indicies, queue_families): + if not state.initialized: + init(False, state.log_level) + + try: + device_ids = [int(x) for x in device_indicies] + except Exception: + set_error("context_create expected a list of integer device indices") + return 0 + + if len(device_ids) != 1: + set_error("CUDA Python backend currently supports exactly one device") + return 0 + + if len(queue_families) != 1 or len(queue_families[0]) != 1: + set_error("CUDA Python backend currently supports exactly one queue") + return 0 + + device_index = device_ids[0] + + cuda_context = None + context_pushed = False + + try: + if device_index < 0 or device_index >= cuda.Device.count(): + set_error(f"Invalid CUDA device index {device_index}") + return 0 + + dev = cuda.Device(device_index) + 
def context_destroy(context):
    """Unregister a context, synchronize its streams and detach it.

    Unknown handles are ignored. Both the synchronization and the detach are
    best effort: any exception they raise is swallowed.
    """
    ctx = state.contexts.pop(int(context), None)
    if ctx is None:
        return

    try:
        with activate_context(ctx):
            for queue_stream in ctx.streams:
                queue_stream.synchronize()
    except Exception:
        pass

    try:
        ctx.cuda_context.detach()
    except Exception:
        pass


def context_stop_threads(context):
    """Mark a registered context as stopped; unknown handles are ignored."""
    ctx = state.contexts.get(int(context))
    if ctx is None:
        return
    ctx.stopped = True


def get_error_string():
    """Return the stored error string, or 0 when no error is recorded."""
    message = state.error_string
    return 0 if message is None else message


def cuda_stream_override_begin(stream_obj):
    """Push an externally provided stream onto the override stack.

    Failures (for example a stream object that cannot be coerced) are
    reported through ``set_error`` instead of being raised.
    """
    try:
        overrides = stream_override_stack()
        overrides.append(coerce_stream_handle(stream_obj))
    except Exception as exc:
        set_error(f"Failed to activate external CUDA stream override: {exc}")


def cuda_stream_override_end():
    """Pop the most recent stream override, if there is one."""
    overrides = stream_override_stack()
    if len(overrides) > 0:
        overrides.pop()
def to_int(value) -> int:
    """Convert a plain int, an enum-like object exposing ``.value``, or anything
    accepted by ``int()`` into a Python int."""
    if isinstance(value, int):
        return int(value)

    if hasattr(value, "value"):
        try:
            return int(value.value)
        except Exception:
            # Fall through to a direct int() conversion of the object itself.
            pass

    return int(value)


def _call_first_available(module, family: str, names, args):
    """Call the first attribute of ``module`` listed in ``names`` that accepts ``args``.

    ``names`` may be a single name or a list of alternative spellings (for
    example a plain and a ``_v2`` symbol). A ``TypeError`` raised by a candidate
    is treated as a signature mismatch and the next candidate is tried.

    Raises:
        RuntimeError: when no candidate exists, or every existing candidate
            rejected the arguments with ``TypeError``.
    """
    if isinstance(names, str):
        names = [names]

    last_error = None
    for name in names:
        fn = getattr(module, name, None)
        if fn is None:
            continue
        try:
            return fn(*args)
        except TypeError as exc:
            last_error = exc

    if last_error is not None:
        raise RuntimeError(f"{family} call failed for {names}: {last_error}") from last_error
    raise RuntimeError(f"{family} symbol not found: {names}")


def drv_call(names, *args):
    """Invoke a CUDA driver entry point by name (or list of alternative names)."""
    return _call_first_available(driver, "CUDA Driver", names, args)


def nvrtc_call(names, *args):
    """Invoke an NVRTC entry point by name (or list of alternative names)."""
    return _call_first_available(nvrtc, "NVRTC", names, args)


def status_success(status) -> bool:
    """Return True when ``status`` denotes success.

    Success is numeric zero; if the status cannot be converted to an int, its
    string form is checked for a ``CUDA_SUCCESS`` / ``NVRTC_SUCCESS`` suffix.
    """
    try:
        return to_int(status) == 0
    except Exception:
        return str(status).endswith("CUDA_SUCCESS") or str(status).endswith("NVRTC_SUCCESS")


def drv_error_string(status) -> str:
    """Describe a driver status as ``"NAME: text"``, falling back to ``str(status)``.

    The lookup itself is best-effort: any failure while querying the driver for
    the name or description results in the plain string form being returned.
    """
    try:
        name_res = drv_call("cuGetErrorName", status)
        string_res = drv_call("cuGetErrorString", status)
        # A non-tuple result carries no status, so it is treated as a failure.
        _name_status = name_res[0] if isinstance(name_res, tuple) else 1
        _string_status = string_res[0] if isinstance(string_res, tuple) else 1
        if status_success(_name_status) and status_success(_string_status):
            name = name_res[1] if isinstance(name_res, tuple) and len(name_res) > 1 else name_res
            text = string_res[1] if isinstance(string_res, tuple) and len(string_res) > 1 else string_res
            if isinstance(name, (bytes, bytearray)):
                name = name.decode("utf-8", errors="replace")
            if isinstance(text, (bytes, bytearray)):
                text = text.decode("utf-8", errors="replace")
            return f"{name}: {text}"
    except Exception:
        pass

    return str(status)


def _split_result(result):
    """Split a binding result into ``(status, payload_tuple)``."""
    if isinstance(result, tuple):
        return result[0], result[1:]
    return result, ()


def _unwrap_payload(payload):
    """Collapse a payload tuple: empty -> None, one item -> that item, else the tuple."""
    if len(payload) == 0:
        return None
    if len(payload) == 1:
        return payload[0]
    return payload


def drv_check(result, op_name: str):
    """Validate a driver call result and return its payload.

    Args:
        result: a bare status or a ``(status, *payload)`` tuple.
        op_name: operation name used in the error message.

    Returns:
        None, the single payload value, or the payload tuple.

    Raises:
        RuntimeError: when the status is not a success status.
    """
    status, payload = _split_result(result)
    if not status_success(status):
        raise RuntimeError(f"{op_name} failed ({drv_error_string(status)})")
    return _unwrap_payload(payload)


def nvrtc_check(result, op_name: str):
    """Validate an NVRTC call result and return its payload.

    Same contract as ``drv_check`` except that the raw status is embedded in
    the error message without a driver lookup.

    Raises:
        RuntimeError: when the status is not a success status.
    """
    status, payload = _split_result(result)
    if not status_success(status):
        raise RuntimeError(f"{op_name} failed ({status})")
    return _unwrap_payload(payload)
def discover_cuda_include_dirs() -> List[str]:
    """Collect directories that contain ``cuda_runtime.h``.

    Probes, in order: the standard CUDA environment variables, the toolkit
    next to ``nvcc`` on PATH, common Unix install locations, conda layouts
    (including every ``targets/<arch>/include`` directory), NVIDIA pip-wheel
    layouts on ``sys.path`` and the ``nvidia.cuda_runtime`` namespace package.

    Returns:
        Resolved directory paths, de-duplicated, in discovery order.
    """
    include_dirs: List[str] = []
    seen = set()

    def add_dir(path_like) -> None:
        if path_like is None:
            return
        try:
            resolved = str(Path(path_like).resolve())
        except Exception:
            resolved = str(path_like)
        if resolved in seen:
            return
        # Remember every probed directory, not only hits, so a location that
        # is reachable through several probes is stat'ed once.
        seen.add(resolved)
        try:
            has_header = (Path(resolved) / "cuda_runtime.h").exists()
        except OSError:
            # e.g. PermissionError on an unreadable directory: skip it rather
            # than aborting the whole discovery.
            has_header = False
        if has_header:
            include_dirs.append(resolved)

    def add_target_dirs(root) -> None:
        # Covers non-x86 toolkits such as targets/sbsa-linux or
        # targets/aarch64-linux.
        try:
            candidates = sorted((Path(root) / "targets").glob("*/include"))
        except OSError:
            candidates = []
        for candidate in candidates:
            add_dir(candidate)

    # Standard CUDA environment variables.
    for env_name in (
        "CUDA_HOME",
        "CUDA_PATH",
        "CUDA_ROOT",
        "CUDA_TOOLKIT_ROOT_DIR",
        "CUDAToolkit_ROOT",
    ):
        root = os.environ.get(env_name)
        if root:
            add_dir(Path(root) / "include")

    # CUDA toolkit from nvcc location.
    nvcc_path = shutil.which("nvcc")
    if nvcc_path:
        try:
            nvcc_root = Path(nvcc_path).resolve().parent.parent
            add_dir(nvcc_root / "include")
        except Exception:
            pass

    # Common Unix install locations.
    add_dir("/usr/local/cuda/include")
    add_dir("/opt/cuda/include")
    add_dir("/usr/include")

    # Conda cudatoolkit layouts.
    conda_prefix = os.environ.get("CONDA_PREFIX")
    if conda_prefix:
        add_dir(Path(conda_prefix) / "include")
        # Kept first so the historical x86_64 ordering is unchanged.
        add_dir(Path(conda_prefix) / "targets" / "x86_64-linux" / "include")
        add_target_dirs(conda_prefix)
        add_dir(Path(conda_prefix) / "Library" / "include")

    # NVIDIA pip wheel layout.
    for base in sys.path:
        add_dir(Path(base) / "nvidia" / "cuda_runtime" / "include")

    # Some environments expose this namespace package.
    try:
        spec = importlib.util.find_spec("nvidia.cuda_runtime")
        if spec is not None and spec.submodule_search_locations:
            for entry in spec.submodule_search_locations:
                add_dir(Path(entry) / "include")
    except Exception:
        pass

    return include_dirs


def prepare_nvrtc_options(options: List[bytes]) -> List[bytes]:
    """Return a copy of ``options`` with discovered include paths appended.

    When the caller already supplies an include path (``-I...`` or
    ``--include-path...``) the options are returned unchanged; otherwise one
    ``--include-path=<dir>`` entry is added per discovered CUDA include dir.
    """
    normalized: List[bytes] = []
    has_include_path = False

    for opt in options:
        as_str = opt.decode("utf-8", errors="replace")
        if as_str.startswith("-I") or as_str.startswith("--include-path"):
            has_include_path = True
        normalized.append(opt)

    if not has_include_path:
        for include_dir in discover_cuda_include_dirs():
            normalized.append(f"--include-path={include_dir}".encode("utf-8"))

    return normalized
def writable_host_ptr(view: memoryview):
    """Return ``(address, keepalive)`` for writing directly into ``view``.

    The address aliases the caller's buffer, so bytes written through it are
    visible in the original object. ``keepalive`` must stay referenced for as
    long as the address is in use.

    Raises:
        BufferError: if ``view`` is read-only. Writing into a temporary copy
            would silently discard the data, so this is reported instead.
    """
    byte_view = view.cast("B")
    if byte_view.readonly:
        raise BufferError("writable_host_ptr requires a writable buffer")
    c_buffer = (ctypes.c_ubyte * len(byte_view)).from_buffer(byte_view)
    return ctypes.addressof(c_buffer), c_buffer


def readonly_host_ptr(view: memoryview):
    """Return ``(address, keepalive)`` for reading the bytes of ``view``.

    Writable buffers are aliased without copying. ``ctypes`` cannot alias a
    read-only buffer, so in that case the bytes are copied into a ctypes
    string buffer and the address of the copy is returned. ``keepalive`` must
    stay referenced for as long as the address is in use.
    """
    byte_view = view.cast("B")
    try:
        c_buffer = (ctypes.c_ubyte * len(byte_view)).from_buffer(byte_view)
        return ctypes.addressof(c_buffer), c_buffer
    except Exception:
        copied = ctypes.create_string_buffer(byte_view.tobytes())
        return ctypes.addressof(copied), copied
class _ContextHandle:
    """Wrapper around a raw driver context with push/detach helpers."""

    def __init__(self, context_raw, device_index: int, uses_primary_context: bool):
        self.context_raw = context_raw
        self.device_index = int(device_index)
        self.uses_primary_context = bool(uses_primary_context)
        self._detached = False

    def push(self):
        """Make this context current via ``cuCtxPushCurrent``."""
        drv_check(
            drv_call(
                "cuCtxPushCurrent",
                as_driver_handle("CUcontext", self.context_raw),
            ),
            "cuCtxPushCurrent",
        )

    def detach(self):
        """Release the context once; later calls are no-ops.

        A primary context is released with ``cuDevicePrimaryCtxRelease``;
        any other context is destroyed with ``cuCtxDestroy``.
        """
        if self._detached:
            return

        if self.uses_primary_context:
            dev = drv_check(drv_call("cuDeviceGet", int(self.device_index)), "cuDeviceGet")
            drv_check(drv_call("cuDevicePrimaryCtxRelease", dev), "cuDevicePrimaryCtxRelease")
        else:
            drv_check(
                drv_call(
                    ["cuCtxDestroy", "cuCtxDestroy_v2"],
                    as_driver_handle("CUcontext", self.context_raw),
                ),
                "cuCtxDestroy",
            )
        self._detached = True


class _StreamHandle:
    """Integer stream handle, either created here or wrapping an external one.

    ``owned`` is True only when the stream was created by this constructor.
    """

    def __init__(self, handle: Optional[int] = None, ptr: Optional[int] = None, *args, **kwargs):
        # Extra keyword arguments are accepted and ignored.
        _ = kwargs
        if handle is None and ptr is None and len(args) == 1:
            handle = int(args[0])
        if handle is None and ptr is not None:
            handle = int(ptr)

        if handle is None:
            stream_raw = drv_check(drv_call("cuStreamCreate", 0), "cuStreamCreate")
            self.handle = int(to_int(stream_raw))
            self.owned = True
        else:
            self.handle = int(handle)
            self.owned = False

    def synchronize(self):
        """Block until all work queued on the stream has finished."""
        drv_check(
            drv_call(
                "cuStreamSynchronize",
                as_driver_handle("CUstream", self.handle),
            ),
            "cuStreamSynchronize",
        )

    def __int__(self):
        return int(self.handle)

    @property
    def ptr(self):
        return int(self.handle)

    @property
    def cuda_stream(self):
        return int(self.handle)


class _EventHandle:
    """Wrapper around a driver event created with ``cuEventCreate``."""

    # Numeric value of CUDA_ERROR_NOT_READY in the driver API's CUresult enum.
    _NOT_READY_CODE = 600

    def __init__(self):
        self.event_raw = drv_check(drv_call("cuEventCreate", 0), "cuEventCreate")

    def record(self, stream_obj: Optional["_StreamHandle"]):
        """Record the event on ``stream_obj`` (stream 0 when None)."""
        stream_handle = 0 if stream_obj is None else int(stream_obj)
        drv_check(
            drv_call(
                "cuEventRecord",
                self.event_raw,
                as_driver_handle("CUstream", stream_handle),
            ),
            "cuEventRecord",
        )

    def query(self) -> bool:
        """Return True once the event has completed, False while it is pending.

        Raises:
            RuntimeError: if the driver reports anything other than success or
                "not ready". Treating such a status as "pending" would hide a
                real failure and let a polling caller wait forever.
        """
        res = drv_call("cuEventQuery", self.event_raw)
        status = res[0] if isinstance(res, tuple) else res

        if status_success(status):
            return True

        if "NOT_READY" in str(status):
            return False

        # Statuses whose string form carries no symbolic name (plain ints).
        try:
            code = to_int(status)
        except Exception:
            code = None
        if code == self._NOT_READY_CODE:
            return False

        raise RuntimeError(f"cuEventQuery failed ({status})")

    def synchronize(self):
        """Block until the event has completed."""
        drv_check(drv_call("cuEventSynchronize", self.event_raw), "cuEventSynchronize")
[None, 0, ctypes.c_void_p(0)] + last_error = None + + for function_handle in function_variants: + for stream_value in stream_variants: + for kernel_params in kernel_param_variants: + for extra in extra_variants: + try: + drv_check( + drv_call( + "cuLaunchKernel", + function_handle, + int(grid[0]), + int(grid[1]), + int(grid[2]), + int(block[0]), + int(block[1]), + int(block[2]), + 0, + stream_value, + kernel_params, + extra, + ), + "cuLaunchKernel", + ) + return + except Exception as exc: + last_error = exc + + try: + drv_check( + drv_call( + "cuLaunchKernel", + function_handle, + int(grid[0]), + int(grid[1]), + int(grid[2]), + int(block[0]), + int(block[1]), + int(block[2]), + 0, + stream_value, + kernel_params, + ), + "cuLaunchKernel", + ) + return + except Exception as exc: + last_error = exc + continue + + if last_error is None: + raise RuntimeError("cuLaunchKernel failed with no diagnostic.") + raise RuntimeError(f"cuLaunchKernel failed: {last_error}") from last_error + + +class SourceModule: + def __init__(self, source: str, no_extern_c: bool = True, options: Optional[List[str]] = None): + _ = no_extern_c + if options is None: + options = [] + + program_name = b"vkdispatch.cu" + source_bytes = source.encode("utf-8") + program = nvrtc_check( + nvrtc_call( + "nvrtcCreateProgram", + source_bytes, + program_name, + 0, + [], + [], + ), + "nvrtcCreateProgram", + ) + + cubin = b"" + ptx = b"" + build_log = b"" + + try: + encoded_options = [opt.encode("utf-8") if isinstance(opt, str) else bytes(opt) for opt in options] + encoded_options = prepare_nvrtc_options(encoded_options) + compile_result = nvrtc_call("nvrtcCompileProgram", program, len(encoded_options), encoded_options) + compile_status = compile_result[0] if isinstance(compile_result, tuple) else compile_result + + build_log = nvrtc_read_bytes(program, "nvrtcGetProgramLogSize", "nvrtcGetProgramLog") + if not status_success(compile_status): + clean_build_log = build_log.rstrip(b"\x00").decode("utf-8", 
errors="replace") + if 'could not open source file "cuda_runtime.h"' in clean_build_log: + discovered = discover_cuda_include_dirs() + hint = ( + " NVRTC could not find CUDA headers. " + f"Discovered include dirs: {discovered if len(discovered) > 0 else 'none'}. " + "Set CUDA_HOME/CUDA_PATH to your toolkit root or ensure nvcc is on PATH." + ) + else: + hint = "" + raise RuntimeError( + f"NVRTC compilation failed: {clean_build_log}{hint}" + ) + + try: + cubin = nvrtc_read_bytes(program, "nvrtcGetCUBINSize", "nvrtcGetCUBIN") + except Exception: + cubin = b"" + + if len(cubin) == 0: + try: + ptx = nvrtc_read_bytes(program, "nvrtcGetPTXSize", "nvrtcGetPTX") + except Exception: + ptx = b"" + finally: + try: + nvrtc_check(nvrtc_call("nvrtcDestroyProgram", program), "nvrtcDestroyProgram") + except Exception: + pass + + image_data = cubin + if len(image_data) == 0: + image_data = ptx + + if len(image_data) == 0: + raise RuntimeError("NVRTC compilation succeeded but produced neither a CUBIN nor a PTX payload.") + + if len(cubin) == 0 and not image_data.endswith(b"\x00"): + image_data += b"\x00" + + self.module_raw = drv_check( + drv_call(["cuModuleLoadDataEx", "cuModuleLoadData"], image_data), + "cuModuleLoadData", + ) + + def get_function(self, name: str): + func_raw = drv_check( + drv_call("cuModuleGetFunction", self.module_raw, name.encode("utf-8")), + "cuModuleGetFunction", + ) + return _KernelFunction(func_raw) + + +class _CudaDevice: + class device_attribute: + MAX_BLOCK_DIM_X = getattr( + getattr(driver, "CUdevice_attribute", object()), + "CU_DEVICE_ATTRIBUTE_MAX_BLOCK_DIM_X", + 0, + ) + MAX_BLOCK_DIM_Y = getattr( + getattr(driver, "CUdevice_attribute", object()), + "CU_DEVICE_ATTRIBUTE_MAX_BLOCK_DIM_Y", + 0, + ) + MAX_BLOCK_DIM_Z = getattr( + getattr(driver, "CUdevice_attribute", object()), + "CU_DEVICE_ATTRIBUTE_MAX_BLOCK_DIM_Z", + 0, + ) + MAX_THREADS_PER_BLOCK = getattr( + getattr(driver, "CUdevice_attribute", object()), + 
"CU_DEVICE_ATTRIBUTE_MAX_THREADS_PER_BLOCK", + 0, + ) + MAX_GRID_DIM_X = getattr( + getattr(driver, "CUdevice_attribute", object()), + "CU_DEVICE_ATTRIBUTE_MAX_GRID_DIM_X", + 0, + ) + MAX_GRID_DIM_Y = getattr( + getattr(driver, "CUdevice_attribute", object()), + "CU_DEVICE_ATTRIBUTE_MAX_GRID_DIM_Y", + 0, + ) + MAX_GRID_DIM_Z = getattr( + getattr(driver, "CUdevice_attribute", object()), + "CU_DEVICE_ATTRIBUTE_MAX_GRID_DIM_Z", + 0, + ) + WARP_SIZE = getattr( + getattr(driver, "CUdevice_attribute", object()), + "CU_DEVICE_ATTRIBUTE_WARP_SIZE", + 0, + ) + MAX_SHARED_MEMORY_PER_BLOCK = getattr( + getattr(driver, "CUdevice_attribute", object()), + "CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK", + 0, + ) + + class Device: + def __init__(self, index: int): + self.index = int(index) + self.device_raw = drv_check(drv_call("cuDeviceGet", self.index), "cuDeviceGet") + + @staticmethod + def count(): + return int(drv_check(drv_call("cuDeviceGetCount"), "cuDeviceGetCount")) + + def get_attributes(self): + attrs = {} + for attr_name in ( + "MAX_BLOCK_DIM_X", + "MAX_BLOCK_DIM_Y", + "MAX_BLOCK_DIM_Z", + "MAX_THREADS_PER_BLOCK", + "MAX_GRID_DIM_X", + "MAX_GRID_DIM_Y", + "MAX_GRID_DIM_Z", + "WARP_SIZE", + "MAX_SHARED_MEMORY_PER_BLOCK", + ): + attr_enum = getattr(_CudaDevice.device_attribute, attr_name) + try: + val = drv_check( + drv_call("cuDeviceGetAttribute", attr_enum, self.device_raw), + "cuDeviceGetAttribute", + ) + attrs[attr_enum] = int(val) + except Exception: + attrs[attr_enum] = 0 + return attrs + + def compute_capability(self): + major_enum = getattr( + getattr(driver, "CUdevice_attribute", object()), + "CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR", + 0, + ) + minor_enum = getattr( + getattr(driver, "CUdevice_attribute", object()), + "CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR", + 0, + ) + major = drv_check(drv_call("cuDeviceGetAttribute", major_enum, self.device_raw), "cuDeviceGetAttribute") + minor = drv_check(drv_call("cuDeviceGetAttribute", minor_enum, 
self.device_raw), "cuDeviceGetAttribute") + return int(major), int(minor) + + def total_memory(self): + return int(drv_check(drv_call(["cuDeviceTotalMem", "cuDeviceTotalMem_v2"], self.device_raw), "cuDeviceTotalMem")) + + def pci_bus_id(self): + try: + bus_id = drv_check(drv_call("cuDeviceGetPCIBusId", 64, self.device_raw), "cuDeviceGetPCIBusId") + if isinstance(bus_id, (bytes, bytearray)): + return bus_id.decode("utf-8", errors="replace").rstrip("\x00") + return str(bus_id) + except Exception: + return f"cuda-device-{self.index}" + + def name(self): + try: + name = drv_check(drv_call("cuDeviceGetName", 128, self.device_raw), "cuDeviceGetName") + if isinstance(name, (bytes, bytearray)): + return name.decode("utf-8", errors="replace").rstrip("\x00") + return str(name) + except Exception: + return f"CUDA Device {self.index}" + + def retain_primary_context(self): + ctx_raw = drv_check(drv_call("cuDevicePrimaryCtxRetain", self.device_raw), "cuDevicePrimaryCtxRetain") + return _ContextHandle(ctx_raw, self.index, True) + + def make_context(self): + ctx_raw = drv_check( + drv_call(["cuCtxCreate", "cuCtxCreate_v2"], 0, self.device_raw), + "cuCtxCreate", + ) + return _ContextHandle(ctx_raw, self.index, False) + + class Context: + @staticmethod + def pop(): + try: + drv_check(drv_call("cuCtxPopCurrent"), "cuCtxPopCurrent") + return + except Exception: + pass + + popped = ctypes.c_void_p() + drv_check(drv_call("cuCtxPopCurrent", popped), "cuCtxPopCurrent") + + Stream = _StreamHandle + ExternalStream = _StreamHandle + Event = _EventHandle + DeviceAllocation = _DeviceAllocation + device_attribute = device_attribute + + @staticmethod + def init(): + drv_check(drv_call("cuInit", 0), "cuInit") + + @staticmethod + def get_driver_version(): + return int(drv_check(drv_call("cuDriverGetVersion"), "cuDriverGetVersion")) + + @staticmethod + def mem_alloc(size: int): + ptr = drv_check( + drv_call(["cuMemAlloc", "cuMemAlloc_v2"], int(size)), + "cuMemAlloc", + ) + return 
_DeviceAllocation(int(to_int(ptr))) + + @staticmethod + def memcpy_htod_async(dst_ptr, src_obj, stream_obj): + src_view = memoryview(src_obj).cast("B") + host_ptr, _keepalive = readonly_host_ptr(src_view) + stream_handle = 0 if stream_obj is None else int(stream_obj) + drv_check( + drv_call( + ["cuMemcpyHtoDAsync", "cuMemcpyHtoDAsync_v2"], + as_driver_handle("CUdeviceptr", int(dst_ptr)), + host_ptr, + len(src_view), + as_driver_handle("CUstream", stream_handle), + ), + "cuMemcpyHtoDAsync", + ) + + @staticmethod + def memcpy_dtoh_async(dst_obj, src_ptr, stream_obj): + dst_view = memoryview(dst_obj).cast("B") + host_ptr, _keepalive = writable_host_ptr(dst_view) + stream_handle = 0 if stream_obj is None else int(stream_obj) + drv_check( + drv_call( + ["cuMemcpyDtoHAsync", "cuMemcpyDtoHAsync_v2"], + host_ptr, + as_driver_handle("CUdeviceptr", int(src_ptr)), + len(dst_view), + as_driver_handle("CUstream", stream_handle), + ), + "cuMemcpyDtoHAsync", + ) + + @staticmethod + def pagelocked_empty(size: int, dtype): + return np.empty(int(size), dtype=dtype) + + +cuda = _CudaDevice diff --git a/vkdispatch/backends/cuda_backend/descriptor_sets.py b/vkdispatch/backends/cuda_backend/descriptor_sets.py new file mode 100644 index 00000000..10670708 --- /dev/null +++ b/vkdispatch/backends/cuda_backend/descriptor_sets.py @@ -0,0 +1,105 @@ +from __future__ import annotations + +from . 
_descriptor_sets: HandleRegistry = HandleRegistry()


class CUDADescriptorSet(CUDAHandle):
    """Per-plan record of buffer bindings and the inline uniform payload."""

    plan_handle: int
    buffer_bindings: Dict[int, Tuple[int, int, int, int, int, int]]
    image_bindings: Dict[int, Tuple[int, int, int, int]]
    inline_uniform_payload: bytes

    def __init__(self, plan_handle: int):
        super().__init__(_descriptor_sets)

        self.plan_handle = plan_handle
        self.buffer_bindings = {}
        self.image_bindings = {}
        self.inline_uniform_payload = b""

    @staticmethod
    def from_handle(handle: int) -> Optional["CUDADescriptorSet"]:
        """Look up a descriptor set by handle; None when the handle is unknown."""
        return _descriptor_sets.get(int(handle))

    def resolve_buffer_pointer(self, binding: int) -> int:
        """Return the device address bound at ``binding`` (buffer base + offset).

        Raises:
            RuntimeError: if nothing is bound at ``binding`` or the bound
                buffer handle is no longer registered.
        """
        entry = self.buffer_bindings.get(binding)
        if entry is None:
            raise RuntimeError(f"Missing descriptor buffer binding {binding}")

        # Only the buffer handle and byte offset are needed here.
        buffer_handle, offset, _range, _uniform, _read, _write = entry

        bound_buffer = state.buffers.get(int(buffer_handle))
        if bound_buffer is None:
            raise RuntimeError(f"Invalid buffer handle {buffer_handle} for binding {binding}")

        return buffer_device_ptr(bound_buffer) + int(offset)


def descriptor_set_create(plan):
    """Create a descriptor set for ``plan``; returns its handle, or 0 on error."""
    plan_handle = int(plan)
    if plan_handle in state.compute_plans:
        return CUDADescriptorSet(plan_handle=plan_handle).handle

    set_error("Invalid compute plan handle for descriptor_set_create")
    return 0


def descriptor_set_destroy(descriptor_set):
    """Forget the descriptor set behind ``descriptor_set`` (unknown handles are ignored)."""
    _descriptor_sets.pop(descriptor_set)


def descriptor_set_write_buffer(
    descriptor_set,
    binding,
    object,
    offset,
    range,
    uniform,
    read_access,
    write_access,
):
    """Record a buffer binding on a descriptor set.

    All values are stored as ints in the order
    ``(buffer, offset, range, uniform, read_access, write_access)``.
    An unknown descriptor set handle is reported through ``set_error``.
    """
    target = CUDADescriptorSet.from_handle(descriptor_set)
    if target is None:
        set_error("Invalid descriptor set handle for descriptor_set_write_buffer")
        return

    fields = (object, offset, range, uniform, read_access, write_access)
    target.buffer_bindings[int(binding)] = tuple(int(field) for field in fields)


def descriptor_set_write_image(
    descriptor_set,
    binding,
    object,
    sampler_obj,
    read_access,
    write_access,
):
    """Image bindings are not implemented; always records an error."""
    # Arguments are accepted for interface compatibility and deliberately unused.
    del descriptor_set, binding, object, sampler_obj, read_access, write_access
    set_error("CUDA Python backend does not support image objects yet")


def descriptor_set_write_inline_uniform(descriptor_set, payload):
    """Store ``payload`` as the descriptor set's inline uniform bytes.

    Unknown handles and conversion failures are reported through ``set_error``.
    """
    target = CUDADescriptorSet.from_handle(descriptor_set)
    if target is None:
        set_error("Invalid descriptor set handle for descriptor_set_write_inline_uniform")
        return

    try:
        target.inline_uniform_payload = to_bytes(payload)
    except Exception as exc:
        set_error(f"Failed to store inline uniform payload: {exc}")
def coerce_stream_handle(stream_obj) -> Optional[int]:
    """Extract an integer CUDA stream handle from ``stream_obj``.

    Accepted inputs, tried in order:
      1. ``None`` (returned as ``None``) or a plain int;
      2. an object implementing ``__cuda_stream__`` (attribute or method);
      3. an object with a ``cuda_stream``, ``ptr`` or ``handle`` attribute;
      4. an object whose ``stream`` attribute is itself coercible;
      5. anything ``int()`` accepts.

    Raises:
        TypeError: when none of the above yields a handle.
    """
    if stream_obj is None:
        return None

    if isinstance(stream_obj, int):
        return int(stream_obj)

    cuda_stream_protocol = getattr(stream_obj, "__cuda_stream__", None)
    if cuda_stream_protocol is not None:
        try:
            proto_value = cuda_stream_protocol() if callable(cuda_stream_protocol) else cuda_stream_protocol
            if isinstance(proto_value, (tuple, list)) and len(proto_value) > 0:
                # The protocol yields (version, handle); the handle is the
                # second element. Taking the first would return the protocol
                # version (0), i.e. silently select the default stream.
                # One-element sequences are still accepted as (handle,).
                proto_value = proto_value[1] if len(proto_value) >= 2 else proto_value[0]
            return int(proto_value)
        except Exception:
            # Fall through to the attribute-based probes below.
            pass

    for attr_name in ("cuda_stream", "ptr", "handle"):
        if hasattr(stream_obj, attr_name):
            try:
                return int(getattr(stream_obj, attr_name))
            except Exception:
                pass

    nested = getattr(stream_obj, "stream", None)
    if nested is not None and nested is not stream_obj:
        try:
            return coerce_stream_handle(nested)
        except Exception:
            pass

    try:
        return int(stream_obj)
    except Exception as exc:
        raise TypeError(
            "Unable to extract a CUDA stream handle from the provided object. "
            "Pass an int handle or an object with __cuda_stream__/.cuda_stream/.ptr/.handle."
        ) from exc
def queue_indices(ctx: CUDAContext, queue_index: int, *, all_on_negative: bool = False) -> List[int]:
    """Translate a requested queue index into the list of queue indices to use.

    * a context without queues yields ``[]``;
    * ``None`` selects queue 0;
    * a negative index selects every queue when ``all_on_negative`` is set,
      otherwise ``-1`` selects queue 0 and any other negative value nothing;
    * a non-negative index selects itself when in range, else nothing.
    """
    total = ctx.queue_count
    if total <= 0:
        return []

    if queue_index is None:
        return [0]

    requested = int(queue_index)

    if requested < 0:
        if all_on_negative:
            return list(range(total))
        return [0] if requested == -1 else []

    return [requested] if requested < total else []
def parse_local_size(source: str) -> Tuple[int, int, int]:
    """Read the expected workgroup size from the kernel source.

    Each axis comes from its ``VKDISPATCH_EXPECTED_LOCAL_SIZE_{X,Y,Z}``
    ``#define``; an axis without a matching define defaults to 1.
    """

    def _axis(pattern) -> int:
        found = pattern.search(source)
        return int(found.group(1)) if found else 1

    return (_axis(LOCAL_X_RE), _axis(LOCAL_Y_RE), _axis(LOCAL_Z_RE))
'{raw_decl}'") + + param_name = name_match.group(1) + + if param_name == "vkdispatch_uniform_ptr": + params.append(CUDAKernelParam("uniform", 0, param_name)) + continue + + if param_name == "vkdispatch_uniform_value": + params.append(CUDAKernelParam("uniform_value", None, param_name)) + continue + + if param_name == "vkdispatch_pc_value": + params.append(CUDAKernelParam("push_constant_value", None, param_name)) + continue + + binding_match = BINDING_PARAM_RE.match(param_name) + if binding_match is not None: + params.append(CUDAKernelParam("storage", int(binding_match.group(1)), param_name)) + continue + + sampler_match = SAMPLER_PARAM_RE.match(param_name) + if sampler_match is not None: + params.append(CUDAKernelParam("sampler", int(sampler_match.group(1)), param_name)) + continue + + params.append(CUDAKernelParam("unknown", None, param_name)) + + return params + +def build_kernel_args_template( + plan: CUDAComputePlan, + descriptor_set: Optional[Any], # CUDADescriptorSet + push_constant_payload: bytes = b"", +) -> Tuple[object, ...]: + args: List[object] = [] + + for param in plan.params: + if param.kind == "uniform": + if descriptor_set is None: + raise RuntimeError("Kernel requires a descriptor set but none was provided") + + args.append(np.uintp(descriptor_set.resolve_buffer_pointer(0))) + continue + + if param.kind == "uniform_value": + if descriptor_set is None: + raise RuntimeError("Kernel requires a descriptor set but none was provided") + + if len(descriptor_set.inline_uniform_payload) == 0: + raise RuntimeError( + "Missing inline uniform payload for CUDA by-value uniform parameter " + f"'{param.raw_name}'." + ) + + args.append(_ByValueKernelArg(descriptor_set.inline_uniform_payload, param.raw_name)) + continue + + if param.kind == "push_constant_value": + if plan.pc_size <= 0: + raise RuntimeError( + f"Kernel parameter '{param.raw_name}' expects push-constant data, but this compute plan has pc_size={plan.pc_size}." 
+ ) + + if len(push_constant_payload) == 0: + raise RuntimeError( + "Missing push-constant payload for CUDA by-value push-constant parameter " + f"'{param.raw_name}'." + ) + + if len(push_constant_payload) != int(plan.pc_size): + raise RuntimeError( + f"Push-constant payload size mismatch for parameter '{param.raw_name}'. " + f"Expected {plan.pc_size} bytes but got {len(push_constant_payload)} bytes." + ) + + args.append(_ByValueKernelArg(push_constant_payload, param.raw_name)) + continue + + if param.kind == "storage": + if descriptor_set is None: + raise RuntimeError("Kernel requires a descriptor set but none was provided") + + if param.binding is None: + raise RuntimeError("Storage parameter has no binding index") + + args.append(np.uintp(descriptor_set.resolve_buffer_pointer(param.binding))) + continue + + if param.kind == "sampler": + raise RuntimeError("CUDA Python backend does not support sampled image bindings yet") + + raise RuntimeError( + f"Unsupported kernel parameter '{param.raw_name}'. " + "Expected vkdispatch_uniform_ptr / vkdispatch_uniform_value / vkdispatch_pc_value / vkdispatch_binding__ptr." + ) + + return tuple(args) + + +def align_up(value: int, alignment: int) -> int: + if alignment <= 1: + return value + return ((value + alignment - 1) // alignment) * alignment + + +def estimate_kernel_param_size_bytes(args: Tuple[object, ...]) -> int: + total_bytes = 0 + + for arg in args: + if isinstance(arg, _ByValueKernelArg): + payload_size = len(arg.payload) + # Kernel params are aligned by argument type. Use a conservative + # 16-byte alignment for by-value structs. 
+ total_bytes = align_up(total_bytes, 16) + total_bytes += payload_size + continue + + total_bytes = align_up(total_bytes, 8) + total_bytes += 8 + + return total_bytes diff --git a/vkdispatch/backends/cuda_backend/image_fft_stubs.py b/vkdispatch/backends/cuda_backend/image_fft_stubs.py new file mode 100644 index 00000000..7b21e627 --- /dev/null +++ b/vkdispatch/backends/cuda_backend/image_fft_stubs.py @@ -0,0 +1,132 @@ +from __future__ import annotations + +from . import state as state +from .helpers import set_error + + +def image_create(context, extent, layers, format, type, view_type, generate_mips): + _ = context + _ = extent + _ = layers + _ = format + _ = type + _ = view_type + _ = generate_mips + set_error("CUDA Python backend does not support image objects yet") + return 0 + + +def image_destroy(image): + _ = image + set_error("CUDA Python backend does not support image objects yet") + + +def image_create_sampler( + context, + mag_filter, + min_filter, + mip_mode, + address_mode, + mip_lod_bias, + min_lod, + max_lod, + border_color, +): + _ = context + _ = mag_filter + _ = min_filter + _ = mip_mode + _ = address_mode + _ = mip_lod_bias + _ = min_lod + _ = max_lod + _ = border_color + set_error("CUDA Python backend does not support image samplers yet") + return 0 + + +def image_destroy_sampler(sampler): + _ = sampler + set_error("CUDA Python backend does not support image samplers yet") + + +def image_write(image, data, offset, extent, baseLayer, layerCount, device_index): + _ = image + _ = data + _ = offset + _ = extent + _ = baseLayer + _ = layerCount + _ = device_index + set_error("CUDA Python backend does not support image writes yet") + + +def image_format_block_size(format): + _ = format + set_error("CUDA Python backend does not support image format block size queries yet") + + +def image_read(image, out_size, offset, extent, baseLayer, layerCount, device_index): + _ = image + _ = offset + _ = extent + _ = baseLayer + _ = layerCount + _ = device_index 
+ set_error("CUDA Python backend does not support image reads yet") + return bytes(max(0, int(out_size))) + + +def stage_fft_plan_create( + context, + dims, + axes, + buffer_size, + do_r2c, + normalize, + pad_left, + pad_right, + frequency_zeropadding, + kernel_num, + kernel_convolution, + conjugate_convolution, + convolution_features, + input_buffer_size, + num_batches, + single_kernel_multiple_batches, + keep_shader_code, +): + _ = context + _ = dims + _ = axes + _ = buffer_size + _ = do_r2c + _ = normalize + _ = pad_left + _ = pad_right + _ = frequency_zeropadding + _ = kernel_num + _ = kernel_convolution + _ = conjugate_convolution + _ = convolution_features + _ = input_buffer_size + _ = num_batches + _ = single_kernel_multiple_batches + _ = keep_shader_code + set_error("CUDA Python backend does not support FFT plans yet") + return 0 + + +def stage_fft_plan_destroy(plan): + _ = plan + set_error("CUDA Python backend does not support FFT plans yet") + + +def stage_fft_record(command_list, plan, buffer, inverse, kernel, input_buffer): + _ = command_list + _ = plan + _ = buffer + _ = inverse + _ = kernel + _ = input_buffer + set_error("CUDA Python backend does not support FFT stages yet") diff --git a/vkdispatch/backends/cuda_backend/signal.py b/vkdispatch/backends/cuda_backend/signal.py new file mode 100644 index 00000000..6dfbca35 --- /dev/null +++ b/vkdispatch/backends/cuda_backend/signal.py @@ -0,0 +1,116 @@ +from __future__ import annotations + +from . 
import state as state +from .helpers import ( + activate_context, + context_from_handle, + queue_indices, + set_error, + stream_for_queue, +) + +from typing import Optional + +from .cuda_primitives import cuda +from .handle import CUDAHandle, HandleRegistry + +_signals: HandleRegistry = HandleRegistry() + +class CUDASignal(CUDAHandle): + context_handle: int + queue_index: int + event: Optional["cuda.Event"] = None + submitted: bool = True + done: bool = True + + def __init__(self, + context_handle: int, + queue_index: int, + event: Optional["cuda.Event"] = None, + submitted: bool = True, + done: bool = True): + super().__init__(_signals) + + self.context_handle = context_handle + self.queue_index = queue_index + self.event = event + self.submitted = submitted + self.done = done + + @staticmethod + def from_handle(handle: int) -> Optional["CUDASignal"]: + return _signals.get(handle) + + def record(self, stream: "cuda.Stream"): + self.submitted = True + self.done = False + if self.event is None: + self.event = cuda.Event() + self.event.record(stream) + + def query(self) -> bool: + if self.event is None: + return bool(self.done) + + try: + done = self.event.query() + except Exception: + return False + + self.done = bool(done) + return self.done + +def signal_wait(signal_ptr, wait_for_timestamp, queue_index): + signal_obj = CUDASignal.from_handle(signal_ptr) + if signal_obj is None: + return True + + if not bool(wait_for_timestamp): + # CUDA Python records signals synchronously on submission; host-side "recorded" waits + # should therefore complete immediately once an event exists. 
+ if signal_obj.event is None: + return bool(signal_obj.done) + return bool(signal_obj.submitted) + + if signal_obj.done: + return True + + if signal_obj.event is None: + return bool(signal_obj.done) + + ctx = state.contexts.get(signal_obj.context_handle) + if ctx is None: + return signal_obj.query() + + try: + with activate_context(ctx): + signal_obj.event.synchronize() + signal_obj.done = True + return True + except Exception: + return signal_obj.query() + + +def signal_insert(context, queue_index): + ctx = context_from_handle(int(context)) + if ctx is None: + return 0 + + selected = queue_indices(ctx, int(queue_index)) + if len(selected) == 0: + selected = [0] + + signal = CUDASignal(context_handle=int(context), queue_index=selected[0], submitted=False, done=False) + + try: + with activate_context(ctx): + signal.record(stream_for_queue(ctx, selected[0])) + except Exception as exc: + set_error(f"Failed to insert signal: {exc}") + return 0 + + return signal.handle + + +def signal_destroy(signal_ptr): + _signals.pop(signal_ptr) diff --git a/vkdispatch/backends/cuda_backend/state.py b/vkdispatch/backends/cuda_backend/state.py new file mode 100644 index 00000000..21e7af25 --- /dev/null +++ b/vkdispatch/backends/cuda_backend/state.py @@ -0,0 +1,87 @@ +from __future__ import annotations + +from dataclasses import dataclass, field +import threading +from typing import Dict, List, Optional, Tuple + +from .constants import LOG_LEVEL_WARNING +from .cuda_primitives import SourceModule, cuda + +#from .api_descriptor import CUDADescriptorSet + +# --- Runtime state --- + +initialized = False +debug_mode = False +log_level = LOG_LEVEL_WARNING +error_string: Optional[str] = None +next_handle = 1 + +contexts: Dict[int, "CUDAContext"] = {} +buffers: Dict[int, "CUDABuffer"] = {} +command_lists: Dict[int, "CUDACommandList"] = {} +compute_plans: Dict[int, "CUDAComputePlan"] = {} +external_stream_cache: Dict[int, object] = {} +stream_override = threading.local() + + +# --- Internal 
objects --- + +@dataclass +class CUDAContext: + device_index: int + cuda_context: "cuda.Context" + streams: List["cuda.Stream"] + queue_count: int + queue_to_device: List[int] + max_kernel_param_size: int + uses_primary_context: bool = False + stopped: bool = False + + +@dataclass +class CUDABuffer: + context_handle: int + size: int + device_ptr: int + device_allocation: Optional["cuda.DeviceAllocation"] + owns_allocation: bool + staging_data: List[object] + signal_handles: List[int] + + +@dataclass +class CUDACommandRecord: + plan_handle: int + descriptor_set_handle: int + blocks: Tuple[int, int, int] + pc_size: int + + +@dataclass +class CUDACommandList: + context_handle: int + commands: List[CUDACommandRecord] = field(default_factory=list) + + +@dataclass +class CUDAKernelParam: + kind: str + binding: Optional[int] + raw_name: str + + +@dataclass +class CUDAComputePlan: + context_handle: int + shader_source: bytes + bindings: List[int] + shader_name: bytes + module: SourceModule + function: object + local_size: Tuple[int, int, int] + params: List[CUDAKernelParam] + pc_size: int + + + diff --git a/vkdispatch/backends/dummy_backend.py b/vkdispatch/backends/dummy_backend.py new file mode 100644 index 00000000..47319abd --- /dev/null +++ b/vkdispatch/backends/dummy_backend.py @@ -0,0 +1,535 @@ +"""Brython-friendly pure-Python shim for ``vkdispatch_native``. + +This module mirrors the Cython-exposed API used by ``vkdispatch`` and provides +dummy metadata helpers for docs/codegen flows. + +Runtime GPU operations are intentionally denied so the dummy backend fails fast +when used outside codegen-only scripts. +""" + +# --- Runtime state --- + +_initialized = False +_debug_mode = False +_log_level = 2 +_error_string = None +_next_handle = 1 + +_contexts = {} +_signals = {} + +# Device limits exposed through get_devices(); mutable so docs UI can tune them. 
+_DEFAULT_SUBGROUP_SIZE = 32 +_DEFAULT_MAX_WORKGROUP_SIZE = (1024, 1024, 64) +_DEFAULT_MAX_WORKGROUP_INVOCATIONS = 1024 +_DEFAULT_MAX_WORKGROUP_COUNT = (65535, 65535, 65535) +_DEFAULT_MAX_COMPUTE_SHARED_MEMORY_SIZE = 64 * 1024 + +_device_subgroup_size = _DEFAULT_SUBGROUP_SIZE +_device_max_workgroup_size = _DEFAULT_MAX_WORKGROUP_SIZE +_device_max_workgroup_invocations = _DEFAULT_MAX_WORKGROUP_INVOCATIONS +_device_max_workgroup_count = _DEFAULT_MAX_WORKGROUP_COUNT +_device_max_compute_shared_memory_size = _DEFAULT_MAX_COMPUTE_SHARED_MEMORY_SIZE + + +# --- Internal objects --- + +class _Signal: + __slots__ = ("done",) + + def __init__(self, done=True): + self.done = bool(done) + + +class _Context: + __slots__ = ( + "device_indices", + "queue_families", + "queue_count", + "queue_to_device", + "stopped", + ) + + def __init__(self, device_indices, queue_families): + self.device_indices = list(device_indices) + self.queue_families = [list(fam) for fam in queue_families] + + normalized = [] + for fam in self.queue_families: + normalized.append(fam if len(fam) > 0 else [0]) + self.queue_families = normalized + + self.queue_count = sum(len(fam) for fam in self.queue_families) + if self.queue_count <= 0: + self.queue_families = [[0]] + self.queue_count = 1 + + queue_to_device = [] + for dev_idx, fam in enumerate(self.queue_families): + for _ in fam: + queue_to_device.append(dev_idx) + + if len(queue_to_device) == 0: + queue_to_device = [0] + + self.queue_to_device = queue_to_device + self.stopped = False + +# --- Internal helpers --- + +def _new_handle(registry, obj): + global _next_handle + handle = _next_handle + _next_handle += 1 + registry[handle] = obj + return handle + +def _set_error(message): + global _error_string + _error_string = str(message) + + +def _clear_error(): + global _error_string + _error_string = None + + +_DUMMY_CODEGEN_ONLY_ERROR = ( + "The 'dummy' backend is codegen-only and does not support runtime GPU " + "operations. 
Use backend='vulkan', backend='pycuda', or backend='cuda-python' for execution." +) + + +def _deny_runtime_native_call(function_name): + raise RuntimeError(f"{_DUMMY_CODEGEN_ONLY_ERROR} (native call: {function_name})") + + +def _as_positive_int(name, value): + try: + parsed = int(value) + except Exception as exc: + raise ValueError("%s must be an integer" % name) from exc + + if parsed <= 0: + raise ValueError("%s must be greater than zero" % name) + + return parsed + + +def _as_positive_triplet(name, value): + try: + parts = list(value) + except Exception as exc: + raise ValueError("%s must contain exactly 3 integers" % name) from exc + + if len(parts) != 3: + raise ValueError("%s must contain exactly 3 integers" % name) + + return ( + _as_positive_int("%s[0]" % name, parts[0]), + _as_positive_int("%s[1]" % name, parts[1]), + _as_positive_int("%s[2]" % name, parts[2]), + ) + + +# --- API: context/init/errors/logging --- + + +def reset_device_options(): + global _device_subgroup_size + global _device_max_workgroup_size + global _device_max_workgroup_invocations + global _device_max_workgroup_count + global _device_max_compute_shared_memory_size + + _device_subgroup_size = _DEFAULT_SUBGROUP_SIZE + _device_max_workgroup_size = _DEFAULT_MAX_WORKGROUP_SIZE + _device_max_workgroup_invocations = _DEFAULT_MAX_WORKGROUP_INVOCATIONS + _device_max_workgroup_count = _DEFAULT_MAX_WORKGROUP_COUNT + _device_max_compute_shared_memory_size = _DEFAULT_MAX_COMPUTE_SHARED_MEMORY_SIZE + + +def set_device_options( + subgroup_size=None, + max_workgroup_size=None, + max_workgroup_invocations=None, + max_workgroup_count=None, + max_compute_shared_memory_size=None, +): + global _device_subgroup_size + global _device_max_workgroup_size + global _device_max_workgroup_invocations + global _device_max_workgroup_count + global _device_max_compute_shared_memory_size + + if subgroup_size is not None: + _device_subgroup_size = _as_positive_int("subgroup_size", subgroup_size) + + if 
max_workgroup_size is not None: + _device_max_workgroup_size = _as_positive_triplet( + "max_workgroup_size", + max_workgroup_size, + ) + + if max_workgroup_invocations is not None: + _device_max_workgroup_invocations = _as_positive_int( + "max_workgroup_invocations", + max_workgroup_invocations, + ) + + if max_workgroup_count is not None: + _device_max_workgroup_count = _as_positive_triplet( + "max_workgroup_count", + max_workgroup_count, + ) + + if max_compute_shared_memory_size is not None: + _device_max_compute_shared_memory_size = _as_positive_int( + "max_compute_shared_memory_size", + max_compute_shared_memory_size, + ) + + +def init(debug, log_level): + global _initialized, _debug_mode, _log_level + _initialized = True + _debug_mode = bool(debug) + _log_level = int(log_level) + _clear_error() + + +def log(log_level, text, file_str, line_str): + # Keep logging quiet in docs/brython by default. + # Function kept for API compatibility. + _ = log_level + _ = text + _ = file_str + _ = line_str + + +def set_log_level(log_level): + global _log_level + _log_level = int(log_level) + + +def get_devices(): + if not _initialized: + init(False, _log_level) + + # One plausible fake discrete GPU with compute+graphics queue families. 
+ device_tuple = ( + 0, # version_variant + 1, # version_major + 3, # version_minor + 0, # version_patch + 1001000, # driver_version + 0x1BAD, # vendor_id + 0x0001, # device_id + 2, # device_type (Discrete GPU) + "VKDispatch Web Dummy GPU", + 1, # shader_buffer_float32_atomics + 1, # shader_buffer_float32_atomic_add + 1, # float_64_support + 1, # float_16_support + 1, # int_64_support + 1, # int_16_support + 1, # storage_buffer_16_bit_access + 1, # uniform_and_storage_buffer_16_bit_access + 1, # storage_push_constant_16 + 1, # storage_input_output_16 + _device_max_workgroup_size, # max_workgroup_size + _device_max_workgroup_invocations, # max_workgroup_invocations + _device_max_workgroup_count, # max_workgroup_count + 8, # max_descriptor_set_count + 256, # max_push_constant_size + 1 << 30, # max_storage_buffer_range + 65536, # max_uniform_buffer_range + 0, # uniform_buffer_alignment + _device_subgroup_size, # subgroup_size + 0x7FFFFFFF, # supported_stages + 0x7FFFFFFF, # supported_operations + 1, # quad_operations_in_all_stages + _device_max_compute_shared_memory_size, # max_compute_shared_memory_size + [ + (8, 0x006), # compute + transfer + (4, 0x007), # graphics + compute + transfer + ], + 1, # scalar_block_layout + 1, # timeline_semaphores + bytes((0x56, 0x4B, 0x44, 0x30, 0x57, 0x45, 0x42, 0x31, 0x44, 0x55, 0x4D, 0x4D, 0x59, 0x00, 0x00, 0x01)), + ) + + return [device_tuple] + + +def context_create(device_indicies, queue_families): + try: + ctx = _Context(device_indicies, queue_families) + return _new_handle(_contexts, ctx) + except Exception as exc: + _set_error("Failed to create context: %s" % exc) + return 0 + + +def signal_wait(signal_ptr, wait_for_timestamp, queue_index): + _ = wait_for_timestamp + _ = queue_index + signal_obj = _signals.get(int(signal_ptr)) + if signal_obj is None: + return True + return bool(signal_obj.done) + + +def signal_insert(context, queue_index): + _ = context + _ = queue_index + return _new_handle(_signals, _Signal(done=True)) + + 
+def signal_destroy(signal_ptr): + _signals.pop(int(signal_ptr), None) + + +def context_destroy(context): + _contexts.pop(int(context), None) + + +def get_error_string(): + if _error_string is None: + return 0 + return _error_string + + +def context_stop_threads(context): + ctx = _contexts.get(int(context)) + if ctx is not None: + ctx.stopped = True + + +# --- API: buffers --- + + +def buffer_create(context, size, per_device): + _deny_runtime_native_call("buffer_create") + + +def buffer_destroy(buffer): + _deny_runtime_native_call("buffer_destroy") + + +def buffer_get_queue_signal(buffer, queue_index): + _deny_runtime_native_call("buffer_get_queue_signal") + + +def buffer_wait_staging_idle(buffer, queue_index): + _deny_runtime_native_call("buffer_wait_staging_idle") + + +def buffer_write_staging(buffer, queue_index, data, size): + _deny_runtime_native_call("buffer_write_staging") + + +def buffer_read_staging(buffer, queue_index, size): + _deny_runtime_native_call("buffer_read_staging") + + +def buffer_write(buffer, offset, size, index): + _deny_runtime_native_call("buffer_write") + + +def buffer_read(buffer, offset, size, index): + _deny_runtime_native_call("buffer_read") + + +# --- API: command lists --- + + +def command_list_create(context): + _deny_runtime_native_call("command_list_create") + + +def command_list_destroy(command_list): + _deny_runtime_native_call("command_list_destroy") + + +def command_list_get_instance_size(command_list): + _deny_runtime_native_call("command_list_get_instance_size") + + +def command_list_reset(command_list): + _deny_runtime_native_call("command_list_reset") + + +def command_list_submit(command_list, data, instance_count, index): + _deny_runtime_native_call("command_list_submit") + + +# --- API: descriptor sets --- + + +def descriptor_set_create(plan): + _deny_runtime_native_call("descriptor_set_create") + + +def descriptor_set_destroy(descriptor_set): + _deny_runtime_native_call("descriptor_set_destroy") + + +def 
descriptor_set_write_buffer( + descriptor_set, + binding, + object, + offset, + range, + uniform, + read_access, + write_access, +): + _deny_runtime_native_call("descriptor_set_write_buffer") + + +def descriptor_set_write_image( + descriptor_set, + binding, + object, + sampler_obj, + read_access, + write_access, +): + _deny_runtime_native_call("descriptor_set_write_image") + + +# --- API: images/samplers --- + + +def image_create(context, extent, layers, format, type, view_type, generate_mips): + _deny_runtime_native_call("image_create") + + +def image_destroy(image): + _deny_runtime_native_call("image_destroy") + + +def image_create_sampler( + context, + mag_filter, + min_filter, + mip_mode, + address_mode, + mip_lod_bias, + min_lod, + max_lod, + border_color, +): + _deny_runtime_native_call("image_create_sampler") + + +def image_destroy_sampler(sampler): + _deny_runtime_native_call("image_destroy_sampler") + + +def image_write(image, data, offset, extent, baseLayer, layerCount, device_index): + _deny_runtime_native_call("image_write") + + +def image_format_block_size(format): + _deny_runtime_native_call("image_format_block_size") + + +def image_read(image, out_size, offset, extent, baseLayer, layerCount, device_index): + _deny_runtime_native_call("image_read") + + +# --- API: compute stage --- + + +def stage_compute_plan_create(context, shader_source, bindings, pc_size, shader_name): + _deny_runtime_native_call("stage_compute_plan_create") + + +def stage_compute_plan_destroy(plan): + _deny_runtime_native_call("stage_compute_plan_destroy") + + +def stage_compute_record(command_list, plan, descriptor_set, blocks_x, blocks_y, blocks_z): + _deny_runtime_native_call("stage_compute_record") + + +# --- API: FFT stage --- + + +def stage_fft_plan_create( + context, + dims, + axes, + buffer_size, + do_r2c, + normalize, + pad_left, + pad_right, + frequency_zeropadding, + kernel_num, + kernel_convolution, + conjugate_convolution, + convolution_features, + input_buffer_size, 
+ num_batches, + single_kernel_multiple_batches, + keep_shader_code, +): + _deny_runtime_native_call("stage_fft_plan_create") + + +def stage_fft_plan_destroy(plan): + _deny_runtime_native_call("stage_fft_plan_destroy") + + +def stage_fft_record(command_list, plan, buffer, inverse, kernel, input_buffer): + _deny_runtime_native_call("stage_fft_record") + + +__all__ = [ + "reset_device_options", + "set_device_options", + "init", + "log", + "set_log_level", + "get_devices", + "context_create", + "signal_wait", + "signal_insert", + "signal_destroy", + "context_destroy", + "get_error_string", + "context_stop_threads", + "buffer_create", + "buffer_destroy", + "buffer_get_queue_signal", + "buffer_wait_staging_idle", + "buffer_write_staging", + "buffer_read_staging", + "buffer_write", + "buffer_read", + "command_list_create", + "command_list_destroy", + "command_list_get_instance_size", + "command_list_reset", + "command_list_submit", + "descriptor_set_create", + "descriptor_set_destroy", + "descriptor_set_write_buffer", + "descriptor_set_write_image", + "image_create", + "image_destroy", + "image_create_sampler", + "image_destroy_sampler", + "image_write", + "image_format_block_size", + "image_read", + "stage_compute_plan_create", + "stage_compute_plan_destroy", + "stage_compute_record", + "stage_fft_plan_create", + "stage_fft_plan_destroy", + "stage_fft_record" +] diff --git a/vkdispatch/backends/opencl_backend.py b/vkdispatch/backends/opencl_backend.py new file mode 100644 index 00000000..eed638a3 --- /dev/null +++ b/vkdispatch/backends/opencl_backend.py @@ -0,0 +1,2080 @@ +"""pyopencl-backed runtime shim mirroring the vkdispatch_native API surface. + +This module intentionally matches the function names exposed by the Cython +extension so existing Python runtime objects can call into either backend. 
+""" + +from __future__ import annotations + +from dataclasses import dataclass, field +import hashlib +import re +import threading +from typing import Dict, List, Optional, Tuple + +import os +import sys + +try: + import numpy as np +except Exception as exc: # pragma: no cover - import failure path + raise ImportError( + "The OpenCL Python backend requires both 'pyopencl' and 'numpy' to be installed." + ) from exc + +try: + import pyopencl as cl +except Exception as exc: # pragma: no cover - import failure path + raise ImportError( + "The OpenCL runtime backend requires the 'pyopencl' package " + "(`pip install pyopencl`)." + ) from exc + + +# Log level constants mirrored from native bindings. +LOG_LEVEL_VERBOSE = 0 +LOG_LEVEL_INFO = 1 +LOG_LEVEL_WARNING = 2 +LOG_LEVEL_ERROR = 3 + +# Descriptor type enum values mirrored from vkdispatch_native/stages_extern.pxd. +DESCRIPTOR_TYPE_STORAGE_BUFFER = 1 +DESCRIPTOR_TYPE_STORAGE_IMAGE = 2 +DESCRIPTOR_TYPE_UNIFORM_BUFFER = 3 +DESCRIPTOR_TYPE_UNIFORM_IMAGE = 4 +DESCRIPTOR_TYPE_SAMPLER = 5 + +# Image format block sizes for formats exposed in vkdispatch.base.image.image_format. 
+_IMAGE_BLOCK_SIZES = { + 13: 1, + 14: 1, + 20: 2, + 21: 2, + 27: 3, + 28: 3, + 41: 4, + 42: 4, + 74: 2, + 75: 2, + 76: 2, + 81: 4, + 82: 4, + 83: 4, + 88: 6, + 89: 6, + 90: 6, + 95: 8, + 96: 8, + 97: 8, + 98: 4, + 99: 4, + 100: 4, + 101: 8, + 102: 8, + 103: 8, + 104: 12, + 105: 12, + 106: 12, + 107: 16, + 108: 16, + 109: 16, + 110: 8, + 111: 8, + 112: 8, + 113: 16, + 114: 16, + 115: 16, + 116: 24, + 117: 24, + 118: 24, + 119: 32, + 120: 32, + 121: 32, +} + +_LOCAL_X_RE = re.compile(r"#define\s+VKDISPATCH_EXPECTED_LOCAL_SIZE_X\s+(\d+)") +_LOCAL_Y_RE = re.compile(r"#define\s+VKDISPATCH_EXPECTED_LOCAL_SIZE_Y\s+(\d+)") +_LOCAL_Z_RE = re.compile(r"#define\s+VKDISPATCH_EXPECTED_LOCAL_SIZE_Z\s+(\d+)") +_REQD_LOCAL_RE = re.compile(r"reqd_work_group_size\(\s*(\d+)\s*,\s*(\d+)\s*,\s*(\d+)\s*\)") +_KERNEL_SIGNATURE_RE = re.compile(r"vkdispatch_main\s*\(([^)]*)\)", re.S) +_BINDING_PARAM_RE = re.compile(r"vkdispatch_binding_(\d+)_ptr$") +_SAMPLER_PARAM_RE = re.compile(r"vkdispatch_sampler_(\d+)$") +_PUSH_CONSTANT_STRUCT_RE = re.compile( + r"typedef\s+struct\s+PushConstant\s*\{(?P.*?)\}\s*PushConstant\s*;", + re.S, +) +_PUSH_CONSTANT_FIELD_RE = re.compile( + r"(?P[A-Za-z_][A-Za-z0-9_]*)\s+" + r"(?P[A-Za-z_][A-Za-z0-9_]*)" + r"(?:\s*\[\s*(?P\d+)\s*\])?$" +) +_VECTOR_TYPE_RE = re.compile(r"([A-Za-z_][A-Za-z0-9_]*?)([2-4])$") +_OPENCL_VERSION_RE = re.compile(r"OpenCL\s+(\d+)\.(\d+)") +_DIGIT_RE = re.compile(r"(\d+)") +_OPENCL_MAX_INFLIGHT_SUBMISSIONS = 4 +_OPENCL_SUBGROUP_PROBE_SOURCE = """ +__kernel void vkdispatch_subgroup_probe(__global uint *out) { + size_t gid = get_global_id(0); + if (gid == 0) { + out[0] = 0u; + } +} +""" + + +# --- Runtime state --- + +_initialized = False +_debug_mode = False +_log_level = LOG_LEVEL_WARNING +_error_string: Optional[str] = None +_next_handle = 1 + +_contexts: Dict[int, "_Context"] = {} +_signals: Dict[int, "_Signal"] = {} +_buffers: Dict[int, "_Buffer"] = {} +_command_lists: Dict[int, "_CommandList"] = {} +_compute_plans: Dict[int, 
"_ComputePlan"] = {} +_descriptor_sets: Dict[int, "_DescriptorSet"] = {} +_images: Dict[int, object] = {} +_samplers: Dict[int, object] = {} +_fft_plans: Dict[int, object] = {} +_subgroup_size_cache: Dict[Tuple[int, int, str, str], int] = {} + +_marker_helpers = threading.local() + + +# --- Internal objects --- + + +@dataclass(frozen=True) +class _DeviceEntry: + logical_index: int + platform_index: int + device_index: int + platform: object + device: object + + +@dataclass +class _Signal: + context_handle: int + queue_index: int + event: Optional[object] = None + submitted: bool = True + done: bool = True + + +@dataclass +class _Context: + device_index: int + cl_context: object + queues: List[object] + queue_count: int + queue_to_device: List[int] + sub_buffer_alignment: int + submission_events: List[List[object]] = field(default_factory=list) + stopped: bool = False + + +@dataclass +class _Buffer: + context_handle: int + size: int + cl_buffer: object + staging_data: List[bytearray] + signal_handles: List[int] + + +@dataclass +class _CommandRecord: + plan_handle: int + descriptor_set_handle: int + blocks: Tuple[int, int, int] + pc_size: int + + +@dataclass +class _CommandList: + context_handle: int + commands: List[_CommandRecord] = field(default_factory=list) + + +@dataclass +class _KernelParam: + kind: str + binding: Optional[int] + raw_name: str + + +@dataclass(frozen=True) +class _PushConstantTypeLayout: + host_elem_size: int + opencl_elem_size: int + opencl_align: int + + +@dataclass(frozen=True) +class _PushConstantFieldDecl: + type_name: str + field_name: str + count: int + + +@dataclass(frozen=True) +class _PushConstantFieldLayout: + type_name: str + field_name: str + count: int + host_offset: int + opencl_offset: int + host_elem_size: int + opencl_elem_size: int + + +@dataclass(frozen=True) +class _PushConstantLayout: + fields: Tuple[_PushConstantFieldLayout, ...] 
+ host_size: int + opencl_size: int + opencl_alignment: int + needs_repack: bool + + +@dataclass +class _ComputePlan: + context_handle: int + shader_source: bytes + bindings: List[int] + shader_name: bytes + program: object + kernel: object + local_size: Tuple[int, int, int] + params: List[_KernelParam] + pc_size: int + pc_layout: Optional[_PushConstantLayout] = None + + +@dataclass +class _DescriptorSet: + plan_handle: int + buffer_bindings: Dict[int, Tuple[int, int, int, int, int, int]] = field(default_factory=dict) + image_bindings: Dict[int, Tuple[int, int, int, int]] = field(default_factory=dict) + + +# --- Helper utilities --- + + +def _new_handle(registry: Dict[int, object], obj: object) -> int: + global _next_handle + handle = _next_handle + _next_handle += 1 + registry[handle] = obj + return handle + + +def _to_bytes(value) -> bytes: + if value is None: + return b"" + if isinstance(value, bytes): + return value + if isinstance(value, bytearray): + return bytes(value) + if isinstance(value, memoryview): + return value.tobytes() + return bytes(value) + + +def _set_error(message: str) -> None: + global _error_string + _error_string = str(message) + + +def _clear_error() -> None: + global _error_string + _error_string = None + +def _enumerate_opencl_devices() -> List[_DeviceEntry]: + entries: List[_DeviceEntry] = [] + + if ( + sys.platform.startswith("linux") + and "OCL_ICD_VENDORS" not in os.environ + and "OPENCL_VENDOR_PATH" not in os.environ + and os.path.isdir("/etc/OpenCL/vendors") + ): + os.environ["OCL_ICD_VENDORS"] = "/etc/OpenCL/vendors" + + try: + platforms = cl.get_platforms() + except Exception as exc: + raise RuntimeError( + f"Failed to get OpenCL Platform: {exc}" + ) from exc + + logical_index = 0 + for platform_index, platform in enumerate(platforms): + try: + devices = platform.get_devices() + except Exception: + continue + + for device_index, device in enumerate(devices): + entries.append( + _DeviceEntry( + logical_index=logical_index, + 
platform_index=platform_index, + device_index=device_index, + platform=platform, + device=device, + ) + ) + logical_index += 1 + + return entries + + +def _coerce_int(value, fallback: int = 0) -> int: + try: + return int(value) + except Exception: + return int(fallback) + + +def _align_up(value: int, alignment: int) -> int: + if alignment <= 1: + return int(value) + return ((int(value) + alignment - 1) // alignment) * alignment + + +def _opencl_version_components(version_text: str) -> Tuple[int, int]: + if not isinstance(version_text, str): + return (0, 0) + + match = _OPENCL_VERSION_RE.search(version_text) + if match is None: + return (0, 0) + + return (_coerce_int(match.group(1), 0), _coerce_int(match.group(2), 0)) + + +def _driver_version_number(driver_text: str) -> int: + if not isinstance(driver_text, str): + return 0 + + pieces = _DIGIT_RE.findall(driver_text) + if len(pieces) == 0: + return 0 + + folded = 0 + weight = 1_000_000 + for token in pieces[:3]: + folded += _coerce_int(token, 0) * weight + weight = max(1, weight // 1000) + return folded + + +def _device_type_to_vkdispatch(device_type: int) -> int: + if device_type & getattr(cl.device_type, "GPU", 0): + return 2 + if device_type & getattr(cl.device_type, "ACCELERATOR", 0): + return 3 + if device_type & getattr(cl.device_type, "CPU", 0): + return 4 + return 0 + + +def _device_uuid(entry: _DeviceEntry, device_name: str, driver_version: str) -> bytes: + platform_vendor = "" + platform_name = "" + try: + platform_vendor = str(entry.platform.vendor) + except Exception: + platform_vendor = "" + try: + platform_name = str(entry.platform.name) + except Exception: + platform_name = "" + + seed = ( + f"opencl:{entry.platform_index}:{entry.device_index}:" + f"{platform_vendor}:" + f"{platform_name}:" + f"{device_name}:{driver_version}" + ) + return hashlib.md5(seed.encode("utf-8")).digest() + + +def _device_attr(device, attr_name: str, default): + try: + return getattr(device, attr_name) + except Exception: + 
return default + + +def _release_opencl_object(obj: object) -> None: + release = getattr(obj, "release", None) + if callable(release): + try: + release() + except Exception: + pass + + +def _device_identity_key(entry: _DeviceEntry, device_name: str, driver_version: str) -> Tuple[int, int, str, str]: + return (int(entry.platform_index), int(entry.device_index), str(device_name), str(driver_version)) + + +def _kernel_preferred_workgroup_multiple(device) -> Optional[int]: + ctx = None + program = None + kernel = None + + try: + ctx = cl.Context(devices=[device]) + program = cl.Program(ctx, _OPENCL_SUBGROUP_PROBE_SOURCE).build() + kernel = cl.Kernel(program, "vkdispatch_subgroup_probe") + multiple = kernel.get_work_group_info( + cl.kernel_work_group_info.PREFERRED_WORK_GROUP_SIZE_MULTIPLE, + device, + ) + multiple_int = _coerce_int(multiple, 0) + if multiple_int > 0: + return multiple_int + except Exception: + return None + finally: + _release_opencl_object(kernel) + _release_opencl_object(program) + _release_opencl_object(ctx) + + return None + + +def _round_down_power_of_two(value: int) -> int: + value = int(value) + if value <= 1: + return 1 + return 1 << (value.bit_length() - 1) + + +def _vendor_subgroup_fallback( + *, + device_type: int, + vendor_text: str, + platform_name: str, + device_name: str, + max_workgroup_invocations: int, +) -> int: + if device_type == 4: + return 1 + + combined = " ".join( + token.lower() + for token in (vendor_text, platform_name, device_name) + if isinstance(token, str) and len(token) > 0 + ) + + if "nvidia" in combined: + return 32 + + if "advanced micro devices" in combined or " amd" in f" {combined}" or "radeon" in combined: + return 64 + + if "apple" in combined or "m1" in combined or "m2" in combined or "m3" in combined or "m4" in combined: + return 32 + + if "intel" in combined: + return 16 if device_type == 2 else 1 + + if device_type == 2: + bounded = min(max(1, int(max_workgroup_invocations)), 64) + if bounded >= 32: + return 
32 + return _round_down_power_of_two(bounded) + + return 1 + + +def _estimate_subgroup_size( + entry: _DeviceEntry, + device, + *, + device_name: str, + driver_version: str, + device_type: int, + max_workgroup_invocations: int, +) -> int: + cache_key = _device_identity_key(entry, device_name, driver_version) + cached = _subgroup_size_cache.get(cache_key) + if cached is not None: + return cached + + platform_name = str(_device_attr(entry.platform, "name", "")) + vendor_text = str(_device_attr(device, "vendor", _device_attr(entry.platform, "vendor", ""))) + + subgroup_size = _kernel_preferred_workgroup_multiple(device) + if subgroup_size is None: + subgroup_size = _vendor_subgroup_fallback( + device_type=device_type, + vendor_text=vendor_text, + platform_name=platform_name, + device_name=device_name, + max_workgroup_invocations=max_workgroup_invocations, + ) + + subgroup_size = max(1, int(subgroup_size)) + _subgroup_size_cache[cache_key] = subgroup_size + return subgroup_size + + +def _context_from_handle(context_handle: int) -> Optional[_Context]: + ctx = _contexts.get(int(context_handle)) + if ctx is None: + _set_error(f"Invalid context handle {context_handle}") + return ctx + + +def _queue_indices(ctx: _Context, queue_index: int, *, all_on_negative: bool = False) -> List[int]: + if ctx.queue_count <= 0: + return [] + + if queue_index is None: + return [0] + + queue_index = int(queue_index) + + if all_on_negative and queue_index < 0: + return list(range(ctx.queue_count)) + + if queue_index == -1: + return [0] + + if 0 <= queue_index < ctx.queue_count: + return [queue_index] + + return [] + + +def _record_signal(signal: _Signal, event_obj: Optional[object]) -> None: + if signal.event is not None and signal.event is not event_obj: + try: + signal.event.release() + except Exception: + pass + signal.submitted = True + signal.done = event_obj is None + signal.event = event_obj + + +def _query_event_done(event_obj: Optional[object]) -> bool: + if event_obj is None: + 
return True + + try: + complete = int(getattr(getattr(cl, "command_execution_status", object()), "COMPLETE", 0)) + status = _coerce_int(event_obj.command_execution_status, 0) + return status == complete + except Exception: + return False + + +def _query_signal(signal: _Signal) -> bool: + signal.done = _query_event_done(signal.event) if signal.event is not None else bool(signal.done) + return signal.done + + +def _wait_signal(signal: _Signal) -> bool: + if signal.event is None: + return bool(signal.done) + + try: + signal.event.wait() + signal.done = True + return True + except Exception: + return _query_signal(signal) + + +def _parse_local_size(source: str) -> Tuple[int, int, int]: + x_match = _LOCAL_X_RE.search(source) + y_match = _LOCAL_Y_RE.search(source) + z_match = _LOCAL_Z_RE.search(source) + + if x_match is not None and y_match is not None and z_match is not None: + return ( + _coerce_int(x_match.group(1), 1), + _coerce_int(y_match.group(1), 1), + _coerce_int(z_match.group(1), 1), + ) + + reqd_match = _REQD_LOCAL_RE.search(source) + if reqd_match is not None: + return ( + _coerce_int(reqd_match.group(1), 1), + _coerce_int(reqd_match.group(2), 1), + _coerce_int(reqd_match.group(3), 1), + ) + + return (1, 1, 1) + + +def _opencl_device_launch_limits(logical_device_index: int) -> Tuple[Tuple[int, int, int], int]: + entries = _enumerate_opencl_devices() + if logical_device_index < 0 or logical_device_index >= len(entries): + raise RuntimeError( + f"OpenCL device index {logical_device_index} is out of range for launch validation" + ) + + device = entries[logical_device_index].device + max_work_item_sizes = tuple( + _coerce_int(x, 1) + for x in _device_attr(device, "max_work_item_sizes", (1, 1, 1)) + ) + + if len(max_work_item_sizes) < 3: + max_work_item_sizes = (max_work_item_sizes + (1, 1, 1))[:3] + else: + max_work_item_sizes = max_work_item_sizes[:3] + + max_workgroup_size = ( + max(1, int(max_work_item_sizes[0])), + max(1, int(max_work_item_sizes[1])), + 
max(1, int(max_work_item_sizes[2])), + ) + max_workgroup_invocations = max( + 1, + _coerce_int(_device_attr(device, "max_work_group_size", 1), 1), + ) + + return max_workgroup_size, max_workgroup_invocations + + +def _validate_local_size_for_enqueue(ctx: _Context, local_size: Tuple[int, int, int]) -> None: + max_workgroup_size, max_workgroup_invocations = _opencl_device_launch_limits(ctx.device_index) + local_x, local_y, local_z = (max(1, int(dim)) for dim in local_size) + local_invocations = local_x * local_y * local_z + + violations = [] + if local_x > max_workgroup_size[0]: + violations.append(f"x={local_x} exceeds {max_workgroup_size[0]}") + if local_y > max_workgroup_size[1]: + violations.append(f"y={local_y} exceeds {max_workgroup_size[1]}") + if local_z > max_workgroup_size[2]: + violations.append(f"z={local_z} exceeds {max_workgroup_size[2]}") + if local_invocations > max_workgroup_invocations: + violations.append( + f"total invocations={local_invocations} exceeds {max_workgroup_invocations}" + ) + + if violations: + raise RuntimeError( + "OpenCL local size is invalid for the active device: " + f"requested ({local_x}, {local_y}, {local_z}), " + f"device limits {max_workgroup_size} with max_work_group_size=" + f"{max_workgroup_invocations} ({'; '.join(violations)})" + ) + + +_PUSH_CONSTANT_SCALAR_LAYOUTS: Dict[str, Tuple[int, int]] = { + "char": (1, 1), + "uchar": (1, 1), + "short": (2, 2), + "ushort": (2, 2), + "int": (4, 4), + "uint": (4, 4), + "long": (8, 8), + "ulong": (8, 8), + "half": (2, 2), + "float": (4, 4), + "double": (8, 8), +} + +_PUSH_CONSTANT_MATRIX_LAYOUTS: Dict[str, _PushConstantTypeLayout] = { + "vkdispatch_mat2": _PushConstantTypeLayout(host_elem_size=16, opencl_elem_size=16, opencl_align=8), + "vkdispatch_mat3": _PushConstantTypeLayout(host_elem_size=36, opencl_elem_size=36, opencl_align=1), + "vkdispatch_mat4": _PushConstantTypeLayout(host_elem_size=64, opencl_elem_size=64, opencl_align=16), + "vkdispatch_packed_float3": 
_PushConstantTypeLayout(host_elem_size=12, opencl_elem_size=12, opencl_align=1), +} + + +def _extract_push_constant_struct_body(source: str) -> Optional[str]: + struct_match = _PUSH_CONSTANT_STRUCT_RE.search(source) + if struct_match is None: + return None + return struct_match.group("body") + + +def _parse_push_constant_struct_fields(body: str) -> List[_PushConstantFieldDecl]: + fields: List[_PushConstantFieldDecl] = [] + + for raw_decl in body.split(";"): + decl = " ".join(raw_decl.strip().split()) + if len(decl) == 0: + continue + + field_match = _PUSH_CONSTANT_FIELD_RE.fullmatch(decl) + if field_match is None: + raise RuntimeError(f"Unable to parse PushConstant field declaration '{decl}'") + + type_name = field_match.group("type") + field_name = field_match.group("name") + count_token = field_match.group("count") + count = 1 if count_token is None else _coerce_int(count_token, 0) + + if count <= 0: + raise RuntimeError(f"Invalid PushConstant array size for field '{field_name}'") + + fields.append(_PushConstantFieldDecl(type_name=type_name, field_name=field_name, count=count)) + + return fields + + +def _push_constant_type_layout(type_name: str) -> _PushConstantTypeLayout: + matrix_layout = _PUSH_CONSTANT_MATRIX_LAYOUTS.get(type_name) + if matrix_layout is not None: + return matrix_layout + + scalar_layout = _PUSH_CONSTANT_SCALAR_LAYOUTS.get(type_name) + if scalar_layout is not None: + size, align = scalar_layout + return _PushConstantTypeLayout(host_elem_size=size, opencl_elem_size=size, opencl_align=align) + + vector_match = _VECTOR_TYPE_RE.fullmatch(type_name) + if vector_match is not None: + scalar_name = vector_match.group(1) + lane_count = _coerce_int(vector_match.group(2), 0) + scalar_info = _PUSH_CONSTANT_SCALAR_LAYOUTS.get(scalar_name) + if scalar_info is None: + raise RuntimeError(f"Unsupported PushConstant vector scalar type '{scalar_name}'") + + scalar_size, _scalar_align = scalar_info + host_elem_size = scalar_size * lane_count + + if lane_count == 
3: + opencl_elem_size = scalar_size * 4 + opencl_align = scalar_size * 4 + else: + opencl_elem_size = host_elem_size + opencl_align = opencl_elem_size + + return _PushConstantTypeLayout( + host_elem_size=host_elem_size, + opencl_elem_size=opencl_elem_size, + opencl_align=opencl_align, + ) + + raise RuntimeError(f"Unsupported PushConstant field type '{type_name}'") + + +def _compute_push_constant_layout(field_decls: List[_PushConstantFieldDecl]) -> _PushConstantLayout: + host_offset = 0 + opencl_offset = 0 + max_opencl_align = 1 + needs_repack = False + field_layouts: List[_PushConstantFieldLayout] = [] + + for field_decl in field_decls: + type_layout = _push_constant_type_layout(field_decl.type_name) + + opencl_offset = _align_up(opencl_offset, type_layout.opencl_align) + + if type_layout.opencl_align > max_opencl_align: + max_opencl_align = type_layout.opencl_align + + if host_offset != opencl_offset: + needs_repack = True + if type_layout.host_elem_size != type_layout.opencl_elem_size: + needs_repack = True + + field_layouts.append( + _PushConstantFieldLayout( + type_name=field_decl.type_name, + field_name=field_decl.field_name, + count=field_decl.count, + host_offset=host_offset, + opencl_offset=opencl_offset, + host_elem_size=type_layout.host_elem_size, + opencl_elem_size=type_layout.opencl_elem_size, + ) + ) + + host_offset += type_layout.host_elem_size * field_decl.count + opencl_offset += type_layout.opencl_elem_size * field_decl.count + + opencl_size = _align_up(opencl_offset, max_opencl_align) + if opencl_size != host_offset: + needs_repack = True + + return _PushConstantLayout( + fields=tuple(field_layouts), + host_size=host_offset, + opencl_size=opencl_size, + opencl_alignment=max_opencl_align, + needs_repack=needs_repack, + ) + + +def _build_push_constant_layout(source: str, expected_host_size: int) -> Optional[_PushConstantLayout]: + expected_host_size = int(expected_host_size) + if expected_host_size <= 0: + return None + + body = 
_extract_push_constant_struct_body(source) + if body is None: + raise RuntimeError("Could not find PushConstant struct declaration in OpenCL source") + + field_decls = _parse_push_constant_struct_fields(body) + if len(field_decls) == 0: + raise RuntimeError("PushConstant struct declaration is empty") + + layout = _compute_push_constant_layout(field_decls) + if layout.host_size != expected_host_size: + raise RuntimeError( + f"PushConstant host layout mismatch. Expected {expected_host_size} bytes " + f"but parsed {layout.host_size} bytes from OpenCL source." + ) + + return layout + + +def _repack_push_constant_payload( + push_constant_payload: bytes, + layout: Optional[_PushConstantLayout], +) -> bytes: + payload = _to_bytes(push_constant_payload) + + if layout is None or not layout.needs_repack: + return payload + + if len(payload) != int(layout.host_size): + raise RuntimeError( + f"PushConstant payload length mismatch for repack. " + f"Expected {layout.host_size} bytes but got {len(payload)} bytes." + ) + + out = bytearray(int(layout.opencl_size)) + + for field in layout.fields: + if field.host_elem_size > field.opencl_elem_size: + raise RuntimeError( + f"PushConstant field '{field.field_name}' host element size ({field.host_elem_size}) " + f"exceeds OpenCL ABI element size ({field.opencl_elem_size})." 
+ ) + + for element_index in range(int(field.count)): + host_start = field.host_offset + (element_index * field.host_elem_size) + host_end = host_start + field.host_elem_size + opencl_start = field.opencl_offset + (element_index * field.opencl_elem_size) + opencl_end = opencl_start + field.host_elem_size + out[opencl_start:opencl_end] = payload[host_start:host_end] + + return bytes(out) + + +def _parse_kernel_params(source: str) -> List[_KernelParam]: + signature_match = _KERNEL_SIGNATURE_RE.search(source) + if signature_match is None: + raise RuntimeError("Could not find vkdispatch_main kernel signature in OpenCL source") + + signature_blob = signature_match.group(1).strip() + if len(signature_blob) == 0: + return [] + + params: List[_KernelParam] = [] + + for raw_decl in [part.strip() for part in signature_blob.split(",") if len(part.strip()) > 0]: + name_match = re.search(r"([A-Za-z_][A-Za-z0-9_]*)\s*$", raw_decl) + if name_match is None: + raise RuntimeError(f"Unable to parse kernel parameter declaration '{raw_decl}'") + + param_name = name_match.group(1) + + if param_name == "vkdispatch_uniform_ptr": + params.append(_KernelParam("uniform", 0, param_name)) + continue + + if param_name == "vkdispatch_pc_value": + params.append(_KernelParam("push_constant_value", None, param_name)) + continue + + binding_match = _BINDING_PARAM_RE.match(param_name) + if binding_match is not None: + params.append(_KernelParam("storage", _coerce_int(binding_match.group(1), 0), param_name)) + continue + + sampler_match = _SAMPLER_PARAM_RE.match(param_name) + if sampler_match is not None: + params.append(_KernelParam("sampler", _coerce_int(sampler_match.group(1), 0), param_name)) + continue + + params.append(_KernelParam("unknown", None, param_name)) + + return params + + +def _buffer_access_flags(read_access: int, write_access: int) -> int: + read_enabled = int(read_access) != 0 + write_enabled = int(write_access) != 0 + + if read_enabled and not write_enabled: + return 
int(cl.mem_flags.READ_ONLY) + if write_enabled and not read_enabled: + return int(cl.mem_flags.WRITE_ONLY) + return int(cl.mem_flags.READ_WRITE) + + +def _resolve_descriptor_buffer( + descriptor_set: _DescriptorSet, + binding: int, + ctx: _Context, + keepalive: List[object], +): + binding_info = descriptor_set.buffer_bindings.get(int(binding)) + if binding_info is None: + raise RuntimeError(f"Missing descriptor buffer binding {binding}") + + buffer_handle, offset, requested_range, _uniform, read_access, write_access = binding_info + + buffer_obj = _buffers.get(int(buffer_handle)) + if buffer_obj is None: + raise RuntimeError(f"Invalid buffer handle {buffer_handle} for binding {binding}") + + offset = int(offset) + requested_range = int(requested_range) + + if offset < 0: + raise RuntimeError(f"Negative descriptor offset {offset} for binding {binding}") + + max_size = int(buffer_obj.size) + if offset > max_size: + raise RuntimeError(f"Descriptor offset {offset} exceeds buffer size {max_size} for binding {binding}") + + sub_size = max_size - offset if requested_range <= 0 else requested_range + if sub_size < 0: + raise RuntimeError(f"Invalid descriptor range {sub_size} for binding {binding}") + + if offset + sub_size > max_size: + raise RuntimeError( + f"Descriptor range (offset={offset}, size={sub_size}) exceeds buffer size {max_size} for binding {binding}" + ) + + if offset == 0 and sub_size == max_size: + return buffer_obj.cl_buffer + + if (offset % ctx.sub_buffer_alignment) != 0: + raise RuntimeError( + f"Descriptor offset {offset} for binding {binding} is not aligned to " + f"{ctx.sub_buffer_alignment} bytes required by this OpenCL device" + ) + + sub_buffer = buffer_obj.cl_buffer.get_sub_region( + int(offset), + int(sub_size), + _buffer_access_flags(read_access, write_access), + ) + keepalive.append(sub_buffer) + return sub_buffer + + +def _build_kernel_args( + plan: _ComputePlan, + descriptor_set: Optional[_DescriptorSet], + ctx: _Context, + 
push_constant_payload: bytes = b"", +) -> Tuple[List[object], List[object]]: + args: List[object] = [] + keepalive: List[object] = [] + + for param in plan.params: + if param.kind == "uniform": + if descriptor_set is None: + raise RuntimeError("Kernel requires a descriptor set but none was provided") + args.append(_resolve_descriptor_buffer(descriptor_set, 0, ctx, keepalive)) + continue + + if param.kind == "storage": + if descriptor_set is None: + raise RuntimeError("Kernel requires a descriptor set but none was provided") + if param.binding is None: + raise RuntimeError("Storage parameter has no binding index") + args.append(_resolve_descriptor_buffer(descriptor_set, int(param.binding), ctx, keepalive)) + continue + + if param.kind == "push_constant_value": + if int(plan.pc_size) <= 0: + raise RuntimeError( + f"Kernel parameter '{param.raw_name}' expects push-constant data, but this compute plan has pc_size={plan.pc_size}." + ) + + if len(push_constant_payload) == 0: + raise RuntimeError( + "Missing push-constant payload for OpenCL by-value push-constant parameter " + f"'{param.raw_name}'." + ) + + if len(push_constant_payload) != int(plan.pc_size): + raise RuntimeError( + f"Push-constant payload size mismatch for parameter '{param.raw_name}'. " + f"Expected {plan.pc_size} bytes but got {len(push_constant_payload)} bytes." + ) + + args.append(_repack_push_constant_payload(push_constant_payload, plan.pc_layout)) + continue + + if param.kind == "sampler": + raise RuntimeError("OpenCL backend does not support image/sampler bindings") + + raise RuntimeError( + f"Unsupported kernel parameter '{param.raw_name}'. " + "Expected vkdispatch_uniform_ptr / vkdispatch_pc_value / vkdispatch_binding__ptr." 
+ ) + + return args, keepalive + + +def _marker_wait_functions() -> List[object]: + cached = getattr(_marker_helpers, "funcs", None) + if cached is not None: + return cached + + funcs: List[object] = [] + for fn_name in ( + "enqueue_marker", + "enqueue_marker_with_wait_list", + "enqueue_barrier_with_wait_list", + ): + fn = getattr(cl, fn_name, None) + if fn is not None: + funcs.append(fn) + + _marker_helpers.funcs = funcs + return funcs + + +def _insert_queue_marker_event(queue) -> Optional[object]: + for marker_fn in _marker_wait_functions(): + try: + event_obj = marker_fn(queue) + if event_obj is not None: + return event_obj + except TypeError: + try: + event_obj = marker_fn(queue, wait_for=[]) + if event_obj is not None: + return event_obj + except Exception: + continue + except Exception: + continue + + return None + + +def _release_event(event_obj: Optional[object]) -> None: + if event_obj is None: + return + + try: + event_obj.release() + except Exception: + pass + + +def _prune_submission_events(ctx: _Context, queue_index: int) -> int: + pending_events: List[object] = [] + + for event_obj in ctx.submission_events[queue_index]: + if _query_event_done(event_obj): + _release_event(event_obj) + continue + + pending_events.append(event_obj) + + ctx.submission_events[queue_index] = pending_events + return len(pending_events) + + +def _reserve_submission_slot(ctx: _Context, queue_index: int) -> bool: + return _prune_submission_events(ctx, queue_index) < _OPENCL_MAX_INFLIGHT_SUBMISSIONS + + +def _track_submission_completion(ctx: _Context, queue_index: int) -> None: + queue = ctx.queues[queue_index] + marker_event = _insert_queue_marker_event(queue) + + if marker_event is None: + queue.finish() + _prune_submission_events(ctx, queue_index) + return + + ctx.submission_events[queue_index].append(marker_event) + queue.flush() + + +# --- API: context/init/logging --- + + +def init(debug, log_level): + global _initialized, _debug_mode, _log_level + + _debug_mode = 
bool(debug) + _log_level = int(log_level) + _clear_error() + + if _initialized: + return + + _initialized = True + + +def log(log_level, text, file_str, line_str): + _ = log_level + _ = text + _ = file_str + _ = line_str + + +def set_log_level(log_level): + global _log_level + _log_level = int(log_level) + + +def get_devices(): + if not _initialized: + init(False, _log_level) + + entries = _enumerate_opencl_devices() + devices = [] + + for entry in entries: + device = entry.device + opencl_version = _device_attr(device, "version", "") + version_major, version_minor = _opencl_version_components(opencl_version) + version_patch = 0 + + driver_version = str(_device_attr(device, "driver_version", "")) + driver_version_num = _driver_version_number(driver_version) + + vendor_id = _coerce_int(_device_attr(device, "vendor_id", 0), 0) + device_id = int(entry.logical_index) + device_type = _device_type_to_vkdispatch(_coerce_int(_device_attr(device, "type", 0), 0)) + device_name = str(_device_attr(device, "name", f"OpenCL Device {entry.logical_index}")) + + extensions = str(_device_attr(device, "extensions", "")) + float32_atomic_support = ( + "cl_ext_float_atomics" in extensions + or "cl_khr_float_atomics" in extensions + ) + float64_support = "cl_khr_fp64" in extensions or _coerce_int(_device_attr(device, "double_fp_config", 0), 0) != 0 + float16_support = "cl_khr_fp16" in extensions or _coerce_int(_device_attr(device, "half_fp_config", 0), 0) != 0 + int64_support = _coerce_int(_device_attr(device, "address_bits", 0), 0) >= 64 + int16_support = _coerce_int(_device_attr(device, "preferred_vector_width_short", 0), 0) > 0 + + max_work_item_sizes = tuple( + _coerce_int(x, 1) + for x in _device_attr(device, "max_work_item_sizes", (1, 1, 1)) + ) + if len(max_work_item_sizes) < 3: + max_work_item_sizes = ( + max_work_item_sizes + (1, 1, 1) + )[:3] + else: + max_work_item_sizes = max_work_item_sizes[:3] + + max_workgroup_size = ( + max(1, int(max_work_item_sizes[0])), + max(1, 
int(max_work_item_sizes[1])), + max(1, int(max_work_item_sizes[2])), + ) + max_workgroup_invocations = max(1, _coerce_int(_device_attr(device, "max_work_group_size", 1), 1)) + + max_workgroup_count = (2 ** 31 - 1, 2 ** 31 - 1, 2 ** 31 - 1) + + max_storage_buffer_range = max( + 1, + min( + _coerce_int(_device_attr(device, "max_mem_alloc_size", 1), 1), + (1 << 31) - 1, + ), + ) + max_uniform_buffer_range = max(1, _coerce_int(_device_attr(device, "max_constant_buffer_size", 65536), 65536)) + uniform_alignment = max( + 1, + _coerce_int(_device_attr(device, "mem_base_addr_align", 8), 8) // 8, + ) + max_push_constant_size = max(0, _coerce_int(_device_attr(device, "max_parameter_size", 0), 0)) + + subgroup_size = _estimate_subgroup_size( + entry, + device, + device_name=device_name, + driver_version=driver_version, + device_type=device_type, + max_workgroup_invocations=max_workgroup_invocations, + ) + + max_compute_shared_memory_size = max( + 1, + _coerce_int(_device_attr(device, "local_mem_size", 1), 1), + ) + + uuid_bytes = _device_uuid(entry, device_name, driver_version) + + devices.append( + ( + 0, # Vulkan variant + int(version_major), + int(version_minor), + int(version_patch), + int(driver_version_num), + int(vendor_id), + int(device_id), + int(device_type), + str(device_name), + 1 if float32_atomic_support else 0, + 1 if float32_atomic_support else 0, + 1 if float64_support else 0, + 1 if float16_support else 0, + 1 if int64_support else 0, + 1 if int16_support else 0, + 1 if int16_support else 0, # storage_buffer_16_bit_access + 1 if int16_support else 0, # uniform_and_storage_buffer_16_bit_access + 0, # storage_push_constant_16 + 1 if int16_support else 0, # storage_input_output_16 + max_workgroup_size, + int(max_workgroup_invocations), + max_workgroup_count, + 8, # max descriptor sets (virtualized for parity) + int(max_push_constant_size), + int(max_storage_buffer_range), + int(max_uniform_buffer_range), + int(uniform_alignment), + subgroup_size, # subgroup 
size + 0, # subgroup stages + 0, # subgroup operations + 0, # quad operations in all stages + int(max_compute_shared_memory_size), + [(1, 0x006)], # compute + transfer queue + 1, # scalar block layout equivalent + 0, # timeline semaphores equivalent + uuid_bytes, + ) + ) + + return devices + + +def context_create(device_indicies, queue_families): + if not _initialized: + init(False, _log_level) + + try: + device_ids = [int(x) for x in device_indicies] + except Exception: + _set_error("context_create expected a list of integer device indices") + return 0 + + if len(device_ids) != 1: + _set_error("OpenCL backend currently supports exactly one device") + return 0 + + try: + normalized_families = [[int(x) for x in family] for family in queue_families] + except Exception: + _set_error("context_create expected queue_families to be a nested integer list") + return 0 + + if len(normalized_families) != 1 or len(normalized_families[0]) != 1: + _set_error("OpenCL backend currently supports exactly one queue") + return 0 + + entries = _enumerate_opencl_devices() + if len(entries) == 0: + if _error_string is None: + _set_error("No OpenCL devices were found") + return 0 + + logical_device_index = int(device_ids[0]) + if logical_device_index < 0 or logical_device_index >= len(entries): + _set_error( + f"Invalid OpenCL device index {logical_device_index}. 
" + f"Expected range [0, {len(entries) - 1}]" + ) + return 0 + + entry = entries[logical_device_index] + + try: + cl_context = cl.Context(devices=[entry.device]) + queue = cl.CommandQueue(cl_context, device=entry.device) + sub_buffer_alignment = max( + 1, + _coerce_int(_device_attr(entry.device, "mem_base_addr_align", 8), 8) // 8, + ) + ctx = _Context( + device_index=logical_device_index, + cl_context=cl_context, + queues=[queue], + queue_count=1, + queue_to_device=[0], + sub_buffer_alignment=sub_buffer_alignment, + submission_events=[[]], + stopped=False, + ) + return _new_handle(_contexts, ctx) + except Exception as exc: + _set_error(f"Failed to create OpenCL context: {exc}") + return 0 + + +def context_destroy(context): + ctx = _contexts.pop(int(context), None) + if ctx is None: + return + + for queue_events in ctx.submission_events: + for event_obj in queue_events: + _release_event(event_obj) + queue_events.clear() + + for queue in ctx.queues: + try: + queue.finish() + except Exception: + pass + try: + queue.release() + except Exception: + pass + + try: + ctx.cl_context.release() + except Exception: + pass + + +def context_stop_threads(context): + ctx = _contexts.get(int(context)) + if ctx is not None: + ctx.stopped = True + + +def get_error_string(): + if _error_string is None: + return 0 + return _error_string + + +# --- API: signals --- + + +def signal_wait(signal_ptr, wait_for_timestamp, queue_index): + _ = queue_index + + signal_obj = _signals.get(int(signal_ptr)) + if signal_obj is None: + return True + + if not bool(wait_for_timestamp): + if signal_obj.event is None: + return bool(signal_obj.done) + return bool(signal_obj.submitted) + + return _wait_signal(signal_obj) + + +def signal_insert(context, queue_index): + ctx = _context_from_handle(int(context)) + if ctx is None: + return 0 + + selected = _queue_indices(ctx, int(queue_index)) + if len(selected) == 0: + selected = [0] + + signal = _Signal(context_handle=int(context), queue_index=selected[0], 
submitted=False, done=False) + handle = _new_handle(_signals, signal) + + try: + event_obj = _insert_queue_marker_event(ctx.queues[selected[0]]) + if event_obj is None: + ctx.queues[selected[0]].finish() + signal.done = True + signal.submitted = True + else: + _record_signal(signal, event_obj) + except Exception as exc: + _set_error(f"Failed to insert signal: {exc}") + return 0 + + return handle + + +def signal_destroy(signal_ptr): + signal_obj = _signals.pop(int(signal_ptr), None) + if signal_obj is None: + return + + try: + if signal_obj.event is not None: + signal_obj.event.release() + except Exception: + pass + + +# --- API: buffers --- + + +def buffer_create(context, size, per_device): + _ = per_device + + ctx = _context_from_handle(int(context)) + if ctx is None: + return 0 + + size = int(size) + if size <= 0: + _set_error("Buffer size must be greater than zero") + return 0 + + try: + cl_buffer = cl.Buffer(ctx.cl_context, cl.mem_flags.READ_WRITE, size=size) + signal_handles = [ + _new_handle(_signals, _Signal(context_handle=int(context), queue_index=i, done=True)) + for i in range(ctx.queue_count) + ] + obj = _Buffer( + context_handle=int(context), + size=size, + cl_buffer=cl_buffer, + staging_data=[bytearray(size) for _ in range(ctx.queue_count)], + signal_handles=signal_handles, + ) + return _new_handle(_buffers, obj) + except Exception as exc: + _set_error(f"Failed to create OpenCL buffer: {exc}") + return 0 + + +def buffer_create_external(context, size, device_ptr): + _ = context + _ = size + _ = device_ptr + _set_error("OpenCL backend does not support external buffer aliases in MVP") + return 0 + + +def buffer_destroy(buffer): + obj = _buffers.pop(int(buffer), None) + if obj is None: + return + + for signal_handle in obj.signal_handles: + signal_destroy(signal_handle) + + try: + obj.cl_buffer.release() + except Exception: + pass + + +def buffer_get_queue_signal(buffer, queue_index): + obj = _buffers.get(int(buffer)) + if obj is None: + return 
_new_handle(_signals, _Signal(context_handle=0, queue_index=0, done=True)) + + queue_index = int(queue_index) + if queue_index < 0 or queue_index >= len(obj.signal_handles): + queue_index = 0 + + return obj.signal_handles[queue_index] + + +def buffer_wait_staging_idle(buffer, queue_index): + signal_handle = buffer_get_queue_signal(buffer, queue_index) + signal_obj = _signals.get(int(signal_handle)) + if signal_obj is None: + return True + return _query_signal(signal_obj) + + +def buffer_write_staging(buffer, queue_index, data, size): + obj = _buffers.get(int(buffer)) + if obj is None: + return + + queue_index = int(queue_index) + if queue_index < 0 or queue_index >= len(obj.staging_data): + return + + payload = _to_bytes(data) + size = min(int(size), len(payload), obj.size) + if size <= 0: + return + + obj.staging_data[queue_index][:size] = payload[:size] + + +def buffer_read_staging(buffer, queue_index, size): + obj = _buffers.get(int(buffer)) + if obj is None: + return bytes(int(size)) + + queue_index = int(queue_index) + if queue_index < 0 or queue_index >= len(obj.staging_data): + return bytes(int(size)) + + size = max(0, int(size)) + staging = obj.staging_data[queue_index] + + if size <= len(staging): + return bytes(staging[:size]) + + return bytes(staging) + bytes(size - len(staging)) + + +def buffer_write(buffer, offset, size, index): + obj = _buffers.get(int(buffer)) + if obj is None: + return + + ctx = _contexts.get(obj.context_handle) + if ctx is None: + _set_error(f"Missing context for buffer handle {buffer}") + return + + offset = int(offset) + size = int(size) + if size <= 0 or offset < 0: + return + + try: + for queue_index in _queue_indices(ctx, int(index), all_on_negative=True): + queue = ctx.queues[queue_index] + end = min(offset + size, obj.size) + copy_size = end - offset + if copy_size <= 0: + continue + + host_src = np.frombuffer(obj.staging_data[queue_index], dtype=np.uint8, count=copy_size) + event_obj = cl.enqueue_copy( + queue, + 
obj.cl_buffer, + host_src, + dst_offset=offset, + is_blocking=False, + ) + + signal_obj = _signals.get(obj.signal_handles[queue_index]) + if signal_obj is not None: + _record_signal(signal_obj, event_obj) + except Exception as exc: + _set_error(f"Failed to write OpenCL buffer: {exc}") + + +def buffer_read(buffer, offset, size, index): + obj = _buffers.get(int(buffer)) + if obj is None: + return + + ctx = _contexts.get(obj.context_handle) + if ctx is None: + _set_error(f"Missing context for buffer handle {buffer}") + return + + queue_index = int(index) + if queue_index < 0 or queue_index >= ctx.queue_count: + _set_error(f"Invalid queue index {queue_index} for buffer read") + return + + offset = int(offset) + size = int(size) + if size <= 0 or offset < 0: + return + + try: + queue = ctx.queues[queue_index] + end = min(offset + size, obj.size) + copy_size = end - offset + if copy_size <= 0: + return + + host_dst = np.frombuffer(obj.staging_data[queue_index], dtype=np.uint8, count=copy_size) + event_obj = cl.enqueue_copy( + queue, + host_dst, + obj.cl_buffer, + src_offset=offset, + is_blocking=False, + ) + + signal_obj = _signals.get(obj.signal_handles[queue_index]) + if signal_obj is not None: + _record_signal(signal_obj, event_obj) + except Exception as exc: + _set_error(f"Failed to read OpenCL buffer: {exc}") + + +# --- API: command lists --- + + +def command_list_create(context): + if int(context) not in _contexts: + _set_error("Invalid context handle for command_list_create") + return 0 + + return _new_handle(_command_lists, _CommandList(context_handle=int(context))) + + +def command_list_destroy(command_list): + _command_lists.pop(int(command_list), None) + + +def command_list_get_instance_size(command_list): + obj = _command_lists.get(int(command_list)) + if obj is None: + return 0 + + return int(sum(int(command.pc_size) for command in obj.commands)) + + +def command_list_reset(command_list): + obj = _command_lists.get(int(command_list)) + if obj is None: + 
return + + obj.commands = [] + + +def command_list_submit(command_list, data, instance_count, index): + obj = _command_lists.get(int(command_list)) + if obj is None: + return True + + ctx = _contexts.get(obj.context_handle) + if ctx is None: + _set_error(f"Missing context for command list {command_list}") + return True + + instance_count = int(instance_count) + if instance_count <= 0: + return True + + instance_size = command_list_get_instance_size(command_list) + payload = _to_bytes(data) + expected_payload_size = int(instance_size) * int(instance_count) + + if expected_payload_size == 0: + if len(payload) != 0: + _set_error( + f"Unexpected push-constant data for command list with instance_size=0 " + f"(got {len(payload)} bytes)." + ) + return True + elif len(payload) != expected_payload_size: + _set_error( + f"Push-constant data size mismatch. Expected {expected_payload_size} bytes " + f"(instance_size={instance_size}, instance_count={instance_count}) but got {len(payload)} bytes." + ) + return True + + queue_targets = _queue_indices(ctx, int(index), all_on_negative=True) + if len(queue_targets) == 0: + queue_targets = [0] + + try: + for queue_index in queue_targets: + if not _reserve_submission_slot(ctx, queue_index): + return False + + for queue_index in queue_targets: + queue = ctx.queues[queue_index] + for instance_index in range(instance_count): + instance_base_offset = instance_index * instance_size + per_instance_offset = 0 + for command in obj.commands: + plan = _compute_plans.get(command.plan_handle) + if plan is None: + raise RuntimeError(f"Invalid compute plan handle {command.plan_handle}") + + descriptor_set = None + if command.descriptor_set_handle != 0: + descriptor_set = _descriptor_sets.get(command.descriptor_set_handle) + if descriptor_set is None: + raise RuntimeError( + f"Invalid descriptor set handle {command.descriptor_set_handle}" + ) + + command_pc_size = int(command.pc_size) + pc_payload = b"" + if command_pc_size > 0 and len(payload) > 0: 
+ pc_start = instance_base_offset + per_instance_offset + pc_end = pc_start + command_pc_size + pc_payload = payload[pc_start:pc_end] + + args, _keepalive = _build_kernel_args( + plan, + descriptor_set, + ctx, + pc_payload, + ) + + for arg_index, arg_value in enumerate(args): + plan.kernel.set_arg(arg_index, arg_value) + + local_x = max(1, int(plan.local_size[0])) + local_y = max(1, int(plan.local_size[1])) + local_z = max(1, int(plan.local_size[2])) + _validate_local_size_for_enqueue(ctx, (local_x, local_y, local_z)) + + blocks_x = max(1, int(command.blocks[0])) + blocks_y = max(1, int(command.blocks[1])) + blocks_z = max(1, int(command.blocks[2])) + + global_size = ( + blocks_x * local_x, + blocks_y * local_y, + blocks_z * local_z, + ) + + cl.enqueue_nd_range_kernel( + queue, + plan.kernel, + global_size, + (local_x, local_y, local_z), + ) + + per_instance_offset += command_pc_size + + if per_instance_offset != instance_size: + raise RuntimeError( + f"Internal command list size mismatch: computed {per_instance_offset} bytes, " + f"expected {instance_size} bytes." 
+ ) + + _track_submission_completion(ctx, queue_index) + except Exception as exc: + _set_error(f"Failed to submit OpenCL command list: {exc}") + + return True + + +# --- API: descriptor sets --- + + +def descriptor_set_create(plan): + if int(plan) not in _compute_plans: + _set_error("Invalid compute plan handle for descriptor_set_create") + return 0 + + return _new_handle(_descriptor_sets, _DescriptorSet(plan_handle=int(plan))) + + +def descriptor_set_destroy(descriptor_set): + _descriptor_sets.pop(int(descriptor_set), None) + + +def descriptor_set_write_buffer( + descriptor_set, + binding, + object, + offset, + range, + uniform, + read_access, + write_access, +): + ds = _descriptor_sets.get(int(descriptor_set)) + if ds is None: + _set_error("Invalid descriptor set handle for descriptor_set_write_buffer") + return + + ds.buffer_bindings[int(binding)] = ( + int(object), + int(offset), + int(range), + int(uniform), + int(read_access), + int(write_access), + ) + + +def descriptor_set_write_image( + descriptor_set, + binding, + object, + sampler_obj, + read_access, + write_access, +): + _ = descriptor_set + _ = binding + _ = object + _ = sampler_obj + _ = read_access + _ = write_access + _set_error("OpenCL backend does not support image objects in MVP") + + +# --- API: compute stage --- + + +def stage_compute_plan_create(context, shader_source, bindings, pc_size, shader_name): + ctx = _context_from_handle(int(context)) + if ctx is None: + return 0 + + source_bytes = _to_bytes(shader_source) + shader_name_bytes = _to_bytes(shader_name) + source_text = source_bytes.decode("utf-8", errors="replace") + pc_size = int(pc_size) + + try: + program = cl.Program(ctx.cl_context, source_text).build() + kernel = cl.Kernel(program, "vkdispatch_main") + except Exception as exc: + kernel_name = shader_name_bytes.decode("utf-8", errors="replace") + _set_error(f"Failed to compile OpenCL kernel '{kernel_name}': {exc}") + return 0 + + try: + params = _parse_kernel_params(source_text) + 
local_size = _parse_local_size(source_text) + pc_layout = _build_push_constant_layout(source_text, pc_size) + except Exception as exc: + _set_error(f"Failed to parse OpenCL kernel metadata: {exc}") + return 0 + + plan = _ComputePlan( + context_handle=int(context), + shader_source=source_bytes, + bindings=[int(x) for x in bindings], + shader_name=shader_name_bytes, + program=program, + kernel=kernel, + local_size=local_size, + params=params, + pc_size=pc_size, + pc_layout=pc_layout, + ) + + return _new_handle(_compute_plans, plan) + + +def stage_compute_plan_destroy(plan): + plan_obj = _compute_plans.pop(int(plan), None) + if plan_obj is None: + return + + try: + plan_obj.kernel.release() + except Exception: + pass + + try: + plan_obj.program.release() + except Exception: + pass + + +def stage_compute_record(command_list, plan, descriptor_set, blocks_x, blocks_y, blocks_z): + cl_obj = _command_lists.get(int(command_list)) + cp_obj = _compute_plans.get(int(plan)) + if cl_obj is None or cp_obj is None: + _set_error("Invalid command list or compute plan handle for stage_compute_record") + return + + cl_obj.commands.append( + _CommandRecord( + plan_handle=int(plan), + descriptor_set_handle=int(descriptor_set), + blocks=(int(blocks_x), int(blocks_y), int(blocks_z)), + pc_size=int(cp_obj.pc_size), + ) + ) + + +# --- API: images/samplers (MVP unsupported) --- + + +def image_create(context, extent, layers, format, type, view_type, generate_mips): + _ = context + _ = extent + _ = layers + _ = format + _ = type + _ = view_type + _ = generate_mips + _set_error("OpenCL backend does not support image objects in MVP") + return 0 + + +def image_destroy(image): + _images.pop(int(image), None) + + +def image_create_sampler( + context, + mag_filter, + min_filter, + mip_mode, + address_mode, + mip_lod_bias, + min_lod, + max_lod, + border_color, +): + _ = context + _ = mag_filter + _ = min_filter + _ = mip_mode + _ = address_mode + _ = mip_lod_bias + _ = min_lod + _ = max_lod + _ = 
border_color + _set_error("OpenCL backend does not support image samplers in MVP") + return 0 + + +def image_destroy_sampler(sampler): + _samplers.pop(int(sampler), None) + + +def image_write(image, data, offset, extent, baseLayer, layerCount, device_index): + _ = image + _ = data + _ = offset + _ = extent + _ = baseLayer + _ = layerCount + _ = device_index + _set_error("OpenCL backend does not support image writes in MVP") + + +def image_format_block_size(format): + return int(_IMAGE_BLOCK_SIZES.get(int(format), 4)) + + +def image_read(image, out_size, offset, extent, baseLayer, layerCount, device_index): + _ = image + _ = offset + _ = extent + _ = baseLayer + _ = layerCount + _ = device_index + _set_error("OpenCL backend does not support image reads in MVP") + return bytes(max(0, int(out_size))) + + +# --- API: FFT stage (MVP unsupported) --- + + +def stage_fft_plan_create( + context, + dims, + axes, + buffer_size, + do_r2c, + normalize, + pad_left, + pad_right, + frequency_zeropadding, + kernel_num, + kernel_convolution, + conjugate_convolution, + convolution_features, + input_buffer_size, + num_batches, + single_kernel_multiple_batches, + keep_shader_code, +): + _ = context + _ = dims + _ = axes + _ = buffer_size + _ = do_r2c + _ = normalize + _ = pad_left + _ = pad_right + _ = frequency_zeropadding + _ = kernel_num + _ = kernel_convolution + _ = conjugate_convolution + _ = convolution_features + _ = input_buffer_size + _ = num_batches + _ = single_kernel_multiple_batches + _ = keep_shader_code + _set_error("OpenCL backend does not support FFT plans in MVP") + return 0 + + +def stage_fft_plan_destroy(plan): + _fft_plans.pop(int(plan), None) + + +def stage_fft_record(command_list, plan, buffer, inverse, kernel, input_buffer): + _ = command_list + _ = plan + _ = buffer + _ = inverse + _ = kernel + _ = input_buffer + _set_error("OpenCL backend does not support FFT stages in MVP") + + +__all__ = [ + "LOG_LEVEL_VERBOSE", + "LOG_LEVEL_INFO", + "LOG_LEVEL_WARNING", + 
"LOG_LEVEL_ERROR", + "DESCRIPTOR_TYPE_STORAGE_BUFFER", + "DESCRIPTOR_TYPE_STORAGE_IMAGE", + "DESCRIPTOR_TYPE_UNIFORM_BUFFER", + "DESCRIPTOR_TYPE_UNIFORM_IMAGE", + "DESCRIPTOR_TYPE_SAMPLER", + "init", + "log", + "set_log_level", + "get_devices", + "context_create", + "signal_wait", + "signal_insert", + "signal_destroy", + "context_destroy", + "get_error_string", + "context_stop_threads", + "buffer_create", + "buffer_create_external", + "buffer_destroy", + "buffer_get_queue_signal", + "buffer_wait_staging_idle", + "buffer_write_staging", + "buffer_read_staging", + "buffer_write", + "buffer_read", + "command_list_create", + "command_list_destroy", + "command_list_get_instance_size", + "command_list_reset", + "command_list_submit", + "descriptor_set_create", + "descriptor_set_destroy", + "descriptor_set_write_buffer", + "descriptor_set_write_image", + "image_create", + "image_destroy", + "image_create_sampler", + "image_destroy_sampler", + "image_write", + "image_format_block_size", + "image_read", + "stage_compute_plan_create", + "stage_compute_plan_destroy", + "stage_compute_record", + "stage_fft_plan_create", + "stage_fft_plan_destroy", + "stage_fft_record", +] diff --git a/vkdispatch/tests/test_utils.py b/vkdispatch/base/__init__.py similarity index 100% rename from vkdispatch/tests/test_utils.py rename to vkdispatch/base/__init__.py diff --git a/vkdispatch/base/brython_utils.py b/vkdispatch/base/brython_utils.py new file mode 100644 index 00000000..fa4e7b6b --- /dev/null +++ b/vkdispatch/base/brython_utils.py @@ -0,0 +1,4 @@ +import sys + +def is_brython() -> bool: + return sys.implementation.name == "Brython" \ No newline at end of file diff --git a/vkdispatch/base/buffer.py b/vkdispatch/base/buffer.py index c0aa417c..6f49b622 100644 --- a/vkdispatch/base/buffer.py +++ b/vkdispatch/base/buffer.py @@ -1,38 +1,86 @@ from typing import Tuple -from typing import Dict +from typing import List from typing import Union +from typing import Optional +from contextlib 
import nullcontext -import numpy as np - +from .init import is_cuda from .dtype import dtype -from .context import Handle +from .context import Handle, Signal from .errors import check_for_errors -from .dtype import to_numpy_dtype, from_numpy_dtype, complex64 +from .dtype import complex64 +from . import dtype as dtypes + +from ..compat import numpy_compat as npc +from .dtype import to_numpy_dtype, from_numpy_dtype -import vkdispatch_native +from ..backends.backend_selection import native import typing _ArgType = typing.TypeVar('_ArgType', bound=dtype) +import dataclasses + +def _suspend_cuda_capture_if_needed(): + if not is_cuda(): + return nullcontext() + + from ..execution_pipeline.cuda_graph_capture import suspend_cuda_capture + return suspend_cuda_capture() + +@dataclasses.dataclass +class ExternalBufferInfo: + writable: bool + iface: dict + keepalive: bool + cuda_ptr: int + class Buffer(Handle, typing.Generic[_ArgType]): - """TODO: Docstring""" + """ + Represents a contiguous block of memory on the GPU (or shared across multiple devices). + + Buffers are the primary mechanism for transferring data between the host (CPU) + and the device (GPU). They are typed using ``vkdispatch.dtype`` and support + multi-dimensional shapes, similar to NumPy arrays. + + :param shape: The dimensions of the buffer. Must be a tuple of 1, 2, or 3 integers. + :type shape: Tuple[int, ...] + :param var_type: The data type of the elements stored in the buffer. + :type var_type: vkdispatch.base.dtype.dtype + :raises ValueError: If the shape has more than 3 dimensions or if the requested size exceeds 2^30 elements. 
+ """ var_type: dtype shape: Tuple[int] size: int mem_size: int - - def __init__(self, shape: Tuple[int, ...], var_type: dtype) -> None: + signals: List[Signal] + is_external: bool + owns_memory: bool + is_writable: bool + cuda_ptr: typing.Optional[int] + cuda_source: typing.Any + cuda_array_stream: typing.Optional[typing.Any] + + def __init__(self, shape: Tuple[int, ...], var_type: dtype, external_buffer: ExternalBufferInfo = None) -> None: super().__init__() + if isinstance(shape, int): + shape = (shape,) + if len(shape) > 3: raise ValueError("Buffer shape must be 1, 2, or 3 dimensions!") self.var_type: dtype = var_type self.shape: Tuple[int] = shape - self.size: int = int(np.prod(shape)) + + size = 1 + for dim in shape: + size *= dim + self.size = size + self.mem_size: int = self.size * self.var_type.item_size if self.size > 2 ** 30: @@ -47,63 +95,176 @@ def __init__(self, shape: Tuple[int, ...], var_type: dtype) -> None: self.shader_shape = tuple(shader_shape_internal) - handle = vkdispatch_native.buffer_create( - self.context._handle, self.mem_size, 0 - ) - check_for_errors() + self.signals = [] + self.is_external = external_buffer is not None + self.owns_memory = external_buffer is None + self.is_writable = True if external_buffer is None else external_buffer.writable + self.cuda_ptr = None if external_buffer is None else external_buffer.cuda_ptr + self.cuda_source = None if external_buffer is None else (external_buffer.iface if external_buffer.keepalive else None) + self.cuda_array_stream = None if external_buffer is None else external_buffer.iface.get("stream") + + with _suspend_cuda_capture_if_needed(): + if external_buffer is not None: + handle = native.buffer_create_external( + self.context._handle, + self.mem_size, + self.cuda_ptr, + ) + else: + handle = native.buffer_create( + self.context._handle, self.mem_size, 0 + ) + check_for_errors() + + self.signals = [ + Signal( + native.buffer_get_queue_signal( + handle, queue_index + ) + ) + for queue_index 
in range(self.context.queue_count) + ] self.register_handle(handle) + def __repr__(self): + return f"""Buffer {self._handle}: + shape={self.shape} + var_type={self.var_type.name} + mem_size={self.mem_size} bytes + is_external={self.is_external} + writable={self.is_writable} + cuda_ptr={self.cuda_ptr} + cuda_iface={self.cuda_source} +""" + def _destroy(self) -> None: """Destroy the buffer and all child handles.""" - vkdispatch_native.buffer_destroy(self._handle) + + for ii, signal in enumerate(self.signals): + signal.wait(False, ii) + + native.buffer_destroy(self._handle) def __del__(self) -> None: self.destroy() - def write(self, data: Union[bytes, np.ndarray], index: int = -1) -> None: - """Given data in some numpy array, write that data to the buffer at the - specified index. The default index of -1 will write to - all buffers. + def _wait_staging_idle(self, index: int): + with _suspend_cuda_capture_if_needed(): + is_idle = native.buffer_wait_staging_idle(self._handle, index) + check_for_errors() + return is_idle + + def _do_writes(self, data: bytes, index: int = None): + indicies = [index] if index is not None else range(self.context.queue_count) + completed_stages = [0] * len(indicies) + + with _suspend_cuda_capture_if_needed(): + while not all(stage == 1 for stage in completed_stages): + for i in range(len(indicies)): + if completed_stages[i] == 1: + continue + + queue_index = indicies[i] + + if not self.signals[queue_index].try_wait(True, queue_index): + continue + + completed_stages[i] = 1 - Parameters: - data (np.ndarray): The data to write to the buffer. - index (int): The index to write the data to. Default is -1 and - will write to all buffers. 
+ native.buffer_write_staging(self._handle, queue_index, data, len(data)) + check_for_errors() - Returns: - None + native.buffer_write(self._handle, 0, len(data), queue_index) + check_for_errors() + + def write(self, data: Union[bytes, bytearray, memoryview, typing.Any], index: int = None) -> None: """ - if index < -1: - raise ValueError(f"Invalid buffer index {index}!") + Uploads data from the host to the GPU buffer. + + If ``index`` is None, the data is broadcast to the memory of all active devices + in the context. Otherwise, it writes only to the device specified by the index. + + :param data: The source data. Can be a bytes-like object or an array-like object. + :type data: Union[bytes, bytearray, memoryview, Any] + :param index: The device index to write to. Defaults to -1 (all devices). + :type index: int + :raises ValueError: If the data size exceeds the buffer size or if the index is invalid. + """ + if index is not None: + assert isinstance(index, int), "Index must be an integer or None!" + assert index >= 0 and index < self.context.queue_count, "Index must be valid!" 
+ + if not getattr(self, "is_writable", True): + raise ValueError("Cannot write to a read-only buffer alias.") true_data_object = None - if isinstance(data, np.ndarray): - if data.size * np.dtype(data.dtype).itemsize != self.mem_size: + if npc.is_array_like(data): + if npc.array_nbytes(data) != self.mem_size: raise ValueError("Numpy buffer sizes must match!") - true_data_object = np.ascontiguousarray(data).tobytes() + true_data_object = npc.as_contiguous_bytes(data) else: - if len(data) > self.mem_size: + true_data_object = npc.ensure_bytes(data) + + if len(true_data_object) > self.mem_size: raise ValueError("Data Size must be less than buffer size") - true_data_object = data + self._do_writes(true_data_object, index) - vkdispatch_native.buffer_write( - self._handle, true_data_object, 0, len(true_data_object), index - ) - check_for_errors() + def _do_reads(self, var_type: dtype, shape: List[int], index: int = None) -> bytes: + assert index is None or (isinstance(index, int) and index >= 0), "Index must be None or a non-negative integer!" + + indicies = [index] if index is not None else range(self.context.queue_count) + completed_stages = [0] * len(indicies) + bytes_list: List[bytes] = [None] * len(indicies) + + mem_size = int(npc.prod(shape)) * var_type.item_size - def read(self, index: Union[int, None] = None) -> np.ndarray: - """Read the data in the buffer at the specified device index and return it as a - numpy array. + with _suspend_cuda_capture_if_needed(): + while not all(stage == 2 for stage in completed_stages): + for i in range(len(indicies)): + if completed_stages[i] == 2: + continue - Parameters: - index (int): The index to read the data from. Default is 0. + queue_index = indicies[i] - Returns: - (np.ndarray): The data in the buffer as a numpy array. 
+ if completed_stages[i] == 0: + if self.signals[queue_index].try_wait(False, queue_index): + completed_stages[i] = 1 + native.buffer_read(self._handle, 0, mem_size, queue_index) + check_for_errors() + else: + continue + + if completed_stages[i] == 1: + if self.signals[queue_index].try_wait(True, queue_index): + completed_stages[i] = 2 + else: + continue + + bytes_list[i] = native.buffer_read_staging(self._handle, queue_index, mem_size) + check_for_errors() + + host_arrays = [] + + for b in bytes_list: + host_arrays.append( + npc.from_buffer(b, dtype=to_numpy_dtype(var_type), shape=tuple(shape)) + ) + + return host_arrays if index is None else host_arrays[0] + + def read(self, index: Union[int, None] = None): + """ + Downloads data from the GPU buffer to the host. + + :param index: The device index to read from. If ``None``, reads from all devices + and returns a stacked array with an extra dimension for the device index. + :type index: Union[int, None] + :return: A host array representation containing the buffer data. + :raises ValueError: If the specified index is invalid. 
""" true_scalar = self.var_type.scalar @@ -111,67 +272,177 @@ def read(self, index: Union[int, None] = None) -> np.ndarray: if true_scalar is None: true_scalar = self.var_type - if index is not None: - if index < 0: - raise ValueError(f"Invalid buffer index {index}!") - - result_bytes = vkdispatch_native.buffer_read( - self._handle, 0, self.mem_size, index - ) + data_shape = list(self.shape) + list(self.var_type.true_numpy_shape) - result = np.frombuffer(result_bytes, dtype=to_numpy_dtype(true_scalar)).reshape(self.shape + self.var_type.true_numpy_shape) + if index is not None: + return self._do_reads(true_scalar, data_shape, index) + + results = self._do_reads(true_scalar, data_shape, None) - check_for_errors() - else: - result = np.zeros((self.context.queue_count,) + self.shape + self.var_type.true_numpy_shape, dtype=to_numpy_dtype(true_scalar)) + if npc.HAS_NUMPY: + return npc.numpy_module().array(results) - for i in range(self.context.queue_count): - result[i] = self.read(i) + return results - return result +def asbuffer(array: typing.Any) -> Buffer: + """Cast an array-like object to a buffer object.""" + if hasattr(array, "__cuda_array_interface__"): + return from_cuda_array(array) -def asbuffer(array: np.ndarray) -> Buffer: - """Cast a numpy array to a buffer object.""" + if not npc.is_array_like(array): + raise TypeError("Expected an array-like object") - buffer = Buffer(array.shape, from_numpy_dtype(array.dtype)) + buffer = Buffer(npc.array_shape(array), from_numpy_dtype(npc.array_dtype(array))) buffer.write(array) return buffer +def from_cuda_array( + obj: typing.Any, + var_type: typing.Optional[dtype] = None, + require_contiguous: bool = True, + writable: typing.Optional[bool] = None, + keepalive: bool = True, +) -> Buffer: + assert is_cuda(), "__cuda_array_interface__ is only supported with CUDA backends." 
+ + if not hasattr(obj, "__cuda_array_interface__"): + raise TypeError("Expected an object with __cuda_array_interface__") + + npc.require_numpy("from_cuda_array") + np = npc.numpy_module() + + iface = obj.__cuda_array_interface__ + if not isinstance(iface, dict): + raise TypeError("__cuda_array_interface__ must be a dictionary") + + if "shape" not in iface or "typestr" not in iface or "data" not in iface: + raise ValueError("__cuda_array_interface__ is missing required fields (shape/typestr/data)") + + shape = tuple(int(dim) for dim in iface["shape"]) + if len(shape) == 0: + shape = (1,) + + data_entry = iface["data"] + if not (isinstance(data_entry, tuple) and len(data_entry) >= 2): + raise ValueError("__cuda_array_interface__['data'] must be a tuple (ptr, read_only)") + + ptr = int(data_entry[0]) + source_read_only = bool(data_entry[1]) + + inferred_np_dtype = np.dtype(iface["typestr"]) + inferred_var_type = from_numpy_dtype(inferred_np_dtype) + if var_type is None: + var_type = inferred_var_type + + if not (var_type == inferred_var_type): + raise ValueError( + f"CAI dtype ({inferred_np_dtype}) does not match requested vd dtype ({var_type.name})." 
+ ) + + if require_contiguous: + strides = iface.get("strides") + if strides is not None: + expected_strides = [] + stride = int(inferred_np_dtype.itemsize) + for dim in reversed(shape): + expected_strides.insert(0, stride) + stride *= int(dim) + if tuple(int(x) for x in strides) != tuple(expected_strides): + raise ValueError("Only contiguous C-order CUDA arrays are supported in from_cuda_array().") + + buffer_writable = (not source_read_only) if writable is None else bool(writable) + if buffer_writable and source_read_only: + raise ValueError("Requested writable=True for a read-only CUDA array.") + + external_buffer_info = ExternalBufferInfo( + writable=buffer_writable, + iface=iface, + keepalive=keepalive, + cuda_ptr=ptr + ) + + return Buffer(shape, var_type, external_buffer=external_buffer_info) + class RFFTBuffer(Buffer): - def __init__(self, shape: Tuple[int, ...]): - super().__init__(tuple(shape[:-1]) + (shape[-1] // 2 + 1,), complex64) + real_shape: Tuple[int, ...] + fourier_shape: Tuple[int, ...] 
+ real_type: dtype + + def __init__(self, shape: Tuple[int, ...], fourier_type: dtype = complex64): + if not dtypes.is_complex(fourier_type): + raise ValueError("RFFTBuffer fourier_type must be complex32, complex64, or complex128") + + if not dtypes.is_float_dtype(fourier_type.child_type): + raise ValueError("RFFTBuffer fourier_type must use a floating-point scalar") + + super().__init__(tuple(shape[:-1]) + (shape[-1] // 2 + 1,), fourier_type) self.real_shape = shape self.fourier_shape = self.shape - - def read_real(self, index: Union[int, None] = None) -> np.ndarray: - return self.read(index).view(np.float32)[..., :self.real_shape[-1]] + self.real_type = fourier_type.child_type + + def read_real(self, index: Union[int, None] = None): + npc.require_numpy("RFFTBuffer.read_real") + np = npc.numpy_module() + + packed_shape = list(self.shape[:-1]) + [self.shape[-1] * 2] + packed_data = self._do_reads(self.real_type, packed_shape, index) - def read_fourier(self, index: Union[int, None] = None) -> np.ndarray: + if index is None: + packed_data = np.array(packed_data) + + return packed_data[..., :self.real_shape[-1]] + + def read_fourier(self, index: Union[int, None] = None): return self.read(index) - - def write_real(self, data: np.ndarray, index: int = -1): + + def write_real(self, data, index: int = None): + npc.require_numpy("RFFTBuffer.write_real") + np = npc.numpy_module() assert data.shape == self.real_shape, "Data shape must match real shape!" - assert not np.issubdtype(data.dtype, np.complexfloating) , "Data dtype must be scalar!" + assert not np.issubdtype(data.dtype, np.complexfloating), "Data dtype must be scalar!" 
- true_data = np.zeros(self.shape[:-1] + (self.shape[-1] * 2,), dtype=np.float32) + real_dtype = to_numpy_dtype(self.real_type) + true_data = np.zeros(self.shape[:-1] + (self.shape[-1] * 2,), dtype=real_dtype) true_data[..., :self.real_shape[-1]] = data - self.write(np.ascontiguousarray(true_data).view(np.complex64), index) + self.write(np.ascontiguousarray(true_data), index) - def write_fourier(self, data: np.ndarray, index: int = -1): + def write_fourier(self, data, index: int = None): + npc.require_numpy("RFFTBuffer.write_fourier") + np = npc.numpy_module() assert data.shape == self.fourier_shape, f"Data shape {data.shape} must match fourier shape {self.fourier_shape}!" - assert np.issubdtype(data.dtype, np.complexfloating) , "Data dtype must be complex!" + assert np.issubdtype(data.dtype, np.complexfloating), "Data dtype must be complex!" + + target_fourier_dtype = to_numpy_dtype(self.var_type) + if npc.is_host_dtype(target_fourier_dtype): + # complex32: pack complex values into float16 real/imag pairs. + complex_data = np.ascontiguousarray(data.astype(np.complex64)) + packed_pairs = np.empty(complex_data.shape + (2,), dtype=np.float16) + packed_pairs[..., 0] = complex_data.real.astype(np.float16) + packed_pairs[..., 1] = complex_data.imag.astype(np.float16) - self.write(np.ascontiguousarray(data.astype(np.complex64)).view(np.float32), index) + packed_real_shape = self.shape[:-1] + (self.shape[-1] * 2,) + self.write(np.ascontiguousarray(packed_pairs).reshape(packed_real_shape), index) + return -def asrfftbuffer(data: np.ndarray) -> RFFTBuffer: + self.write(np.ascontiguousarray(data.astype(target_fourier_dtype)), index) + + +def asrfftbuffer(data, fourier_type: Optional[dtype] = None) -> RFFTBuffer: + npc.require_numpy("asrfftbuffer") + np = npc.numpy_module() assert not np.issubdtype(data.dtype, np.complexfloating), "Data dtype must be scalar!" 
- buffer = RFFTBuffer(data.shape) + if fourier_type is None: + scalar_dtype = from_numpy_dtype(data.dtype) + scalar_dtype = dtypes.make_floating_dtype(scalar_dtype) + fourier_type = dtypes.complex_from_float(scalar_dtype) + + buffer = RFFTBuffer(data.shape, fourier_type=fourier_type) buffer.write_real(data) - return buffer \ No newline at end of file + return buffer diff --git a/vkdispatch/base/buffer_allocators.py b/vkdispatch/base/buffer_allocators.py new file mode 100644 index 00000000..e14fed86 --- /dev/null +++ b/vkdispatch/base/buffer_allocators.py @@ -0,0 +1,119 @@ +from .buffer import Buffer +from . import dtype as dt +from typing import Tuple + +def buffer_u32(shape: Tuple[int, ...]) -> Buffer: + """Create a buffer of unsigned 32-bit integers with the specified shape.""" + return Buffer(shape, dt.uint32) + +def buffer_uv2(shape: Tuple[int, ...]) -> Buffer: + """Create a buffer of unsigned 32-bit integer vectors of size 2 with the specified shape.""" + return Buffer(shape, dt.uvec2) + +def buffer_uv3(shape: Tuple[int, ...]) -> Buffer: + """Create a buffer of unsigned 32-bit integer vectors of size 3 with the specified shape.""" + return Buffer(shape, dt.uvec3) + +def buffer_uv4(shape: Tuple[int, ...]) -> Buffer: + """Create a buffer of unsigned 32-bit integer vectors of size 4 with the specified shape.""" + return Buffer(shape, dt.uvec4) + +def buffer_i32(shape: Tuple[int, ...]) -> Buffer: + """Create a buffer of signed 32-bit integers with the specified shape.""" + return Buffer(shape, dt.int32) + +def buffer_iv2(shape: Tuple[int, ...]) -> Buffer: + """Create a buffer of signed 32-bit integer vectors of size 2 with the specified shape.""" + return Buffer(shape, dt.ivec2) + +def buffer_iv3(shape: Tuple[int, ...]) -> Buffer: + """Create a buffer of signed 32-bit integer vectors of size 3 with the specified shape.""" + return Buffer(shape, dt.ivec3) + +def buffer_iv4(shape: Tuple[int, ...]) -> Buffer: + """Create a buffer of signed 32-bit integer vectors of 
size 4 with the specified shape.""" + return Buffer(shape, dt.ivec4) + +def buffer_f32(shape: Tuple[int, ...]) -> Buffer: + """Create a buffer of 32-bit floating-point numbers with the specified shape.""" + return Buffer(shape, dt.float32) + +def buffer_v2(shape: Tuple[int, ...]) -> Buffer: + """Create a buffer of 32-bit floating-point vectors of size 2 with the specified shape.""" + return Buffer(shape, dt.vec2) + +def buffer_v3(shape: Tuple[int, ...]) -> Buffer: + """Create a buffer of 32-bit floating-point vectors of size 3 with the specified shape.""" + return Buffer(shape, dt.vec3) + +def buffer_v4(shape: Tuple[int, ...]) -> Buffer: + """Create a buffer of 32-bit floating-point vectors of size 4 with the specified shape.""" + return Buffer(shape, dt.vec4) + +def buffer_c64(shape: Tuple[int, ...]) -> Buffer: + """Create a buffer of 64-bit complex numbers with the specified shape.""" + return Buffer(shape, dt.complex64) + +def buffer_u16(shape: Tuple[int, ...]) -> Buffer: + """Create a buffer of unsigned 16-bit integers with the specified shape.""" + return Buffer(shape, dt.uint16) + +def buffer_uhv2(shape: Tuple[int, ...]) -> Buffer: + """Create a buffer of unsigned 16-bit integer vectors of size 2 with the specified shape.""" + return Buffer(shape, dt.uhvec2) + +def buffer_uhv3(shape: Tuple[int, ...]) -> Buffer: + """Create a buffer of unsigned 16-bit integer vectors of size 3 with the specified shape.""" + return Buffer(shape, dt.uhvec3) + +def buffer_uhv4(shape: Tuple[int, ...]) -> Buffer: + """Create a buffer of unsigned 16-bit integer vectors of size 4 with the specified shape.""" + return Buffer(shape, dt.uhvec4) + +def buffer_i16(shape: Tuple[int, ...]) -> Buffer: + """Create a buffer of signed 16-bit integers with the specified shape.""" + return Buffer(shape, dt.int16) + +def buffer_ihv2(shape: Tuple[int, ...]) -> Buffer: + """Create a buffer of signed 16-bit integer vectors of size 2 with the specified shape.""" + return Buffer(shape, dt.ihvec2) + 
+def buffer_ihv3(shape: Tuple[int, ...]) -> Buffer: + """Create a buffer of signed 16-bit integer vectors of size 3 with the specified shape.""" + return Buffer(shape, dt.ihvec3) + +def buffer_ihv4(shape: Tuple[int, ...]) -> Buffer: + """Create a buffer of signed 16-bit integer vectors of size 4 with the specified shape.""" + return Buffer(shape, dt.ihvec4) + +def buffer_f16(shape: Tuple[int, ...]) -> Buffer: + """Create a buffer of 16-bit floating-point numbers with the specified shape.""" + return Buffer(shape, dt.float16) + +def buffer_hv2(shape: Tuple[int, ...]) -> Buffer: + """Create a buffer of 16-bit floating-point vectors of size 2 with the specified shape.""" + return Buffer(shape, dt.hvec2) + +def buffer_hv3(shape: Tuple[int, ...]) -> Buffer: + """Create a buffer of 16-bit floating-point vectors of size 3 with the specified shape.""" + return Buffer(shape, dt.hvec3) + +def buffer_hv4(shape: Tuple[int, ...]) -> Buffer: + """Create a buffer of 16-bit floating-point vectors of size 4 with the specified shape.""" + return Buffer(shape, dt.hvec4) + +def buffer_f64(shape: Tuple[int, ...]) -> Buffer: + """Create a buffer of 64-bit floating-point numbers with the specified shape.""" + return Buffer(shape, dt.float64) + +def buffer_dv2(shape: Tuple[int, ...]) -> Buffer: + """Create a buffer of 64-bit floating-point vectors of size 2 with the specified shape.""" + return Buffer(shape, dt.dvec2) + +def buffer_dv3(shape: Tuple[int, ...]) -> Buffer: + """Create a buffer of 64-bit floating-point vectors of size 3 with the specified shape.""" + return Buffer(shape, dt.dvec3) + +def buffer_dv4(shape: Tuple[int, ...]) -> Buffer: + """Create a buffer of 64-bit floating-point vectors of size 4 with the specified shape.""" + return Buffer(shape, dt.dvec4) \ No newline at end of file diff --git a/vkdispatch/base/command_list.py b/vkdispatch/base/command_list.py index 67ea91d0..99fa2799 100644 --- a/vkdispatch/base/command_list.py +++ b/vkdispatch/base/command_list.py @@ -1,33 
+1,38 @@ from typing import Tuple from typing import Optional -import vkdispatch_native +from ..backends.backend_selection import native +from .init import is_cuda from .context import Handle from .errors import check_for_errors +from ..execution_pipeline.cuda_graph_capture import get_cuda_capture + from .compute_plan import ComputePlan from .descriptor_set import DescriptorSet -import numpy as np - class CommandList(Handle): """ - A class for recording and submitting command lists to the device. + Represents a sequence of GPU commands to be executed on a device. + + CommandLists are used to record dispatch operations, memory barriers, and + synchronization points. They act as the primary unit of work submission + to the Vulkan queue. Attributes: - _handle (int): The handle to the command list. + _handle (int): The internal handle to the native Vulkan command buffer wrapper. """ def __init__(self) -> None: super().__init__() - handle = vkdispatch_native.command_list_create(self.context._handle) + handle = native.command_list_create(self.context._handle) self.register_handle(handle) check_for_errors() def _destroy(self) -> None: - vkdispatch_native.command_list_destroy(self._handle) + native.command_list_destroy(self._handle) check_for_errors() def __del__(self) -> None: @@ -35,7 +40,7 @@ def __del__(self) -> None: def get_instance_size(self) -> int: """Get the total size of the command list in bytes.""" - result = vkdispatch_native.command_list_get_instance_size(self._handle) + result = native.command_list_get_instance_size(self._handle) check_for_errors() return result @@ -44,17 +49,19 @@ def record_compute_plan(self, descriptor_set: DescriptorSet, blocks: Tuple[int, int, int]) -> None: """ - Record a compute plan to the command list. - - Args: - plan (ComputePlan): The compute plan to record to the command list. - descriptor_set (DescriptorSet): The descriptor set to bind to the compute plan. 
- blocks (Tuple[int, int, int]): The number of blocks to run the compute shader in. + Records a compute shader dispatch into the command list. + + :param plan: The compiled compute plan (shader) to execute. + :type plan: vkdispatch.base.compute_plan.ComputePlan + :param descriptor_set: The resource bindings (buffers, images) for this execution. + :type descriptor_set: vkdispatch.base.descriptor_set.DescriptorSet + :param blocks: The dimensions of the workgroup grid (x, y, z) to dispatch. + :type blocks: Tuple[int, int, int] """ self.register_parent(plan) self.register_parent(descriptor_set) - vkdispatch_native.stage_compute_record( + native.stage_compute_record( self._handle, plan._handle, descriptor_set._handle, @@ -67,14 +74,30 @@ def record_compute_plan(self, def reset(self) -> None: """Reset the command list. """ - vkdispatch_native.command_list_reset(self._handle) + native.command_list_reset(self._handle) check_for_errors() self.clear_parents() - def submit(self, data: Optional[bytes] = None, queue_index: int = -2, instance_count: Optional[int] = None) -> None: + def submit( + self, + data: Optional[bytes] = None, + queue_index: int = -2, + instance_count: Optional[int] = None, + cuda_stream=None + ) -> None: """ - Submit the command list to the specified device with additional data to + Submits the recorded command list to the GPU queue for execution. + + :param data: Optional binary data (e.g., push constants) to append to the + front of the command list buffer before submission. + :type data: Optional[bytes] + :param queue_index: The index of the queue to submit to. -2 uses the default queue associated + with the command list's context. + :type queue_index: int + :param instance_count: The number of instances to execute if instanced dispatch is used. + :type instance_count: Optional[int] + :raises ValueError: If data length logic conflicts with instance size. 
""" if data is None and instance_count is None: @@ -92,7 +115,22 @@ def submit(self, data: Optional[bytes] = None, queue_index: int = -2, instance_c if self.get_instance_size() != 0: assert self.get_instance_size() * instance_count == len(data), "Data length must be the product of the instance size and instance count!" - vkdispatch_native.command_list_submit( - self._handle, data, instance_count, queue_index - ) - check_for_errors() \ No newline at end of file + if cuda_stream is None and get_cuda_capture() is not None: + cuda_stream = get_cuda_capture().cuda_stream + + if cuda_stream is not None: + if not is_cuda(): + raise RuntimeError("cuda_stream=... is currently only supported with CUDA backends.") + + native.cuda_stream_override_begin(cuda_stream) + check_for_errors() + + done = False + while not done: + done = native.command_list_submit( + self._handle, data, instance_count, queue_index + ) + check_for_errors() + + if cuda_stream is not None: + native.cuda_stream_override_end() diff --git a/vkdispatch/base/compute_plan.py b/vkdispatch/base/compute_plan.py index 087c1582..995ae177 100644 --- a/vkdispatch/base/compute_plan.py +++ b/vkdispatch/base/compute_plan.py @@ -1,4 +1,4 @@ -import vkdispatch_native +from ..backends.backend_selection import native from .context import Handle from .errors import check_for_compute_stage_errors, check_for_errors @@ -6,13 +6,21 @@ class ComputePlan(Handle): """ - ComputePlan is a wrapper for the native functions which create and dispatch Vulkan compute shaders. - - Attributes: - pc_size (int): The size of the push constants for the compute shader (in bytes) - shader_source (str): The source code of the compute shader (in GLSL) - binding_list (list): A list of binding types for the shader resources. - _handle (int): A pointer to the compute plan created by the native Vulkan dispatch. + Represents a compiled Compute Pipeline ready for execution. 
+ + A ComputePlan wraps the native Vulkan pipeline objects, including the shader module, + descriptor set layouts, and pipeline layouts. It is created by compiling GLSL + source code generated by the ``vkdispatch.codegen`` module. + + :param shader_source: The GLSL source code for the compute shader. + :type shader_source: str + :param binding_type_list: A list of integers representing the type of resource + bound to each binding index. + :type binding_type_list: list + :param pc_size: The size of the push constant block in bytes. + :type pc_size: int + :param shader_name: A name for the shader, used for debugging and logging. + :type shader_name: str """ def __init__(self, shader_source: str, binding_type_list: list, pc_size: int, shader_name: str) -> None: @@ -22,15 +30,14 @@ def __init__(self, shader_source: str, binding_type_list: list, pc_size: int, sh self.shader_source = shader_source self.binding_list = binding_type_list - handle = vkdispatch_native.stage_compute_plan_create( + handle = native.stage_compute_plan_create( self.context._handle, shader_source.encode(), self.binding_list, pc_size, shader_name.encode() ) check_for_compute_stage_errors() - self.register_handle(handle) def _destroy(self) -> None: - vkdispatch_native.stage_compute_plan_destroy(self._handle) + native.stage_compute_plan_destroy(self._handle) check_for_errors() def __del__(self) -> None: diff --git a/vkdispatch/base/context.py b/vkdispatch/base/context.py index 386eb06e..d10f0c9a 100644 --- a/vkdispatch/base/context.py +++ b/vkdispatch/base/context.py @@ -10,9 +10,12 @@ import os, signal from .errors import check_for_errors, set_running -from .init import DeviceInfo, get_devices, initialize, set_log_level, LogLevel +from .init import DeviceInfo, is_cuda, is_opencl, is_dummy, get_devices, initialize, log_info +from ..backends.backend_selection import native -import vkdispatch_native +VK_SHADER_STAGE_COMPUTE_BIT = 0x00000020 + +VK_SUBGROUP_FEATURE_ARITHMETIC_BIT = 0x00000004 class 
Handle: context: "Context" @@ -53,6 +56,8 @@ def clear_parents(self) -> None: """ Clears the parent handles. """ + # children_dict uses weak references, so a child key may disappear + # before teardown reaches this point. for parent in self.parents.values(): parent.remove_child_handle(self) @@ -71,10 +76,8 @@ def remove_child_handle(self, child: "Handle") -> None: """ Removes a child handle from the current handle. """ - if child._handle not in self.children_dict.keys(): - raise ValueError(f"Child handle {child._handle} does not exist in parent handle!") - - self.children_dict.pop(child._handle) + # Be idempotent to tolerate repeated teardown paths and weakref eviction. + self.children_dict.pop(child._handle, None) def _destroy(self) -> None: raise NotImplementedError("destroy is an abstract method and must be implemented by subclasses.") @@ -84,28 +87,57 @@ def destroy(self) -> None: Destroys the context handle and cleans up resources. """ if self.destroyed: - return + return - child_list = list(self.children_dict.values()) + self.destroyed = True + self.clear_parents() - for child in child_list: - child.destroy() + child_keys = list(self.children_dict.keys()) + + for child_handle in child_keys: + if child_handle in self.children_dict: + child = self.children_dict[child_handle] + child.destroy() assert len(self.children_dict) == 0, "Not all children were destroyed!" assert not self.canary, "Handle was already destroyed!" 
- self._destroy() + if self._handle is not None: + self._destroy() + check_for_errors() + self.canary = True - check_for_errors() - - self.clear_parents() if self._handle in self.context.handles_dict.keys(): self.context.handles_dict.pop(self._handle) - self.destroyed = True +class Signal: + ptr_addr: int + + def __init__(self, ptr_addr: int = None): + self.ptr_addr = ptr_addr + + def wait(self, wait_for_timestamp: bool, queue_index: int): + done = False + while not done: + done = native.signal_wait( + self.ptr_addr, wait_for_timestamp, queue_index + ) + check_for_errors() + + def try_wait(self, wait_for_timestamp: bool, queue_index: int): + done = native.signal_wait( + self.ptr_addr, wait_for_timestamp, queue_index + ) + check_for_errors() + + return done + + def free(self): + native.signal_destroy(self.ptr_addr) + class Context: """ A class for managing the context of the vkdispatch library. @@ -125,11 +157,14 @@ class Context: """ _handle: int - devices: List[int] + device_ids: List[int] + mapped_device_ids: List[int] device_infos: List[DeviceInfo] queue_families: List[List[int]] queue_count: int subgroup_size: int + subgroup_enabled: bool + subgroup_arithmetic: bool max_workgroup_size: Tuple[int] max_workgroup_invocations: int max_workgroup_count: Tuple[int, int, int] @@ -139,17 +174,21 @@ class Context: def __init__( self, - devices: List[int], + device_ids: List[int], queue_families: List[List[int]] ) -> None: - self.devices = devices - self.device_infos = [get_devices()[dev] for dev in devices] + self.device_ids = device_ids + self.device_infos = [get_devices()[dev] for dev in device_ids] self.queue_families = queue_families self.queue_count = sum([len(i) for i in queue_families]) self.handles_dict = weakref.WeakValueDictionary() - self._handle = vkdispatch_native.context_create(devices, queue_families) + self.mapped_device_ids = [dev.dev_index for dev in self.device_infos] + self._handle = native.context_create(self.mapped_device_ids, queue_families) 
check_for_errors() - + + self._refresh_limits_from_device_infos() + + def _refresh_limits_from_device_infos(self) -> None: subgroup_sizes = [] max_workgroup_sizes_x = [] max_workgroup_sizes_y = [] @@ -161,6 +200,9 @@ def __init__( uniform_buffer_alignments = [] max_shared_memory = [] + subgroup_enabled = True + subgroup_arithmetic = True + for device in self.device_infos: subgroup_sizes.append(device.sub_group_size) @@ -178,6 +220,14 @@ def __init__( max_shared_memory.append(device.max_compute_shared_memory_size) + if not device.supported_stages & VK_SHADER_STAGE_COMPUTE_BIT: + subgroup_enabled = False + + if not device.supported_operations & VK_SUBGROUP_FEATURE_ARITHMETIC_BIT: + subgroup_arithmetic = False + + self.subgroup_enabled = subgroup_enabled + self.subgroup_arithmetic = subgroup_arithmetic self.subgroup_size = min(subgroup_sizes) self.max_workgroup_size = ( min(max_workgroup_sizes_x), @@ -341,6 +391,18 @@ def make_context( select_queue_families(dev_index, queue_family_count) ) + if is_cuda() or is_opencl(): + backend_name = "CUDA" if is_cuda() else "OpenCL" + if len(device_ids) != 1: + raise NotImplementedError( + f"The {backend_name} backend currently supports exactly one device." + ) + + if len(queue_families) != 1 or len(queue_families[0]) != 1: + raise NotImplementedError( + f"The {backend_name} backend currently supports exactly one queue." 
+ ) + total_devices = len(get_devices()) # Do type checking before passing to native code @@ -358,6 +420,8 @@ def make_context( __context = Context(device_ids, queue_families) + queue_wait_idle(queue_index=None, context=__context) + return __context def is_context_initialized() -> bool: @@ -370,7 +434,110 @@ def get_context() -> Context: def get_context_handle() -> int: return get_context()._handle -def queue_wait_idle(queue_index: int = None) -> None: +def _as_positive_int(name: str, value) -> int: + try: + result = int(value) + except Exception as exc: + raise ValueError(f"{name} must be a positive integer") from exc + + if result <= 0: + raise ValueError(f"{name} must be a positive integer") + + return result + +def _as_positive_triplet(name: str, value) -> Tuple[int, int, int]: + try: + parts = list(value) + except Exception as exc: + raise ValueError(f"{name} must contain exactly 3 positive integers") from exc + + if len(parts) != 3: + raise ValueError(f"{name} must contain exactly 3 positive integers") + + return ( + _as_positive_int(f"{name}[0]", parts[0]), + _as_positive_int(f"{name}[1]", parts[1]), + _as_positive_int(f"{name}[2]", parts[2]), + ) + +def set_dummy_context_params( + subgroup_size: int = None, + max_workgroup_size: Tuple[int, int, int] = None, + max_workgroup_invocations: int = None, + max_workgroup_count: Tuple[int, int, int] = None, + max_shared_memory: int = None, +) -> None: + """ + Update cached context/device limit values for the active dummy backend context. + + This only works when a dummy context already exists. + """ + global __context + + if not is_dummy(): + raise RuntimeError( + "set_dummy_context_params() is only supported when running with backend='dummy'." 
+ ) + + if __context is None: + __context = get_context() + + validated_subgroup_size = None + validated_max_workgroup_size = None + validated_max_workgroup_invocations = None + validated_max_workgroup_count = None + validated_max_shared_memory = None + + backend_kwargs = {} + + if subgroup_size is not None: + validated_subgroup_size = _as_positive_int("subgroup_size", subgroup_size) + backend_kwargs["subgroup_size"] = validated_subgroup_size + + if max_workgroup_size is not None: + validated_max_workgroup_size = _as_positive_triplet("max_workgroup_size", max_workgroup_size) + backend_kwargs["max_workgroup_size"] = validated_max_workgroup_size + + if max_workgroup_invocations is not None: + validated_max_workgroup_invocations = _as_positive_int( + "max_workgroup_invocations", + max_workgroup_invocations, + ) + backend_kwargs["max_workgroup_invocations"] = validated_max_workgroup_invocations + + if max_workgroup_count is not None: + validated_max_workgroup_count = _as_positive_triplet("max_workgroup_count", max_workgroup_count) + backend_kwargs["max_workgroup_count"] = validated_max_workgroup_count + + if max_shared_memory is not None: + validated_max_shared_memory = _as_positive_int("max_shared_memory", max_shared_memory) + backend_kwargs["max_compute_shared_memory_size"] = validated_max_shared_memory + + if backend_kwargs: + native.set_device_options(**backend_kwargs) + check_for_errors() + + for device in __context.device_infos: + if validated_subgroup_size is not None: + device.sub_group_size = validated_subgroup_size + + if validated_max_workgroup_size is not None: + device.max_workgroup_size = validated_max_workgroup_size + + if validated_max_workgroup_invocations is not None: + device.max_workgroup_invocations = validated_max_workgroup_invocations + + if validated_max_workgroup_count is not None: + device.max_workgroup_count = validated_max_workgroup_count + + if validated_max_shared_memory is not None: + device.max_compute_shared_memory_size = 
validated_max_shared_memory + + device.uniform_buffer_alignment = 0 + + __context._refresh_limits_from_device_infos() + +def queue_wait_idle(queue_index: int = None, context: Context = None) -> None: """ Wait for the specified queue to finish processing. For all queues, leave queue_index as None. @@ -378,13 +545,27 @@ def queue_wait_idle(queue_index: int = None) -> None: queue_index (int): The index of the queue. """ + if context is None: + context = get_context() + assert queue_index is None or isinstance(queue_index, int), "queue_index must be an integer or None." - assert queue_index is None or queue_index >= -1, "queue_index must be a non-negative integer or -1 (for all queues)." - assert queue_index is None or queue_index < get_context().queue_count, f"Queue index {queue_index} is out of bounds for context with {get_context().queue_count} queues." + assert queue_index is None or queue_index >= 0, "queue_index must be a non-negative integer or None (for all queues)." + assert queue_index is None or queue_index < context.queue_count, f"Queue index {queue_index} is out of bounds for context with {context.queue_count} queues." + + if queue_index is None: + for i in range(context.queue_count): + queue_wait_idle(i, context) + return - vkdispatch_native.context_queue_wait_idle(get_context_handle(), queue_index if queue_index is not None else -1) + signal_ptr = native.signal_insert(context._handle, queue_index) + check_for_errors() + + signal = Signal(signal_ptr) + signal.wait(True, queue_index) check_for_errors() + signal.free() + def destroy_context() -> None: """ Destroys the current context and cleans up resources. 
@@ -392,16 +573,22 @@ def destroy_context() -> None: global __context set_running(False) - if __context is not None: - handles_list = list(__context.handles_dict.values()) + if __context is None: + return + + log_info("Destroying context...") - for handle in handles_list: - handle.destroy() + handles_list = list(__context.handles_dict.values()) - assert len(__context.handles_dict) == 0, "Not all handles were destroyed!" + for handle in handles_list: + log_info(f"Destroying handle {handle._handle}...") + handle.destroy() - vkdispatch_native.context_destroy(__context._handle) - __context = None + assert len(__context.handles_dict) == 0, "Not all handles were destroyed!" + + log_info("Calling native context destroy...") + native.context_destroy(__context._handle) + __context = None atexit.register(destroy_context) @@ -409,7 +596,7 @@ def stop_threads() -> None: """ Stops all threads in the context. """ - vkdispatch_native.context_stop_threads(get_context_handle()) + native.context_stop_threads(get_context_handle()) _shutdown_once = False @@ -427,6 +614,8 @@ def _sig_handler(signum, frame): signal.signal(signum, signal.SIG_DFL) os.kill(os.getpid(), signum) -# Install from the main thread -signal.signal(signal.SIGINT, _sig_handler) -signal.signal(signal.SIGTERM, _sig_handler) \ No newline at end of file + +from .brython_utils import is_brython +if not is_brython(): + signal.signal(signal.SIGINT, _sig_handler) + signal.signal(signal.SIGTERM, _sig_handler) diff --git a/vkdispatch/base/descriptor_set.py b/vkdispatch/base/descriptor_set.py index e1814cef..e9d2823a 100644 --- a/vkdispatch/base/descriptor_set.py +++ b/vkdispatch/base/descriptor_set.py @@ -1,4 +1,4 @@ -import vkdispatch_native +from ..backends.backend_selection import native from .errors import check_for_errors @@ -8,6 +8,7 @@ from .image import Sampler from .init import log_info +from .init import is_cuda class DescriptorSet(Handle): """TODO: Docstring""" @@ -15,22 +16,25 @@ def __init__(self, compute_plan: 
ComputePlan) -> None: super().__init__() self._bound_resources = [] - handle = vkdispatch_native.descriptor_set_create(compute_plan._handle) + handle = native.descriptor_set_create(compute_plan._handle) check_for_errors() self.register_handle(handle) self.register_parent(compute_plan) def _destroy(self) -> None: - vkdispatch_native.descriptor_set_destroy(self._handle) + native.descriptor_set_destroy(self._handle) check_for_errors() def __del__(self) -> None: self.destroy() def bind_buffer(self, buffer: Buffer, binding: int, offset: int = 0, range: int = 0, uniform: bool = False, read_access: bool = True, write_access: bool = True) -> None: + if write_access and not getattr(buffer, "is_writable", True): + raise ValueError("Cannot bind a read-only buffer with write access enabled.") + self.register_parent(buffer) - vkdispatch_native.descriptor_set_write_buffer( + native.descriptor_set_write_buffer( self._handle, binding, buffer._handle, @@ -45,7 +49,7 @@ def bind_buffer(self, buffer: Buffer, binding: int, offset: int = 0, range: int def bind_sampler(self, sampler: Sampler, binding: int, read_access: bool = True, write_access: bool = True) -> None: self.register_parent(sampler) - vkdispatch_native.descriptor_set_write_image( + native.descriptor_set_write_image( self._handle, binding, sampler.image._handle, @@ -53,4 +57,11 @@ def bind_sampler(self, sampler: Sampler, binding: int, read_access: bool = True, 1 if read_access else 0, 1 if write_access else 0 ) - check_for_errors() \ No newline at end of file + check_for_errors() + + def set_inline_uniform_payload(self, payload: bytes) -> None: + if not is_cuda(): + raise RuntimeError("Inline uniform payloads are currently only supported on CUDA backends.") + + native.descriptor_set_write_inline_uniform(self._handle, payload) + check_for_errors() diff --git a/vkdispatch/base/dtype.py b/vkdispatch/base/dtype.py index 9c94434a..62ea81d3 100644 --- a/vkdispatch/base/dtype.py +++ b/vkdispatch/base/dtype.py @@ -1,12 +1,11 @@ 
-import numpy as np +from typing import Any, Optional -from typing import Optional +from ..compat import numpy_compat as npc class dtype: name: str item_size: int glsl_type: str - glsl_type_extern: Optional[str] = None dimentions: int format_str: str child_type: "dtype" @@ -22,8 +21,21 @@ class _Scalar(dtype): shape = (1,) numpy_shape = (1,) true_numpy_shape = () + child_type = None scalar = None +class _I16(_Scalar): + name = "int16" + item_size = 2 + glsl_type = "int16_t" + format_str = "%d" + +class _U16(_Scalar): + name = "uint16" + item_size = 2 + glsl_type = "uint16_t" + format_str = "%u" + class _I32(_Scalar): name = "int32" item_size = 4 @@ -36,20 +48,61 @@ class _U32(_Scalar): glsl_type = "uint" format_str = "%u" +class _I64(_Scalar): + name = "int64" + item_size = 8 + glsl_type = "int64_t" + format_str = "%lld" + +class _U64(_Scalar): + name = "uint64" + item_size = 8 + glsl_type = "uint64_t" + format_str = "%llu" + +class _F16(_Scalar): + name = "float16" + item_size = 2 + glsl_type = "float16_t" + format_str = "%f" + class _F32(_Scalar): name = "float32" item_size = 4 glsl_type = "float" format_str = "%f" +class _F64(_Scalar): + name = "float64" + item_size = 8 + glsl_type = "double" + format_str = "%lf" + +int16 = _I16 # type: ignore +uint16 = _U16 # type: ignore int32 = _I32 # type: ignore uint32 = _U32 # type: ignore +int64 = _I64 # type: ignore +uint64 = _U64 # type: ignore +float16 = _F16 # type: ignore float32 = _F32 # type: ignore +float64 = _F64 # type: ignore class _Complex(dtype): dimentions = 0 child_count = 2 +class _CF32(_Complex): + name = "complex32" + item_size = 4 + glsl_type = "f16vec2" + format_str = "(%f, %f)" + child_type = float16 + shape = (2,) + numpy_shape = (1,) + true_numpy_shape = () + scalar = None + class _CF64(_Complex): name = "complex64" item_size = 8 @@ -61,11 +114,64 @@ class _CF64(_Complex): true_numpy_shape = () scalar = None +class _CF128(_Complex): + name = "complex128" + item_size = 16 + glsl_type = "dvec2" + 
format_str = "(%lf, %lf)" + child_type = float64 + shape = (2,) + numpy_shape = (1,) + true_numpy_shape = () + scalar = None + +complex32 = _CF32 # type: ignore complex64 = _CF64 # type: ignore +complex128 = _CF128 # type: ignore class _Vector(dtype): dimentions = 1 +# --- float16 vectors --- + +class _V2F16(_Vector): + name = "hvec2" + item_size = 4 + glsl_type = "f16vec2" + format_str = "(%f, %f)" + child_type = float16 + child_count = 2 + shape = (2,) + numpy_shape = (2,) + true_numpy_shape = (2,) + scalar = float16 + +class _V3F16(_Vector): + name = "hvec3" + item_size = 6 + glsl_type = "f16vec3" + format_str = "(%f, %f, %f)" + child_type = float16 + child_count = 3 + shape = (3,) + numpy_shape = (3,) + true_numpy_shape = (3,) + scalar = float16 + +class _V4F16(_Vector): + name = "hvec4" + item_size = 8 + glsl_type = "f16vec4" + format_str = "(%f, %f, %f, %f)" + child_type = float16 + child_count = 4 + shape = (4,) + numpy_shape = (4,) + true_numpy_shape = (4,) + scalar = float16 + +# --- float32 vectors --- + class _V2F32(_Vector): name = "vec2" item_size = 8 @@ -80,9 +186,8 @@ class _V2F32(_Vector): class _V3F32(_Vector): name = "vec3" - item_size = 16 + item_size = 12 glsl_type = "vec3" - glsl_type_extern = "vec4" format_str = "(%f, %f, %f)" child_type = float32 child_count = 3 @@ -103,6 +208,84 @@ class _V4F32(_Vector): true_numpy_shape = (4,) scalar = float32 +# --- float64 vectors --- + +class _V2F64(_Vector): + name = "dvec2" + item_size = 16 + glsl_type = "dvec2" + format_str = "(%lf, %lf)" + child_type = float64 + child_count = 2 + shape = (2,) + numpy_shape = (2,) + true_numpy_shape = (2,) + scalar = float64 + +class _V3F64(_Vector): + name = "dvec3" + item_size = 24 + glsl_type = "dvec3" + format_str = "(%lf, %lf, %lf)" + child_type = float64 + child_count = 3 + shape = (3,) + numpy_shape = (3,) + true_numpy_shape = (3,) + scalar = float64 + +class _V4F64(_Vector): + name = "dvec4" + item_size = 32 + glsl_type = "dvec4" + format_str = "(%lf, %lf, 
%lf, %lf)" + child_type = float64 + child_count = 4 + shape = (4,) + numpy_shape = (4,) + true_numpy_shape = (4,) + scalar = float64 + +# --- int16 vectors --- + +class _V2I16(_Vector): + name = "ihvec2" + item_size = 4 + glsl_type = "i16vec2" + format_str = "(%d, %d)" + child_type = int16 + child_count = 2 + shape = (2,) + numpy_shape = (2,) + true_numpy_shape = (2,) + scalar = int16 + +class _V3I16(_Vector): + name = "ihvec3" + item_size = 6 + glsl_type = "i16vec3" + format_str = "(%d, %d, %d)" + child_type = int16 + child_count = 3 + shape = (3,) + numpy_shape = (3,) + true_numpy_shape = (3,) + scalar = int16 + +class _V4I16(_Vector): + name = "ihvec4" + item_size = 8 + glsl_type = "i16vec4" + format_str = "(%d, %d, %d, %d)" + child_type = int16 + child_count = 4 + shape = (4,) + numpy_shape = (4,) + true_numpy_shape = (4,) + scalar = int16 + +# --- int32 vectors --- + class _V2I32(_Vector): name = "ivec2" item_size = 8 @@ -117,9 +300,8 @@ class _V2I32(_Vector): class _V3I32(_Vector): name = "ivec3" - item_size = 16 + item_size = 12 glsl_type = "ivec3" - glsl_type_extern = "ivec4" format_str = "(%d, %d, %d)" child_type = int32 child_count = 3 @@ -140,6 +322,46 @@ class _V4I32(_Vector): true_numpy_shape = (4,) scalar = int32 +# --- uint16 vectors --- + +class _V2U16(_Vector): + name = "uhvec2" + item_size = 4 + glsl_type = "u16vec2" + format_str = "(%u, %u)" + child_type = uint16 + child_count = 2 + shape = (2,) + numpy_shape = (2,) + true_numpy_shape = (2,) + scalar = uint16 + +class _V3U16(_Vector): + name = "uhvec3" + item_size = 6 + glsl_type = "u16vec3" + format_str = "(%u, %u, %u)" + child_type = uint16 + child_count = 3 + shape = (3,) + numpy_shape = (3,) + true_numpy_shape = (3,) + scalar = uint16 + +class _V4U16(_Vector): + name = "uhvec4" + item_size = 8 + glsl_type = "u16vec4" + format_str = "(%u, %u, %u, %u)" + child_type = uint16 + child_count = 4 + shape = (4,) + numpy_shape = (4,) + true_numpy_shape = (4,) + scalar = uint16 + +# --- uint32 vectors 
--- + class _V2U32(_Vector): name = "uvec2" item_size = 8 @@ -154,9 +376,8 @@ class _V2U32(_Vector): class _V3U32(_Vector): name = "uvec3" - item_size = 16 + item_size = 12 glsl_type = "uvec3" - glsl_type_extern = "uvec4" format_str = "(%u, %u, %u)" child_type = uint32 child_count = 3 @@ -177,12 +398,24 @@ class _V4U32(_Vector): true_numpy_shape = (4,) scalar = uint32 +hvec2 = _V2F16 # type: ignore +hvec3 = _V3F16 # type: ignore +hvec4 = _V4F16 # type: ignore vec2 = _V2F32 # type: ignore vec3 = _V3F32 # type: ignore vec4 = _V4F32 # type: ignore +dvec2 = _V2F64 # type: ignore +dvec3 = _V3F64 # type: ignore +dvec4 = _V4F64 # type: ignore +ihvec2 = _V2I16 # type: ignore +ihvec3 = _V3I16 # type: ignore +ihvec4 = _V4I16 # type: ignore ivec2 = _V2I32 # type: ignore ivec3 = _V3I32 # type: ignore ivec4 = _V4I32 # type: ignore +uhvec2 = _V2U16 # type: ignore +uhvec3 = _V3U16 # type: ignore +uhvec4 = _V4U16 # type: ignore uvec2 = _V2U32 # type: ignore uvec3 = _V3U32 # type: ignore uvec4 = _V4U32 # type: ignore @@ -202,6 +435,18 @@ class _M2F32(_Matrix): true_numpy_shape = (2, 2) scalar = float32 +class _M3F32(_Matrix): + name = "mat3" + item_size = 36 + glsl_type = "mat3" + format_str = "\\\\n[%f, %f, %f]\\\\n[%f, %f, %f]\\\\n[%f, %f, %f]\\\\n" + child_type = vec3 + child_count = 3 + shape = (3, 3) + numpy_shape = (3, 3) + true_numpy_shape = (3, 3) + scalar = float32 + class _M4F32(_Matrix): name = "mat4" item_size = 64 @@ -215,35 +460,28 @@ class _M4F32(_Matrix): scalar = float32 mat2 = _M2F32 +mat3 = _M3F32 mat4 = _M4F32 +# Maps scalar dtype -> {count: vector_dtype} +_VECTOR_TABLE = { + int16: {1: int16, 2: ihvec2, 3: ihvec3, 4: ihvec4}, + uint16: {1: uint16, 2: uhvec2, 3: uhvec3, 4: uhvec4}, + int32: {1: int32, 2: ivec2, 3: ivec3, 4: ivec4}, + uint32: {1: uint32, 2: uvec2, 3: uvec3, 4: uvec4}, + float16: {1: float16, 2: hvec2, 3: hvec3, 4: hvec4}, + float32: {1: float32, 2: vec2, 3: vec3, 4: vec4}, + float64: {1: float64, 2: dvec2, 3: dvec3, 4: dvec4}, +} + def 
to_vector(dtype: dtype, count: int) -> dtype: # type: ignore - if count < 2 or count > 4: + if count < 1 or count > 4: raise ValueError(f"Unsupported count ({count})!") - if dtype == int32: - if count == 2: - return ivec2 - elif count == 3: - return ivec3 - elif count == 4: - return ivec4 - elif dtype == uint32: - if count == 2: - return uvec2 - elif count == 3: - return uvec3 - elif count == 4: - return uvec4 - elif dtype == float32: - if count == 2: - return vec2 - elif count == 3: - return vec3 - elif count == 4: - return vec4 - else: + table = _VECTOR_TABLE.get(dtype) + if table is None: raise ValueError(f"Unsupported dtype ({dtype})!") + return table[count] def is_dtype(in_type: dtype) -> bool: return issubclass(in_type, dtype) # type: ignore @@ -260,26 +498,223 @@ def is_vector(dtype: dtype) -> bool: def is_matrix(dtype: dtype) -> bool: return issubclass(dtype, _Matrix) # type: ignore -def from_numpy_dtype(dtype: type) -> dtype: - if dtype == np.int32: - return int32 - elif dtype == np.uint32: - return uint32 - elif dtype == np.float32: +def is_float_dtype(dtype: dtype) -> bool: + if not is_scalar(dtype): + dtype = dtype.scalar + + return dtype == float16 or dtype == float32 or dtype == float64 + +def is_integer_dtype(dtype: dtype) -> bool: + if not is_scalar(dtype): + dtype = dtype.scalar + + return dtype in (int16, uint16, int32, uint32, int64, uint64) + +# Promotion precedence: float64 > float32 > float16 > int64 > int32 > int16 > uint64 > uint32 > uint16 +_SCALAR_RANK = { + uint16: 0, + int16: 1, + uint32: 2, + int32: 3, + uint64: 4, + int64: 5, + float16: 6, + float32: 7, + float64: 8, +} + +_COMPLEX_FROM_FLOAT = { + float16: complex32, + float32: complex64, + float64: complex128, +} + +def complex_from_float(dtype: dtype) -> dtype: + if not is_scalar(dtype): + raise ValueError(f"Unsupported dtype ({dtype})!") + + result = _COMPLEX_FROM_FLOAT.get(dtype) + if result is None: + raise ValueError(f"Unsupported complex base dtype ({dtype})!") + return result 
+ +def _promote_scalar(dtype: dtype) -> dtype: + """Return the floating-point type that matches the width of *dtype*. + + Used by make_floating_dtype to convert integer scalars to their natural + floating counterpart. + """ + if dtype == int16 or dtype == uint16: + return float16 + if dtype == int32 or dtype == uint32: return float32 - elif dtype == np.complex64: - return complex64 + if dtype == int64 or dtype == uint64: + return float64 + return dtype + +def make_floating_dtype(dtype: dtype) -> dtype: + if is_scalar(dtype): + return _promote_scalar(dtype) + elif is_vector(dtype): + return to_vector(_promote_scalar(dtype.scalar), dtype.child_count) + elif is_matrix(dtype): + return dtype + elif is_complex(dtype): + return dtype else: raise ValueError(f"Unsupported dtype ({dtype})!") -def to_numpy_dtype(shader_type: dtype) -> np.dtype: - if shader_type == int32: - return np.int32 - elif shader_type == uint32: - return np.uint32 - elif shader_type == float32: - return np.float32 - elif shader_type == complex64: - return np.complex64 +def vector_size(dtype: dtype) -> int: + if not is_vector(dtype): + raise ValueError(f"Type ({dtype}) is not a vector!") + + return dtype.child_count + +def cross_scalar_scalar(dtype1: dtype, dtype2: dtype) -> dtype: + assert is_scalar(dtype1) and is_scalar(dtype2), "Both types must be scalar types!" + + r1 = _SCALAR_RANK[dtype1] + r2 = _SCALAR_RANK[dtype2] + return dtype1 if r1 >= r2 else dtype2 + +def cross_vector_scalar(dtype1: dtype, dtype2: dtype) -> dtype: + assert is_vector(dtype1) and is_scalar(dtype2), "First type must be vector type and second type must be scalar type!" + + return to_vector(cross_scalar_scalar(dtype1.scalar, dtype2), dtype1.child_count) + +def cross_vector_vector(dtype1: dtype, dtype2: dtype) -> dtype: + assert is_vector(dtype1) and is_vector(dtype2), "Both types must be vector types!" + + if dtype1.child_count != dtype2.child_count: + raise ValueError(f"Cannot cross types of vectors of two sizes! 
({dtype1.child_count} != {dtype2.child_count})") + + return to_vector(cross_scalar_scalar(dtype1.scalar, dtype2.scalar), dtype1.child_count) + +def cross_vector(dtype1: dtype, dtype2: dtype) -> dtype: + assert is_vector(dtype1), "First type must be vector type!" + + if is_vector(dtype2): + return cross_vector_vector(dtype1, dtype2) + elif is_scalar(dtype2): + return cross_vector_scalar(dtype1, dtype2) + elif is_complex(dtype2): + raise ValueError("Cannot cross vector and complex types!") else: + raise ValueError("Second type must be vector or scalar type!") + +def cross_matrix(dtype1: dtype, dtype2: dtype) -> dtype: + assert is_matrix(dtype1), "Both types must be matrix types!" + + if is_matrix(dtype2): + if dtype1.shape != dtype2.shape: + raise ValueError( + f"Cannot cross types of matrices with incompatible shapes! ({dtype1.shape} and {dtype2.shape})") + + return dtype1 + + if is_vector(dtype2) or is_complex(dtype2): + raise ValueError("Cannot cross matrix and vector/complex types!") + + if is_scalar(dtype2): + return dtype1 + + raise ValueError("Second type must be matrix or scalar type!") + +def cross_type(dtype1: dtype, dtype2: dtype) -> dtype: + if is_matrix(dtype1): + return cross_matrix(dtype1, dtype2) + elif is_matrix(dtype2): + return cross_matrix(dtype2, dtype1) + + if is_vector(dtype1): + return cross_vector(dtype1, dtype2) + elif is_vector(dtype2): + return cross_vector(dtype2, dtype1) + + if is_complex(dtype1): + if is_complex(dtype2): + return complex_from_float(cross_scalar_scalar(dtype1.child_type, dtype2.child_type)) + if is_scalar(dtype2): + return complex_from_float(cross_scalar_scalar(dtype1.child_type, _promote_scalar(dtype2))) + raise ValueError("Cannot cross complex and non-scalar types!") + elif is_complex(dtype2): + if is_scalar(dtype1): + return complex_from_float(cross_scalar_scalar(dtype2.child_type, _promote_scalar(dtype1))) + raise ValueError("Cannot cross complex and non-scalar types!") + + if is_scalar(dtype1) and is_scalar(dtype2): 
+ return cross_scalar_scalar(dtype1, dtype2) + +def cross_multiply_type(dtype1: dtype, dtype2: dtype) -> dtype: + """Resolve result type for multiplication. + + Unlike ``cross_type``, multiplication is order-sensitive for matrix/vector + combinations and supports ``matN * vecN`` and ``vecN * matN``. + """ + if is_matrix(dtype1) and is_vector(dtype2): + if dtype1.child_count != dtype2.child_count: + raise ValueError( + f"Cannot multiply matrix '{dtype1.name}' and vector '{dtype2.name}' with incompatible dimensions!" + ) + if dtype1.scalar != float32 or dtype2.scalar != float32: + raise ValueError("Matrix/vector multiplication only supports float32 matrix and vector types.") + return dtype2 + + if is_vector(dtype1) and is_matrix(dtype2): + if dtype1.child_count != dtype2.child_count: + raise ValueError( + f"Cannot multiply vector '{dtype1.name}' and matrix '{dtype2.name}' with incompatible dimensions!" + ) + if dtype1.scalar != float32 or dtype2.scalar != float32: + raise ValueError("Matrix/vector multiplication only supports float32 matrix and vector types.") + return dtype1 + + return cross_type(dtype1, dtype2) + +def from_numpy_dtype(dtype: Any) -> dtype: + dtype_name = npc.host_dtype_name(dtype) + + _NAME_MAP = { + "int16": int16, + "uint16": uint16, + "int32": int32, + "uint32": uint32, + "int64": int64, + "uint64": uint64, + "float16": float16, + "float32": float32, + "float64": float64, + "complex32": complex32, + "complex64": complex64, + "complex128": complex128, + } + + result = _NAME_MAP.get(dtype_name) + if result is None: + raise ValueError(f"Unsupported dtype ({dtype})!") + return result + + +def to_numpy_dtype(shader_type: dtype) -> Any: + _TYPE_MAP = { + int16: "int16", + uint16: "uint16", + int32: "int32", + uint32: "uint32", + int64: "int64", + uint64: "uint64", + float16: "float16", + float32: "float32", + float64: "float64", + complex32: "complex32", + complex64: "complex64", + complex128: "complex128", + } + + name = _TYPE_MAP.get(shader_type) + 
if name is None: raise ValueError(f"Unsupported shader_type ({shader_type})!") + + if npc.HAS_NUMPY and hasattr(npc.numpy_module(), name): + return getattr(npc.numpy_module(), name) + return npc.host_dtype(name) diff --git a/vkdispatch/base/errors.py b/vkdispatch/base/errors.py index 07d3324a..ca6068b1 100644 --- a/vkdispatch/base/errors.py +++ b/vkdispatch/base/errors.py @@ -1,4 +1,4 @@ -import vkdispatch_native +from ..backends.backend_selection import native running = True @@ -17,7 +17,7 @@ def check_for_errors(): Check for errors in the vkdispatch_native library and raise a RuntimeError if found. """ global running - error = vkdispatch_native.get_error_string() + error = native.get_error_string() if error == 0 or not running: return @@ -26,18 +26,21 @@ def check_for_errors(): raise RuntimeError(error) else: raise RuntimeError("Unknown error occurred") - + + def check_for_compute_stage_errors(): """ Check for errors in the shader compilation stage of the vkdispatch_native library and raise a RuntimeError if found. """ - error = vkdispatch_native.get_error_string() + error = native.get_error_string() if error == 0: return if not isinstance(error, str): raise RuntimeError("Unknown error occurred") + + print("Shader compilation error:\n", error) - raise RuntimeError("Error occurred in compute stage") \ No newline at end of file + raise RuntimeError("Error occurred in compute stage") diff --git a/vkdispatch/base/image.py b/vkdispatch/base/image.py index 30b8c92a..f78ec483 100644 --- a/vkdispatch/base/image.py +++ b/vkdispatch/base/image.py @@ -1,23 +1,13 @@ import typing from enum import Enum -import numpy as np - -import vkdispatch_native +from ..backends.backend_selection import native +from ..compat import numpy_compat as npc from . 
import dtype as vdt from .context import Handle -__MAPPING__ = { - (np.uint8, 1), - (np.uint8, 1), - (np.uint8, 2), - (np.uint8, 2), - (np.uint8, 3), - (np.uint8, 3), - (np.uint8, 4), - (np.uint8, 4), -} +__MAPPING__ = set() class image_format(Enum): # TODO: Fix class naming scheme to adhere to convention @@ -82,46 +72,6 @@ def select_image_format(dtype: vdt.dtype, channels: int) -> image_format: # } # return __MAPPING__[(dtype, channels)] - """ - - if dtype == np.uint8: - if channels == 1: - return image_format.R8_UINT - elif channels == 2: - return image_format.R8G8_UINT - elif channels == 3: - return image_format.R8G8B8_UINT - elif channels == 4: - return image_format.R8G8B8A8_UINT - elif dtype == np.int8: - if channels == 1: - return image_format.R8_SINT - elif channels == 2: - return image_format.R8G8_SINT - elif channels == 3: - return image_format.R8G8B8_SINT - elif channels == 4: - return image_format.R8G8B8A8_SINT - elif dtype == np.uint16: - if channels == 1: - return image_format.R16_UINT - elif channels == 2: - return image_format.R16G16_UINT - elif channels == 3: - return image_format.R16G16B16_UINT - elif channels == 4: - return image_format.R16G16B16A16_UINT - elif dtype == np.int16: - if channels == 1: - return image_format.R16_SINT - elif channels == 2: - return image_format.R16G16_SINT - elif channels == 3: - return image_format.R16G16B16_SINT - elif channels == 4: - return image_format.R16G16B16A16_SINT - el """ - if dtype == vdt.uint32: if channels == 1: return image_format.R32_UINT @@ -268,7 +218,7 @@ def __init__(self, self.image = image - handle = vkdispatch_native.image_create_sampler( + handle = native.image_create_sampler( self.context._handle, mag_filter.value, min_filter.value, @@ -284,7 +234,7 @@ def __init__(self, self.register_parent(image) def _destroy(self): - vkdispatch_native.image_destroy_sampler(self._handle) + native.image_destroy_sampler(self._handle) def __del__(self) -> None: self.destroy() @@ -346,13 +296,13 @@ def 
__init__( if channels == 1: self.array_shape = self.array_shape[:-1] - self.block_size: int = vkdispatch_native.image_format_block_size( + self.block_size: int = native.image_format_block_size( self.format.value ) - self.mem_size: int = np.prod(self.shape) * self.block_size + self.mem_size: int = npc.prod(self.shape) * self.block_size - handle: int = vkdispatch_native.image_create( + handle: int = native.image_create( self.context._handle, self.extent, self.layers, @@ -365,17 +315,27 @@ def __init__( self.register_handle(handle) def _destroy(self) -> None: - vkdispatch_native.image_destroy(self._handle) + native.image_destroy(self._handle) def __del__(self) -> None: self.destroy() - def write(self, data: np.ndarray, device_index: int = -1) -> None: - if data.size * np.dtype(data.dtype).itemsize != self.mem_size: - raise ValueError(f"Numpy buffer sizes must match! {data.size * np.dtype(data.dtype).itemsize} != {self.mem_size}") - vkdispatch_native.image_write( + def write(self, data: typing.Any, device_index: int = -1) -> None: + if npc.is_array_like(data): + true_data = npc.as_contiguous_bytes(data) + data_size = npc.array_nbytes(data) + elif npc.is_bytes_like(data): + true_data = npc.ensure_bytes(data) + data_size = len(true_data) + else: + raise TypeError("Expected array-like or bytes-like image input") + + if data_size != self.mem_size: + raise ValueError(f"Image buffer sizes must match! 
{data_size} != {self.mem_size}") + + native.image_write( self._handle, - np.ascontiguousarray(data).tobytes(), + true_data, [0, 0, 0], self.extent, 0, @@ -383,17 +343,17 @@ def write(self, data: np.ndarray, device_index: int = -1) -> None: device_index, ) - def read(self, device_index: int = 0) -> np.ndarray: + def read(self, device_index: int = 0): true_scalar = self.dtype.scalar if self.dtype.scalar is None: true_scalar = self.dtype - out_size = np.prod(self.array_shape) * true_scalar.item_size - out_bytes = vkdispatch_native.image_read( + out_size = npc.prod(self.array_shape) * true_scalar.item_size + out_bytes = native.image_read( self._handle, out_size, [0, 0, 0], self.extent, 0, self.layers, device_index ) - return np.frombuffer(out_bytes, dtype=vdt.to_numpy_dtype(true_scalar)).reshape(self.array_shape) + return npc.from_buffer(out_bytes, dtype=vdt.to_numpy_dtype(true_scalar), shape=self.array_shape) def sample(self, mag_filter: Filter = Filter.LINEAR, @@ -428,7 +388,7 @@ def __class_getitem__(cls, arg: vdt.dtype) -> type: class Image2D(Image): def __init__( - self, shape: typing.Tuple[int, int], dtype: type = np.float32, channels: int = 1, enable_mipmaps: bool = False + self, shape: typing.Tuple[int, int], dtype: type = vdt.float32, channels: int = 1, enable_mipmaps: bool = False ) -> None: assert len(shape) == 2, "Shape must be 2D!" 
super().__init__(shape, 1, dtype, channels, image_view_type.VIEW_TYPE_2D, enable_mipmaps) @@ -443,7 +403,7 @@ def __init__( self, shape: typing.Tuple[int, int], layers: int, - dtype: type = np.float32, + dtype: type = vdt.float32, channels: int = 1, enable_mipmaps: bool = False ) -> None: @@ -459,7 +419,7 @@ def __class_getitem__(cls, arg: tuple) -> type: class Image3D(Image): def __init__( - self, shape: typing.Tuple[int, int, int], dtype: type = np.float32, channels: int = 1, enable_mipmaps: bool = False + self, shape: typing.Tuple[int, int, int], dtype: type = vdt.float32, channels: int = 1, enable_mipmaps: bool = False ) -> None: assert len(shape) == 3, "Shape must be 3D!" super().__init__(shape, 1, dtype, channels, image_view_type.VIEW_TYPE_3D, enable_mipmaps) diff --git a/vkdispatch/base/init.py b/vkdispatch/base/init.py index 474c0813..a4aa7c26 100644 --- a/vkdispatch/base/init.py +++ b/vkdispatch/base/init.py @@ -1,12 +1,24 @@ -import typing + from enum import Enum import os +from typing import Tuple, List, Optional import inspect from .errors import check_for_errors - -import vkdispatch_native +from ..backends.backend_selection import ( + BACKEND_CUDA, + BACKEND_OPENCL, + BACKEND_VULKAN, + BACKEND_DUMMY, + BackendUnavailableError, + clear_active_backend, + get_active_backend_name, + get_backend_module, + native, + get_environment_backend, + set_active_backend, +) # string representations of device types device_type_id_to_str_dict = { @@ -31,7 +43,7 @@ 4: 1 } -def get_queue_type_strings(queue_type: int, verbose: bool) -> typing.List[str]: +def get_queue_type_strings(queue_type: int, verbose: bool) -> List[str]: """ A function which returns a list of strings representing the queue's supported operations. 
@@ -154,9 +166,9 @@ def __init__( uniform_and_storage_buffer_16_bit_access: int, storage_push_constant_16: int, storage_input_output_16: int, - max_workgroup_size: typing.Tuple[int, int, int], + max_workgroup_size: Tuple[int, int, int], max_workgroup_invocations: int, - max_workgroup_count: typing.Tuple[int, int, int], + max_workgroup_count: Tuple[int, int, int], max_bound_descriptor_sets: int, max_push_constant_size: int, max_storage_buffer_range: int, @@ -167,9 +179,13 @@ def __init__( supported_operations: int, quad_operations_in_all_stages: int, max_compute_shared_memory_size: int, - queue_properties: typing.List[typing.Tuple[int, int]] + queue_properties: List[Tuple[int, int]], + scalar_block_layout: int, + timeline_semaphores: int, + uuid: Optional[bytes], ): self.dev_index = dev_index + self.sorted_index = -1 # to be set later self.version_variant = version_variant self.version_major = version_major @@ -216,6 +232,10 @@ def __init__( self.queue_properties = queue_properties + self.scalar_block_layout = scalar_block_layout + self.timeline_semaphores = timeline_semaphores + self.uuid = uuid + def is_nvidia(self) -> bool: """ A method which checks if the device is an NVIDIA device. @@ -245,9 +265,21 @@ def get_info_string(self, verbose: bool = False) -> str: str: A string representation of the device information. 
""" - result = f"Device {self.dev_index}: {self.device_name}\n" + result = f"Device {self.sorted_index}: {self.device_name}\n" - result += f"\tVulkan Version: {self.version_major}.{self.version_minor}.{self.version_patch}\n" + backend_type = "Vulkan" + version_number = f"{self.version_major}.{self.version_minor}.{self.version_patch}" + + if is_cuda(): + backend_type = "CUDA Compute Capability" + version_number = f"{self.version_major}.{self.version_minor}" + elif is_opencl(): + backend_type = "OpenCL" + version_number = f"{self.version_major}.{self.version_minor}" + elif is_dummy(): + backend_type = "Dummy" + + result += f"\t{backend_type} Version: {version_number}\n" result += f"\tDevice Type: {device_type_id_to_str_dict[self.device_type]}\n" if self.version_variant != 0: @@ -258,10 +290,23 @@ def get_info_string(self, verbose: bool = False) -> str: result += f"\tVendor ID={self.vendor_id}\n" result += f"\tDevice ID={self.device_id}\n" + + if self.uuid is not None: + uuid_str = '-'.join([ + self.uuid[0:4].hex(), + self.uuid[4:6].hex(), + self.uuid[6:8].hex(), + self.uuid[8:10].hex(), + self.uuid[10:16].hex(), + ]) + result += f"\tUUID: {uuid_str}\n" + result += "\n\tFeatures:\n" if verbose: result += f"\t\tFloat32 Atomics: {self.shader_buffer_float32_atomics == 1}\n" + result += f"\t\tScalar Block Layout: {self.scalar_block_layout == 1}\n" + result += f"\t\tTimeline Semaphores: {self.timeline_semaphores == 1}\n" result += f"\t\tFloat32 Atomic Add: {self.shader_buffer_float32_atomic_add == 1}\n" @@ -306,13 +351,16 @@ def get_info_string(self, verbose: bool = False) -> str: result += f"\t\t{ii} (count={queue[0]}, flags={hex(queue[1])}): " result += " | ".join(queue_types) + "\n" + + return result def __repr__(self) -> str: return self.get_info_string() __initilized_instance: bool = False - +__device_infos: List[DeviceInfo] = None +__backend_name: str = BACKEND_VULKAN def is_initialized() -> bool: """ @@ -326,7 +374,154 @@ def is_initialized() -> bool: return 
__initilized_instance -def initialize(debug_mode: bool = False, log_level: LogLevel = LogLevel.WARNING, loader_debug_logs: bool = False): +def get_cuda_device_map(): + """ + Returns a dict mapping CUDA device index -> UUID (bytes). + Format: { 0: b'\x00...', 1: b'\x01...' } + + If the CUDA driver bindings are not available, returns None. + """ + try: + from cuda.bindings import driver + except (ImportError, ModuleNotFoundError): + __log_noinit("'cuda-python' not installed, skipping CUDA device matching", level=LogLevel.WARNING) + return None + + try: + err, = driver.cuInit(0) + if err != driver.CUresult.CUDA_SUCCESS: + raise RuntimeError("Failed to initialize CUDA Driver API") + + err, count = driver.cuDeviceGetCount() + if err != driver.CUresult.CUDA_SUCCESS: + raise RuntimeError("Failed to get CUDA device count") + + uuid_map = {} + + for i in range(count): + err, device = driver.cuDeviceGet(i) + if err != driver.CUresult.CUDA_SUCCESS: + continue + + err, uuid_bytes = driver.cuDeviceGetUuid(device) + if err == driver.CUresult.CUDA_SUCCESS: + assert len(uuid_bytes.bytes) == 16 + uuid_map[i] = uuid_bytes.bytes + except Exception as e: + __log_noinit(f"Error while querying CUDA devices: {e}", level=LogLevel.WARNING) + return None + + return uuid_map + + +def _set_initialized_state(backend_name: str, devices: List[DeviceInfo]) -> None: + global __initilized_instance + global __backend_name + global __device_infos + + __initilized_instance = True + __backend_name = backend_name + __device_infos = devices + + for ii, dev in enumerate(__device_infos): + dev.sorted_index = ii + + +def _build_no_gpu_backend_error( + vulkan_error: Exception, + cuda_python_error: Exception, + opencl_error: Exception, +) -> RuntimeError: + return RuntimeError( + "vkdispatch could not find an available GPU backend.\n" + f"Vulkan backend unavailable: {vulkan_error}\n" + f"CUDA Python backend unavailable: {cuda_python_error}\n" + f"OpenCL backend unavailable: {opencl_error}\n" + "Install the 
Vulkan backend with `pip install vkdispatch`, or install CUDA support " + "(`pip install cuda-python`), or install OpenCL support (`pip install pyopencl`), " + "or explicitly use `vd.initialize(backend='dummy')` " + "for codegen-only workflows." + ) + + +def _build_vulkan_backend_error(vulkan_error: Exception) -> RuntimeError: + return RuntimeError( + "vkdispatch could not load the Vulkan backend.\n" + f"Vulkan backend unavailable: {vulkan_error}\n" + "Install the Vulkan backend with `pip install vkdispatch`, use a CUDA backend " + "(`pip install cuda-python`), use an OpenCL backend (`pip install pyopencl`), " + "or explicitly use `vd.initialize(backend='dummy')` " + "for codegen-only workflows." + ) + + +def _initialize_with_backend( + backend_name: str, + debug_mode: bool, + log_level: LogLevel, + loader_debug_logs: bool, +) -> None: + global __initilized_instance + + set_active_backend(backend_name) + + try: + if loader_debug_logs and backend_name == BACKEND_VULKAN: + os.environ["VK_LOADER_DEBUG"] = "all" + + # Force import now so backend availability errors are distinct from runtime init errors. 
+ get_backend_module(backend_name) + + native.init(debug_mode, log_level.value) + check_for_errors() + + devivces = [ + DeviceInfo(ii, *dev_obj) + for ii, dev_obj in enumerate(native.get_devices()) + ] + + if backend_name != BACKEND_VULKAN: + _set_initialized_state(backend_name, devivces) + return + + is_cuda = any(dev.is_nvidia() for dev in devivces) + cuda_uuids = get_cuda_device_map() if is_cuda else None + + if cuda_uuids is None: + _set_initialized_state(backend_name, devivces) + return + + # try to match CUDA devices to Vulkan devices by UUID + cuda_uuid_to_index = { + uuid_bytes: cuda_index + for cuda_index, uuid_bytes in cuda_uuids.items() + } + matched_devices: List[Tuple[int, DeviceInfo]] = [] + unmatched_devices: List[DeviceInfo] = [] + for dev in devivces: + if dev.uuid is not None and dev.uuid in cuda_uuid_to_index: + matched_devices.append((cuda_uuid_to_index[dev.uuid], dev)) + else: + unmatched_devices.append(dev) + + matched_devices.sort(key=lambda x: x[0]) + result = [dev for _, dev in matched_devices] + unmatched_devices + + for dev_id, dev in enumerate(result): + dev.sorted_index = dev_id + + _set_initialized_state(backend_name, result) + except Exception: + if not __initilized_instance: + clear_active_backend() + raise + +def initialize( + debug_mode: bool = False, + log_level: LogLevel = LogLevel.WARNING, + loader_debug_logs: bool = False, + backend: Optional[str] = None, +): """ A function which initializes the Vulkan dispatch library. @@ -338,23 +533,76 @@ def initialize(debug_mode: bool = False, log_level: LogLevel = LogLevel.WARNING, LogLevel.WARNING LogLevel.ERROR loader_debug_logs (bool): A flag to enable vulkan loader debug logs. + backend (`Optional[str]`): Runtime backend to use. Supported values are + "vulkan", "cuda", "opencl", and "dummy". If omitted, the currently selected backend is + reused. If no backend was selected yet, `VKDISPATCH_BACKEND` is used + when set, otherwise "vulkan" is used. 
""" global __initilized_instance + + backend_name = get_active_backend_name(backend) + backend_explicitly_selected = (backend is not None) or (get_environment_backend() is not None) if __initilized_instance: + if __backend_name != backend_name: + raise RuntimeError( + f"vkdispatch is already initialized with backend '{__backend_name}'. " + f"Cannot reinitialize with '{backend_name}' in the same process." + ) return - - if loader_debug_logs: - os.environ["VK_LOADER_DEBUG"] = "all" - - vkdispatch_native.init(debug_mode, log_level.value) - check_for_errors() - - __initilized_instance = True - -def get_devices() -> typing.List[DeviceInfo]: + if ( + not backend_explicitly_selected + and backend_name == BACKEND_VULKAN + ): + try: + _initialize_with_backend( + BACKEND_VULKAN, + debug_mode=debug_mode, + log_level=log_level, + loader_debug_logs=loader_debug_logs, + ) + return + except BackendUnavailableError as vulkan_error: + try: + _initialize_with_backend( + BACKEND_CUDA, + debug_mode=debug_mode, + log_level=log_level, + loader_debug_logs=loader_debug_logs, + ) + return + except Exception as cuda_python_error: + try: + _initialize_with_backend( + BACKEND_OPENCL, + debug_mode=debug_mode, + log_level=log_level, + loader_debug_logs=loader_debug_logs, + ) + return + except Exception as opencl_error: + raise _build_no_gpu_backend_error( + vulkan_error, + cuda_python_error, + opencl_error, + ) from opencl_error + + try: + _initialize_with_backend( + backend_name, + debug_mode=debug_mode, + log_level=log_level, + loader_debug_logs=loader_debug_logs, + ) + except BackendUnavailableError as backend_error: + if backend_name == BACKEND_VULKAN: + raise _build_vulkan_backend_error(backend_error) from backend_error + raise + + +def get_devices() -> List[DeviceInfo]: """ Get a list of DeviceInfo instances representing all the Vulkan devices on the system. @@ -362,14 +610,60 @@ def get_devices() -> typing.List[DeviceInfo]: `List[DeviceInfo]`: A list of DeviceInfo instances. 
""" + global __device_infos + initialize() + + return __device_infos - return [ - DeviceInfo(ii, *dev_obj) - for ii, dev_obj in enumerate(vkdispatch_native.get_devices()) - ] -def log(text: str, end: str = '\n', level: LogLevel = LogLevel.ERROR, stack_offset: int = 1): +def get_backend() -> str: + if __initilized_instance: + return __backend_name + + return get_active_backend_name() + +def is_vulkan() -> bool: + """ + A function which checks if the active backend is the Vulkan backend. + + Returns: + `bool`: A flag indicating whether the active backend is the Vulkan backend. + """ + + return get_backend() == BACKEND_VULKAN + +def is_cuda() -> bool: + """ + A function which checks if the active backend is a CUDA backend. + + Returns: + `bool`: A flag indicating whether the active backend is a CUDA backend. + """ + + return get_backend() == BACKEND_CUDA + +def is_opencl() -> bool: + """ + A function which checks if the active backend is the OpenCL backend. + + Returns: + `bool`: A flag indicating whether the active backend is the OpenCL backend. + """ + + return get_backend() == BACKEND_OPENCL + +def is_dummy() -> bool: + """ + A function which checks if the active backend is the dummy backend. + + Returns: + `bool`: A flag indicating whether the active backend is the dummy backend. + """ + + return get_backend() == BACKEND_DUMMY + +def __log_noinit(text: str, end: str = '\n', level: LogLevel = LogLevel.ERROR, stack_offset: int = 1): """ A function which logs a message at the specified log level. @@ -378,16 +672,27 @@ def log(text: str, end: str = '\n', level: LogLevel = LogLevel.ERROR, stack_offs message (`str`): The message to log. 
""" - initialize() - frame = inspect.stack()[stack_offset] - vkdispatch_native.log( + native.log( level.value, (text + end).encode(), os.path.relpath(frame.filename, os.getcwd()).encode(), frame.lineno ) +def log(text: str, end: str = '\n', level: LogLevel = LogLevel.ERROR, stack_offset: int = 1): + """ + A function which logs a message at the specified log level. + + Args: + level (`LogLevel`): The log level. + message (`str`): The message to log. + """ + + initialize() + + __log_noinit(text, end, level, stack_offset + 1) + def log_error(text: str, end: str = '\n'): """ A function which logs an error message. @@ -438,4 +743,4 @@ def set_log_level(level: LogLevel): initialize() - vkdispatch_native.set_log_level(level.value) \ No newline at end of file + native.set_log_level(level.value) diff --git a/vkdispatch/codegen/__init__.py b/vkdispatch/codegen/__init__.py index 58b2af8f..1d07e8eb 100644 --- a/vkdispatch/codegen/__init__.py +++ b/vkdispatch/codegen/__init__.py @@ -1,50 +1,89 @@ - from .arguments import Constant, Variable, ConstantArray, VariableArray from .arguments import Buffer, Image1D, Image2D, Image3D from .arguments import _ArgType -from .struct_builder import StructBuilder, StructElement -#from .variables import ShaderVariable # BaseVariable, ShaderVariable -#from .variables import BoundVariable, BufferVariable, ImageVariable - -from .builder import ShaderBinding -from .builder import ShaderDescription -from .builder import ShaderBuilder -from .builder import ShaderVariable, BufferVariable, ImageVariable - -from .global_builder import inf_f32, ninf_f32, set_global_builder, comment -from .global_builder import global_invocation, local_invocation, workgroup -from .global_builder import workgroup_size, num_workgroups, num_subgroups -from .global_builder import subgroup_id, subgroup_size, subgroup_invocation, shared_buffer - -from .global_builder import abs, acos, acosh, asin, asinh -from .global_builder import atan, atan2, atanh, atomic_add, barrier -from 
.global_builder import ceil, clamp, cos, cosh, cross -from .global_builder import degrees, determinant, distance, dot -from .global_builder import exp, exp2, float_bits_to_int, float_bits_to_uint -from .global_builder import floor, fma, int_bits_to_float -from .global_builder import inverse, inverse_sqrt, isinf, isnan -from .global_builder import length, log, log2, max, memory_barrier -from .global_builder import memory_barrier_shared, min, mix, mod -from .global_builder import normalize, pow, radians, round, round_even -from .global_builder import sign, sin, sinh, smoothstep, sqrt, step -from .global_builder import tan, tanh, transpose, trunc, uint_bits_to_float -from .global_builder import mult_c64, mult_conj_c64, complex_from_euler_angle, mult_c64_by_const - -from .global_builder import if_statement, if_any, if_all, else_statement -from .global_builder import else_if_statement, else_if_any, else_if_all -from .global_builder import return_statement, while_statement, new_scope, end -from .global_builder import logical_and, logical_or -from .global_builder import subgroup_add, subgroup_mul -from .global_builder import subgroup_min, subgroup_max, subgroup_and -from .global_builder import subgroup_or, subgroup_xor, subgroup_elect -from .global_builder import subgroup_barrier, mapping_index, kernel_index, mapping_registers -from .global_builder import set_kernel_index, set_mapping_index, set_mapping_registers -from .global_builder import printf, unravel_index -from .global_builder import print_vars as print, builder_context -from .global_builder import new, new_float, new_int, new_uint -from .global_builder import new_vec2, new_ivec2, new_uvec2 -from .global_builder import new_vec3, new_ivec3, new_uvec3 -from .global_builder import new_vec4, new_ivec4, new_uvec4 - -from .abreviations import * \ No newline at end of file +from .struct_builder import StructElement + +from .variables.variables import ShaderVariable + +from .variables.bound_variables import 
BufferVariable, ImageVariable, BoundVariable + +from .functions.common_builtins import abs, sign, floor, ceil, trunc, round, round_even, comment +from .functions.common_builtins import fract, mod, modf, min, max, clip, clamp, mix +from .functions.common_builtins import step, smoothstep, isnan, isinf, float_bits_to_int +from .functions.common_builtins import float_bits_to_uint, int_bits_to_float, uint_bits_to_float, fma + +from .functions.trigonometry import sin, cos, tan, asin, acos, atan, atan2 +from .functions.trigonometry import sinh, cosh, tanh, asinh, acosh, atanh, radians, degrees + +from .functions.complex_numbers import complex_from_euler_angle + +from .functions.exponential import exp, exp2, log, log2, pow, sqrt, inversesqrt + +from .functions.geometric import length, distance, dot, cross, normalize + +from .functions.block_synchonization import barrier, memory_barrier, memory_barrier_buffer +from .functions.block_synchonization import memory_barrier_shared, memory_barrier_image, group_memory_barrier + +from .functions.matrix import matrix_comp_mult, outer_product, transpose +from .functions.matrix import determinant, inverse + +from .functions.atomic_memory import atomic_add + +from .functions.type_casting import to_dtype, str_to_dtype +from .functions.type_casting import to_float16, to_float, to_float64 +from .functions.type_casting import to_int16, to_int, to_int64, to_uint16, to_uint, to_uint64 +from .functions.type_casting import to_complex, to_complex32, to_complex64, to_complex128 +from .functions.type_casting import to_hvec2, to_hvec3, to_hvec4 +from .functions.type_casting import to_vec2, to_vec3, to_vec4 +from .functions.type_casting import to_dvec2, to_dvec3, to_dvec4 +from .functions.type_casting import to_ihvec2, to_ihvec3, to_ihvec4 +from .functions.type_casting import to_ivec2, to_ivec3, to_ivec4 +from .functions.type_casting import to_uhvec2, to_uhvec3, to_uhvec4 +from .functions.type_casting import to_uvec2, to_uvec3, to_uvec4 +from 
.functions.type_casting import to_mat2, to_mat3, to_mat4 + +from .functions.registers import new_register, new_complex_register +from .functions.registers import new_float16_register, new_float_register, new_float64_register +from .functions.registers import new_int16_register, new_int_register, new_int64_register +from .functions.registers import new_uint16_register, new_uint_register, new_uint64_register +from .functions.registers import new_complex32_register, new_complex64_register, new_complex128_register +from .functions.registers import new_hvec2_register, new_hvec3_register, new_hvec4_register +from .functions.registers import new_vec2_register, new_vec3_register, new_vec4_register +from .functions.registers import new_dvec2_register, new_dvec3_register, new_dvec4_register +from .functions.registers import new_ihvec2_register, new_ihvec3_register, new_ihvec4_register +from .functions.registers import new_ivec2_register, new_ivec3_register, new_ivec4_register +from .functions.registers import new_uhvec2_register, new_uhvec3_register, new_uhvec4_register +from .functions.registers import new_uvec2_register, new_uvec3_register, new_uvec4_register +from .functions.registers import new_mat2_register, new_mat3_register, new_mat4_register + +from .functions.subgroups import subgroup_add, subgroup_mul +from .functions.subgroups import subgroup_min, subgroup_max, subgroup_and +from .functions.subgroups import subgroup_or, subgroup_xor, subgroup_elect +from .functions.subgroups import subgroup_barrier + +from .functions.control_flow import if_statement, if_any, if_all, else_statement +from .functions.control_flow import else_if_statement, else_if_any, else_if_all +from .functions.control_flow import return_statement, while_statement, new_scope, end +from .functions.control_flow import logical_and, logical_or + +from .functions.complex_numbers import mult_complex, complex_from_euler_angle + +from .functions.builtin_constants import global_invocation_id, 
local_invocation_id, workgroup_id, local_invocation_index +from .functions.builtin_constants import workgroup_size, num_workgroups, num_subgroups, subgroup_id +from .functions.builtin_constants import subgroup_size, subgroup_invocation_id, inf_f32, ninf_f32, inf_f64, ninf_f64, inf_f16, ninf_f16 + +from .functions.index_raveling import ravel_index, unravel_index + +from .functions.printing import printf +from .functions.printing import print_vars as print + +from .builder import ShaderBinding, ShaderDescription +from .builder import ShaderBuilder, ShaderFlags + +from .backends import CodeGenBackend, GLSLBackend, CUDABackend, OpenCLBackend + +from .global_builder import set_builder, get_builder, shared_buffer, set_shader_print_line_numbers, get_shader_print_line_numbers +from .global_builder import set_codegen_backend, get_codegen_backend + +from .abreviations import * diff --git a/vkdispatch/codegen/abreviations.py b/vkdispatch/codegen/abreviations.py index 1fdff076..f9815812 100644 --- a/vkdispatch/codegen/abreviations.py +++ b/vkdispatch/codegen/abreviations.py @@ -7,20 +7,40 @@ from .arguments import Image2D as Img2 from .arguments import Image3D as Img3 +from vkdispatch.base.dtype import float16 as f16 from vkdispatch.base.dtype import float32 as f32 -from vkdispatch.base.dtype import uint32 as u32 +from vkdispatch.base.dtype import float64 as f64 +from vkdispatch.base.dtype import int16 as i16 +from vkdispatch.base.dtype import uint16 as u16 from vkdispatch.base.dtype import int32 as i32 +from vkdispatch.base.dtype import uint32 as u32 +from vkdispatch.base.dtype import int64 as i64 +from vkdispatch.base.dtype import uint64 as u64 +from vkdispatch.base.dtype import complex32 as c32 from vkdispatch.base.dtype import complex64 as c64 +from vkdispatch.base.dtype import complex128 as c128 +from vkdispatch.base.dtype import hvec2 as hv2 +from vkdispatch.base.dtype import hvec3 as hv3 +from vkdispatch.base.dtype import hvec4 as hv4 from vkdispatch.base.dtype import 
vec2 as v2 from vkdispatch.base.dtype import vec3 as v3 from vkdispatch.base.dtype import vec4 as v4 -from vkdispatch.base.dtype import uvec2 as uv2 -from vkdispatch.base.dtype import uvec3 as uv3 -from vkdispatch.base.dtype import uvec4 as uv4 +from vkdispatch.base.dtype import dvec2 as dv2 +from vkdispatch.base.dtype import dvec3 as dv3 +from vkdispatch.base.dtype import dvec4 as dv4 +from vkdispatch.base.dtype import ihvec2 as ihv2 +from vkdispatch.base.dtype import ihvec3 as ihv3 +from vkdispatch.base.dtype import ihvec4 as ihv4 from vkdispatch.base.dtype import ivec2 as iv2 from vkdispatch.base.dtype import ivec3 as iv3 from vkdispatch.base.dtype import ivec4 as iv4 +from vkdispatch.base.dtype import uhvec2 as uhv2 +from vkdispatch.base.dtype import uhvec3 as uhv3 +from vkdispatch.base.dtype import uhvec4 as uhv4 +from vkdispatch.base.dtype import uvec2 as uv2 +from vkdispatch.base.dtype import uvec3 as uv3 +from vkdispatch.base.dtype import uvec4 as uv4 from vkdispatch.base.dtype import mat2 as m2 from vkdispatch.base.dtype import mat4 as m4 diff --git a/vkdispatch/codegen/backends/__init__.py b/vkdispatch/codegen/backends/__init__.py new file mode 100644 index 00000000..773f5bee --- /dev/null +++ b/vkdispatch/codegen/backends/__init__.py @@ -0,0 +1,4 @@ +from .base import CodeGenBackend +from .glsl import GLSLBackend +from .cuda import CUDABackend +from .opencl import OpenCLBackend diff --git a/vkdispatch/codegen/backends/base.py b/vkdispatch/codegen/backends/base.py new file mode 100644 index 00000000..aafdab6f --- /dev/null +++ b/vkdispatch/codegen/backends/base.py @@ -0,0 +1,271 @@ +from typing import List, Optional + +import vkdispatch.base.dtype as dtypes + + +class CodeGenBackend: + """ + Interface for backend-specific code generation. + + Subclasses should override all methods that are used by the codegen + pipeline. The base implementation raises NotImplementedError so placeholder + backends can be defined incrementally. 
+ """ + + name: str = "base" + + def reset_state(self) -> None: + # Stateless backends can ignore this. + return + + def mark_feature_usage(self, feature_name: str) -> None: + # Backends that emit optional helper code can override this. + return + + def mark_composite_unary_op(self, var_type: dtypes.dtype, op: str) -> None: + # Backends with composite helper/operator code can override this. + return + + def mark_composite_binary_op( + self, + lhs_type: dtypes.dtype, + rhs_type: dtypes.dtype, + op: str, + *, + inplace: bool = False, + ) -> None: + # Backends with composite helper/operator code can override this. + return + + def type_name(self, var_type: dtypes.dtype) -> str: + raise NotImplementedError + + def constructor( + self, + var_type: dtypes.dtype, + args: List[str], + arg_types: Optional[List[Optional[dtypes.dtype]]] = None, + ) -> str: + _ = arg_types + raise NotImplementedError + + def component_access_expr(self, expr: str, component: str, base_type: dtypes.dtype) -> str: + return f"{expr}.{component}" + + def buffer_component_expr( + self, + scalar_buffer_expr: str, + base_type: dtypes.dtype, + element_index_expr: str, + component_index_expr: str, + ) -> Optional[str]: + _ = (scalar_buffer_expr, base_type, element_index_expr, component_index_expr) + return None + + def fma_function_name(self, var_type: dtypes.dtype) -> str: + return "fma" + + def math_func_name(self, func_name: str, var_type: dtypes.dtype) -> str: + """Return the backend-specific function name for a math operation. + + Backends can override this to remap function names for specific types + (e.g. CUDA __half intrinsics). 
+ """ + return func_name + + def unary_math_expr(self, func_name: str, arg_type: dtypes.dtype, arg_expr: str) -> str: + return f"{self.math_func_name(func_name, arg_type)}({arg_expr})" + + def binary_math_expr( + self, + func_name: str, + lhs_type: dtypes.dtype, + lhs_expr: str, + rhs_type: dtypes.dtype, + rhs_expr: str, + ) -> str: + mapped = self.math_func_name(func_name, lhs_type) + if func_name == "atan2": + mapped_atan = self.math_func_name("atan", lhs_type) + return f"{mapped_atan}({lhs_expr}, {rhs_expr})" + + return f"{mapped}({lhs_expr}, {rhs_expr})" + + def arithmetic_unary_expr(self, op: str, var_type: dtypes.dtype, var_expr: str) -> Optional[str]: + """Optional backend override for unary arithmetic expressions.""" + _ = (op, var_type, var_expr) + return None + + def arithmetic_binary_expr( + self, + op: str, + lhs_type: dtypes.dtype, + lhs_expr: str, + rhs_type: dtypes.dtype, + rhs_expr: str, + ) -> Optional[str]: + """Optional backend override for binary arithmetic expressions.""" + _ = (op, lhs_type, lhs_expr, rhs_type, rhs_expr) + return None + + def pre_header(self, *, enable_subgroup_ops: bool, enable_printf: bool) -> str: + raise NotImplementedError + + def make_source(self, header: str, body: str, x: int, y: int, z: int) -> str: + raise NotImplementedError + + def constant_namespace(self) -> str: + raise NotImplementedError + + def variable_namespace(self) -> str: + raise NotImplementedError + + def exec_bounds_guard(self, exec_count_expr: str) -> str: + raise NotImplementedError + + def shared_buffer_declaration(self, var_type: dtypes.dtype, name: str, size: int) -> str: + raise NotImplementedError + + def uniform_block_declaration(self, contents: str) -> str: + raise NotImplementedError + + def storage_buffer_declaration(self, binding: int, var_type: dtypes.dtype, name: str) -> str: + raise NotImplementedError + + def sampler_declaration(self, binding: int, dimensions: int, name: str) -> str: + raise NotImplementedError + + def 
push_constant_declaration(self, contents: str) -> str: + raise NotImplementedError + + def entry_point(self, body_contents: str) -> str: + raise NotImplementedError + + def inf_f32_expr(self) -> str: + raise NotImplementedError + + def ninf_f32_expr(self) -> str: + raise NotImplementedError + + def inf_f64_expr(self) -> str: + raise NotImplementedError + + def ninf_f64_expr(self) -> str: + raise NotImplementedError + + def inf_f16_expr(self) -> str: + raise NotImplementedError + + def ninf_f16_expr(self) -> str: + raise NotImplementedError + + def float_bits_to_int_expr(self, var_expr: str) -> str: + raise NotImplementedError + + def float_bits_to_uint_expr(self, var_expr: str) -> str: + raise NotImplementedError + + def int_bits_to_float_expr(self, var_expr: str) -> str: + raise NotImplementedError + + def uint_bits_to_float_expr(self, var_expr: str) -> str: + raise NotImplementedError + + def global_invocation_id_expr(self) -> str: + raise NotImplementedError + + def local_invocation_id_expr(self) -> str: + raise NotImplementedError + + def local_invocation_index_expr(self) -> str: + raise NotImplementedError + + def workgroup_id_expr(self) -> str: + raise NotImplementedError + + def workgroup_size_expr(self) -> str: + raise NotImplementedError + + def num_workgroups_expr(self) -> str: + raise NotImplementedError + + def num_subgroups_expr(self) -> str: + raise NotImplementedError + + def subgroup_id_expr(self) -> str: + raise NotImplementedError + + def subgroup_size_expr(self) -> str: + raise NotImplementedError + + def subgroup_invocation_id_expr(self) -> str: + raise NotImplementedError + + def barrier_statement(self) -> str: + raise NotImplementedError + + def memory_barrier_statement(self) -> str: + raise NotImplementedError + + def memory_barrier_buffer_statement(self) -> str: + raise NotImplementedError + + def memory_barrier_shared_statement(self) -> str: + raise NotImplementedError + + def memory_barrier_image_statement(self) -> str: + raise 
NotImplementedError + + def group_memory_barrier_statement(self) -> str: + raise NotImplementedError + + def subgroup_add_expr(self, arg_expr: str, arg_type: Optional[dtypes.dtype] = None) -> str: + _ = arg_type + raise NotImplementedError + + def subgroup_mul_expr(self, arg_expr: str, arg_type: Optional[dtypes.dtype] = None) -> str: + _ = arg_type + raise NotImplementedError + + def subgroup_min_expr(self, arg_expr: str, arg_type: Optional[dtypes.dtype] = None) -> str: + _ = arg_type + raise NotImplementedError + + def subgroup_max_expr(self, arg_expr: str, arg_type: Optional[dtypes.dtype] = None) -> str: + _ = arg_type + raise NotImplementedError + + def subgroup_and_expr(self, arg_expr: str, arg_type: Optional[dtypes.dtype] = None) -> str: + _ = arg_type + raise NotImplementedError + + def subgroup_or_expr(self, arg_expr: str, arg_type: Optional[dtypes.dtype] = None) -> str: + _ = arg_type + raise NotImplementedError + + def subgroup_xor_expr(self, arg_expr: str, arg_type: Optional[dtypes.dtype] = None) -> str: + _ = arg_type + raise NotImplementedError + + def subgroup_elect_expr(self) -> str: + raise NotImplementedError + + def subgroup_barrier_statement(self) -> str: + raise NotImplementedError + + def printf_statement(self, fmt: str, args: List[str]) -> str: + raise NotImplementedError + + def texture_size_expr(self, texture_expr: str, lod: int, dimensions: int) -> str: + raise NotImplementedError + + def sample_texture_expr(self, texture_expr: str, coord_expr: str, lod_expr: Optional[str] = None) -> str: + raise NotImplementedError + + def mark_texture_sample_dimension(self, dimensions: int) -> None: + return + + def atomic_add_expr(self, mem_expr: str, value_expr: str, var_type: dtypes.dtype) -> str: + raise NotImplementedError( + f"atomic_add is not supported for backend '{self.name}' and type '{var_type.name}'" + ) diff --git a/vkdispatch/codegen/backends/cuda/__init__.py b/vkdispatch/codegen/backends/cuda/__init__.py new file mode 100644 index 
00000000..31730746 --- /dev/null +++ b/vkdispatch/codegen/backends/cuda/__init__.py @@ -0,0 +1,3 @@ +from .backend import CUDABackend + +__all__ = ["CUDABackend"] diff --git a/vkdispatch/codegen/backends/cuda/backend.py b/vkdispatch/codegen/backends/cuda/backend.py new file mode 100644 index 00000000..7cd91f29 --- /dev/null +++ b/vkdispatch/codegen/backends/cuda/backend.py @@ -0,0 +1,892 @@ +from typing import Dict, List, Optional, Set, Tuple + +import vkdispatch.base.dtype as dtypes + +from ..base import CodeGenBackend +from .composite_emitters import ( + _cuda_emit_mat_helpers, + _cuda_emit_mat_type, + _cuda_emit_subgroup_shuffle_xor_vec_overloads, + _cuda_emit_vec_helper, + _cuda_emit_vec_type, + _cuda_emit_vec_wrapper_conversion_helpers, +) +from .helper_snippets import ( + _HELPER_DEPENDENCIES as _CUDA_HELPER_DEPENDENCIES, + _HELPER_ORDER as _CUDA_HELPER_ORDER, + _HELPER_SNIPPETS as _CUDA_HELPER_SNIPPETS, + initialize_feature_usage, +) +from .math_utils import ( + cuda_fast_binary_math_name, + cuda_fast_unary_math_name, + cuda_float_vec_components_for_suffix, + cuda_float_vec_helper_suffix, + cuda_scalar_binary_math_name, + cuda_scalar_unary_math_name, + emit_used_vec_math_helpers, +) +from .specs import ( + _CUDA_MAT_ORDER, + _CUDA_MAT_TYPE_SPECS, + _CUDA_VEC_ORDER, + _CUDA_VEC_TYPE_SPECS, + _DTYPE_TO_COMPOSITE_KEY as _CUDA_DTYPE_TO_COMPOSITE_KEY, + _FLOAT_VEC_DTYPES as _CUDA_FLOAT_VEC_DTYPES, + _FLOAT_VEC_HELPER_SUFFIX_MAP as _CUDA_FLOAT_VEC_HELPER_SUFFIX_MAP, + _SCALAR_TYPE_NAMES as _CUDA_SCALAR_TYPE_NAMES, +) + +class CUDABackend(CodeGenBackend): + name = "cuda" + _CUDA_BUILTIN_UVEC3_SENTINELS: Dict[str, Dict[str, str]] = { + "global_invocation_id": { + "sentinel": "VKDISPATCH_CUDA_GLOBAL_INVOCATION_ID_SENTINEL()", + "x": "(unsigned int)(blockIdx.x * blockDim.x + threadIdx.x)", + "y": "(unsigned int)(blockIdx.y * blockDim.y + threadIdx.y)", + "z": "(unsigned int)(blockIdx.z * blockDim.z + threadIdx.z)", + }, + "local_invocation_id": { + "sentinel": 
"VKDISPATCH_CUDA_LOCAL_INVOCATION_ID_SENTINEL()", + "x": "(unsigned int)threadIdx.x", + "y": "(unsigned int)threadIdx.y", + "z": "(unsigned int)threadIdx.z", + }, + "workgroup_id": { + "sentinel": "VKDISPATCH_CUDA_WORKGROUP_ID_SENTINEL()", + "x": "(unsigned int)blockIdx.x", + "y": "(unsigned int)blockIdx.y", + "z": "(unsigned int)blockIdx.z", + }, + } + + _HELPER_SNIPPETS: Dict[str, str] = _CUDA_HELPER_SNIPPETS + _HELPER_ORDER: List[str] = _CUDA_HELPER_ORDER + _HELPER_DEPENDENCIES: Dict[str, List[str]] = _CUDA_HELPER_DEPENDENCIES + + def __init__(self) -> None: + self._fixed_preamble = "" + self.reset_state() + + def reset_state(self) -> None: + self._kernel_params: List[str] = [] + self._entry_alias_lines: List[str] = [] + self._composite_type_usage: Set[str] = set() + self._composite_vec_op_usage: Dict[str, Set[str]] = {} + self._composite_mat_op_usage: Dict[str, Set[str]] = {} + self._composite_vec_unary_math_usage: Dict[str, Set[str]] = {} + self._composite_vec_binary_math_usage: Dict[str, Set[str]] = {} + self._sample_texture_dims: Set[int] = set() + self._needs_cuda_fp16: bool = False + self._feature_usage: Dict[str, bool] = initialize_feature_usage() + + def mark_feature_usage(self, feature_name: str) -> None: + if feature_name in self._feature_usage: + self._feature_usage[feature_name] = True + + _DTYPE_TO_COMPOSITE_KEY = _CUDA_DTYPE_TO_COMPOSITE_KEY + + def _composite_key_for_dtype(self, var_type: dtypes.dtype) -> Optional[str]: + return self._DTYPE_TO_COMPOSITE_KEY.get(var_type) + + def _record_composite_type_key(self, key: str) -> None: + self.mark_feature_usage("composite_types") + self._composite_type_usage.add(key) + + if key in _CUDA_MAT_TYPE_SPECS: + dim = _CUDA_MAT_TYPE_SPECS[key][3] + self._composite_type_usage.add(f"float{dim}") + + def _record_composite_type(self, var_type: dtypes.dtype) -> Optional[str]: + key = self._composite_key_for_dtype(var_type) + if key is None: + return None + self._record_composite_type_key(key) + return key + + def 
_record_vec_op(self, key: str, token: str) -> None: + self._record_composite_type_key(key) + self._composite_vec_op_usage.setdefault(key, set()).add(token) + + def _record_mat_op(self, key: str, token: str) -> None: + self._record_composite_type_key(key) + self._composite_mat_op_usage.setdefault(key, set()).add(token) + + def _record_vec_unary_math(self, key: str, func_name: str) -> None: + self._record_composite_type_key(key) + self._composite_vec_unary_math_usage.setdefault(key, set()).add(func_name) + + def _record_vec_binary_math(self, key: str, func_name: str, signature: str) -> None: + self._record_composite_type_key(key) + self._composite_vec_binary_math_usage.setdefault(key, set()).add(f"{func_name}:{signature}") + + def _propagate_matrix_vec_dependencies(self, mat_key: str, token: str) -> None: + dim = _CUDA_MAT_TYPE_SPECS[mat_key][3] + vec_key = f"float{dim}" + + if token == "un:-": + self._record_vec_op(vec_key, "un:-") + return + + if token.startswith("cmpd:"): + if token.endswith(":m"): + vec_token = token[:-1] + "v" + self._record_vec_op(vec_key, vec_token) + return + if token.endswith(":s"): + self._record_vec_op(vec_key, token) + return + + if token.startswith("bin:"): + parts = token.split(":") + if len(parts) != 3: + return + _, op, shape = parts + if shape == "mm": + if op in ["+", "-"]: + self._record_vec_op(vec_key, f"bin:{op}:vv") + elif op == "*": + self._record_mat_op(mat_key, "bin:*:mv") + self._propagate_matrix_vec_dependencies(mat_key, "bin:*:mv") + return + if shape == "ms": + self._record_vec_op(vec_key, f"bin:{op}:vs") + return + if shape == "sm": + self._record_vec_op(vec_key, f"bin:{op}:sv") + return + if shape == "mv": + self._record_vec_op(vec_key, "bin:*:vs") + self._record_vec_op(vec_key, "bin:+:vv") + return + if shape == "vm": + return + + def mark_composite_unary_op(self, var_type: dtypes.dtype, op: str) -> None: + key = self._record_composite_type(var_type) + if key is None: + return + + token = f"un:{op}" + if key in 
_CUDA_VEC_TYPE_SPECS: + self._record_vec_op(key, token) + return + if key in _CUDA_MAT_TYPE_SPECS: + self._record_mat_op(key, token) + self._propagate_matrix_vec_dependencies(key, token) + + def mark_composite_binary_op( + self, + lhs_type: dtypes.dtype, + rhs_type: dtypes.dtype, + op: str, + *, + inplace: bool = False, + ) -> None: + lhs_key = self._record_composite_type(lhs_type) + rhs_key = self._record_composite_type(rhs_type) + + lhs_is_composite = lhs_key is not None + rhs_is_composite = rhs_key is not None + if not lhs_is_composite and not rhs_is_composite: + return + + lhs_is_scalar = dtypes.is_scalar(lhs_type) + rhs_is_scalar = dtypes.is_scalar(rhs_type) + + if lhs_key in _CUDA_VEC_TYPE_SPECS and (rhs_is_scalar or rhs_key in _CUDA_VEC_TYPE_SPECS): + if inplace: + suffix = "s" if rhs_is_scalar else "v" + self._record_vec_op(lhs_key, f"cmpd:{op}=:{suffix}") + return + shape = "vs" if rhs_is_scalar else "vv" + self._record_vec_op(lhs_key, f"bin:{op}:{shape}") + return + + if rhs_key in _CUDA_VEC_TYPE_SPECS and lhs_is_scalar and not inplace: + self._record_vec_op(rhs_key, f"bin:{op}:sv") + return + + if lhs_key in _CUDA_MAT_TYPE_SPECS: + if inplace: + if rhs_is_scalar: + token = f"cmpd:{op}=:s" + elif rhs_key in _CUDA_MAT_TYPE_SPECS: + token = f"cmpd:{op}=:m" + else: + return + self._record_mat_op(lhs_key, token) + self._propagate_matrix_vec_dependencies(lhs_key, token) + return + + if rhs_is_scalar: + token = f"bin:{op}:ms" + self._record_mat_op(lhs_key, token) + self._propagate_matrix_vec_dependencies(lhs_key, token) + return + + if rhs_key in _CUDA_MAT_TYPE_SPECS: + token = "bin:*:mm" if op == "*" else f"bin:{op}:mm" + self._record_mat_op(lhs_key, token) + self._propagate_matrix_vec_dependencies(lhs_key, token) + return + + if rhs_key in _CUDA_VEC_TYPE_SPECS and op == "*": + token = "bin:*:mv" + self._record_mat_op(lhs_key, token) + self._propagate_matrix_vec_dependencies(lhs_key, token) + return + + if rhs_key in _CUDA_MAT_TYPE_SPECS and lhs_is_scalar and 
not inplace: + token = f"bin:{op}:sm" + self._record_mat_op(rhs_key, token) + self._propagate_matrix_vec_dependencies(rhs_key, token) + return + + if lhs_key in _CUDA_VEC_TYPE_SPECS and rhs_key in _CUDA_MAT_TYPE_SPECS and op == "*" and not inplace: + token = "bin:*:vm" + self._record_mat_op(rhs_key, token) + self._propagate_matrix_vec_dependencies(rhs_key, token) + + def _emit_used_composite_helpers(self) -> str: + if len(self._composite_type_usage) == 0: + return "" + + parts: List[str] = [] + + # Subgroup helpers use vector binary operators internally (e.g. value = value + shuffled) + # even if user code never directly emits the corresponding operator on that vector type. + subgroup_vec_op_requirements = [ + ("subgroup_add", "bin:+:vv"), + ("subgroup_mul", "bin:*:vv"), + ("subgroup_and", "bin:&:vv"), + ("subgroup_or", "bin:|:vv"), + ("subgroup_xor", "bin:^:vv"), + ] + for feature_name, token in subgroup_vec_op_requirements: + if not self._feature_usage.get(feature_name, False): + continue + for key in self._composite_type_usage: + if key in _CUDA_VEC_TYPE_SPECS: + self._composite_vec_op_usage.setdefault(key, set()).add(token) + + emitted_vec_keys: Set[str] = set() + for key in _CUDA_VEC_ORDER: + if key not in self._composite_type_usage: + continue + vec_name, scalar_type, dim, cuda_native_type, allow_neg, enable_bitwise = _CUDA_VEC_TYPE_SPECS[key] + emitted_vec_keys.add(key) + parts.append( + _cuda_emit_vec_type( + vec_name, + scalar_type, + dim, + cuda_native_type, + allow_unary_neg=allow_neg, + enable_bitwise=enable_bitwise, + needed_ops=self._composite_vec_op_usage.get(key, set()), + ) + ) + parts.append(_cuda_emit_vec_helper(key, vec_name, scalar_type, dim)) + for key in _CUDA_VEC_ORDER: + if key not in emitted_vec_keys: + continue + vec_name, scalar_type, dim, _, _, _ = _CUDA_VEC_TYPE_SPECS[key] + conversion_helpers = _cuda_emit_vec_wrapper_conversion_helpers( + key, + vec_name, + scalar_type, + dim, + available_keys=emitted_vec_keys, + ) + if 
len(conversion_helpers) > 0: + parts.append(conversion_helpers) + + subgroup_shuffle_overloads = _cuda_emit_subgroup_shuffle_xor_vec_overloads(emitted_vec_keys) + if len(subgroup_shuffle_overloads) > 0: + parts.append(subgroup_shuffle_overloads) + + for key in _CUDA_MAT_ORDER: + if key not in self._composite_type_usage: + continue + mat_name, vec_name, vec_helper_suffix, dim = _CUDA_MAT_TYPE_SPECS[key] + parts.append(_cuda_emit_mat_type(mat_name, vec_name, dim, self._composite_mat_op_usage.get(key, set()))) + parts.append(_cuda_emit_mat_helpers(mat_name, key, vec_name, vec_helper_suffix, dim)) + + vec_math_helpers = self._emit_used_vec_math_helpers() + if len(vec_math_helpers) > 0: + parts.append(vec_math_helpers) + + return "\n\n".join(parts) + + @staticmethod + def _cuda_scalar_unary_math_name(func_name: str, scalar_type: str) -> str: + return cuda_scalar_unary_math_name(func_name, scalar_type) + + @staticmethod + def _cuda_scalar_binary_math_name(func_name: str, scalar_type: str) -> str: + return cuda_scalar_binary_math_name(func_name, scalar_type) + + def _emit_used_vec_math_helpers(self) -> str: + return emit_used_vec_math_helpers( + self._composite_vec_unary_math_usage, + self._composite_vec_binary_math_usage, + ) + + def _register_kernel_param(self, param_decl: str) -> None: + if param_decl not in self._kernel_params: + self._kernel_params.append(param_decl) + + def _register_alias_line(self, alias_line: str) -> None: + if alias_line not in self._entry_alias_lines: + self._entry_alias_lines.append(alias_line) + + @staticmethod + def _is_plain_integer_literal(expr: str) -> bool: + if len(expr) == 0: + return False + if expr[0] in "+-": + return len(expr) > 1 and expr[1:].isdigit() + return expr.isdigit() + + _SCALAR_TYPE_NAMES = _CUDA_SCALAR_TYPE_NAMES + + def type_name(self, var_type: dtypes.dtype) -> str: + scalar_name = self._SCALAR_TYPE_NAMES.get(var_type) + if scalar_name is not None: + if var_type == dtypes.float16: + self._needs_cuda_fp16 = True + 
return scalar_name + + key = self._composite_key_for_dtype(var_type) + if key is not None: + self._record_composite_type(var_type) + if key in _CUDA_VEC_TYPE_SPECS: + # Track fp16 header need when half vector types are used. + if _CUDA_VEC_TYPE_SPECS[key][1] == "__half": + self._needs_cuda_fp16 = True + return _CUDA_VEC_TYPE_SPECS[key][0] + if key in _CUDA_MAT_TYPE_SPECS: + return _CUDA_MAT_TYPE_SPECS[key][0] + + raise ValueError(f"Unsupported CUDA type mapping for '{var_type.name}'") + + _FLOAT_VEC_DTYPES = _CUDA_FLOAT_VEC_DTYPES + + def constructor( + self, + var_type: dtypes.dtype, + args: List[str], + arg_types: Optional[List[Optional[dtypes.dtype]]] = None, + ) -> str: + _ = arg_types + if ( + len(args) == 1 + and var_type in self._FLOAT_VEC_DTYPES + and self._is_plain_integer_literal(args[0]) + ): + scalar_type = None + if dtypes.is_complex(var_type): + scalar_type = var_type.child_type + elif dtypes.is_vector(var_type): + scalar_type = var_type.scalar + + if scalar_type == dtypes.float64: + args = [f"{args[0]}.0"] + else: + args = [f"{args[0]}.0f"] + + target_type = self.type_name(var_type) + + if dtypes.is_scalar(var_type): + assert len(args) > 0, f"Constructor for scalar type '{var_type.name}' needs at least one argument." 
+ return f"(({target_type})({args[0]}))" + + if var_type == dtypes.mat2: + self.mark_feature_usage("make_mat2") + return f"vkdispatch_make_mat2({', '.join(args)})" + if var_type == dtypes.mat3: + self.mark_feature_usage("make_mat3") + return f"vkdispatch_make_mat3({', '.join(args)})" + if var_type == dtypes.mat4: + self.mark_feature_usage("make_mat4") + return f"vkdispatch_make_mat4({', '.join(args)})" + + helper_suffix = target_type[len("vkdispatch_"):] if target_type.startswith("vkdispatch_") else target_type + helper_name = f"vkdispatch_make_{helper_suffix}" + self.mark_feature_usage(f"make_{helper_suffix}") + return f"{helper_name}({', '.join(args)})" + + def component_access_expr(self, expr: str, component: str, base_type: dtypes.dtype) -> str: + if dtypes.is_scalar(base_type): + if component == "x": + return expr + return super().component_access_expr(expr, component, base_type) + + if dtypes.is_vector(base_type) or dtypes.is_complex(base_type): + direct_builtin_component = self._cuda_builtin_uvec3_component_expr(expr, component, base_type) + if direct_builtin_component is not None: + return direct_builtin_component + return f"{expr}.v.{component}" + + return super().component_access_expr(expr, component, base_type) + + def pre_header(self, *, enable_subgroup_ops: bool, enable_printf: bool) -> str: + subgroup_support = "1" if enable_subgroup_ops else "0" + printf_support = "1" if enable_printf else "0" + + self._enable_subgroup_ops = enable_subgroup_ops + self._enable_printf = enable_printf + + helper_header = self._helper_header() + fp16_include = "#include \n" if self._needs_cuda_fp16 else "" + + self._fixed_preamble = ( + "#include \n" + f"{fp16_include}\n" + f"#define VKDISPATCH_ENABLE_SUBGROUP_OPS {subgroup_support}\n" + f"#define VKDISPATCH_ENABLE_PRINTF {printf_support}\n\n" + f"{helper_header}\n\n" + ) + + return self._fixed_preamble + + def _resolve_helper_dependencies(self, helpers: Set[str]) -> Set[str]: + pending = list(helpers) + resolved = 
set(helpers) + + while len(pending) > 0: + helper_name = pending.pop() + + for dependency in self._HELPER_DEPENDENCIES.get(helper_name, []): + if dependency not in resolved: + resolved.add(dependency) + pending.append(dependency) + + return resolved + + def _helper_header(self) -> str: + enabled_helpers = { + helper_name + for helper_name, is_enabled in self._feature_usage.items() + if is_enabled + } + + resolved_helpers = self._resolve_helper_dependencies(enabled_helpers) + + if len(resolved_helpers) == 0: + return "" + + helper_sections: List[str] = [] + + for helper_name in self._HELPER_ORDER: + if helper_name in resolved_helpers: + if helper_name == "composite_types": + composite_helpers = self._emit_used_composite_helpers() + if len(composite_helpers) > 0: + helper_sections.append(composite_helpers) + continue + + snippet = self._HELPER_SNIPPETS[helper_name] + if len(snippet) > 0: + helper_sections.append(snippet) + + return "\n\n".join(helper_sections) + "\n\n" + + def make_source(self, header: str, body: str, x: int, y: int, z: int) -> str: + header, body = self._finalize_cuda_builtin_uvec3_sentinels(header, body) + + expected_size_header = ( + f"// Expected local size: ({x}, {y}, {z})\n" + f"#define VKDISPATCH_EXPECTED_LOCAL_SIZE_X {x}\n" + f"#define VKDISPATCH_EXPECTED_LOCAL_SIZE_Y {y}\n" + f"#define VKDISPATCH_EXPECTED_LOCAL_SIZE_Z {z}\n" + ) + + return f"{expected_size_header}\n{header}\n{body}" + + def constant_namespace(self) -> str: + return "UBO" + + def variable_namespace(self) -> str: + return "PC" + + def exec_bounds_guard(self, exec_count_expr: str) -> str: + gid = self.global_invocation_id_expr() + exec_expr = f"({exec_count_expr})" + gid_expr = f"({gid})" + return ( + f"if ({self.component_access_expr(exec_expr, 'x', dtypes.uvec4)} <= {self.component_access_expr(gid_expr, 'x', dtypes.uvec3)} || " + f"{self.component_access_expr(exec_expr, 'y', dtypes.uvec4)} <= {self.component_access_expr(gid_expr, 'y', dtypes.uvec3)} || " + 
f"{self.component_access_expr(exec_expr, 'z', dtypes.uvec4)} <= {self.component_access_expr(gid_expr, 'z', dtypes.uvec3)}) {{ return; }}\n" + ) + + def shared_buffer_declaration(self, var_type: dtypes.dtype, name: str, size: int) -> str: + return f"__shared__ {self.type_name(var_type)} {name}[{size}];" + + def uniform_block_declaration(self, contents: str) -> str: + self._register_kernel_param("const UniformObjectBuffer vkdispatch_uniform_value") + self._register_alias_line("const UniformObjectBuffer& UBO = vkdispatch_uniform_value;") + return f"\nstruct UniformObjectBuffer {{\n{contents}\n}};\n" + + def storage_buffer_declaration(self, binding: int, var_type: dtypes.dtype, name: str) -> str: + struct_name = f"Buffer{binding}" + param_name = f"vkdispatch_binding_{binding}_ptr" + self._register_kernel_param(f"{self.type_name(var_type)}* {param_name}") + self._register_alias_line(f"{struct_name} {name} = {{{param_name}}};") + return f"struct {struct_name} {{ {self.type_name(var_type)}* data; }};\n" + + def sampler_declaration(self, binding: int, dimensions: int, name: str) -> str: + param_name = f"vkdispatch_sampler_{binding}" + self._register_kernel_param(f"cudaTextureObject_t {param_name}") + self._register_alias_line(f"cudaTextureObject_t {name} = {param_name};") + return f"// sampler binding {binding}, dimensions={dimensions}\n" + + def push_constant_declaration(self, contents: str) -> str: + self._register_kernel_param("const PushConstant vkdispatch_pc_value") + self._register_alias_line("const PushConstant& PC = vkdispatch_pc_value;") + return f"\nstruct PushConstant {{\n{contents}\n}};\n" + + def entry_point(self, body_contents: str) -> str: + params = ", ".join(self._kernel_params) + + alias_block = "" + for line in self._entry_alias_lines: + alias_block += f" {line}\n" + + return ( + f'extern "C" __global__ void vkdispatch_main({params}) {{\n' + f"{alias_block}" + f"{body_contents}" + f"}}\n" + ) + + def inf_f32_expr(self) -> str: + 
self.mark_feature_usage("uintBitsToFloat") + return "uintBitsToFloat(0x7F800000u)" + + def ninf_f32_expr(self) -> str: + self.mark_feature_usage("uintBitsToFloat") + return "uintBitsToFloat(0xFF800000u)" + + def inf_f64_expr(self) -> str: + self.mark_feature_usage("longlong_as_double") + return "__longlong_as_double(0x7FF0000000000000LL)" + + def ninf_f64_expr(self) -> str: + self.mark_feature_usage("longlong_as_double") + return "__longlong_as_double(0xFFF0000000000000LL)" + + def inf_f16_expr(self) -> str: + self.mark_feature_usage("ushort_as_half") + return "__ushort_as_half(0x7C00u)" + + def ninf_f16_expr(self) -> str: + self.mark_feature_usage("ushort_as_half") + return "__ushort_as_half(0xFC00u)" + + def fma_function_name(self, var_type: dtypes.dtype) -> str: + if var_type == dtypes.float16: + return "__hfma" + if var_type == dtypes.float32: + return "fmaf" + return "fma" + + def math_func_name(self, func_name: str, var_type: dtypes.dtype) -> str: + scalar = var_type + if dtypes.is_vector(var_type) or dtypes.is_matrix(var_type): + scalar = var_type.scalar + elif dtypes.is_complex(var_type): + scalar = var_type.child_type + + if scalar == dtypes.float16: + return self._cuda_scalar_unary_math_name(func_name, "__half") + if scalar == dtypes.float32: + return self._cuda_fast_unary_math_name(func_name) + # double and integer types use standard C names + return func_name + + @staticmethod + def _cuda_fast_unary_math_name(func_name: str) -> str: + return cuda_fast_unary_math_name(func_name) + + @staticmethod + def _cuda_fast_binary_math_name(func_name: str) -> str: + return cuda_fast_binary_math_name(func_name) + + _FLOAT_VEC_HELPER_SUFFIX_MAP = _CUDA_FLOAT_VEC_HELPER_SUFFIX_MAP + + @staticmethod + def _cuda_float_vec_helper_suffix(var_type: dtypes.dtype) -> Optional[str]: + return cuda_float_vec_helper_suffix(var_type) + + @staticmethod + def _cuda_float_vec_components_for_suffix(helper_suffix: str) -> List[str]: + return 
cuda_float_vec_components_for_suffix(helper_suffix) + + def _cuda_componentwise_unary_math_expr(self, func_name: str, arg_type: dtypes.dtype, arg_expr: str) -> Optional[str]: + helper_suffix = self._cuda_float_vec_helper_suffix(arg_type) + if helper_suffix is None: + return None + + self._record_vec_unary_math(helper_suffix, func_name) + return f"{func_name}({arg_expr})" + + def _cuda_componentwise_binary_math_expr( + self, + func_name: str, + lhs_type: dtypes.dtype, + lhs_expr: str, + rhs_type: dtypes.dtype, + rhs_expr: str, + ) -> Optional[str]: + lhs_helper = self._cuda_float_vec_helper_suffix(lhs_type) + rhs_helper = self._cuda_float_vec_helper_suffix(rhs_type) + + if lhs_helper is None and rhs_helper is None: + return None + + if lhs_helper is not None and rhs_helper is not None and lhs_helper != rhs_helper: + return None + + helper_suffix = lhs_helper if lhs_helper is not None else rhs_helper + assert helper_suffix is not None + + signature = ("v" if lhs_helper is not None else "s") + ("v" if rhs_helper is not None else "s") + self._record_vec_binary_math(helper_suffix, func_name, signature) + return f"{func_name}({lhs_expr}, {rhs_expr})" + + def unary_math_expr(self, func_name: str, arg_type: dtypes.dtype, arg_expr: str) -> str: + vector_expr = self._cuda_componentwise_unary_math_expr(func_name, arg_type, arg_expr) + if vector_expr is not None: + return vector_expr + + mapped = self.math_func_name(func_name, arg_type) + return f"{mapped}({arg_expr})" + + def binary_math_expr( + self, + func_name: str, + lhs_type: dtypes.dtype, + lhs_expr: str, + rhs_type: dtypes.dtype, + rhs_expr: str, + ) -> str: + vector_expr = self._cuda_componentwise_binary_math_expr( + func_name, + lhs_type, + lhs_expr, + rhs_type, + rhs_expr, + ) + if vector_expr is not None: + return vector_expr + + if dtypes.is_scalar(lhs_type) and dtypes.is_scalar(rhs_type): + scalar = lhs_type + scalar_name = self._SCALAR_TYPE_NAMES.get(scalar, "float") + return 
f"{self._cuda_scalar_binary_math_name(func_name, scalar_name)}({lhs_expr}, {rhs_expr})" + + return f"{func_name}({lhs_expr}, {rhs_expr})" + + def float_bits_to_int_expr(self, var_expr: str) -> str: + self.mark_feature_usage("floatBitsToInt") + return f"floatBitsToInt({var_expr})" + + def float_bits_to_uint_expr(self, var_expr: str) -> str: + self.mark_feature_usage("floatBitsToUint") + return f"floatBitsToUint({var_expr})" + + def int_bits_to_float_expr(self, var_expr: str) -> str: + self.mark_feature_usage("intBitsToFloat") + return f"intBitsToFloat({var_expr})" + + def uint_bits_to_float_expr(self, var_expr: str) -> str: + self.mark_feature_usage("uintBitsToFloat") + return f"uintBitsToFloat({var_expr})" + + def global_invocation_id_expr(self) -> str: + return self._CUDA_BUILTIN_UVEC3_SENTINELS["global_invocation_id"]["sentinel"] + + def local_invocation_id_expr(self) -> str: + return self._CUDA_BUILTIN_UVEC3_SENTINELS["local_invocation_id"]["sentinel"] + + def local_invocation_index_expr(self) -> str: + self.mark_feature_usage("local_invocation_index") + return "vkdispatch_local_invocation_index()" + + def workgroup_id_expr(self) -> str: + return self._CUDA_BUILTIN_UVEC3_SENTINELS["workgroup_id"]["sentinel"] + + def workgroup_size_expr(self) -> str: + self._record_composite_type_key("uint3") + self.mark_feature_usage("make_uint3") + return "vkdispatch_make_uint3((unsigned int)blockDim.x, (unsigned int)blockDim.y, (unsigned int)blockDim.z)" + + def num_workgroups_expr(self) -> str: + self._record_composite_type_key("uint3") + self.mark_feature_usage("make_uint3") + return "vkdispatch_make_uint3((unsigned int)gridDim.x, (unsigned int)gridDim.y, (unsigned int)gridDim.z)" + + def num_subgroups_expr(self) -> str: + self.mark_feature_usage("num_subgroups") + return "vkdispatch_num_subgroups()" + + def subgroup_id_expr(self) -> str: + self.mark_feature_usage("subgroup_id") + return "vkdispatch_subgroup_id()" + + def subgroup_size_expr(self) -> str: + 
self.mark_feature_usage("subgroup_size") + return "vkdispatch_subgroup_size()" + + def subgroup_invocation_id_expr(self) -> str: + self.mark_feature_usage("subgroup_invocation_id") + return "vkdispatch_subgroup_invocation_id()" + + def barrier_statement(self) -> str: + return "__syncthreads();" + + def memory_barrier_statement(self) -> str: + return "__threadfence();" + + def memory_barrier_buffer_statement(self) -> str: + return "__threadfence();" + + def memory_barrier_shared_statement(self) -> str: + return "__threadfence_block();" + + def memory_barrier_image_statement(self) -> str: + return "__threadfence();" + + def group_memory_barrier_statement(self) -> str: + return "__threadfence_block();" + + @staticmethod + def _strip_outer_parens(expr: str) -> str: + stripped = expr.strip() + while len(stripped) >= 2 and stripped[0] == "(" and stripped[-1] == ")": + depth = 0 + balanced = True + for idx, ch in enumerate(stripped): + if ch == "(": + depth += 1 + elif ch == ")": + depth -= 1 + if depth < 0: + balanced = False + break + if depth == 0 and idx != len(stripped) - 1: + balanced = False + break + if not balanced or depth != 0: + break + stripped = stripped[1:-1].strip() + return stripped + + def _cuda_builtin_uvec3_component_expr( + self, + expr: str, + component: str, + base_type: dtypes.dtype, + ) -> Optional[str]: + if base_type != dtypes.uvec3 or component not in ("x", "y", "z"): + return None + + stripped_expr = self._strip_outer_parens(expr) + for builtin_spec in self._CUDA_BUILTIN_UVEC3_SENTINELS.values(): + if stripped_expr == builtin_spec["sentinel"]: + return builtin_spec[component] + + return None + + def _finalize_cuda_builtin_uvec3_sentinels(self, header: str, body: str) -> Tuple[str, str]: + for builtin_spec in self._CUDA_BUILTIN_UVEC3_SENTINELS.values(): + sentinel = builtin_spec["sentinel"] + if sentinel not in header and sentinel not in body: + continue + + self._record_composite_type_key("uint3") + self.mark_feature_usage("make_uint3") + 
replacement = ( + "vkdispatch_make_uint3(" + f"{builtin_spec['x']}, {builtin_spec['y']}, {builtin_spec['z']}" + ")" + ) + header = header.replace(sentinel, replacement) + body = body.replace(sentinel, replacement) + + return header, body + + def subgroup_add_expr(self, arg_expr: str, arg_type: Optional[dtypes.dtype] = None) -> str: + _ = arg_type + self.mark_feature_usage("subgroup_add") + return f"vkdispatch_subgroup_add({arg_expr})" + + def subgroup_mul_expr(self, arg_expr: str, arg_type: Optional[dtypes.dtype] = None) -> str: + _ = arg_type + self.mark_feature_usage("subgroup_mul") + return f"vkdispatch_subgroup_mul({arg_expr})" + + def subgroup_min_expr(self, arg_expr: str, arg_type: Optional[dtypes.dtype] = None) -> str: + _ = arg_type + self.mark_feature_usage("subgroup_min") + return f"vkdispatch_subgroup_min({arg_expr})" + + def subgroup_max_expr(self, arg_expr: str, arg_type: Optional[dtypes.dtype] = None) -> str: + _ = arg_type + self.mark_feature_usage("subgroup_max") + return f"vkdispatch_subgroup_max({arg_expr})" + + def subgroup_and_expr(self, arg_expr: str, arg_type: Optional[dtypes.dtype] = None) -> str: + _ = arg_type + self.mark_feature_usage("subgroup_and") + return f"vkdispatch_subgroup_and({arg_expr})" + + def subgroup_or_expr(self, arg_expr: str, arg_type: Optional[dtypes.dtype] = None) -> str: + _ = arg_type + self.mark_feature_usage("subgroup_or") + return f"vkdispatch_subgroup_or({arg_expr})" + + def subgroup_xor_expr(self, arg_expr: str, arg_type: Optional[dtypes.dtype] = None) -> str: + _ = arg_type + self.mark_feature_usage("subgroup_xor") + return f"vkdispatch_subgroup_xor({arg_expr})" + + def subgroup_elect_expr(self) -> str: + self.mark_feature_usage("subgroup_invocation_id") + return "((int)(vkdispatch_subgroup_invocation_id() == 0u))" + + def subgroup_barrier_statement(self) -> str: + return "__syncwarp();" + + def printf_statement(self, fmt: str, args: List[str]) -> str: + #safe_fmt = fmt.replace("\\", "\\\\").replace('"', '\\"') + 
+ if len(args) == 0: + return f'printf("{fmt}");' + + return f'printf("{fmt}", {", ".join(args)});' + + def texture_size_expr(self, texture_expr: str, lod: int, dimensions: int) -> str: + # CUDA texture objects do not expose shape directly in device code. + # The future CUDA backend should pass explicit texture shape parameters. + if dimensions == 1: + return "1.0f" + if dimensions == 2: + self.mark_feature_usage("make_float2") + return "vkdispatch_make_float2(1.0f)" + if dimensions == 3: + self.mark_feature_usage("make_float3") + return "vkdispatch_make_float3(1.0f)" + + raise ValueError(f"Unsupported texture dimensions '{dimensions}'") + + def sample_texture_expr(self, texture_expr: str, coord_expr: str, lod_expr: Optional[str] = None) -> str: + raise NotImplementedError("Direct texture sampling is not supported in CUDA backend. Use vkdispatch_sample_texture helper functions instead.") + + def atomic_add_expr(self, mem_expr: str, value_expr: str, var_type: dtypes.dtype) -> str: + if var_type not in (dtypes.int32, dtypes.uint32): + raise NotImplementedError(f"CUDA atomic_add only supports int32/uint32, got '{var_type.name}'") + + return f"atomicAdd(&({mem_expr}), {value_expr})" diff --git a/vkdispatch/codegen/backends/cuda/composite_emitters.py b/vkdispatch/codegen/backends/cuda/composite_emitters.py new file mode 100644 index 00000000..abb23ed6 --- /dev/null +++ b/vkdispatch/codegen/backends/cuda/composite_emitters.py @@ -0,0 +1,380 @@ +from typing import List, Optional, Set + +from .specs import _CUDA_MAT_TYPE_SPECS, _CUDA_VEC_ORDER, _CUDA_VEC_TYPE_SPECS + + +def _cuda_vec_components(dim: int) -> List[str]: + if dim < 2 or dim > 4: + raise ValueError(f"Unsupported vector dimension '{dim}'") + return list("xyzw"[:dim]) + + +def _cuda_join_statements(statements: List[str]) -> str: + if len(statements) == 0: + return "" + return " ".join(statements) + + +def _cuda_emit_vec_type( + vec_name: str, + scalar_type: str, + dim: int, + cuda_native_type: str, + *, + 
allow_unary_neg: bool, + enable_bitwise: bool, + needed_ops: Optional[Set[str]] = None, +) -> str: + comps = _cuda_vec_components(dim) + if needed_ops is None: + needed_ops = set() + if allow_unary_neg: + needed_ops.add("un:-") + if enable_bitwise: + needed_ops.add("un:~") + for op in ["+", "-", "*", "/"]: + needed_ops.add(f"cmpd:{op}=:v") + needed_ops.add(f"cmpd:{op}=:s") + needed_ops.add(f"bin:{op}:vv") + needed_ops.add(f"bin:{op}:vs") + needed_ops.add(f"bin:{op}:sv") + if enable_bitwise: + for op in ["&", "|", "^", "<<", ">>"]: + needed_ops.add(f"cmpd:{op}=:v") + needed_ops.add(f"cmpd:{op}=:s") + needed_ops.add(f"bin:{op}:vv") + needed_ops.add(f"bin:{op}:vs") + needed_ops.add(f"bin:{op}:sv") + + def has(token: str) -> bool: + return token in needed_ops + + def self_comp(c: str) -> str: + return f"v.{c}" + + def wrap_comp(obj: str, c: str) -> str: + return f"{obj}.v.{c}" + + def native_comp(obj: str, c: str) -> str: + return f"{obj}.{c}" + + def index_op_body() -> str: + branches: List[str] = [] + for idx, c in enumerate(comps): + prefix = "if" if idx == 0 else "else if" + branches.append(f"{prefix} (i == {idx}) return v.{c};") + branches.append(f"else return v.{comps[0]};") + return " ".join(branches) + + lines: List[str] = [f"struct {vec_name} {{"] + lines.append(f" {cuda_native_type} v;") + lines.append("") + ctor_args = ", ".join([f"{scalar_type} {c}_" for c in comps]) + ctor_init = "{" + ", ".join([f"{c}_" for c in comps]) + "}" + splat_init = "{" + ", ".join(["s" for _ in comps]) + "}" + cast_init = "{" + ", ".join([f"({scalar_type}){native_comp('src', c)}" for c in comps]) + "}" + member_guard = ", ".join([f"(void)(((const TVec*)0)->{c})" for c in comps]) + lines.append(f" __device__ __forceinline__ {vec_name}() = default;") + lines.append(f" __device__ __forceinline__ {vec_name}({ctor_args}) : v{ctor_init} {{}}") + lines.append(f" __device__ __forceinline__ explicit {vec_name}({scalar_type} s) : v{splat_init} {{}}") + lines.append(f" __device__ 
__forceinline__ explicit {vec_name}(const {cuda_native_type}& native) : v(native) {{}}") + lines.append(f" template ") + lines.append(f" __device__ __forceinline__ explicit {vec_name}(const TVec& src) : v{cast_init} {{}}") + lines.append(f" __device__ __forceinline__ {scalar_type}& operator[](int i) {{ {index_op_body()} }}") + lines.append(f" __device__ __forceinline__ const {scalar_type}& operator[](int i) const {{ {index_op_body()} }}") + + if allow_unary_neg and has("un:-"): + neg_expr = ", ".join([f"-{self_comp(c)}" for c in comps]) + lines.append(f" __device__ __forceinline__ {vec_name} operator-() const {{ return {vec_name}({neg_expr}); }}") + + if enable_bitwise and has("un:~"): + not_expr = ", ".join([f"~{self_comp(c)}" for c in comps]) + lines.append(f" __device__ __forceinline__ {vec_name} operator~() const {{ return {vec_name}({not_expr}); }}") + + for op in ["+", "-", "*", "/"]: + op_assign = op + "=" + if has(f"cmpd:{op}=:v"): + vv_ops = _cuda_join_statements([f"{self_comp(c)} {op_assign} {wrap_comp('b', c)};" for c in comps]) + lines.append( + f" __device__ __forceinline__ {vec_name}& operator{op_assign}(const {vec_name}& b) {{ {vv_ops} return *this; }}" + ) + if has(f"cmpd:{op}=:s"): + sv_ops = _cuda_join_statements([f"{self_comp(c)} {op_assign} b;" for c in comps]) + lines.append( + f" __device__ __forceinline__ {vec_name}& operator{op_assign}({scalar_type} b) {{ {sv_ops} return *this; }}" + ) + + if enable_bitwise: + for op in ["&", "|", "^", "<<", ">>"]: + op_assign = op + "=" + if has(f"cmpd:{op}=:v"): + vv_ops = _cuda_join_statements([f"{self_comp(c)} {op_assign} {wrap_comp('b', c)};" for c in comps]) + lines.append( + f" __device__ __forceinline__ {vec_name}& operator{op_assign}(const {vec_name}& b) {{ {vv_ops} return *this; }}" + ) + if has(f"cmpd:{op}=:s"): + sv_ops = _cuda_join_statements([f"{self_comp(c)} {op_assign} b;" for c in comps]) + lines.append( + f" __device__ __forceinline__ {vec_name}& operator{op_assign}({scalar_type} b) {{ 
{sv_ops} return *this; }}" + ) + + lines.append("};") + lines.append( + f'static_assert(sizeof({vec_name}) == sizeof({cuda_native_type}), "{vec_name} size must match {cuda_native_type}");' + ) + lines.append( + f'static_assert(alignof({vec_name}) == alignof({cuda_native_type}), "{vec_name} alignment must match {cuda_native_type}");' + ) + + for op in ["+", "-", "*", "/"]: + if has(f"bin:{op}:vv"): + vv_expr = ", ".join([f"({wrap_comp('a', c)} {op} {wrap_comp('b', c)})" for c in comps]) + lines.append( + f"__device__ __forceinline__ {vec_name} operator{op}(const {vec_name}& a, const {vec_name}& b) {{ return {vec_name}({vv_expr}); }}" + ) + if has(f"bin:{op}:vs"): + vs_expr = ", ".join([f"({wrap_comp('a', c)} {op} b)" for c in comps]) + lines.append( + f"__device__ __forceinline__ {vec_name} operator{op}(const {vec_name}& a, {scalar_type} b) {{ return {vec_name}({vs_expr}); }}" + ) + if has(f"bin:{op}:sv"): + if op in ["+", "*"]: + sv_expr = ", ".join([f"(a {op} {wrap_comp('b', c)})" for c in comps]) + else: + sv_expr = ", ".join([f"({scalar_type})(a {op} {wrap_comp('b', c)})" for c in comps]) + lines.append( + f"__device__ __forceinline__ {vec_name} operator{op}({scalar_type} a, const {vec_name}& b) {{ return {vec_name}({sv_expr}); }}" + ) + + if enable_bitwise: + for op in ["&", "|", "^", "<<", ">>"]: + if has(f"bin:{op}:vv"): + vv_expr = ", ".join([f"({wrap_comp('a', c)} {op} {wrap_comp('b', c)})" for c in comps]) + lines.append( + f"__device__ __forceinline__ {vec_name} operator{op}(const {vec_name}& a, const {vec_name}& b) {{ return {vec_name}({vv_expr}); }}" + ) + if has(f"bin:{op}:vs"): + vs_expr = ", ".join([f"({wrap_comp('a', c)} {op} b)" for c in comps]) + lines.append( + f"__device__ __forceinline__ {vec_name} operator{op}(const {vec_name}& a, {scalar_type} b) {{ return {vec_name}({vs_expr}); }}" + ) + if has(f"bin:{op}:sv"): + sv_expr = ", ".join([f"({scalar_type})(a {op} {wrap_comp('b', c)})" for c in comps]) + lines.append( + f"__device__ 
__forceinline__ {vec_name} operator{op}({scalar_type} a, const {vec_name}& b) {{ return {vec_name}({sv_expr}); }}" + ) + + return "\n".join(lines) + + +def _cuda_emit_vec_helper(helper_suffix: str, vec_name: str, scalar_type: str, dim: int) -> str: + comps = _cuda_vec_components(dim) + args = ", ".join([f"{scalar_type} {c}" for c in comps]) + ctor_args = ", ".join(comps) + member_guard = ", ".join([f"(void)(((const TVec*)0)->{c})" for c in comps]) + return "\n".join( + [ + f"__device__ __forceinline__ {vec_name} vkdispatch_make_{helper_suffix}({args}) {{ return {vec_name}({ctor_args}); }}", + f"__device__ __forceinline__ {vec_name} vkdispatch_make_{helper_suffix}({scalar_type} x) {{ return {vec_name}(x); }}", + f"__device__ __forceinline__ {vec_name} vkdispatch_make_{helper_suffix}(const {vec_name}& v) {{ return v; }}", + f"template ", + f"__device__ __forceinline__ {vec_name} vkdispatch_make_{helper_suffix}(const TVec& v) {{ return {vec_name}(v); }}", + ] + ) + + +def _cuda_emit_vec_wrapper_conversion_helpers( + helper_suffix: str, + vec_name: str, + scalar_type: str, + dim: int, + *, + available_keys: Optional[Set[str]] = None, +) -> str: + comps = _cuda_vec_components(dim) + dim_keys = [key for key in _CUDA_VEC_TYPE_SPECS if key.endswith(str(dim))] + if available_keys is not None: + dim_keys = [key for key in dim_keys if key in available_keys] + + lines: List[str] = [] + for src_key in dim_keys: + if src_key == helper_suffix: + continue + src_vec_name = _CUDA_VEC_TYPE_SPECS[src_key][0] + ctor_args = ", ".join([f"({scalar_type})src.v.{c}" for c in comps]) + lines.append( + f"__device__ __forceinline__ {vec_name} vkdispatch_make_{helper_suffix}(const {src_vec_name}& src) {{ return {vec_name}({ctor_args}); }}" + ) + + return "\n".join(lines) + + +def _cuda_emit_mat_type(mat_name: str, vec_name: str, dim: int, needed_ops: Optional[Set[str]] = None) -> str: + cols = [f"c{i}" for i in range(dim)] + if needed_ops is None: + needed_ops = { + "un:-", + "cmpd:+=:m", + 
"cmpd:+=:s", + "cmpd:-=:m", + "cmpd:-=:s", + "cmpd:*=:s", + "cmpd:/=:s", + "bin:+:mm", + "bin:+:ms", + "bin:+:sm", + "bin:-:mm", + "bin:-:ms", + "bin:-:sm", + "bin:*:ms", + "bin:*:sm", + "bin:/:ms", + "bin:/:sm", + "bin:*:mv", + "bin:*:vm", + "bin:*:mm", + } + + def has(token: str) -> bool: + return token in needed_ops + + lines: List[str] = [f"struct {mat_name} {{"] + lines.extend([f" {vec_name} {c};" for c in cols]) + lines.append("") + lines.append(f" __device__ __forceinline__ {mat_name}() = default;") + ctor_args = ", ".join([f"{vec_name} {c}_" for c in cols]) + ctor_init = ", ".join([f"{c}({c}_)" for c in cols]) + lines.append(f" __device__ __forceinline__ {mat_name}({ctor_args}) : {ctor_init} {{}}") + + zero = "0.0f" + diag_init = ", ".join( + [f"c{col_idx}({vec_name}({', '.join(['s' if row_idx == col_idx else zero for row_idx in range(dim)])}))" for col_idx in range(dim)] + ) + lines.append(f" __device__ __forceinline__ explicit {mat_name}(float s) : {diag_init} {{}}") + lines.append(f" __device__ __forceinline__ {vec_name}& operator[](int i) {{ return (&c0)[i]; }}") + lines.append(f" __device__ __forceinline__ const {vec_name}& operator[](int i) const {{ return (&c0)[i]; }}") + if has("un:-"): + lines.append( + f" __device__ __forceinline__ {mat_name} operator-() const {{ return {mat_name}({', '.join([f'-c{i}' for i in range(dim)])}); }}" + ) + + for op in ["+", "-"]: + op_assign = op + "=" + if has(f"cmpd:{op}=:m"): + mm_ops = _cuda_join_statements([f"c{i} {op_assign} b.c{i};" for i in range(dim)]) + lines.append( + f" __device__ __forceinline__ {mat_name}& operator{op_assign}(const {mat_name}& b) {{ {mm_ops} return *this; }}" + ) + if has(f"cmpd:{op}=:s"): + ms_ops = _cuda_join_statements([f"c{i} {op_assign} b;" for i in range(dim)]) + lines.append( + f" __device__ __forceinline__ {mat_name}& operator{op_assign}(float b) {{ {ms_ops} return *this; }}" + ) + + for op in ["*", "/"]: + op_assign = op + "=" + if has(f"cmpd:{op}=:s"): + ms_ops = 
_cuda_join_statements([f"c{i} {op_assign} b;" for i in range(dim)]) + lines.append( + f" __device__ __forceinline__ {mat_name}& operator{op_assign}(float b) {{ {ms_ops} return *this; }}" + ) + + lines.append("};") + + for op in ["+", "-"]: + if has(f"bin:{op}:mm"): + cols_expr = ", ".join([f"(a.c{i} {op} b.c{i})" for i in range(dim)]) + lines.append( + f"__device__ __forceinline__ {mat_name} operator{op}(const {mat_name}& a, const {mat_name}& b) {{ return {mat_name}({cols_expr}); }}" + ) + if has(f"bin:{op}:ms"): + cols_expr = ", ".join([f"(a.c{i} {op} b)" for i in range(dim)]) + lines.append( + f"__device__ __forceinline__ {mat_name} operator{op}(const {mat_name}& a, float b) {{ return {mat_name}({cols_expr}); }}" + ) + if has(f"bin:{op}:sm"): + cols_expr = ", ".join([f"(a {op} b.c{i})" for i in range(dim)]) + lines.append( + f"__device__ __forceinline__ {mat_name} operator{op}(float a, const {mat_name}& b) {{ return {mat_name}({cols_expr}); }}" + ) + + for op in ["*", "/"]: + if has(f"bin:{op}:ms"): + cols_expr = ", ".join([f"(a.c{i} {op} b)" for i in range(dim)]) + lines.append( + f"__device__ __forceinline__ {mat_name} operator{op}(const {mat_name}& a, float b) {{ return {mat_name}({cols_expr}); }}" + ) + if has(f"bin:{op}:sm"): + cols_expr = ", ".join([f"(a {op} b.c{i})" for i in range(dim)]) + lines.append( + f"__device__ __forceinline__ {mat_name} operator{op}(float a, const {mat_name}& b) {{ return {mat_name}({cols_expr}); }}" + ) + + vec_comps = _cuda_vec_components(dim) + if has("bin:*:mv"): + mat_vec_terms = [f"(m.c{i} * v.v.{vec_comps[i]})" for i in range(dim)] + mat_vec_expr = " + ".join(mat_vec_terms) + lines.append( + f"__device__ __forceinline__ {vec_name} operator* (const {mat_name}& m, const {vec_name}& v) {{ return {mat_vec_expr}; }}" + ) + + if has("bin:*:vm"): + row_exprs: List[str] = [] + for col_idx in range(dim): + terms = [f"(v.v.{vec_comps[row_idx]} * m.c{col_idx}.v.{vec_comps[row_idx]})" for row_idx in range(dim)] + row_exprs.append(" + 
".join(terms)) + lines.append( + f"__device__ __forceinline__ {vec_name} operator* (const {vec_name}& v, const {mat_name}& m) {{ return {vec_name}({', '.join(row_exprs)}); }}" + ) + + if has("bin:*:mm"): + col_products = ", ".join([f"(a * b.c{i})" for i in range(dim)]) + lines.append( + f"__device__ __forceinline__ {mat_name} operator* (const {mat_name}& a, const {mat_name}& b) {{ return {mat_name}({col_products}); }}" + ) + + return "\n".join(lines) + + +def _cuda_emit_mat_helpers(mat_name: str, helper_suffix: str, vec_name: str, vec_helper_suffix: str, dim: int) -> str: + col_type = vec_name + col_args = ", ".join([f"{col_type} c{i}" for i in range(dim)]) + col_ctor = ", ".join([f"c{i}" for i in range(dim)]) + + flat_names = [f"m{col}{row}" for col in range(dim) for row in range(dim)] + flat_args = ", ".join([f"float {name}" for name in flat_names]) + flat_cols: List[str] = [] + for col in range(dim): + values = [f"m{col}{row}" for row in range(dim)] + flat_cols.append(f"vkdispatch_make_{vec_helper_suffix}({', '.join(values)})") + flat_ctor = ", ".join(flat_cols) + + cast_cols = ", ".join([f"vkdispatch_make_{vec_helper_suffix}(m[{i}])" for i in range(dim)]) + + return "\n".join( + [ + f"__device__ __forceinline__ {mat_name} vkdispatch_make_{helper_suffix}({col_args}) {{ return {mat_name}({col_ctor}); }}", + f"__device__ __forceinline__ {mat_name} vkdispatch_make_{helper_suffix}(float s) {{ return {mat_name}(s); }}", + f"__device__ __forceinline__ {mat_name} vkdispatch_make_{helper_suffix}({flat_args}) {{ return {mat_name}({flat_ctor}); }}", + "template ", + f"__device__ __forceinline__ {mat_name} vkdispatch_make_{helper_suffix}(TMat m) {{ return {mat_name}({cast_cols}); }}", + ] + ) + + +def _cuda_emit_subgroup_shuffle_xor_vec_overloads(vec_keys: Set[str]) -> str: + lines: List[str] = [] + + for key in _CUDA_VEC_ORDER: + if key not in vec_keys: + continue + + vec_name, _, dim, _, _, _ = _CUDA_VEC_TYPE_SPECS[key] + comps = _cuda_vec_components(dim) + comp_exprs = 
", ".join([f"__shfl_xor_sync(mask, value.v.{c}, lane_mask)" for c in comps]) + lines.append( + f"__device__ __forceinline__ {vec_name} vkdispatch_subgroup_shuffle_xor(unsigned int mask, const {vec_name}& value, int lane_mask) " + f"{{ return vkdispatch_make_{key}({comp_exprs}); }}" + ) + + return "\n".join(lines) diff --git a/vkdispatch/codegen/backends/cuda/helper_snippets.py b/vkdispatch/codegen/backends/cuda/helper_snippets.py new file mode 100644 index 00000000..93fa3eeb --- /dev/null +++ b/vkdispatch/codegen/backends/cuda/helper_snippets.py @@ -0,0 +1,287 @@ +from typing import Dict, List + + +_HELPER_SNIPPETS: Dict[str, str] = { + "composite_types": "", + "mat2_type": "", + "mat3_type": "", + "mat4_type": "", + "make_mat2": "", + "make_mat3": "", + "make_mat4": "", + "make_short2": "", + "make_short3": "", + "make_short4": "", + "make_ushort2": "", + "make_ushort3": "", + "make_ushort4": "", + "make_int2": "", + "make_int3": "", + "make_int4": "", + "make_uint2": "", + "make_uint3": "", + "make_uint4": "", + "make_half2": "", + "make_half3": "", + "make_half4": "", + "float2_ops": "", + "make_float2": "", + "make_float3": "", + "make_float4": "", + "make_double2": "", + "make_double3": "", + "make_double4": "", + "global_invocation_id": ( + "__device__ __forceinline__ vkdispatch_uint3 vkdispatch_global_invocation_id() {\n" + " return vkdispatch_uint3(\n" + " (unsigned int)(blockIdx.x * blockDim.x + threadIdx.x),\n" + " (unsigned int)(blockIdx.y * blockDim.y + threadIdx.y),\n" + " (unsigned int)(blockIdx.z * blockDim.z + threadIdx.z)\n" + " );\n" + "}" + ), + "local_invocation_id": ( + "__device__ __forceinline__ vkdispatch_uint3 vkdispatch_local_invocation_id() {\n" + " return vkdispatch_uint3((unsigned int)threadIdx.x, (unsigned int)threadIdx.y, (unsigned int)threadIdx.z);\n" + "}" + ), + "workgroup_id": ( + "__device__ __forceinline__ vkdispatch_uint3 vkdispatch_workgroup_id() {\n" + " return vkdispatch_uint3((unsigned int)blockIdx.x, (unsigned 
int)blockIdx.y, (unsigned int)blockIdx.z);\n" + "}" + ), + "local_invocation_index": ( + "__device__ __forceinline__ unsigned int vkdispatch_local_invocation_index() {\n" + " return (unsigned int)(threadIdx.x + blockDim.x * (threadIdx.y + blockDim.y * threadIdx.z));\n" + "}" + ), + "subgroup_size": "__device__ __forceinline__ unsigned int vkdispatch_subgroup_size() { return (unsigned int)warpSize; }", + "num_subgroups": ( + "__device__ __forceinline__ unsigned int vkdispatch_num_subgroups() {\n" + " unsigned int local_count = (unsigned int)(blockDim.x * blockDim.y * blockDim.z);\n" + " return (local_count + vkdispatch_subgroup_size() - 1u) / vkdispatch_subgroup_size();\n" + "}" + ), + "subgroup_id": ( + "__device__ __forceinline__ unsigned int vkdispatch_subgroup_id() {\n" + " return vkdispatch_local_invocation_index() / vkdispatch_subgroup_size();\n" + "}" + ), + "subgroup_invocation_id": ( + "__device__ __forceinline__ unsigned int vkdispatch_subgroup_invocation_id() {\n" + " return vkdispatch_local_invocation_index() % vkdispatch_subgroup_size();\n" + "}" + ), + "subgroup_shuffle_xor": ( + "template \n" + "__device__ __forceinline__ T vkdispatch_subgroup_shuffle_xor(unsigned int mask, T value, int lane_mask) {\n" + " return __shfl_xor_sync(mask, value, lane_mask);\n" + "}" + ), + "subgroup_add": ( + "template \n" + "__device__ __forceinline__ T vkdispatch_subgroup_add(T value) {\n" + " unsigned int mask = 0xffffffffu;\n" + " for (unsigned int offset = vkdispatch_subgroup_size() >> 1; offset > 0u; offset >>= 1u) {\n" + " value = value + vkdispatch_subgroup_shuffle_xor(mask, value, (int)offset);\n" + " }\n" + " return value;\n" + "}" + ), + "subgroup_mul": ( + "template \n" + "__device__ __forceinline__ T vkdispatch_subgroup_mul(T value) {\n" + " unsigned int mask = 0xffffffffu;\n" + " for (unsigned int offset = vkdispatch_subgroup_size() >> 1; offset > 0u; offset >>= 1u) {\n" + " value = value * vkdispatch_subgroup_shuffle_xor(mask, value, (int)offset);\n" + " 
}\n" + " return value;\n" + "}" + ), + "subgroup_min": ( + "template \n" + "__device__ __forceinline__ T vkdispatch_subgroup_min(T value) {\n" + " unsigned int mask = 0xffffffffu;\n" + " for (unsigned int offset = vkdispatch_subgroup_size() >> 1; offset > 0u; offset >>= 1u) {\n" + " T other = vkdispatch_subgroup_shuffle_xor(mask, value, (int)offset);\n" + " value = other < value ? other : value;\n" + " }\n" + " return value;\n" + "}" + ), + "subgroup_max": ( + "template \n" + "__device__ __forceinline__ T vkdispatch_subgroup_max(T value) {\n" + " unsigned int mask = 0xffffffffu;\n" + " for (unsigned int offset = vkdispatch_subgroup_size() >> 1; offset > 0u; offset >>= 1u) {\n" + " T other = vkdispatch_subgroup_shuffle_xor(mask, value, (int)offset);\n" + " value = other > value ? other : value;\n" + " }\n" + " return value;\n" + "}" + ), + "subgroup_and": ( + "template \n" + "__device__ __forceinline__ T vkdispatch_subgroup_and(T value) {\n" + " unsigned int mask = 0xffffffffu;\n" + " for (unsigned int offset = vkdispatch_subgroup_size() >> 1; offset > 0u; offset >>= 1u) {\n" + " value = value & vkdispatch_subgroup_shuffle_xor(mask, value, (int)offset);\n" + " }\n" + " return value;\n" + "}" + ), + "subgroup_or": ( + "template \n" + "__device__ __forceinline__ T vkdispatch_subgroup_or(T value) {\n" + " unsigned int mask = 0xffffffffu;\n" + " for (unsigned int offset = vkdispatch_subgroup_size() >> 1; offset > 0u; offset >>= 1u) {\n" + " value = value | vkdispatch_subgroup_shuffle_xor(mask, value, (int)offset);\n" + " }\n" + " return value;\n" + "}" + ), + "subgroup_xor": ( + "template \n" + "__device__ __forceinline__ T vkdispatch_subgroup_xor(T value) {\n" + " unsigned int mask = 0xffffffffu;\n" + " for (unsigned int offset = vkdispatch_subgroup_size() >> 1; offset > 0u; offset >>= 1u) {\n" + " value = value ^ vkdispatch_subgroup_shuffle_xor(mask, value, (int)offset);\n" + " }\n" + " return value;\n" + "}" + ), + "mod": ( + "__device__ __forceinline__ float 
mod(float x, float y) { return fmodf(x, y); }\n" + "__device__ __forceinline__ double mod(double x, double y) { return fmod(x, y); }" + ), + "fract": ( + "__device__ __forceinline__ float fract(float x) { return x - floorf(x); }\n" + "__device__ __forceinline__ double fract(double x) { return x - floor(x); }" + ), + "roundEven": ( + "__device__ __forceinline__ float roundEven(float x) { return nearbyintf(x); }\n" + "__device__ __forceinline__ double roundEven(double x) { return nearbyint(x); }" + ), + "mix": ( + "__device__ __forceinline__ float mix(float x, float y, float a) { return x + (y - x) * a; }\n" + "__device__ __forceinline__ double mix(double x, double y, double a) { return x + (y - x) * a; }" + ), + "step": ( + "__device__ __forceinline__ float step(float edge, float x) { return x < edge ? 0.0f : 1.0f; }\n" + "__device__ __forceinline__ double step(double edge, double x) { return x < edge ? 0.0 : 1.0; }" + ), + "smoothstep": ( + "__device__ __forceinline__ float smoothstep(float edge0, float edge1, float x) {\n" + " float t = fminf(fmaxf((x - edge0) / (edge1 - edge0), 0.0f), 1.0f);\n" + " return t * t * (3.0f - 2.0f * t);\n" + "}\n" + "__device__ __forceinline__ double smoothstep(double edge0, double edge1, double x) {\n" + " double t = fmin(fmax((x - edge0) / (edge1 - edge0), 0.0), 1.0);\n" + " return t * t * (3.0 - 2.0 * t);\n" + "}" + ), + "radians": ( + "__device__ __forceinline__ float radians(float x) { return x * (3.14159265358979323846f / 180.0f); }\n" + "__device__ __forceinline__ double radians(double x) { return x * (3.14159265358979323846 / 180.0); }" + ), + "degrees": ( + "__device__ __forceinline__ float degrees(float x) { return x * (180.0f / 3.14159265358979323846f); }\n" + "__device__ __forceinline__ double degrees(double x) { return x * (180.0 / 3.14159265358979323846); }" + ), + "inversesqrt": ( + "__device__ __forceinline__ float inversesqrt(float x) { return rsqrtf(x); }\n" + "__device__ __forceinline__ double inversesqrt(double x) 
{ return rsqrt(x); }" + ), + "floatBitsToInt": "__device__ __forceinline__ int floatBitsToInt(float x) { return __float_as_int(x); }", + "floatBitsToUint": "__device__ __forceinline__ unsigned int floatBitsToUint(float x) { return __float_as_uint(x); }", + "intBitsToFloat": "__device__ __forceinline__ float intBitsToFloat(int x) { return __int_as_float(x); }", + "uintBitsToFloat": "__device__ __forceinline__ float uintBitsToFloat(unsigned int x) { return __uint_as_float(x); }", + "longlong_as_double": "__device__ __forceinline__ double longlong_as_double(long long x) { return __longlong_as_double(x); }", + "ushort_as_half": "__device__ __forceinline__ __half ushort_as_half(unsigned short x) { __half h; *reinterpret_cast(&h) = x; return h; }", + "sample_texture": "", +} + +_HELPER_ORDER: List[str] = [ + "composite_types", + "global_invocation_id", + "local_invocation_id", + "workgroup_id", + "local_invocation_index", + "subgroup_size", + "num_subgroups", + "subgroup_id", + "subgroup_invocation_id", + "subgroup_shuffle_xor", + "subgroup_add", + "subgroup_mul", + "subgroup_min", + "subgroup_max", + "subgroup_and", + "subgroup_or", + "subgroup_xor", + "mod", + "fract", + "roundEven", + "mix", + "step", + "smoothstep", + "radians", + "degrees", + "inversesqrt", + "floatBitsToInt", + "floatBitsToUint", + "intBitsToFloat", + "uintBitsToFloat", + "longlong_as_double", + "ushort_as_half", + "sample_texture", +] + +_HELPER_DEPENDENCIES: Dict[str, List[str]] = { + "mat2_type": ["composite_types"], + "mat3_type": ["composite_types"], + "mat4_type": ["composite_types"], + "make_mat2": ["composite_types"], + "make_mat3": ["composite_types"], + "make_mat4": ["composite_types"], + "make_short2": ["composite_types"], + "make_short3": ["composite_types"], + "make_short4": ["composite_types"], + "make_ushort2": ["composite_types"], + "make_ushort3": ["composite_types"], + "make_ushort4": ["composite_types"], + "make_int2": ["composite_types"], + "make_int3": ["composite_types"], + 
"make_int4": ["composite_types"], + "make_uint2": ["composite_types"], + "make_uint3": ["composite_types"], + "make_uint4": ["composite_types"], + "make_half2": ["composite_types"], + "make_half3": ["composite_types"], + "make_half4": ["composite_types"], + "float2_ops": ["composite_types"], + "make_float2": ["composite_types"], + "make_float3": ["composite_types"], + "make_float4": ["composite_types"], + "make_double2": ["composite_types"], + "make_double3": ["composite_types"], + "make_double4": ["composite_types"], + "global_invocation_id": ["composite_types"], + "local_invocation_id": ["composite_types"], + "workgroup_id": ["composite_types"], + "sample_texture": ["composite_types"], + "num_subgroups": ["subgroup_size"], + "subgroup_id": ["local_invocation_index", "subgroup_size"], + "subgroup_invocation_id": ["local_invocation_index", "subgroup_size"], + "subgroup_add": ["subgroup_size", "subgroup_shuffle_xor"], + "subgroup_mul": ["subgroup_size", "subgroup_shuffle_xor"], + "subgroup_min": ["subgroup_size", "subgroup_shuffle_xor"], + "subgroup_max": ["subgroup_size", "subgroup_shuffle_xor"], + "subgroup_and": ["subgroup_size", "subgroup_shuffle_xor"], + "subgroup_or": ["subgroup_size", "subgroup_shuffle_xor"], + "subgroup_xor": ["subgroup_size", "subgroup_shuffle_xor"], +} + + +def initialize_feature_usage() -> Dict[str, bool]: + return {feature_name: False for feature_name in _HELPER_SNIPPETS} diff --git a/vkdispatch/codegen/backends/cuda/math_utils.py b/vkdispatch/codegen/backends/cuda/math_utils.py new file mode 100644 index 00000000..fc5ce5ad --- /dev/null +++ b/vkdispatch/codegen/backends/cuda/math_utils.py @@ -0,0 +1,174 @@ +from typing import Dict, List, Optional, Set + +import vkdispatch.base.dtype as dtypes + +from .composite_emitters import _cuda_vec_components +from .specs import _CUDA_VEC_TYPE_SPECS, _FLOAT_VEC_HELPER_SUFFIX_MAP + + +def cuda_fast_unary_math_name(func_name: str) -> str: + if func_name == "sin": + return "__sinf" + if func_name == 
"cos": + return "__cosf" + if func_name == "tan": + return "__tanf" + if func_name == "exp": + return "__expf" + if func_name == "exp2": + return "__exp2f" + if func_name == "log": + return "__logf" + if func_name == "log2": + return "__log2f" + if func_name == "asin": + return "asinf" + if func_name == "acos": + return "acosf" + if func_name == "atan": + return "atanf" + if func_name == "sinh": + return "sinhf" + if func_name == "cosh": + return "coshf" + if func_name == "tanh": + return "tanhf" + if func_name == "asinh": + return "asinhf" + if func_name == "acosh": + return "acoshf" + if func_name == "atanh": + return "atanhf" + if func_name == "sqrt": + return "sqrtf" + + return func_name + + +def cuda_fast_binary_math_name(func_name: str) -> str: + if func_name == "atan2": + return "atan2f" + if func_name == "pow": + return "__powf" + + return func_name + + +def cuda_scalar_unary_math_name(func_name: str, scalar_type: str) -> str: + if scalar_type == "__half": + half_math = { + "sin": "hsin", + "cos": "hcos", + "exp": "hexp", + "exp2": "hexp2", + "log": "hlog", + "log2": "hlog2", + "sqrt": "hsqrt", + } + return half_math.get(func_name, func_name) + if scalar_type == "double": + return func_name + return cuda_fast_unary_math_name(func_name) + + +def cuda_scalar_binary_math_name(func_name: str, scalar_type: str) -> str: + if scalar_type == "__half": + return func_name + if scalar_type == "double": + return func_name + return cuda_fast_binary_math_name(func_name) + + +def cuda_float_vec_components_for_suffix(helper_suffix: str) -> List[str]: + dim_char = helper_suffix[-1] + if dim_char == "2": + return ["x", "y"] + if dim_char == "3": + return ["x", "y", "z"] + if dim_char == "4": + return ["x", "y", "z", "w"] + + raise ValueError(f"Unsupported CUDA float vector helper suffix '{helper_suffix}'") + + +def cuda_float_vec_helper_suffix(var_type: dtypes.dtype) -> Optional[str]: + return _FLOAT_VEC_HELPER_SUFFIX_MAP.get(var_type) + + +def emit_used_vec_math_helpers( + 
composite_vec_unary_math_usage: Dict[str, Set[str]], + composite_vec_binary_math_usage: Dict[str, Set[str]], +) -> str: + helper_sections: List[str] = [] + + unary_order = [ + "sin", + "cos", + "tan", + "asin", + "acos", + "atan", + "sinh", + "cosh", + "tanh", + "asinh", + "acosh", + "atanh", + "exp", + "exp2", + "log", + "log2", + "sqrt", + ] + binary_order = ["atan2", "pow"] + signature_order = ["vv", "vs", "sv"] + + for key in ["half2", "half3", "half4", "float2", "float3", "float4", "double2", "double3", "double4"]: + unary_funcs = composite_vec_unary_math_usage.get(key, set()) + binary_tokens = composite_vec_binary_math_usage.get(key, set()) + if len(unary_funcs) == 0 and len(binary_tokens) == 0: + continue + + if key not in _CUDA_VEC_TYPE_SPECS: + continue + + vec_name, scalar_type, dim, _, _, _ = _CUDA_VEC_TYPE_SPECS[key] + comps = _cuda_vec_components(dim) + lines: List[str] = [] + + for func_name in unary_order: + if func_name not in unary_funcs: + continue + scalar_func = cuda_scalar_unary_math_name(func_name, scalar_type) + comp_args = ", ".join([f"{scalar_func}(v.v.{c})" for c in comps]) + lines.append( + f"__device__ __forceinline__ {vec_name} {func_name}(const {vec_name}& v) {{ return vkdispatch_make_{key}({comp_args}); }}" + ) + + for func_name in binary_order: + scalar_func = cuda_scalar_binary_math_name(func_name, scalar_type) + for signature in signature_order: + token = f"{func_name}:{signature}" + if token not in binary_tokens: + continue + + if signature == "vv": + comp_args = ", ".join([f"{scalar_func}(a.v.{c}, b.v.{c})" for c in comps]) + lines.append( + f"__device__ __forceinline__ {vec_name} {func_name}(const {vec_name}& a, const {vec_name}& b) {{ return vkdispatch_make_{key}({comp_args}); }}" + ) + elif signature == "vs": + comp_args = ", ".join([f"{scalar_func}(a.v.{c}, b)" for c in comps]) + lines.append( + f"__device__ __forceinline__ {vec_name} {func_name}(const {vec_name}& a, {scalar_type} b) {{ return 
vkdispatch_make_{key}({comp_args}); }}" + ) + elif signature == "sv": + comp_args = ", ".join([f"{scalar_func}(a, b.v.{c})" for c in comps]) + lines.append( + f"__device__ __forceinline__ {vec_name} {func_name}({scalar_type} a, const {vec_name}& b) {{ return vkdispatch_make_{key}({comp_args}); }}" + ) + + if len(lines) > 0: + helper_sections.append("\n".join(lines)) + + return "\n\n".join(helper_sections) diff --git a/vkdispatch/codegen/backends/cuda/specs.py b/vkdispatch/codegen/backends/cuda/specs.py new file mode 100644 index 00000000..c029b5b0 --- /dev/null +++ b/vkdispatch/codegen/backends/cuda/specs.py @@ -0,0 +1,120 @@ +from typing import Dict, FrozenSet, Tuple + +import vkdispatch.base.dtype as dtypes + + +_CUDA_VEC_TYPE_SPECS: Dict[str, Tuple[str, str, int, str, bool, bool]] = { + "short2": ("vkdispatch_short2", "short", 2, "short2", True, True), + "short3": ("vkdispatch_short3", "short", 3, "short3", True, True), + "short4": ("vkdispatch_short4", "short", 4, "short4", True, True), + "ushort2": ("vkdispatch_ushort2", "unsigned short", 2, "ushort2", False, True), + "ushort3": ("vkdispatch_ushort3", "unsigned short", 3, "ushort3", False, True), + "ushort4": ("vkdispatch_ushort4", "unsigned short", 4, "ushort4", False, True), + "int2": ("vkdispatch_int2", "int", 2, "int2", True, True), + "int3": ("vkdispatch_int3", "int", 3, "int3", True, True), + "int4": ("vkdispatch_int4", "int", 4, "int4", True, True), + "uint2": ("vkdispatch_uint2", "unsigned int", 2, "uint2", False, True), + "uint3": ("vkdispatch_uint3", "unsigned int", 3, "uint3", False, True), + "uint4": ("vkdispatch_uint4", "unsigned int", 4, "uint4", False, True), + "half2": ("vkdispatch_half2", "__half", 2, "half2", True, False), + "half3": ("vkdispatch_half3", "__half", 3, "half3", True, False), + "half4": ("vkdispatch_half4", "__half", 4, "half4", True, False), + "float2": ("vkdispatch_float2", "float", 2, "float2", True, False), + "float3": ("vkdispatch_float3", "float", 3, "float3", True, 
False), + "float4": ("vkdispatch_float4", "float", 4, "float4", True, False), + "double2": ("vkdispatch_double2", "double", 2, "double2", True, False), + "double3": ("vkdispatch_double3", "double", 3, "double3", True, False), + "double4": ("vkdispatch_double4", "double", 4, "double4", True, False), +} + +_CUDA_MAT_TYPE_SPECS: Dict[str, Tuple[str, str, str, int]] = { + "mat2": ("vkdispatch_mat2", "vkdispatch_float2", "float2", 2), + "mat3": ("vkdispatch_mat3", "vkdispatch_float3", "float3", 3), + "mat4": ("vkdispatch_mat4", "vkdispatch_float4", "float4", 4), +} + +_CUDA_VEC_ORDER = [ + "short2", "short3", "short4", + "ushort2", "ushort3", "ushort4", + "int2", "int3", "int4", + "uint2", "uint3", "uint4", + "half2", "half3", "half4", + "float2", "float3", "float4", + "double2", "double3", "double4", +] + +_CUDA_MAT_ORDER = ["mat2", "mat3", "mat4"] + +_DTYPE_TO_COMPOSITE_KEY = { + dtypes.ihvec2: "short2", + dtypes.ihvec3: "short3", + dtypes.ihvec4: "short4", + dtypes.uhvec2: "ushort2", + dtypes.uhvec3: "ushort3", + dtypes.uhvec4: "ushort4", + dtypes.ivec2: "int2", + dtypes.ivec3: "int3", + dtypes.ivec4: "int4", + dtypes.uvec2: "uint2", + dtypes.uvec3: "uint3", + dtypes.uvec4: "uint4", + dtypes.hvec2: "half2", + dtypes.hvec3: "half3", + dtypes.hvec4: "half4", + dtypes.complex32: "half2", + dtypes.complex64: "float2", + dtypes.complex128: "double2", + dtypes.vec2: "float2", + dtypes.vec3: "float3", + dtypes.vec4: "float4", + dtypes.dvec2: "double2", + dtypes.dvec3: "double3", + dtypes.dvec4: "double4", + dtypes.mat2: "mat2", + dtypes.mat3: "mat3", + dtypes.mat4: "mat4", +} + +_SCALAR_TYPE_NAMES = { + dtypes.int16: "short", + dtypes.uint16: "unsigned short", + dtypes.int32: "int", + dtypes.uint32: "unsigned int", + dtypes.int64: "long long", + dtypes.uint64: "unsigned long long", + dtypes.float16: "__half", + dtypes.float32: "float", + dtypes.float64: "double", +} + +_FLOAT_VEC_DTYPES: FrozenSet[dtypes.dtype] = frozenset( + { + dtypes.complex32, + dtypes.complex64, + 
dtypes.complex128, + dtypes.hvec2, + dtypes.hvec3, + dtypes.hvec4, + dtypes.vec2, + dtypes.vec3, + dtypes.vec4, + dtypes.dvec2, + dtypes.dvec3, + dtypes.dvec4, + } +) + +_FLOAT_VEC_HELPER_SUFFIX_MAP = { + dtypes.hvec2: "half2", + dtypes.hvec3: "half3", + dtypes.hvec4: "half4", + dtypes.complex32: "half2", + dtypes.complex64: "float2", + dtypes.complex128: "double2", + dtypes.vec2: "float2", + dtypes.vec3: "float3", + dtypes.vec4: "float4", + dtypes.dvec2: "double2", + dtypes.dvec3: "double3", + dtypes.dvec4: "double4", +} diff --git a/vkdispatch/codegen/backends/glsl.py b/vkdispatch/codegen/backends/glsl.py new file mode 100644 index 00000000..c2187e06 --- /dev/null +++ b/vkdispatch/codegen/backends/glsl.py @@ -0,0 +1,235 @@ +from typing import List, Optional, Set + +import vkdispatch.base.dtype as dtypes + +from .base import CodeGenBackend + +# Map scalar dtypes to GLSL extension names. +_GLSL_TYPE_EXTENSIONS = { + dtypes.float16: "GL_EXT_shader_explicit_arithmetic_types_float16", + dtypes.int16: "GL_EXT_shader_explicit_arithmetic_types_int16", + dtypes.uint16: "GL_EXT_shader_explicit_arithmetic_types_int16", + dtypes.int64: "GL_ARB_gpu_shader_int64", + dtypes.uint64: "GL_ARB_gpu_shader_int64", + dtypes.float64: "GL_ARB_gpu_shader_fp64", +} + + +class GLSLBackend(CodeGenBackend): + name = "glsl" + + def __init__(self) -> None: + super().__init__() + self._needed_extensions: Set[str] = set() + + def reset_state(self) -> None: + self._needed_extensions = set() + + def _track_type_extension(self, var_type: dtypes.dtype) -> None: + """Record the GLSL extension required by *var_type* (if any).""" + scalar = var_type + if dtypes.is_vector(var_type) or dtypes.is_matrix(var_type): + scalar = var_type.scalar + elif dtypes.is_complex(var_type): + scalar = var_type.child_type + ext = _GLSL_TYPE_EXTENSIONS.get(scalar) + if ext is not None: + self._needed_extensions.add(ext) + + def type_name(self, var_type: dtypes.dtype) -> str: + self._track_type_extension(var_type) + return 
var_type.glsl_type + + def constructor( + self, + var_type: dtypes.dtype, + args: List[str], + arg_types: Optional[List[Optional[dtypes.dtype]]] = None, + ) -> str: + _ = arg_types + return f"{self.type_name(var_type)}({', '.join(args)})" + + def pre_header(self, *, enable_subgroup_ops: bool, enable_printf: bool) -> str: + header = "#version 450\n" + header += "#extension GL_EXT_scalar_block_layout : require\n" + + if enable_subgroup_ops: + header += "#extension GL_KHR_shader_subgroup_arithmetic : require\n" + + if enable_printf: + header += "#extension GL_EXT_debug_printf : require\n" + + ext_block = "" + for ext in sorted(self._needed_extensions): + ext_line = f"#extension {ext} : require\n" + if ext_line not in header: + ext_block += ext_line + + return header + ext_block + + def make_source(self, header: str, body: str, x: int, y: int, z: int) -> str: + layout_str = f"layout(local_size_x = {x}, local_size_y = {y}, local_size_z = {z}) in;" + + return f"{header}\n{layout_str}\n{body}" + + def constant_namespace(self) -> str: + return "UBO" + + def variable_namespace(self) -> str: + return "PC" + + def exec_bounds_guard(self, exec_count_expr: str) -> str: + return f"if(any(lessThanEqual({exec_count_expr}.xyz, {self.global_invocation_id_expr()}))) {{ return; }}\n" + + def shared_buffer_declaration(self, var_type: dtypes.dtype, name: str, size: int) -> str: + return f"shared {self.type_name(var_type)} {name}[{size}];" + + def uniform_block_declaration(self, contents: str) -> str: + return f"\nlayout(set = 0, binding = 0, scalar) uniform UniformObjectBuffer {{\n{contents}\n}} UBO;\n" + + def storage_buffer_declaration(self, binding: int, var_type: dtypes.dtype, name: str) -> str: + return f"layout(set = 0, binding = {binding}, scalar) buffer Buffer{binding} {{ {self.type_name(var_type)} data[]; }} {name};\n" + + def sampler_declaration(self, binding: int, dimensions: int, name: str) -> str: + return f"layout(set = 0, binding = {binding}) uniform sampler{dimensions}D 
{name};\n" + + def push_constant_declaration(self, contents: str) -> str: + return f"\nlayout(push_constant, scalar) uniform PushConstant {{\n{contents}\n}} PC;\n" + + def entry_point(self, body_contents: str) -> str: + return f"void main() {{\n{body_contents}}}\n" + + def inf_f32_expr(self) -> str: + return "uintBitsToFloat(0x7F800000)" + + def ninf_f32_expr(self) -> str: + return "uintBitsToFloat(0xFF800000)" + + def inf_f64_expr(self) -> str: + return "packDouble2x32(uvec2(0x00000000u, 0x7FF00000u))" + + def ninf_f64_expr(self) -> str: + return "packDouble2x32(uvec2(0x00000000u, 0xFFF00000u))" + + def inf_f16_expr(self) -> str: + return "float16_t(uintBitsToFloat(0x7F800000))" + + def ninf_f16_expr(self) -> str: + return "float16_t(uintBitsToFloat(0xFF800000))" + + def float_bits_to_int_expr(self, var_expr: str) -> str: + return f"floatBitsToInt({var_expr})" + + def float_bits_to_uint_expr(self, var_expr: str) -> str: + return f"floatBitsToUint({var_expr})" + + def int_bits_to_float_expr(self, var_expr: str) -> str: + return f"intBitsToFloat({var_expr})" + + def uint_bits_to_float_expr(self, var_expr: str) -> str: + return f"uintBitsToFloat({var_expr})" + + def global_invocation_id_expr(self) -> str: + return "gl_GlobalInvocationID" + + def local_invocation_id_expr(self) -> str: + return "gl_LocalInvocationID" + + def local_invocation_index_expr(self) -> str: + return "gl_LocalInvocationIndex" + + def workgroup_id_expr(self) -> str: + return "gl_WorkGroupID" + + def workgroup_size_expr(self) -> str: + return "gl_WorkGroupSize" + + def num_workgroups_expr(self) -> str: + return "gl_NumWorkGroups" + + def num_subgroups_expr(self) -> str: + return "gl_NumSubgroups" + + def subgroup_id_expr(self) -> str: + return "gl_SubgroupID" + + def subgroup_size_expr(self) -> str: + return "gl_SubgroupSize" + + def subgroup_invocation_id_expr(self) -> str: + return "gl_SubgroupInvocationID" + + def barrier_statement(self) -> str: + return "barrier();" + + def 
memory_barrier_statement(self) -> str: + return "memoryBarrier();" + + def memory_barrier_buffer_statement(self) -> str: + return "memoryBarrierBuffer();" + + def memory_barrier_shared_statement(self) -> str: + return "memoryBarrierShared();" + + def memory_barrier_image_statement(self) -> str: + return "memoryBarrierImage();" + + def group_memory_barrier_statement(self) -> str: + return "groupMemoryBarrier();" + + def subgroup_add_expr(self, arg_expr: str, arg_type: Optional[dtypes.dtype] = None) -> str: + _ = arg_type + return f"subgroupAdd({arg_expr})" + + def subgroup_mul_expr(self, arg_expr: str, arg_type: Optional[dtypes.dtype] = None) -> str: + _ = arg_type + return f"subgroupMul({arg_expr})" + + def subgroup_min_expr(self, arg_expr: str, arg_type: Optional[dtypes.dtype] = None) -> str: + _ = arg_type + return f"subgroupMin({arg_expr})" + + def subgroup_max_expr(self, arg_expr: str, arg_type: Optional[dtypes.dtype] = None) -> str: + _ = arg_type + return f"subgroupMax({arg_expr})" + + def subgroup_and_expr(self, arg_expr: str, arg_type: Optional[dtypes.dtype] = None) -> str: + _ = arg_type + return f"subgroupAnd({arg_expr})" + + def subgroup_or_expr(self, arg_expr: str, arg_type: Optional[dtypes.dtype] = None) -> str: + _ = arg_type + return f"subgroupOr({arg_expr})" + + def subgroup_xor_expr(self, arg_expr: str, arg_type: Optional[dtypes.dtype] = None) -> str: + _ = arg_type + return f"subgroupXor({arg_expr})" + + def subgroup_elect_expr(self) -> str: + return "subgroupElect()" + + def subgroup_barrier_statement(self) -> str: + return "subgroupBarrier();" + + def printf_statement(self, fmt: str, args: List[str]) -> str: + args_suffix = "" + + if len(args) > 0: + args_suffix = ", " + ", ".join(args) + + return f'debugPrintfEXT("{fmt}"{args_suffix});' + + def texture_size_expr(self, texture_expr: str, lod: int, dimensions: int) -> str: + return f"textureSize({texture_expr}, {lod})" + + def sample_texture_expr(self, texture_expr: str, coord_expr: str, 
lod_expr: Optional[str] = None) -> str: + if lod_expr is None: + return f"texture({texture_expr}, {coord_expr})" + + return f"texture({texture_expr}, {coord_expr}, {lod_expr})" + + def atomic_add_expr(self, mem_expr: str, value_expr: str, var_type: dtypes.dtype) -> str: + if var_type not in (dtypes.int32, dtypes.uint32): + raise NotImplementedError(f"GLSL atomic_add only supports int32/uint32, got '{var_type.name}'") + + return f"atomicAdd({mem_expr}, {value_expr})" diff --git a/vkdispatch/codegen/backends/opencl.py b/vkdispatch/codegen/backends/opencl.py new file mode 100644 index 00000000..907f0508 --- /dev/null +++ b/vkdispatch/codegen/backends/opencl.py @@ -0,0 +1,701 @@ +from typing import List, Optional, Set + +import vkdispatch.base.dtype as dtypes + +from .base import CodeGenBackend + + +class OpenCLBackend(CodeGenBackend): + name = "opencl" + + _SCALAR_TYPE_NAMES = { + dtypes.int16: "short", + dtypes.uint16: "ushort", + dtypes.int32: "int", + dtypes.uint32: "uint", + dtypes.int64: "long", + dtypes.uint64: "ulong", + dtypes.float16: "half", + dtypes.float32: "float", + dtypes.float64: "double", + } + + _MATRIX_TYPE_NAMES = { + dtypes.mat2: "vkdispatch_mat2", + dtypes.mat3: "vkdispatch_mat3", + dtypes.mat4: "vkdispatch_mat4", + } + + def __init__(self) -> None: + self.reset_state() + + def reset_state(self) -> None: + self._kernel_params: List[str] = [] + self._entry_alias_lines: List[str] = [] + self._shared_buffer_lines: List[str] = [] + self._matrix_type_usage: Set[int] = set() + + def _register_kernel_param(self, param_decl: str) -> None: + if param_decl not in self._kernel_params: + self._kernel_params.append(param_decl) + + def _register_alias_line(self, alias_line: str) -> None: + self._entry_alias_lines.append(alias_line) + + def _record_matrix_dim(self, dim: int) -> None: + if dim not in (2, 3, 4): + raise ValueError(f"Unsupported OpenCL matrix dimension '{dim}'") + self._matrix_type_usage.add(dim) + + def _record_matrix_type(self, var_type: 
dtypes.dtype) -> None: + if dtypes.is_matrix(var_type): + self._record_matrix_dim(var_type.child_count) + + @staticmethod + def _matrix_helper_name(dim: int, constructor_kind: str) -> str: + return f"vkdispatch_make_mat{dim}_{constructor_kind}" + + def _is_matrix_copy_constructor_arg(self, arg_expr: str, dim: int) -> bool: + stripped = arg_expr.strip() + mat_type = self._matrix_struct_name(dim) + + if stripped.startswith(f"({mat_type})") or stripped.startswith(f"(({mat_type})"): + return True + + if f"vkdispatch_make_mat{dim}_" in stripped: + return True + + if f"vkdispatch_mat{dim}_" in stripped: + return True + + return False + + @classmethod + def _scalar_type_name(cls, scalar_type: dtypes.dtype) -> str: + type_name = cls._SCALAR_TYPE_NAMES.get(scalar_type) + if type_name is None: + raise ValueError(f"Unsupported OpenCL scalar type mapping for '{scalar_type.name}'") + return type_name + + def type_name(self, var_type: dtypes.dtype) -> str: + if dtypes.is_scalar(var_type): + return self._scalar_type_name(var_type) + + if dtypes.is_vector(var_type): + return f"{self._scalar_type_name(var_type.scalar)}{var_type.child_count}" + + if dtypes.is_complex(var_type): + return f"{self._scalar_type_name(var_type.child_type)}2" + + if dtypes.is_matrix(var_type): + self._record_matrix_type(var_type) + matrix_name = self._MATRIX_TYPE_NAMES.get(var_type) + if matrix_name is None: + raise ValueError(f"Unsupported OpenCL matrix type mapping for '{var_type.name}'") + return matrix_name + + raise ValueError(f"Unsupported OpenCL type mapping for '{var_type.name}'") + + def constructor( + self, + var_type: dtypes.dtype, + args: List[str], + arg_types: Optional[List[Optional[dtypes.dtype]]] = None, + ) -> str: + target_type = self.type_name(var_type) + + if dtypes.is_scalar(var_type): + assert len(args) > 0, f"Constructor for scalar type '{var_type.name}' needs at least one argument." 
+ return f"(({target_type})({args[0]}))" + + if dtypes.is_matrix(var_type): + dim = var_type.child_count + assert len(args) in (1, dim, dim * dim), ( + f"Constructor for matrix type '{var_type.name}' needs 1, {dim}, or {dim * dim} arguments." + ) + if len(args) == 1: + single_arg = args[0] + helper_name = self._matrix_helper_name( + dim, + "copy" if self._is_matrix_copy_constructor_arg(single_arg, dim) else "scalar", + ) + return f"{helper_name}({single_arg})" + + if len(args) == dim: + return f"{self._matrix_helper_name(dim, 'cols')}({', '.join(args)})" + + return f"{self._matrix_helper_name(dim, 'flat')}({', '.join(args)})" + + # NVIDIA's OpenCL frontend rejects direct vector casts between different + # vector base types (e.g. uint2 -> float2). Use convert_* builtins when + # we know this is a vector/complex-to-vector/complex conversion. + if ( + len(args) == 1 + and arg_types is not None + and len(arg_types) == 1 + and arg_types[0] is not None + and (dtypes.is_vector(var_type) or dtypes.is_complex(var_type)) + and (dtypes.is_vector(arg_types[0]) or dtypes.is_complex(arg_types[0])) + ): + return f"convert_{target_type}({args[0]})" + + return f"(({target_type})({', '.join(args)}))" + + def component_access_expr(self, expr: str, component: str, base_type: dtypes.dtype) -> str: + if dtypes.is_scalar(base_type) and component == "x": + return expr + return super().component_access_expr(expr, component, base_type) + + def buffer_component_expr( + self, + scalar_buffer_expr: str, + base_type: dtypes.dtype, + element_index_expr: str, + component_index_expr: str, + ) -> Optional[str]: + if dtypes.is_complex(base_type): + component_count = base_type.child_count + elif dtypes.is_vector(base_type): + component_count = base_type.child_count + else: + return None + + return ( + f"{scalar_buffer_expr}[" + f"(({element_index_expr}) * {component_count}) + ({component_index_expr})" + f"]" + ) + + def _cast_math_arg(self, arg_type: dtypes.dtype, arg_expr: str) -> str: + if 
dtypes.is_scalar(arg_type) or dtypes.is_vector(arg_type) or dtypes.is_complex(arg_type): + return self.constructor(arg_type, [arg_expr], arg_types=[arg_type]) + + return arg_expr + + def math_func_name(self, func_name: str, var_type: dtypes.dtype) -> str: + func_name_dict = { + "sin": "native_sin", + "cos": "native_cos", + "tan": "native_tan", + "sqrt": "native_sqrt", + "exp": "native_exp", + "exp2": "native_exp2", + "log": "native_log", + "log2": "native_log2", + } + + if func_name in func_name_dict: + return func_name_dict[func_name] + + return func_name + + def unary_math_expr(self, func_name: str, arg_type: dtypes.dtype, arg_expr: str) -> str: + mapped = self.math_func_name(func_name, arg_type) + return f"{mapped}({self._cast_math_arg(arg_type, arg_expr)})" + + def binary_math_expr( + self, + func_name: str, + lhs_type: dtypes.dtype, + lhs_expr: str, + rhs_type: dtypes.dtype, + rhs_expr: str, + ) -> str: + mapped = self.math_func_name(func_name, lhs_type) + lhs_cast_expr = self._cast_math_arg(lhs_type, lhs_expr) + rhs_cast_expr = self._cast_math_arg(rhs_type, rhs_expr) + return f"{mapped}({lhs_cast_expr}, {rhs_cast_expr})" + + def pre_header(self, *, enable_subgroup_ops: bool, enable_printf: bool) -> str: + _ = enable_subgroup_ops + _ = enable_printf + header = ( + "// OpenCL C source generated by vkdispatch\n" + "#ifdef cl_khr_global_int32_base_atomics\n" + "#pragma OPENCL EXTENSION cl_khr_global_int32_base_atomics : enable\n" + "#endif\n" + "#ifdef cl_khr_local_int32_base_atomics\n" + "#pragma OPENCL EXTENSION cl_khr_local_int32_base_atomics : enable\n" + "#endif\n" + "#ifdef cl_khr_fp64\n" + "#pragma OPENCL EXTENSION cl_khr_fp64 : enable\n" + "#endif\n" + "#ifdef cl_khr_fp16\n" + "#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n" + "#endif\n" + ) + matrix_helpers = self._emit_matrix_helpers() + if len(matrix_helpers) > 0: + header += f"\n{matrix_helpers}\n" + return header + + def _emit_matrix_helpers(self) -> str: + if len(self._matrix_type_usage) == 0: + 
return "" + + sections: List[str] = [] + if 3 in self._matrix_type_usage: + sections.append( + "typedef struct __attribute__((packed)) vkdispatch_packed_float3 {\n" + " float x;\n" + " float y;\n" + " float z;\n" + "} vkdispatch_packed_float3;\n" + "static inline float3 vkdispatch_unpack_float3(vkdispatch_packed_float3 v) { return (float3)(v.x, v.y, v.z); }\n" + "static inline vkdispatch_packed_float3 vkdispatch_pack_float3(float3 v) {\n" + " vkdispatch_packed_float3 out = {v.x, v.y, v.z};\n" + " return out;\n" + "}" + ) + + for dim in sorted(self._matrix_type_usage): + sections.append(self._emit_matrix_helpers_for_dim(dim)) + + return "\n\n".join(sections) + + @staticmethod + def _vector_components(dim: int) -> List[str]: + return list("xyzw"[:dim]) + + @staticmethod + def _matrix_struct_name(dim: int) -> str: + return f"vkdispatch_mat{dim}" + + @staticmethod + def _vector_type_name(dim: int) -> str: + return f"float{dim}" + + def _matrix_col_expr(self, mat_expr: str, col: int, dim: int) -> str: + if dim == 3: + return f"vkdispatch_unpack_float3({mat_expr}.c{col})" + return f"{mat_expr}.c{col}" + + def _matrix_col_assign_stmt(self, target_expr: str, col: int, value_expr: str, dim: int) -> str: + if dim == 3: + return f"{target_expr}.c{col} = vkdispatch_pack_float3({value_expr});" + return f"{target_expr}.c{col} = {value_expr};" + + def _emit_matrix_helpers_for_dim(self, dim: int) -> str: + mat_type = self._matrix_struct_name(dim) + vec_type = self._vector_type_name(dim) + comps = self._vector_components(dim) + scalar_helper_name = self._matrix_helper_name(dim, "scalar") + copy_helper_name = self._matrix_helper_name(dim, "copy") + cols_helper_name = self._matrix_helper_name(dim, "cols") + flat_helper_name = self._matrix_helper_name(dim, "flat") + + lines: List[str] = [] + + if dim == 3: + lines.append( + "typedef struct __attribute__((packed)) vkdispatch_mat3 {\n" + " vkdispatch_packed_float3 c0;\n" + " vkdispatch_packed_float3 c1;\n" + " vkdispatch_packed_float3 
c2;\n" + "} vkdispatch_mat3;" + ) + else: + cols = "\n".join([f" {vec_type} c{i};" for i in range(dim)]) + lines.append(f"typedef struct {mat_type} {{\n{cols}\n}} {mat_type};") + + # Constructors. + lines.append(f"static inline {mat_type} {scalar_helper_name}(float s) {{") + lines.append(f" {mat_type} out;") + for col_idx in range(dim): + diag_values = [("s" if row_idx == col_idx else "0.0f") for row_idx in range(dim)] + vec_expr = f"({vec_type})(" + ", ".join(diag_values) + ")" + lines.append(f" {self._matrix_col_assign_stmt('out', col_idx, vec_expr, dim)}") + lines.append(" return out;") + lines.append("}") + + lines.append(f"static inline {mat_type} {copy_helper_name}({mat_type} m) {{ return m; }}") + + col_args = ", ".join([f"{vec_type} c{i}" for i in range(dim)]) + lines.append(f"static inline {mat_type} {cols_helper_name}({col_args}) {{") + lines.append(f" {mat_type} out;") + for col_idx in range(dim): + lines.append(f" {self._matrix_col_assign_stmt('out', col_idx, f'c{col_idx}', dim)}") + lines.append(" return out;") + lines.append("}") + + flat_names = [f"m{col}{row}" for col in range(dim) for row in range(dim)] + flat_args = ", ".join([f"float {name}" for name in flat_names]) + lines.append(f"static inline {mat_type} {flat_helper_name}({flat_args}) {{") + lines.append(f" return {cols_helper_name}(") + for col_idx in range(dim): + values = [f"m{col_idx}{row_idx}" for row_idx in range(dim)] + suffix = "," if col_idx < dim - 1 else "" + lines.append(f" ({vec_type})({', '.join(values)}){suffix}") + lines.append(" );") + lines.append("}") + + # Unary negation. + lines.append(f"static inline {mat_type} vkdispatch_mat{dim}_neg({mat_type} a) {{") + lines.append(f" {mat_type} out;") + for col_idx in range(dim): + col_expr = self._matrix_col_expr("a", col_idx, dim) + lines.append( + f" {self._matrix_col_assign_stmt('out', col_idx, f'-{col_expr}', dim)}" + ) + lines.append(" return out;") + lines.append("}") + + # Matrix +/- matrix. 
+ for op_name, op_symbol in (("add", "+"), ("sub", "-")): + lines.append( + f"static inline {mat_type} vkdispatch_mat{dim}_{op_name}_mm({mat_type} a, {mat_type} b) {{" + ) + lines.append(f" {mat_type} out;") + for col_idx in range(dim): + lhs_col = self._matrix_col_expr("a", col_idx, dim) + rhs_col = self._matrix_col_expr("b", col_idx, dim) + lines.append( + f" {self._matrix_col_assign_stmt('out', col_idx, f'{lhs_col} {op_symbol} {rhs_col}', dim)}" + ) + lines.append(" return out;") + lines.append("}") + + # Matrix/scalar and scalar/matrix arithmetic. + for op_name, op_symbol in (("add", "+"), ("sub", "-"), ("mul", "*"), ("div", "/")): + lines.append( + f"static inline {mat_type} vkdispatch_mat{dim}_{op_name}_ms({mat_type} a, float b) {{" + ) + lines.append(f" {mat_type} out;") + for col_idx in range(dim): + lhs_col = self._matrix_col_expr("a", col_idx, dim) + lines.append( + f" {self._matrix_col_assign_stmt('out', col_idx, f'{lhs_col} {op_symbol} b', dim)}" + ) + lines.append(" return out;") + lines.append("}") + + lines.append( + f"static inline {mat_type} vkdispatch_mat{dim}_{op_name}_sm(float a, {mat_type} b) {{" + ) + lines.append(f" {mat_type} out;") + for col_idx in range(dim): + rhs_col = self._matrix_col_expr("b", col_idx, dim) + lines.append( + f" {self._matrix_col_assign_stmt('out', col_idx, f'a {op_symbol} {rhs_col}', dim)}" + ) + lines.append(" return out;") + lines.append("}") + + # Matrix/vector product (column-major, GLSL-style): m * v. + mat_vec_terms = [f"({self._matrix_col_expr('m', i, dim)} * v.{comps[i]})" for i in range(dim)] + lines.append(f"static inline {vec_type} vkdispatch_mat{dim}_mul_mv({mat_type} m, {vec_type} v) {{") + lines.append(f" return {' + '.join(mat_vec_terms)};") + lines.append("}") + + # Vector/matrix product (column-major, GLSL-style): v * m. 
+ lines.append(f"static inline {vec_type} vkdispatch_mat{dim}_mul_vm({vec_type} v, {mat_type} m) {{") + for col_idx in range(dim): + lines.append(f" {vec_type} col{col_idx} = {self._matrix_col_expr('m', col_idx, dim)};") + row_exprs = [] + for col_idx in range(dim): + terms = [f"(v.{comps[row_idx]} * col{col_idx}.{comps[row_idx]})" for row_idx in range(dim)] + row_exprs.append(" + ".join(terms)) + lines.append(f" return ({vec_type})({', '.join(row_exprs)});") + lines.append("}") + + return "\n".join(lines) + + def arithmetic_unary_expr(self, op: str, var_type: dtypes.dtype, var_expr: str) -> Optional[str]: + if op == "-" and dtypes.is_matrix(var_type): + dim = var_type.child_count + self._record_matrix_dim(dim) + return f"vkdispatch_mat{dim}_neg({var_expr})" + return None + + def arithmetic_binary_expr( + self, + op: str, + lhs_type: dtypes.dtype, + lhs_expr: str, + rhs_type: dtypes.dtype, + rhs_expr: str, + ) -> Optional[str]: + if not (dtypes.is_matrix(lhs_type) or dtypes.is_matrix(rhs_type)): + return None + + if op not in ("+", "-", "*", "/"): + raise NotImplementedError( + f"OpenCL matrix arithmetic override does not support operator '{op}' " + f"for ({lhs_type.name}, {rhs_type.name})." + ) + + if dtypes.is_matrix(lhs_type): + dim = lhs_type.child_count + if dtypes.is_matrix(rhs_type): + if rhs_type.child_count != dim: + raise ValueError( + f"OpenCL matrix arithmetic requires matching dimensions, got '{lhs_type.name}' and '{rhs_type.name}'." + ) + if op not in ("+", "-"): + raise NotImplementedError( + f"OpenCL matrix arithmetic does not support operator '{op}' for two matrices." 
+ ) + self._record_matrix_dim(dim) + return f"vkdispatch_mat{dim}_{'add' if op == '+' else 'sub'}_mm({lhs_expr}, {rhs_expr})" + + if dtypes.is_scalar(rhs_type): + self._record_matrix_dim(dim) + suffix = "add" if op == "+" else "sub" if op == "-" else "mul" if op == "*" else "div" + return f"vkdispatch_mat{dim}_{suffix}_ms({lhs_expr}, {rhs_expr})" + + if dtypes.is_vector(rhs_type) and op == "*": + if rhs_type.child_count != dim or rhs_type.scalar != dtypes.float32: + raise ValueError( + f"OpenCL matrix/vector multiplication requires float32 vec{dim}, got '{rhs_type.name}'." + ) + self._record_matrix_dim(dim) + return f"vkdispatch_mat{dim}_mul_mv({lhs_expr}, {rhs_expr})" + + raise NotImplementedError( + f"Unsupported OpenCL matrix arithmetic for ({lhs_type.name}, {rhs_type.name}) with operator '{op}'." + ) + + # lhs is not matrix; rhs is matrix + dim = rhs_type.child_count + if dtypes.is_scalar(lhs_type): + self._record_matrix_dim(dim) + suffix = "add" if op == "+" else "sub" if op == "-" else "mul" if op == "*" else "div" + return f"vkdispatch_mat{dim}_{suffix}_sm({lhs_expr}, {rhs_expr})" + + if dtypes.is_vector(lhs_type) and op == "*": + if lhs_type.child_count != dim or lhs_type.scalar != dtypes.float32: + raise ValueError( + f"OpenCL vector/matrix multiplication requires float32 vec{dim}, got '{lhs_type.name}'." + ) + self._record_matrix_dim(dim) + return f"vkdispatch_mat{dim}_mul_vm({lhs_expr}, {rhs_expr})" + + raise NotImplementedError( + f"Unsupported OpenCL matrix arithmetic for ({lhs_type.name}, {rhs_type.name}) with operator '{op}'." 
+ ) + + def make_source(self, header: str, body: str, x: int, y: int, z: int) -> str: + expected_size_header = ( + f"// Expected local size: ({x}, {y}, {z})\n" + f"#define VKDISPATCH_EXPECTED_LOCAL_SIZE_X {x}\n" + f"#define VKDISPATCH_EXPECTED_LOCAL_SIZE_Y {y}\n" + f"#define VKDISPATCH_EXPECTED_LOCAL_SIZE_Z {z}\n" + ) + workgroup_attribute = f"__attribute__((reqd_work_group_size({x}, {y}, {z})))" + if "__kernel void vkdispatch_main" in body: + body = body.replace( + "__kernel void vkdispatch_main", + f"{workgroup_attribute}\n__kernel void vkdispatch_main", + 1, + ) + else: + body = f"{workgroup_attribute}\n{body}" + + return f"{expected_size_header}\n{header}\n{body}" + + def constant_namespace(self) -> str: + return "UBO" + + def variable_namespace(self) -> str: + return "PC" + + def exec_bounds_guard(self, exec_count_expr: str) -> str: + gid_expr = f"({self.global_invocation_id_expr()})" + exec_expr = f"({exec_count_expr})" + return ( + f"if ({self.component_access_expr(exec_expr, 'x', dtypes.uvec4)} <= {self.component_access_expr(gid_expr, 'x', dtypes.uvec3)} || " + f"{self.component_access_expr(exec_expr, 'y', dtypes.uvec4)} <= {self.component_access_expr(gid_expr, 'y', dtypes.uvec3)} || " + f"{self.component_access_expr(exec_expr, 'z', dtypes.uvec4)} <= {self.component_access_expr(gid_expr, 'z', dtypes.uvec3)}) {{ return; }}\n" + ) + + def shared_buffer_declaration(self, var_type: dtypes.dtype, name: str, size: int) -> str: + self._shared_buffer_lines.append(f"__local {self.type_name(var_type)} {name}[{size}];") + # OpenCL requires __local storage declarations at kernel/function scope. 
+ return "" + + def uniform_block_declaration(self, contents: str) -> str: + self._register_kernel_param("__global const UniformObjectBuffer* vkdispatch_uniform_ptr") + self._register_alias_line("const UniformObjectBuffer UBO = *vkdispatch_uniform_ptr;") + return f"\ntypedef struct UniformObjectBuffer {{\n{contents}\n}} UniformObjectBuffer;\n" + + def storage_buffer_declaration(self, binding: int, var_type: dtypes.dtype, name: str) -> str: + struct_name = f"Buffer{binding}" + param_name = f"vkdispatch_binding_{binding}_ptr" + data_type = self.type_name(var_type) + self._register_kernel_param(f"__global {data_type}* {param_name}") + if dtypes.is_complex(var_type): + scalar_type = self.type_name(var_type.child_type) + self._register_alias_line( + f"__global {scalar_type}* {name}_scalar = (__global {scalar_type}*)({param_name});" + ) + elif dtypes.is_vector(var_type): + scalar_type = self.type_name(var_type.scalar) + self._register_alias_line( + f"__global {scalar_type}* {name}_scalar = (__global {scalar_type}*)({param_name});" + ) + self._register_alias_line(f"{struct_name} {name} = {{{param_name}}};") + return f"typedef struct {struct_name} {{ __global {data_type}* data; }} {struct_name};\n" + + def sampler_declaration(self, binding: int, dimensions: int, name: str) -> str: + _ = (binding, dimensions, name) + raise NotImplementedError("image/sampler unsupported in OpenCL backend") + + def push_constant_declaration(self, contents: str) -> str: + self._register_kernel_param("const PushConstant vkdispatch_pc_value") + self._register_alias_line("const PushConstant PC = vkdispatch_pc_value;") + return f"\ntypedef struct PushConstant {{\n{contents}\n}} PushConstant;\n" + + def entry_point(self, body_contents: str) -> str: + params = ", ".join(self._kernel_params) + alias_block = "" + shared_block = "" + for line in self._shared_buffer_lines: + shared_block += f" {line}\n" + for line in self._entry_alias_lines: + alias_block += f" {line}\n" + + return ( + f"__kernel void 
vkdispatch_main({params}) {{\n" + f"{shared_block}" + f"{alias_block}" + f"{body_contents}" + f"}}\n" + ) + + def inf_f32_expr(self) -> str: + return "as_float((uint)0x7F800000u)" + + def ninf_f32_expr(self) -> str: + return "as_float((uint)0xFF800000u)" + + def inf_f64_expr(self) -> str: + return "as_double((ulong)0x7FF0000000000000UL)" + + def ninf_f64_expr(self) -> str: + return "as_double((ulong)0xFFF0000000000000UL)" + + def inf_f16_expr(self) -> str: + return "as_half((ushort)0x7C00u)" + + def ninf_f16_expr(self) -> str: + return "as_half((ushort)0xFC00u)" + + def float_bits_to_int_expr(self, var_expr: str) -> str: + return f"as_int({var_expr})" + + def float_bits_to_uint_expr(self, var_expr: str) -> str: + return f"as_uint({var_expr})" + + def int_bits_to_float_expr(self, var_expr: str) -> str: + return f"as_float({var_expr})" + + def uint_bits_to_float_expr(self, var_expr: str) -> str: + return f"as_float({var_expr})" + + def global_invocation_id_expr(self) -> str: + return "((uint3)((uint)get_global_id(0), (uint)get_global_id(1), (uint)get_global_id(2)))" + + def local_invocation_id_expr(self) -> str: + return "((uint3)((uint)get_local_id(0), (uint)get_local_id(1), (uint)get_local_id(2)))" + + def local_invocation_index_expr(self) -> str: + return ( + "((uint)(get_local_id(0) + " + "get_local_size(0) * (get_local_id(1) + get_local_size(1) * get_local_id(2))))" + ) + + def workgroup_id_expr(self) -> str: + return "((uint3)((uint)get_group_id(0), (uint)get_group_id(1), (uint)get_group_id(2)))" + + def workgroup_size_expr(self) -> str: + return "((uint3)((uint)get_local_size(0), (uint)get_local_size(1), (uint)get_local_size(2)))" + + def num_workgroups_expr(self) -> str: + return "((uint3)((uint)get_num_groups(0), (uint)get_num_groups(1), (uint)get_num_groups(2)))" + + def num_subgroups_expr(self) -> str: + raise NotImplementedError("subgroup operations unsupported in OpenCL backend") + + def subgroup_id_expr(self) -> str: + raise 
NotImplementedError("subgroup operations unsupported in OpenCL backend") + + def subgroup_size_expr(self) -> str: + raise NotImplementedError("subgroup operations unsupported in OpenCL backend") + + def subgroup_invocation_id_expr(self) -> str: + raise NotImplementedError("subgroup operations unsupported in OpenCL backend") + + def barrier_statement(self) -> str: + return "barrier(CLK_LOCAL_MEM_FENCE);" + + def memory_barrier_statement(self) -> str: + return "mem_fence(CLK_LOCAL_MEM_FENCE);" + + def memory_barrier_buffer_statement(self) -> str: + return "mem_fence(CLK_GLOBAL_MEM_FENCE);" + + def memory_barrier_shared_statement(self) -> str: + return "mem_fence(CLK_LOCAL_MEM_FENCE);" + + def memory_barrier_image_statement(self) -> str: + raise NotImplementedError("image/sampler unsupported in OpenCL backend") + + def group_memory_barrier_statement(self) -> str: + return "mem_fence(CLK_LOCAL_MEM_FENCE | CLK_GLOBAL_MEM_FENCE);" + + def subgroup_add_expr(self, arg_expr: str) -> str: + _ = arg_expr + raise NotImplementedError("subgroup operations unsupported in OpenCL backend") + + def subgroup_mul_expr(self, arg_expr: str) -> str: + _ = arg_expr + raise NotImplementedError("subgroup operations unsupported in OpenCL backend") + + def subgroup_min_expr(self, arg_expr: str) -> str: + _ = arg_expr + raise NotImplementedError("subgroup operations unsupported in OpenCL backend") + + def subgroup_max_expr(self, arg_expr: str) -> str: + _ = arg_expr + raise NotImplementedError("subgroup operations unsupported in OpenCL backend") + + def subgroup_and_expr(self, arg_expr: str) -> str: + _ = arg_expr + raise NotImplementedError("subgroup operations unsupported in OpenCL backend") + + def subgroup_or_expr(self, arg_expr: str) -> str: + _ = arg_expr + raise NotImplementedError("subgroup operations unsupported in OpenCL backend") + + def subgroup_xor_expr(self, arg_expr: str) -> str: + _ = arg_expr + raise NotImplementedError("subgroup operations unsupported in OpenCL backend") + + 
def subgroup_elect_expr(self) -> str: + raise NotImplementedError("subgroup operations unsupported in OpenCL backend") + + def subgroup_barrier_statement(self) -> str: + raise NotImplementedError("subgroup operations unsupported in OpenCL backend") + + def printf_statement(self, fmt: str, args: List[str]) -> str: + if len(args) == 0: + return f'printf("{fmt}");' + return f'printf("{fmt}", {", ".join(args)});' + + def texture_size_expr(self, texture_expr: str, lod: int, dimensions: int) -> str: + _ = (texture_expr, lod, dimensions) + raise NotImplementedError("image/sampler unsupported in OpenCL backend") + + def sample_texture_expr(self, texture_expr: str, coord_expr: str, lod_expr: Optional[str] = None) -> str: + _ = (texture_expr, coord_expr, lod_expr) + raise NotImplementedError("image/sampler unsupported in OpenCL backend") + + def mark_texture_sample_dimension(self, dimensions: int) -> None: + _ = dimensions + raise NotImplementedError("image/sampler unsupported in OpenCL backend") + + def atomic_add_expr(self, mem_expr: str, value_expr: str, var_type: dtypes.dtype) -> str: + if var_type not in (dtypes.int32, dtypes.uint32): + raise NotImplementedError(f"OpenCL atomic_add only supports int32/uint32, got '{var_type.name}'") + + return f"atomic_add(&({mem_expr}), {value_expr})" diff --git a/vkdispatch/codegen/builder.py b/vkdispatch/codegen/builder.py index c1eb0478..44e50e48 100644 --- a/vkdispatch/codegen/builder.py +++ b/vkdispatch/codegen/builder.py @@ -1,85 +1,30 @@ import vkdispatch.base.dtype as dtypes -from vkdispatch.base.dtype import dtype, is_scalar, is_vector, is_matrix, is_complex, to_vector from .struct_builder import StructElement, StructBuilder -from typing import Dict -from typing import List -from typing import Tuple -from typing import Union -from typing import Optional -from typing import Callable -from typing import Any +from .shader_writer import ShaderWriter +from .backends import CodeGenBackend +from .global_builder import 
get_codegen_backend -import enum -import dataclasses - -import numpy as np - -ENABLE_SCALED_AND_OFFSET_INT = True +from enum import IntFlag, auto -def do_scaled_int_check(other): - return ENABLE_SCALED_AND_OFFSET_INT and (isinstance(other, int) or np.issubdtype(type(other), np.integer)) +from typing import Dict, List, Optional, Tuple -def is_int_power_of_2(n: int) -> bool: - """Check if an integer is a power of 2.""" - return n > 0 and (n & (n - 1)) == 0 - -def shader_var_name(index: "Union[Any, ShaderVariable]") -> str: - if isinstance(index, ShaderVariable): - result_str = str(index) +import dataclasses - if result_str[0] == "(" and result_str[-1] == ")": - result_str = result_str[1:-1] - - return result_str - - return str(index) +import enum -def var_types_to_floating(var_type: dtype) -> dtype: - if var_type == dtypes.int32 or var_type == dtypes.uint32: - return dtypes.float32 +from .variables.variables import BaseVariable, ShaderVariable, ScaledAndOfftsetIntVariable +from .variables.bound_variables import BufferVariable, ImageVariable - if var_type == dtypes.ivec2 or var_type == dtypes.uvec2: - return dtypes.vec2 +_PUSH_CONSTANT_UNSUPPORTED_BACKENDS = set() - if var_type == dtypes.ivec3 or var_type == dtypes.uvec3: - return dtypes.vec3 - - if var_type == dtypes.ivec4 or var_type == dtypes.uvec4: - return dtypes.vec4 - - return var_type -class BindingType(enum.Enum): - """ - A dataclass that represents the type of a binding in a shader. Either a - STORAGE_BUFFER, UNIFORM_BUFFER, or SAMPLER. - """ - STORAGE_BUFFER = 1 - UNIFORM_BUFFER = 3 - SAMPLER = 5 - -@dataclasses.dataclass -class ShaderBinding: - """ - A dataclass that represents a bound resource in a shader. Either a - buffer or an image. - - Attributes: - dtype (vd.dtype): The dtype of the resource. If - the resource is an image, this should be vd.vec4 - (since all images are sampled with 4 channels in shaders). - name (str): The name of the resource within the shader code. 
- dimension (int): The dimension of the resource. Set to 0 for - buffers and 1, 2, or 3 for images. - binding_type (BindingType): The type of the binding. Either - STORAGE_BUFFER, UNIFORM_BUFFER, or SAMPLER. - """ - dtype: dtype - name: str - dimension: int - binding_type: BindingType +def _push_constant_not_supported_error(backend_name: str) -> str: + return ( + f"Push Constants are not supported for the {backend_name.upper()} backend. " + "Use Const instead." + ) @dataclasses.dataclass class SharedBuffer: @@ -91,10 +36,19 @@ class SharedBuffer: size (int): The size of the shared buffer. name (str): The name of the shared buffer within the shader code. """ - dtype: dtype + dtype: dtypes.dtype size: int name: str +class BindingType(enum.Enum): + """ + A dataclass that represents the type of a binding in a shader. Either a + STORAGE_BUFFER, UNIFORM_BUFFER, or SAMPLER. + """ + STORAGE_BUFFER = 1 + UNIFORM_BUFFER = 3 + SAMPLER = 5 + @dataclasses.dataclass class ShaderDescription: """ @@ -116,11 +70,22 @@ class ShaderDescription: uniform_structure: List[StructElement] binding_type_list: List[BindingType] binding_access: List[Tuple[bool, bool]] # List of tuples indicating read and write access for each binding - exec_count_name: str + exec_count_name: Optional[str] + resource_binding_base: int + backend: Optional[CodeGenBackend] = None def make_source(self, x: int, y: int, z: int) -> str: - layout_str = f"layout(local_size_x = {x}, local_size_y = {y}, local_size_z = {z}) in;" - return f"{self.header}\n{layout_str}\n{self.body}" + if self.backend is None: + layout_str = f"layout(local_size_x = {x}, local_size_y = {y}, local_size_z = {z}) in;" + shader_source = f"{self.header}\n{layout_str}\n{self.body}" + else: + shader_source = self.backend.make_source(self.header, self.body, x, y, z) + + # ff = open(f"sources/{self.name}.comp", "w") + # ff.write(shader_source) + # ff.close() + + return shader_source def __repr__(self): description_string = "" @@ -132,779 +97,39 @@ def 
__repr__(self): description_string += f"Binding Types: {self.binding_type_list}\n" description_string += f"Binding Access: {self.binding_access}\n" description_string += f"Execution Count Name: {self.exec_count_name}\n" + description_string += f"Backend: {self.backend.name if self.backend is not None else 'none'}\n" description_string += f"Header:\n{self.header}\n" description_string += f"Body:\n{self.body}\n" return description_string -class ShaderVariable: - append_func: Callable[[str], None] - name_func: Callable[[str], str] - var_type: dtype - name: str - raw_name: str - can_index: bool = False - use_child_type: bool = True - _varying: bool = False - lexical_unit: bool = False - settable: bool = False - parent_variables: List["ShaderVariable"] - - def __init__(self, - append_func: Callable[[str], None], - name_func: Callable[[str], Tuple[str, str]], - var_type: dtype, - name: Optional[str] = None, - lexical_unit: bool = False, - settable: bool = False, - parent_variables: List["ShaderVariable"] = None - ) -> None: - - self.append_func = append_func - self.name_func = name_func - self.var_type = var_type - self.lexical_unit = lexical_unit - - both_names = self.name_func(name) - self.name = both_names[0] - self.raw_name = both_names[1] - self.settable = settable - - if parent_variables is None: - parent_variables = [] - - self.parent_variables = [] - - for parent_var in parent_variables: - if isinstance(parent_var, ShaderVariable): - self.parent_variables.append(parent_var) - - if is_complex(self.var_type): - self.real = self.new(self.var_type.child_type, f"{self}.x", [self], lexical_unit=True, settable=settable) - self.imag = self.new(self.var_type.child_type, f"{self}.y", [self], lexical_unit=True, settable=settable) - self.x = self.real - self.y = self.imag - - self._register_shape() - - if is_vector(self.var_type): - self.x = self.new(self.var_type.child_type, f"{self}.x", [self], lexical_unit=True, settable=settable) - - if self.var_type.child_count >= 2: - 
self.y = self.new(self.var_type.child_type, f"{self}.y", [self], lexical_unit=True, settable=settable) - - if self.var_type.child_count >= 3: - self.z = self.new(self.var_type.child_type, f"{self}.z", [self], lexical_unit=True, settable=settable) - - if self.var_type.child_count == 4: - self.w = self.new(self.var_type.child_type, f"{self}.w", [self], lexical_unit=True, settable=settable) - - self._register_shape() - - if is_matrix(self.var_type): - self._register_shape() - - self._initilized = True - - def __repr__(self) -> str: - if self.lexical_unit: - return self.name - - return f"({self.name})" - - def read_callback(self): - for parent in self.parent_variables: - parent.read_callback() - - def write_callback(self): - for parent in self.parent_variables: - parent.write_callback() - - def new(self, var_type: dtype, name: str, parents: List["ShaderVariable"], lexical_unit: bool = False, settable: bool = False) -> "ShaderVariable": - return ShaderVariable(self.append_func, self.name_func, var_type, name, lexical_unit=lexical_unit, settable=settable, parent_variables=parents) - - def __getitem__(self, index) -> "ShaderVariable": - if not self.can_index: - raise ValueError("Unsupported indexing!") - - return_type = self.var_type.child_type if self.use_child_type else self.var_type - - if isinstance(index, ShaderVariable) or isinstance(index, (int, np.integer)): - return self.new(return_type, f"{self.name}[{shader_var_name(index)}]", [self], settable=self.settable) - - if isinstance(index, tuple): - index_strs = tuple(shader_var_name(i) for i in index) - - if len(index_strs) == 1: - return self.new(return_type, f"{self.name}[{index_strs[0]}]", [self], settable=self.settable) - elif self.shape is None: - raise ValueError("Cannot do multidimentional index into object with no shape!") - - if len(index_strs) == 2: - true_index = f"{index_strs[0]} * {self.shape.y} + {index_strs[1]}" - return self.new(return_type, f"{self.name}[{true_index}]", [self], 
settable=self.settable) - elif len(index_strs) == 3: - true_index = f"{index_strs[0]} * {self.shape.y} + {index_strs[1]}" - true_index = f"({true_index}) * {self.shape.z} + {index_strs[2]}" - return self.new(return_type, f"{self.name}[{true_index}]", [self], settable=self.settable) - else: - raise ValueError(f"Unsupported number of indicies {len(index)}!") - - else: - raise ValueError(f"Unsupported index type {index} of type {type(index)}!") - - def __setitem__(self, index, value: "ShaderVariable") -> None: - assert self.settable, f"Cannot set value of '{self.name}' because it is not a settable variable!" - - if isinstance(index, slice): - if index.start is None and index.stop is None and index.step is None: - self.write_callback() - - if isinstance(value, ShaderVariable): - value.read_callback() - - self.append_func(f"{self.name} = {shader_var_name(value)};\n") - return - else: - raise ValueError("Unsupported slice!") - - if not self.can_index: - raise ValueError(f"Unsupported indexing {index}!") - - if f"{self.name}[{index}]" == str(value): - return - - self.write_callback() - - if isinstance(index, ShaderVariable): - index.read_callback() - - if isinstance(value, ShaderVariable): - value.read_callback() - - self.append_func(f"{self.name}[{shader_var_name(index)}] = {shader_var_name(value)};\n") - - def _register_shape(self, shape_var: "ShaderVariable" = None, shape_name: str = None, use_child_type: bool = True): - self.shape = shape_var - self.shape_name = shape_name - self.can_index = True - self.use_child_type = use_child_type - - def __bool__(self) -> bool: - raise ValueError(f"Vkdispatch variables cannot be cast to a python boolean") - - def new_scaled_and_offset_int(self, var_type: dtype, name: str, parents: List["ShaderVariable"] = None) -> "ScaledAndOfftsetIntVariable": - return ScaledAndOfftsetIntVariable(self.append_func, self.name_func, var_type, name, parent_variables=parents) - - def copy(self, var_name: str = None): - """Create a new variable with 
the same value as the current variable.""" - new_var = self.new(self.var_type, var_name, [], lexical_unit=True, settable=True) - - self.read_callback() - - self.append_func(f"{self.var_type.glsl_type} {new_var.name} = {self};\n") - return new_var - - def cast_to(self, var_type: dtype): - return self.new(var_type, f"{var_type.glsl_type}({self.name})", [self], lexical_unit=True) - - def printf_args(self) -> str: - total_count = np.prod(self.var_type.shape) - - if total_count == 1: - return self.name - - args_list = [] - - for i in range(0, total_count): - args_list.append(f"{self.name}[{i}]") - - return ",".join(args_list) - - def __setattr__(self, name: str, value: "ShaderVariable") -> "ShaderVariable": - attrib_error = False - attrib_error_msg = "" - - try: - if self._initilized: - if is_complex(self.var_type): - if name == "real": - self.write_callback() - - if isinstance(value, ShaderVariable): - value.read_callback() - - self.append_func(f"{self}.x = {shader_var_name(value)};\n") - return - - if name == "imag": - self.write_callback() - - if isinstance(value, ShaderVariable): - value.read_callback() - - self.append_func(f"{self}.y = {shader_var_name(value)};\n") - return - - if name == "x" or name == "y": - self.write_callback() - - if isinstance(value, ShaderVariable): - value.read_callback() - - self.append_func(f"{self}.{name} = {shader_var_name(value)};\n") - return - - if is_vector(self.var_type): - if name == "y" and self.var_type.shape[0] < 2: - attrib_error = True - attrib_error_msg = f"Cannot set attribute '{name}' in a {self.var_type.name}!" - - if name == "z" and self.var_type.shape[0] < 3: - attrib_error = True - attrib_error_msg = f"Cannot set attribute '{name}' in a {self.var_type.name}!" - - if name == "w" and self.var_type.shape[0] < 4: - attrib_error = True - attrib_error_msg = f"Cannot set attribute '{name}' in a {self.var_type.name}!" 
- - if not attrib_error and (name == "x" or name == "y" or name == "z" or name == "w"): - self.write_callback() - - if isinstance(value, ShaderVariable): - value.read_callback() - - self.append_func(f"{self}.{name} = {shader_var_name(value)};\n") - return - - if is_scalar(self.var_type): - if name == "x": - self.write_callback() - - if isinstance(value, ShaderVariable): - value.read_callback() - - self.append_func(f"{self} = {shader_var_name(value)};\n") - return - except: - super().__setattr__(name, value) - return - - if attrib_error: - raise AttributeError(attrib_error_msg) - - super().__setattr__(name, value) - - def __getattr__(self, name: str) -> "ShaderVariable": - if not set(name).issubset(set("xyzw")): - raise AttributeError(f"Cannot get attribute '{name}'") - - if len(name) > 4: - raise AttributeError(f"Cannot get attribute '{name}'") - - if len(name) == 1: - if len(self.var_type.shape) == 2: - raise AttributeError(f"Cannot get attribute '{name}' from a matrix of shape {self.var_type.shape}!") - - if name == "x" and self.var_type.shape[0] == 1: - return self.new(self.var_type, f"{self}.{name}", [self], lexical_unit=True, settable=self.settable) - - if name == "y" and self.var_type.shape[0] < 2: - raise AttributeError(f"Cannot get attribute '{name}' from a {self.var_type.name}!") - - if name == "z" and self.var_type.shape[0] < 3: - raise AttributeError(f"Cannot get attribute '{name}' from a {self.var_type.name}!") - - if name == "w" and self.var_type.shape[0] < 4: - raise AttributeError(f"Cannot get attribute '{name}' from a {self.var_type.name}!") - - return self.new(self.var_type.child_type, f"{self}.{name}", [self], lexical_unit=True, settable=self.settable) - - new_type = to_vector(self.var_type.child_type, len(name)) - return self.new(new_type, f"{self}.{name}", [self], lexical_unit=True, settable=self.settable) - - def __lt__(self, other): - return self.new(dtypes.int32, f"{self} < {other}", [self, other]) - - def __le__(self, other): - return 
self.new(dtypes.int32, f"{self} <= {other}", [self, other]) - - def __eq__(self, other): - return self.new(dtypes.int32, f"{self} == {other}", [self, other]) - - def __ne__(self, other): - return self.new(dtypes.int32, f"{self} != {other}", [self, other]) - - def __gt__(self, other): - return self.new(dtypes.int32, f"{self} > {other}", [self, other]) - - def __ge__(self, other): - return self.new(dtypes.int32, f"{self} >= {other}", [self, other]) - - def __add__(self, other): - if do_scaled_int_check(other): - result = self.new_scaled_and_offset_int(self.var_type, f"{self}", [self]) - return result.__add__(other) - - return self.new(self.var_type, f"{self} + {other}", [self, other]) - - def __sub__(self, other): - if do_scaled_int_check(other): - result = self.new_scaled_and_offset_int(self.var_type, f"{self}", [self]) - return result.__sub__(other) - - return self.new(self.var_type, f"{self} - {other}", [self, other]) - - def __mul__(self, other): - if do_scaled_int_check(other): - result = self.new_scaled_and_offset_int(self.var_type, f"{self}", [self]) - return result.__mul__(other) - - return_var_type = self.var_type - - if (self.var_type.dimentions == 2 - and other.var_type.dimentions == 1): - return_var_type = other.var_type - - if(self.var_type == dtypes.int32 or self.var_type == dtypes.uint32): - if (isinstance(other, int) and is_int_power_of_2(other)): - if other == 1: - return self - - power = int(np.round(np.log2(other))) - - return self.new(self.var_type, f"{self} << {power}", [self]) - elif (isinstance(other, ShaderVariable) and (other.var_type == dtypes.float32)) or (isinstance(other, float) and np.issubdtype(type(other), np.floating)): - return_var_type = dtypes.float32 - - return self.new(return_var_type, f"{self} * {other}", [self, other]) - - def __truediv__(self, other): - if isinstance(other, int) and is_int_power_of_2(other): - if other == 1: - return self - - if self.var_type != dtypes.int32 and self.var_type != dtypes.uint32: - return 
self.new(self.var_type, f"{self} / {other}", [self, other]) - - power = int(np.round(np.log2(other))) - - return self.new(self.var_type, f"{self} >> {power}", [self]) - - return self.new(self.var_type, f"{self} / {other}", [self, other]) - - # def __floordiv__(self, other: 'shader_variable') -> 'shader_variable': - # return self.builder.make_var(f"{self} / {other}") - - def __mod__(self, other): - return self.new(self.var_type, f"{self} % {other}", [self, other]) - - def __pow__(self, other): - other_str = str(other) - - if isinstance(other, ShaderVariable): - other_str = other.name - - return self.new(self.var_type, f"pow({self.name}, {other_str})", [self, other]) - - def __neg__(self): - return self.new(self.var_type, f"-{self}", [self]) - - def __abs__(self): - return self.new(self.var_type, f"abs({self.name})", [self]) - - def __invert__(self): - return self.new(self.var_type, f"~{self}", [self]) - - def __lshift__(self, other): - return self.new(self.var_type, f"{self} << {other}", [self, other]) - - def __rshift__(self, other): - return self.new(self.var_type, f"{self} >> {other}", [self, other]) - - def __and__(self, other): - return self.new(self.var_type, f"{self} & {other}", [self, other]) - - def __xor__(self, other): - return self.new(self.var_type, f"{self} ^ {other}", [self, other]) - - def __or__(self, other): - return self.new(self.var_type, f"({self} | {other}", [self, other]) - - def __radd__(self, other): - if do_scaled_int_check(other): - result = self.new_scaled_and_offset_int(self.var_type, f"{self}", [self]) - return result.__radd__(other) - - return self.new(self.var_type, f"{other} + {self}", [self, other]) - - def __rsub__(self, other): - if do_scaled_int_check(other): - result = self.new_scaled_and_offset_int(self.var_type, f"{self}", [self]) - return result.__rsub__(other) - - return self.new(self.var_type, f"{other} - {self}", [self, other]) - - def __rmul__(self, other): - if do_scaled_int_check(other): - result = 
self.new_scaled_and_offset_int(self.var_type, f"{self}", [self]) - return result.__rmul__(other) - - return_var_type = self.var_type - - if(self.var_type == dtypes.int32 or self.var_type == dtypes.uint32): - if (isinstance(other, int) and is_int_power_of_2(other)): - if other == 1: - return self - - power = int(np.round(np.log2(other))) - - return self.new(self.var_type, f"{self} << {power}", [self]) - elif (isinstance(other, ShaderVariable) and (other.var_type == dtypes.float32)) or (isinstance(other, float) and np.issubdtype(type(other), np.floating)): - return_var_type = dtypes.float32 - - return self.new(return_var_type, f"{other} * {self}", [self, other]) - - def __rtruediv__(self, other): - return self.new(self.var_type, f"{other} / {self}", [self, other]) - - # def __rfloordiv__(self, other: 'shader_variable') -> 'shader_variable': - # return self.builder.make_var(f"{other} / {self}") - - def __rmod__(self, other): - return self.new(self.var_type, f"{other} % {self}", [self, other]) - - def __rpow__(self, other): - other_str = str(other) - - if isinstance(other, ShaderVariable): - other_str = other.name - - return self.new(self.var_type, f"pow({other_str}, {self.name})", [self, other]) - - def __rand__(self, other): - return self.new(self.var_type, f"{other} & {self}", [self, other]) - - def __rxor__(self, other): - return self.new(self.var_type, f"{other} ^ {self}", [self, other]) - - def __ror__(self, other): - return self.new(self.var_type, f"{other} | {self}", [self, other]) - - def __iadd__(self, other): - assert self.settable, f"Cannot set value of '{self.name}' because it is not a settable variable!" - - self.read_callback() - self.write_callback() - - if isinstance(other, ShaderVariable): - other.read_callback() - - self.append_func(f"{self} += {other};\n") - return self - - def __isub__(self, other): - assert self.settable, f"Cannot set value of '{self.name}' because it is not a settable variable!" 
- - self.read_callback() - self.write_callback() - - if isinstance(other, ShaderVariable): - other.read_callback() - - self.append_func(f"{self} -= {other};\n") - return self - - def __imul__(self, other): - assert self.settable, f"Cannot set value of '{self.name}' because it is not a settable variable!" - - self.read_callback() - self.write_callback() - - if isinstance(other, ShaderVariable): - other.read_callback() - - self.append_func(f"{self} *= {other};\n") - return self - - def __itruediv__(self, other): - assert self.settable, f"Cannot set value of '{self.name}' because it is not a settable variable!" - - self.read_callback() - self.write_callback() - - if isinstance(other, ShaderVariable): - other.read_callback() - - self.append_func(f"{self} /= {other};\n") - return self - - # def __ifloordiv__(self, other: 'shader_variable') -> 'shader_variable': - # self.append_func(f"{self} /= {other};\n") - # return self - - def __imod__(self, other): - assert self.settable, f"Cannot set value of '{self.name}' because it is not a settable variable!" - - self.read_callback() - self.write_callback() - - if isinstance(other, ShaderVariable): - other.read_callback() - - self.append_func(f"{self} %= {other};\n") - return self - - def __ipow__(self, other): - assert self.settable, f"Cannot set value of '{self.name}' because it is not a settable variable!" - - self.read_callback() - self.write_callback() - - other_str = str(other) - - if isinstance(other, ShaderVariable): - other.read_callback() - other_str = other.name - - self.append_func(f"{self} = pow({self.name}, {other_str});\n") - return self - - def __ilshift__(self, other): - assert self.settable, f"Cannot set value of '{self.name}' because it is not a settable variable!" 
- - self.read_callback() - self.write_callback() - - if isinstance(other, ShaderVariable): - other.read_callback() - - self.append_func(f"{self} <<= {other};\n") - return self - - def __irshift__(self, other): - assert self.settable, f"Cannot set value of '{self.name}' because it is not a settable variable!" - - self.read_callback() - self.write_callback() - - if isinstance(other, ShaderVariable): - other.read_callback() - - self.append_func(f"{self} >>= {other};\n") - return self - - def __iand__(self, other): - assert self.settable, f"Cannot set value of '{self.name}' because it is not a settable variable!" - - self.read_callback() - self.write_callback() - - if isinstance(other, ShaderVariable): - other.read_callback() - - self.append_func(f"{self} &= {other};\n") - return self - - def __ixor__(self, other): - assert self.settable, f"Cannot set value of '{self.name}' because it is not a settable variable!" - - self.read_callback() - self.write_callback() - - if isinstance(other, ShaderVariable): - other.read_callback() - - self.append_func(f"{self} ^= {other};\n") - return self - - def __ior__(self, other): - assert self.settable, f"Cannot set value of '{self.name}' because it is not a settable variable!" 
- - self.read_callback() - self.write_callback() - - if isinstance(other, ShaderVariable): - other.read_callback() - - self.append_func(f"{self} |= {other};\n") - return self - -class ScaledAndOfftsetIntVariable(ShaderVariable): - def __init__(self, - append_func: Callable[[str], None], - name_func: Callable[[str], Tuple[str, str]], - var_type: dtype, - name: Optional[str] = None, - scale: int = 1, - offset: int = 0, - parent_variables: List["ShaderVariable"] = None - ) -> None: - self.base_name = str(name) - self.scale = scale - self.offset = offset - - super().__init__(append_func, name_func, var_type, name, parent_variables=parent_variables) - - def new_from_self(self, scale: int = 1, offset: int = 0): - child_vartype = self.var_type - - if isinstance(scale, float) or isinstance(offset, float): - child_vartype = var_types_to_floating(self.var_type) - - return ScaledAndOfftsetIntVariable( - self.append_func, - self.name_func, - child_vartype, - f"{self.name}", - scale=self.scale * scale, - offset=offset + self.offset * scale, - parent_variables=self.parent_variables - ) - - def __repr__(self) -> str: - scale_str = f" * {self.scale}" if self.scale != 1 else "" - offset_str = f" + {self.offset}" if self.offset != 0 else "" - - if scale_str == "" and offset_str == "": - return self.base_name - - return f"({self.base_name}{scale_str}{offset_str})" - - def __add__(self, other): - if isinstance(other, ShaderVariable): - return super().__add__(other) - - return self.new_from_self(offset=other) - - def __sub__(self, other): - if isinstance(other, ShaderVariable): - return super().__sub__(other) - - return self.new_from_self(offset=-other) - - def __mul__(self, other): - if isinstance(other, ShaderVariable): - return super().__mul__(other) - - return self.new_from_self(scale=other) - - def __radd__(self, other): - if isinstance(other, ShaderVariable): - return super().__radd__(other) - - return self.new_from_self(offset=other) - - def __rsub__(self, other): - if 
isinstance(other, ShaderVariable): - return super().__rsub__(other) - - return self.new_from_self(offset=other, scale=-1) - - def __rmul__(self, other): - if isinstance(other, ShaderVariable): - return super().__rmul__(other) - - return self.new_from_self(scale=other) - -class BoundVariable(ShaderVariable): - binding: int = -1 - - def __init__(self, - append_func: Callable[[str], None], - name_func: Callable[[str], str], - var_type: dtype, - binding: int, - name: Optional[str] = None, - ) -> None: - super().__init__(append_func, name_func, var_type, name) - - self.binding = binding - - #def __int__(self): - # return int(self.binding) - -class BufferVariable(BoundVariable): - read_lambda: Callable[[], None] - write_lambda: Callable[[], None] - - def __init__(self, - append_func: Callable[[str], None], - name_func: Callable[[str], Tuple[str, str]], - var_type: dtype, - binding: int, - name: Optional[str] = None, - shape_var: "ShaderVariable" = None, - shape_name: Optional[str] = None, - raw_name: Optional[str] = None, - read_lambda: Callable[[], None] = None, - write_lambda: Callable[[], None] = None, - ) -> None: - super().__init__(append_func, name_func, var_type, binding, name) - - self.name = name if name is not None else self.name - self.raw_name = raw_name if raw_name is not None else self.raw_name - self.settable = True - - self.read_lambda = read_lambda - self.write_lambda = write_lambda - - self._register_shape(shape_var=shape_var, shape_name=shape_name, use_child_type=False) - - def read_callback(self): - self.read_lambda() - - def write_callback(self): - self.write_lambda() - -class ImageVariable(BoundVariable): - dimensions: int = 0 - read_lambda: Callable[[], None] - write_lambda: Callable[[], None] - - def __init__(self, - append_func: Callable[[str], None], - name_func: Callable[[str], Tuple[str, str]], - var_type: dtype, - binding: int, - dimensions: int, - name: Optional[str] = None, - read_lambda: Callable[[], None] = None, - write_lambda: 
Callable[[], None] = None, - ) -> None: - super().__init__(append_func, name_func, var_type, binding, name) - - self.read_lambda = read_lambda - self.write_lambda = write_lambda - self.dimensions = dimensions - - def read_callback(self): - self.read_lambda() - - def write_callback(self): - self.write_lambda() - - def sample(self, coord: "ShaderVariable", lod: "ShaderVariable" = None) -> "ShaderVariable": - if self.dimensions == 0: - raise ValueError("Cannot sample a texture with dimension 0!") - - sample_coord_string = "" +@dataclasses.dataclass +class ShaderBinding: + """ + A dataclass that represents a bound resource in a shader. Either a + buffer or an image. - if self.dimensions == 1: - sample_coord_string = f"((({coord}) + 0.5) / textureSize({self}, 0))" - elif self.dimensions == 2: - sample_coord_string = f"((vec2({coord}.xy) + 0.5) / vec2(textureSize({self}, 0)))" - elif self.dimensions == 3: - sample_coord_string = f"((vec3({coord}.xyz) + 0.5) / vec3(textureSize({self}, 0)))" - else: - raise ValueError("Unsupported number of dimensions!") + Attributes: + dtype (vd.dtype): The dtype of the resource. If + the resource is an image, this should be vd.vec4 + (since all images are sampled with 4 channels in shaders). + name (str): The name of the resource within the shader code. + dimension (int): The dimension of the resource. Set to 0 for + buffers and 1, 2, or 3 for images. + binding_type (BindingType): The type of the binding. Either + STORAGE_BUFFER, UNIFORM_BUFFER, or SAMPLER. 
+ """ + dtype: dtypes.dtype + name: str + dimension: int + binding_type: BindingType - if lod is None: - return self.new(dtypes.vec4, f"texture({self}, {sample_coord_string})", [self]) - - return self.new(dtypes.vec4, f"textureLod({self}, {sample_coord_string}, {lod})", [self]) +class ShaderFlags(IntFlag): + NONE = 0 + NO_SUBGROUP_OPS = auto() + NO_PRINTF = auto() + NO_EXEC_BOUNDS = auto() -class ShaderBuilder: - var_count: int +class ShaderBuilder(ShaderWriter): binding_count: int binding_read_access: Dict[int, bool] binding_write_access: Dict[int, bool] @@ -914,49 +139,27 @@ class ShaderBuilder: pc_struct: StructBuilder uniform_struct: StructBuilder exec_count: Optional[ShaderVariable] - contents: str - pre_header: str + flags: ShaderFlags + backend: CodeGenBackend def __init__(self, - enable_subgroup_ops: bool = True, - enable_atomic_float_ops: bool = True, - enable_printf: bool = True, - enable_exec_bounds: bool = True, - is_apple_device: bool = False) -> None: - self.enable_subgroup_ops = enable_subgroup_ops - self.enable_atomic_float_ops = enable_atomic_float_ops - self.enable_printf = enable_printf - self.enable_exec_bounds = enable_exec_bounds - self.is_apple_device = is_apple_device - - self.pre_header = "#version 450\n" - self.pre_header += "#extension GL_ARB_separate_shader_objects : enable\n" - - if self.enable_subgroup_ops: - self.pre_header += "#extension GL_KHR_shader_subgroup_arithmetic : enable\n" - - #if self.enable_atomic_float_ops: - # self.pre_header += "#extension GL_EXT_shader_atomic_float : enable\n" - - if self.enable_printf: - self.pre_header += "#extension GL_EXT_debug_printf : enable\n" - - self.global_invocation = self.make_var(dtypes.uvec3, "gl_GlobalInvocationID", [], lexical_unit=True) - self.local_invocation = self.make_var(dtypes.uvec3, "gl_LocalInvocationID", [], lexical_unit=True) - self.workgroup = self.make_var(dtypes.uvec3, "gl_WorkGroupID", [], lexical_unit=True) - self.workgroup_size = self.make_var(dtypes.uvec3, 
"gl_WorkGroupSize", [], lexical_unit=True) - self.num_workgroups = self.make_var(dtypes.uvec3, "gl_NumWorkGroups", [], lexical_unit=True) - - self.num_subgroups = self.make_var(dtypes.uint32, "gl_NumSubgroups", [], lexical_unit=True) - self.subgroup_id = self.make_var(dtypes.uint32, "gl_SubgroupID", [], lexical_unit=True) + flags: ShaderFlags = ShaderFlags.NONE, + is_apple_device: bool = False, + backend: Optional[CodeGenBackend] = None) -> None: + super().__init__() - self.subgroup_size = self.make_var(dtypes.uint32, "gl_SubgroupSize", [], lexical_unit=True) - self.subgroup_invocation = self.make_var(dtypes.uint32, "gl_SubgroupInvocationID", [], lexical_unit=True) + self.flags = flags + self.is_apple_device = is_apple_device + if backend is not None: + self.backend = backend + else: + # Use the selected backend type while keeping per-builder backend state isolated. + self.backend = get_codegen_backend().__class__() self.reset() def reset(self) -> None: - self.var_count = 0 + self.backend.reset_state() self.binding_count = 0 self.pc_struct = StructBuilder() self.uniform_struct = StructBuilder() @@ -965,88 +168,51 @@ def reset(self) -> None: self.binding_write_access = {} self.shared_buffers = [] self.scope_num = 1 - self.contents = "" - self.mapping_index: ShaderVariable = None - self.kernel_index: ShaderVariable = None - self.mapping_registers: List[ShaderVariable] = None - self.exec_count = self.declare_constant(dtypes.uvec4, var_name="exec_count") - - if self.enable_exec_bounds: - self.if_statement(self.exec_count.x <= self.global_invocation.x) - self.return_statement() - self.end() - - self.if_statement(self.exec_count.y <= self.global_invocation.y) - self.return_statement() - self.end() - - self.if_statement(self.exec_count.z <= self.global_invocation.z) - self.return_statement() - self.end() - - def set_mapping_index(self, index: ShaderVariable): - self.mapping_index = index - - def set_kernel_index(self, index: ShaderVariable): - self.kernel_index = index - 
- def set_mapping_registers(self, registers: ShaderVariable): - self.mapping_registers = list(registers) - - def append_contents(self, contents: str) -> None: - self.contents += (" " * self.scope_num) + contents - - def comment(self, comment: str) -> None: - self.append_contents("\n") - self.append_contents(f"/* {comment} */\n") - - - def get_name_func(self, prefix: Optional[str] = None, suffix: Optional[str] = None): - my_prefix = [prefix] - my_suffix = [suffix] - def get_name_val(var_name: Union[str, None] = None): - new_var = f"var{self.var_count}" if var_name is None else var_name - raw_name = new_var - - if var_name is None: - self.var_count += 1 - - if my_prefix[0] is not None: - new_var = f"{my_prefix[0]}{new_var}" - my_prefix[0] = None - - if my_suffix[0] is not None: - new_var = f"{new_var}{my_suffix[0]}" - my_suffix[0] = None - - return new_var, raw_name - return get_name_val - - def make_var(self, - var_type: dtype, - var_name: Optional[str], - parents: List[ShaderVariable], - prefix: Optional[str] = None, - suffix: Optional[str] = None, - lexical_unit: bool = False, - settable: bool = False) -> ShaderVariable: - return ShaderVariable( - self.append_contents, - self.get_name_func(prefix, suffix), - var_type, - var_name, - lexical_unit=lexical_unit, - settable=settable, - parent_variables=parents + self.exec_count = None + + if not (self.flags & ShaderFlags.NO_EXEC_BOUNDS): + self.exec_count = self.declare_constant(dtypes.uvec4, var_name="exec_count") + self.append_contents(self.backend.exec_bounds_guard(self.exec_count.resolve())) + + def new_var(self, + var_type: dtypes.dtype, + name: str, + parents: List["ShaderVariable"], + lexical_unit: bool = False, + settable: bool = False, + register: bool = False) -> "ShaderVariable": + return ShaderVariable(var_type, + name, + lexical_unit=lexical_unit, + settable=settable, + register=register, + parents=parents) + + def new_scaled_var(self, + var_type: dtypes.dtype, + name: str, + scale: int = 1, + offset: int 
= 0, + parents: List[BaseVariable] = None): + return ScaledAndOfftsetIntVariable(var_type, + name, + scale=scale, + offset=offset, + parents=parents) + + def declare_constant(self, var_type: dtypes.dtype, count: int = 1, var_name: Optional[str] = None): + if var_name is None: + var_name = self.new_name() + + new_var = ShaderVariable( + var_type=var_type, + name=f"{self.backend.constant_namespace()}.{var_name}", + raw_name=var_name, + lexical_unit=True, + settable=False, + parents=[] ) - - def declare_constant(self, var_type: dtype, count: int = 1, var_name: Optional[str] = None): - suffix = None - if var_type.glsl_type_extern is not None: - suffix = ".xyz" - - new_var = self.make_var(var_type, var_name, [], "UBO.", suffix) if count > 1: new_var.use_child_type = False @@ -1055,13 +221,21 @@ def declare_constant(self, var_type: dtype, count: int = 1, var_name: Optional[s self.uniform_struct.register_element(new_var.raw_name, var_type, count) return new_var - def declare_variable(self, var_type: dtype, count: int = 1, var_name: Optional[str] = None): - suffix = None - if var_type.glsl_type_extern is not None: - suffix = ".xyz" - - new_var = self.make_var(var_type, var_name, [], "PC.", suffix) - new_var._varying = True + def declare_variable(self, var_type: dtypes.dtype, count: int = 1, var_name: Optional[str] = None): + if self.backend.name in _PUSH_CONSTANT_UNSUPPORTED_BACKENDS: + raise NotImplementedError(_push_constant_not_supported_error(self.backend.name)) + + if var_name is None: + var_name = self.new_name() + + new_var = ShaderVariable( + var_type=var_type, + name=f"{self.backend.variable_namespace()}.{var_name}", + raw_name=var_name, + lexical_unit=True, + settable=False, + parents=[] + ) if count > 1: new_var.use_child_type = False @@ -1070,11 +244,15 @@ def declare_variable(self, var_type: dtype, count: int = 1, var_name: Optional[s self.pc_struct.register_element(new_var.raw_name, var_type, count) return new_var - def declare_buffer(self, var_type: dtype, 
var_name: Optional[str] = None): + def declare_buffer(self, var_type: dtypes.dtype, var_name: Optional[str] = None): self.binding_count += 1 buffer_name = f"buf{self.binding_count}" if var_name is None else var_name shape_name = f"{buffer_name}_shape" + scalar_expr = None + + if self.backend.name == "opencl" and (dtypes.is_vector(var_type) or dtypes.is_complex(var_type)): + scalar_expr = f"{buffer_name}_scalar" self.binding_list.append(ShaderBinding(var_type, buffer_name, 0, BindingType.STORAGE_BUFFER)) self.binding_read_access[self.binding_count] = False @@ -1087,15 +265,18 @@ def read_lambda(): def write_lambda(): self.binding_write_access[current_binding_count] = True + + def shape_var_factory(): + return self.declare_constant(dtypes.ivec4, var_name=shape_name) return BufferVariable( - self.append_contents, - self.get_name_func(), var_type, self.binding_count, f"{buffer_name}.data", - self.declare_constant(dtypes.ivec4, var_name=shape_name), - shape_name, + shape_var_factory=shape_var_factory, + shape_name=shape_name, + scalar_expr=scalar_expr, + codegen_backend=self.backend, read_lambda=read_lambda, write_lambda=write_lambda ) @@ -1115,8 +296,6 @@ def write_lambda(): self.binding_write_access[self.binding_count] = True return ImageVariable( - self.append_contents, - self.get_name_func(), dtypes.vec4, self.binding_count, dimensions, @@ -1125,18 +304,23 @@ def write_lambda(): write_lambda=write_lambda ) - def shared_buffer(self, var_type: dtype, size: int, var_name: Optional[str] = None): - buffer_name = self.get_name_func()(var_name)[0] - shape_name = f"{buffer_name}_shape" + def shared_buffer(self, var_type: dtypes.dtype, size: int, var_name: Optional[str] = None): + if var_name is None: + var_name = self.new_name() + + shape_name = f"{var_name}_shape" + + def shape_var_factory(): + return self.declare_constant(dtypes.ivec4, var_name=shape_name) new_var = BufferVariable( - self.append_contents, - self.get_name_func(), var_type, -1, - buffer_name, - 
self.declare_constant(dtypes.ivec4, var_name=shape_name), - shape_name, + var_name, + shape_var_factory=shape_var_factory, + shape_name=shape_name, + scalar_expr=None, + codegen_backend=self.backend, read_lambda=lambda: None, write_lambda=lambda: None ) @@ -1145,465 +329,55 @@ def shared_buffer(self, var_type: dtype, size: int, var_name: Optional[str] = No return new_var - def abs(self, arg: ShaderVariable): - return self.make_var(arg.var_type, f"abs({arg})", [arg], lexical_unit=True) - - def acos(self, arg: ShaderVariable): - return self.make_var(var_types_to_floating(arg.var_type), f"acos({arg})", [arg], lexical_unit=True) - - def acosh(self, arg: ShaderVariable): - return self.make_var(var_types_to_floating(arg.var_type), f"acosh({arg})", [arg], lexical_unit=True) - - def asin(self, arg: ShaderVariable): - return self.make_var(var_types_to_floating(arg.var_type), f"asin({arg})", [arg], lexical_unit=True) - - def asinh(self, arg: ShaderVariable): - return self.make_var(var_types_to_floating(arg.var_type), f"asinh({arg})", [arg], lexical_unit=True) - - def atan(self, arg: ShaderVariable): - return self.make_var(var_types_to_floating(arg.var_type), f"atan({arg})", [arg], lexical_unit=True) - - def atan2(self, arg1: ShaderVariable, arg2: ShaderVariable): - # TODO: correctly handle pure float inputs - - floating_arg1 = var_types_to_floating(arg1.var_type) - floating_arg2 = var_types_to_floating(arg2.var_type) - - assert floating_arg1 == floating_arg2, f"Both arguments to atan2 ({arg1.var_type} and {arg2.var_type}) must be of the same dimentionality" - - return self.make_var(floating_arg1, f"atan({arg1}, {arg2})", [arg1, arg2], lexical_unit=True) - - def atanh(self, arg: ShaderVariable): - return self.make_var(var_types_to_floating(arg.var_type), f"atanh({arg})", [arg], lexical_unit=True) - - def atomic_add(self, arg1: ShaderVariable, arg2: ShaderVariable): - if not isinstance(arg1, ShaderVariable): - raise TypeError("First argument to atomic_add must be a 
ShaderVariable") - - arg1.read_callback() - arg1.write_callback() - - if isinstance(arg2, ShaderVariable): - arg2.read_callback() - - new_var = self.make_var(arg1.var_type, None, []) - self.append_contents(f"{new_var.var_type.glsl_type} {new_var.name} = atomicAdd({arg1}, {arg2});\n") - return new_var - - def barrier(self): - if self.is_apple_device: - self.memory_barrier() - - self.append_contents("barrier();\n") - - def ceil(self, arg: ShaderVariable): - return self.make_var(var_types_to_floating(arg.var_type), f"ceil({arg})", [arg], lexical_unit=True) - - def clamp(self, arg: ShaderVariable, min_val: ShaderVariable, max_val: ShaderVariable): - return self.make_var(var_types_to_floating(arg.var_type), f"clamp({arg}, {min_val}, {max_val})", [arg, min_val, max_val], lexical_unit=True) - - def cos(self, arg: ShaderVariable): - return self.make_var(var_types_to_floating(arg.var_type), f"cos({arg})", [arg], lexical_unit=True) - - def cosh(self, arg: ShaderVariable): - return self.make_var(var_types_to_floating(arg.var_type), f"cosh({arg})", [arg], lexical_unit=True) - - def cross(self, arg1: ShaderVariable, arg2: ShaderVariable): - return self.make_var(dtypes.v3, f"cross({arg1}, {arg2})", [arg1, arg2], lexical_unit=True) - - def degrees(self, arg: ShaderVariable): - return self.make_var(var_types_to_floating(arg.var_type), f"degrees({arg})", [arg], lexical_unit=True) - - def determinant(self, arg: ShaderVariable): - return self.make_var(dtypes.float32, f"determinant({arg})", [arg], lexical_unit=True) - - def distance(self, arg1: ShaderVariable, arg2: ShaderVariable): - return self.make_var(dtypes.float32, f"distance({arg1}, {arg2})", [arg1, arg2], lexical_unit=True) - - def dot(self, arg1: ShaderVariable, arg2: ShaderVariable): - return self.make_var(dtypes.float32, f"dot({arg1}, {arg2})", [arg1, arg2], lexical_unit=True) - - def exp(self, arg: ShaderVariable): - return self.make_var(var_types_to_floating(arg.var_type), f"exp({arg})", [arg], lexical_unit=True) - - def 
exp2(self, arg: ShaderVariable): - return self.make_var(var_types_to_floating(arg.var_type), f"exp2({arg})", [arg], lexical_unit=True) - - def float_bits_to_int(self, arg: ShaderVariable): - return self.make_var(dtypes.int32, f"floatBitsToInt({arg})", [arg], lexical_unit=True) - - def float_bits_to_uint(self, arg: ShaderVariable): - return self.make_var(dtypes.uint32, f"floatBitsToUint({arg})", [arg], lexical_unit=True) - - def floor(self, arg: ShaderVariable): - return self.make_var(var_types_to_floating(arg.var_type), f"floor({arg})", [arg], lexical_unit=True) - - def fma(self, arg1: ShaderVariable, arg2: ShaderVariable, arg3: ShaderVariable): - # TODO: properly handle type conversion and float inputs - - return self.make_var(arg1.var_type, f"fma({arg1}, {arg2}, {arg3})", [arg1, arg2, arg3], lexical_unit=True) - - def int_bits_to_float(self, arg: ShaderVariable): - return self.make_var(dtypes.float32, f"intBitsToFloat({arg})", [arg], lexical_unit=True) - - def inverse(self, arg: ShaderVariable): - assert arg.var_type.dimentions == 2, f"Cannot apply inverse to non-matrix type {arg.var_type}" - - return self.make_var(arg.var_type, f"inverse({arg})", [arg], lexical_unit=True) - - def inverse_sqrt(self, arg: ShaderVariable): - return self.make_var(var_types_to_floating(arg.var_type), f"inversesqrt({arg})", [arg], lexical_unit=True) - - def isinf(self, arg: ShaderVariable): - return self.make_var(dtypes.int32, f"any(isinf({arg}))", [arg], lexical_unit=True) - - def isnan(self, arg: ShaderVariable): - return self.make_var(dtypes.int32, f"any(isnan({arg}))", [arg], lexical_unit=True) - - def length(self, arg: ShaderVariable): - return self.make_var(dtypes.float32, f"length({arg})", [arg], lexical_unit=True) - - def log(self, arg: ShaderVariable): - return self.make_var(var_types_to_floating(arg.var_type), f"log({arg})", [arg], lexical_unit=True) - - def log2(self, arg: ShaderVariable): - return self.make_var(var_types_to_floating(arg.var_type), f"log2({arg})", [arg], 
lexical_unit=True) - - def max(self, arg1: ShaderVariable, arg2: ShaderVariable): - # TODO: properly handle type conversion and float inputs - - return self.make_var(arg1.var_type, f"max({arg1}, {arg2})", [arg1, arg2], lexical_unit=True) - - def memory_barrier(self): - self.append_contents("memoryBarrier();\n") - - def memory_barrier_shared(self): - self.append_contents("memoryBarrierShared();\n") - - def min(self, arg1: ShaderVariable, arg2: ShaderVariable): - # TODO: properly handle type conversion and float inputs - - return self.make_var(arg1.var_type, f"min({arg1}, {arg2})", [arg1, arg2], lexical_unit=True) - - def mix(self, arg1: ShaderVariable, arg2: ShaderVariable, arg3: ShaderVariable): - # TODO: properly handle type conversion and float inputs - - return self.make_var(arg1.var_type, f"mix({arg1}, {arg2}, {arg3})", [arg1, arg2, arg3], lexical_unit=True) - - def mod(self, arg1: ShaderVariable, arg2: ShaderVariable): - # TODO: properly handle type conversion and float inputs - - return self.make_var(arg1.var_type, f"mod({arg1}, {arg2})", [arg1, arg2], lexical_unit=True) - - def normalize(self, arg: ShaderVariable): - return self.make_var(var_types_to_floating(arg.var_type), f"normalize({arg})", [arg], lexical_unit=True) - - def pow(self, arg1: ShaderVariable, arg2: ShaderVariable): - return self.make_var(arg1.var_type, f"pow({arg1}, {arg2})", [arg1, arg2], lexical_unit=True) - - def radians(self, arg: ShaderVariable): - return self.make_var(var_types_to_floating(arg.var_type), f"radians({arg})", [arg], lexical_unit=True) - - def round(self, arg: ShaderVariable): - return self.make_var(var_types_to_floating(arg.var_type), f"round({arg})", [arg], lexical_unit=True) - - def round_even(self, arg: ShaderVariable): - return self.make_var(var_types_to_floating(arg.var_type), f"roundEven({arg})", [arg], lexical_unit=True) - - def sign(self, arg: ShaderVariable): - return self.make_var(arg.var_type, f"sign({arg})", [arg], lexical_unit=True) - - def sin(self, arg: 
ShaderVariable): - return self.make_var(var_types_to_floating(arg.var_type), f"sin({arg})", [arg], lexical_unit=True) - - def sinh(self, arg: ShaderVariable): - return self.make_var(var_types_to_floating(arg.var_type), f"sinh({arg})", [arg], lexical_unit=True) - - def smoothstep(self, arg1: ShaderVariable, arg2: ShaderVariable, arg3: ShaderVariable): - # TODO: properly handle type conversion and float inputs - - return self.make_var(arg1.var_type, f"smoothstep({arg1}, {arg2}, {arg3})", [arg1, arg2, arg3], lexical_unit=True) - - def sqrt(self, arg: ShaderVariable): - return self.make_var(var_types_to_floating(arg.var_type), f"sqrt({arg})", [arg], lexical_unit=True) - - def step(self, arg1: ShaderVariable, arg2: ShaderVariable): - return self.make_var(arg1.var_type, f"step({arg1}, {arg2})", [arg1, arg2], lexical_unit=True) - - def tan(self, arg: ShaderVariable): - return self.make_var(var_types_to_floating(arg.var_type), f"tan({arg})", [arg], lexical_unit=True) - - def tanh(self, arg: ShaderVariable): - return self.make_var(var_types_to_floating(arg.var_type), f"tanh({arg})", [arg], lexical_unit=True) - - def transpose(self, arg: ShaderVariable): - return self.make_var(arg.var_type, f"transpose({arg})", [arg], lexical_unit=True) - - def trunc(self, arg: ShaderVariable): - return self.make_var(arg.var_type, f"trunc({arg})", [arg], lexical_unit=True) - - def uint_bits_to_float(self, arg: ShaderVariable): - return self.make_var(dtypes.float32, f"uintBitsToFloat({arg})", [arg], lexical_unit=True) - - def mult_c64(self, arg1: ShaderVariable, arg2: ShaderVariable): - new_var = self.make_var( - arg1.var_type, - f"vec2({arg1}.x * {arg2}.x - {arg1}.y * {arg2}.y, {arg1}.x * {arg2}.y + {arg1}.y * {arg2}.x)", - [arg1, arg2], - lexical_unit=True - ) - return new_var - - def mult_c64_by_const(self, arg1: ShaderVariable, number: complex): - if isinstance(number, ShaderVariable): - raise ValueError("Cannot multiply complex number by a variable, use mult_c64 instead.") - - new_var = 
self.make_var( - arg1.var_type, - f"vec2({arg1}.x * {number.real} - {arg1}.y * {number.imag}, {arg1}.x * {number.imag} + {arg1}.y * {number.real})", - [arg1], - lexical_unit=True - ) - return new_var - - def mult_conj_c64(self, arg1: ShaderVariable, arg2: ShaderVariable): - new_var = self.make_var( - arg1.var_type, - f"vec2({arg1}.x * {arg2}.x + {arg1}.y * {arg2}.y, {arg1}.y * {arg2}.x - {arg1}.x * {arg2}.y)", - [arg1, arg2], - lexical_unit=True - ) - return new_var - - def if_statement(self, arg: ShaderVariable, command: Optional[str] = None): - if command is None: - self.append_contents(f"if({arg}) {'{'}\n") - self.scope_num += 1 - return - - self.append_contents(f"if({arg})\n") - self.scope_num += 1 - self.append_contents(f"{command}\n") - self.scope_num -= 1 - - def if_any(self, *args: List[ShaderVariable]): - self.append_contents(f"if({' || '.join([str(elem) for elem in args])}) {'{'}\n") - self.scope_num += 1 - - def if_all(self, *args: List[ShaderVariable]): - self.append_contents(f"if({' && '.join([str(elem) for elem in args])}) {'{'}\n") - self.scope_num += 1 - - def else_statement(self): - self.scope_num -= 1 - self.append_contents("} else {\n") - self.scope_num += 1 - - def else_if_statement(self, arg: ShaderVariable): - self.scope_num -= 1 - self.append_contents(f"}} else if({arg}) {'{'}\n") - self.scope_num += 1 - - def else_if_any(self, *args: List[ShaderVariable]): - self.scope_num -= 1 - self.append_contents(f"}} else if({' || '.join([str(elem) for elem in args])}) {'{'}\n") - self.scope_num += 1 - - def else_if_all(self, *args: List[ShaderVariable]): - self.scope_num -= 1 - self.append_contents(f"}} else if({' && '.join([str(elem) for elem in args])}) {'{'}\n") - self.scope_num += 1 - - def return_statement(self, arg=None): - arg = arg if arg is not None else "" - self.append_contents(f"return {arg};\n") - - def while_statement(self, arg: ShaderVariable): - self.append_contents(f"while({arg}) {'{'}\n") - self.scope_num += 1 - - def new_scope(self, 
comment: str = None): - if comment is None: - self.append_contents("{\n") - else: - self.append_contents("{ " + f"/* {comment} */\n") - - self.scope_num += 1 - - def end(self): - self.scope_num -= 1 - self.append_contents("}\n") - - def logical_and(self, arg1: ShaderVariable, arg2: ShaderVariable): - return self.make_var(dtypes.int32, f"({arg1} && {arg2})", [arg1, arg2]) - - def logical_or(self, arg1: ShaderVariable, arg2: ShaderVariable): - return self.make_var(dtypes.int32, f"({arg1} || {arg2})", [arg1, arg2]) - - def subgroup_add(self, arg1: ShaderVariable): - return self.make_var(arg1.var_type, f"subgroupAdd({arg1})", [arg1], lexical_unit=True) - - def subgroup_mul(self, arg1: ShaderVariable): - return self.make_var(arg1.var_type, f"subgroupMul({arg1})", [arg1], lexical_unit=True) - - def subgroup_min(self, arg1: ShaderVariable): - return self.make_var(arg1.var_type, f"subgroupMin({arg1})", [arg1], lexical_unit=True) - - def subgroup_max(self, arg1: ShaderVariable): - return self.make_var(arg1.var_type, f"subgroupMax({arg1})", [arg1], lexical_unit=True) - - def subgroup_and(self, arg1: ShaderVariable): - return self.make_var(arg1.var_type, f"subgroupAnd({arg1})", [arg1], lexical_unit=True) - - def subgroup_or(self, arg1: ShaderVariable): - return self.make_var(arg1.var_type, f"subgroupOr({arg1})", [arg1], lexical_unit=True) - - def subgroup_xor(self, arg1: ShaderVariable): - return self.make_var(arg1.var_type, f"subgroupXor({arg1})", [arg1], lexical_unit=True) - - def subgroup_elect(self): - return self.make_var(dtypes.int32, f"subgroupElect()", [], lexical_unit=True) - - def subgroup_barrier(self): - self.append_contents("subgroupBarrier();\n") - - def new(self, var_type: dtype, *args, var_name: Optional[str] = None): - new_var = self.make_var(var_type, var_name, [], lexical_unit=True, settable=True) - - for arg in args: - if isinstance(arg, ShaderVariable): - arg.read_callback() - - decleration_suffix = "" - if len(args) > 0: - decleration_suffix = f" = 
{var_type.glsl_type}({', '.join([str(elem) for elem in args])})" - - self.append_contents(f"{new_var.var_type.glsl_type} {new_var.name}{decleration_suffix};\n") - - return new_var - - def new_float(self, *args, var_name: Optional[str] = None): - return self.new(dtypes.float32, *args, var_name=var_name) - - def new_int(self, *args, var_name: Optional[str] = None): - return self.new(dtypes.int32, *args, var_name=var_name) - - def new_uint(self, *args, var_name: Optional[str] = None): - return self.new(dtypes.uint32, *args, var_name=var_name) - - def new_vec2(self, *args, var_name: Optional[str] = None): - return self.new(dtypes.vec2, *args, var_name=var_name) - - def new_vec3(self, *args, var_name: Optional[str] = None): - return self.new(dtypes.vec3, *args, var_name=var_name) - - def new_vec4(self, *args, var_name: Optional[str] = None): - return self.new(dtypes.vec4, *args, var_name=var_name) - - def new_uvec2(self, *args, var_name: Optional[str] = None): - return self.new(dtypes.uvec2, *args, var_name=var_name) - - def new_uvec3(self, *args, var_name: Optional[str] = None): - return self.new(dtypes.uvec3, *args, var_name=var_name) - - def new_uvec4(self, *args, var_name: Optional[str] = None): - return self.new(dtypes.uvec4, *args, var_name=var_name) - - def new_ivec2(self, *args, var_name: Optional[str] = None): - return self.new(dtypes.ivec2, *args, var_name=var_name) - - def new_ivec3(self, *args, var_name: Optional[str] = None): - return self.new(dtypes.ivec3, *args, var_name=var_name) - - def new_ivec4(self, *args, var_name: Optional[str] = None): - return self.new(dtypes.ivec4, *args, var_name=var_name) - - def printf(self, format: str, *args: Union[ShaderVariable, str], seperator=" "): - args_string = "" - - for arg in args: - args_string += f", {arg}" - - self.append_contents(f'debugPrintfEXT("{format}" {args_string});\n') - - def print_vars(self, *args: Union[ShaderVariable, str], seperator=" "): - args_list = [] - - fmts = [] - - for arg in args: - if 
isinstance(arg, ShaderVariable): - args_list.append(arg.printf_args()) - fmts.append(arg.var_type.format_str) - else: - fmts.append(str(arg)) - - fmt = seperator.join(fmts) - - args_argument = "" - - if len(args_list) > 0: - args_argument = f", {','.join(args_list)}" - - self.append_contents(f'debugPrintfEXT("{fmt}"{args_argument});\n') - - def unravel_index(self, index: ShaderVariable, shape: ShaderVariable): - new_var = self.new_uvec3() - - new_var.x = index % shape.x - new_var.y = (index / shape.x) % shape.y - new_var.z = index / (shape.x * shape.y) - - return new_var - - def complex_from_euler_angle(self, angle: ShaderVariable): - return self.make_var(dtypes.vec2, f"vec2({self.cos(angle)}, {self.sin(angle)})", [angle]) - def compose_struct_decleration(self, elements: List[StructElement]) -> str: declerations = [] for elem in elements: - decleration_type = f"{elem.dtype.glsl_type}" - if elem.dtype.glsl_type_extern is not None: - decleration_type = f"{elem.dtype.glsl_type_extern}" + decleration_type = self.backend.type_name(elem.dtype) decleration_suffix = "" if elem.count > 1: decleration_suffix = f"[{elem.count}]" - declerations.append(f"\t{decleration_type} {elem.name}{decleration_suffix};") + declerations.append(f" {decleration_type} {elem.name}{decleration_suffix};") return "\n".join(declerations) def build(self, name: str) -> ShaderDescription: - header = "" + self.pre_header + header = "" for shared_buffer in self.shared_buffers: - header += f"shared {shared_buffer.dtype.glsl_type} {shared_buffer.name}[{shared_buffer.size}];\n" + header += self.backend.shared_buffer_declaration( + shared_buffer.dtype, + shared_buffer.name, + shared_buffer.size + ) + "\n" uniform_elements = self.uniform_struct.build() uniform_decleration_contents = self.compose_struct_decleration(uniform_elements) - if len(uniform_decleration_contents) > 0: - header += f"\nlayout(set = 0, binding = 0) uniform UniformObjectBuffer {{\n { uniform_decleration_contents } \n}} UBO;\n" + 
has_uniform_buffer = len(uniform_decleration_contents) > 0 + if has_uniform_buffer: + header += self.backend.uniform_block_declaration(uniform_decleration_contents) - binding_type_list = [BindingType.UNIFORM_BUFFER] - binding_access = [(True, False)] # UBO is read-only + binding_base = 1 if has_uniform_buffer else 0 + binding_type_list = [] + binding_access = [] + if has_uniform_buffer: + binding_type_list.append(BindingType.UNIFORM_BUFFER) + binding_access.append((True, False)) # UBO is read-only for ii, binding in enumerate(self.binding_list): + emitted_binding = ii + binding_base if binding.binding_type == BindingType.STORAGE_BUFFER: - true_type = binding.dtype.glsl_type - if binding.dtype.glsl_type_extern is not None: - true_type = binding.dtype.glsl_type_extern - - header += f"layout(set = 0, binding = {ii + 1}) buffer Buffer{ii + 1} {{ {true_type} data[]; }} {binding.name};\n" + header += self.backend.storage_buffer_declaration(emitted_binding, binding.dtype, binding.name) binding_type_list.append(binding.binding_type) binding_access.append(( self.binding_read_access[ii + 1], self.binding_write_access[ii + 1] )) else: - header += f"layout(set = 0, binding = {ii + 1}) uniform sampler{binding.dimension}D {binding.name};\n" + header += self.backend.sampler_declaration(emitted_binding, binding.dimension, binding.name) binding_type_list.append(binding.binding_type) binding_access.append(( self.binding_read_access[ii + 1], @@ -1615,16 +389,26 @@ def build(self, name: str) -> ShaderDescription: pc_decleration_contents = self.compose_struct_decleration(pc_elements) if len(pc_decleration_contents) > 0: - header += f"\nlayout(push_constant) uniform PushConstant {{\n { pc_decleration_contents } \n}} PC;\n" + assert self.backend.name not in _PUSH_CONSTANT_UNSUPPORTED_BACKENDS, ( + _push_constant_not_supported_error(self.backend.name) + ) + header += self.backend.push_constant_declaration(pc_decleration_contents) + + pre_header = self.backend.pre_header( + 
enable_subgroup_ops=not (self.flags & ShaderFlags.NO_SUBGROUP_OPS), + enable_printf=not (self.flags & ShaderFlags.NO_PRINTF) + ) return ShaderDescription( - header=header, - body=f"void main() {{\n{self.contents}\n}}\n", + header=f"{pre_header}{header}", + body=self.backend.entry_point(self.contents), name=name, pc_size=self.pc_struct.size, pc_structure=pc_elements, uniform_structure=uniform_elements, binding_type_list=[binding.value for binding in binding_type_list], binding_access=binding_access, - exec_count_name=self.exec_count.raw_name - ) \ No newline at end of file + exec_count_name=self.exec_count.raw_name if self.exec_count is not None else None, + resource_binding_base=binding_base, + backend=self.backend + ) diff --git a/vkdispatch/codegen/functions/__init__.py b/vkdispatch/codegen/functions/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/vkdispatch/codegen/functions/atomic_memory.py b/vkdispatch/codegen/functions/atomic_memory.py new file mode 100644 index 00000000..7efb8590 --- /dev/null +++ b/vkdispatch/codegen/functions/atomic_memory.py @@ -0,0 +1,76 @@ +from typing import Any, List + +import vkdispatch.base.dtype as dtypes + +from ..variables.base_variable import BaseVariable +from ..variables.bound_variables import BufferVariable +from ..variables.variables import ShaderVariable +from . 
import utils + + +def _is_buffer_backed_target(var: ShaderVariable) -> bool: + stack: List[BaseVariable] = [var] + visited_ids = set() + + while len(stack) > 0: + current = stack.pop() + current_id = id(current) + if current_id in visited_ids: + continue + visited_ids.add(current_id) + + if isinstance(current, BufferVariable): + return True + + stack.extend(current.parents) + + return False + + +# https://docs.vulkan.org/glsl/latest/chapters/builtinfunctions.html#atomic-memory-functions +def atomic_add(mem: ShaderVariable, y: Any) -> ShaderVariable: + assert isinstance(mem, ShaderVariable), f"atomic_add target must be a ShaderVariable, got {type(mem)}" + assert dtypes.is_scalar(mem.var_type), "atomic_add target must be a scalar lvalue" + assert mem.is_setable(), "atomic_add target must be a writable lvalue" + assert not mem.is_register(), "atomic_add does not support register/local variables as target" + assert _is_buffer_backed_target(mem), "atomic_add target must reference a buffer element (e.g., buf[idx])" + + assert mem.var_type in (dtypes.int32, dtypes.uint32), ( + f"atomic_add currently supports only int32/uint32 targets, got '{mem.var_type.name}'" + ) + + parents: List[BaseVariable] = [mem] + + if isinstance(y, ShaderVariable): + assert dtypes.is_scalar(y.var_type), "atomic_add increment variable must be scalar" + assert dtypes.is_integer_dtype(y.var_type), ( + f"atomic_add increment variable must be integer-typed, got '{y.var_type.name}'" + ) + y.read_callback() + parents.append(y) + y_expr = utils.backend_constructor(mem.var_type, y) + elif utils.is_int_number(y): + y_expr = utils.backend_constructor(mem.var_type, y) + elif utils.is_number(y): + raise TypeError(f"atomic_add increment must be an integer scalar, got {y!r}") + else: + raise TypeError(f"atomic_add increment must be an integer scalar or ShaderVariable, got {type(y)}") + + mem.read_callback() + mem.write_callback() + + result_var = utils.new_var( + mem.var_type, + None, + parents=parents, + 
lexical_unit=True, + settable=True, + register=True + ) + + atomic_expr = utils.codegen_backend().atomic_add_expr(mem.resolve(), y_expr, mem.var_type) + utils.append_contents( + f"{utils.backend_type_name(result_var.var_type)} {result_var.name} = {atomic_expr};\n" + ) + + return result_var diff --git a/vkdispatch/codegen/functions/base_functions/__init__.py b/vkdispatch/codegen/functions/base_functions/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/vkdispatch/codegen/functions/base_functions/arithmetic.py b/vkdispatch/codegen/functions/base_functions/arithmetic.py new file mode 100644 index 00000000..79e890e5 --- /dev/null +++ b/vkdispatch/codegen/functions/base_functions/arithmetic.py @@ -0,0 +1,531 @@ +import vkdispatch.base.dtype as dtypes +from vkdispatch.codegen.variables.base_variable import BaseVariable +from typing import Any, Tuple, Union + +from .. import scalar_eval as se + +def my_log2_int(x: int) -> int: + return int(se.round(se.log2(x))) + + +from . 
import base_utils + + +def _mark_arith_unary(var: BaseVariable, op: str) -> None: + base_utils.get_codegen_backend().mark_composite_unary_op(var.var_type, op) + + +def _mark_arith_binary(lhs_type: dtypes.dtype, rhs_type: dtypes.dtype, op: str, *, inplace: bool = False) -> None: + base_utils.get_codegen_backend().mark_composite_binary_op(lhs_type, rhs_type, op, inplace=inplace) + +def _resolve_arithmetic_binary_expr( + op: str, + lhs_type: dtypes.dtype, + lhs_expr: str, + rhs_type: dtypes.dtype, + rhs_expr: str, +) -> Tuple[str, bool]: + override_expr = base_utils.get_codegen_backend().arithmetic_binary_expr( + op, lhs_type, lhs_expr, rhs_type, rhs_expr + ) + if override_expr is not None: + return override_expr, True + return f"{lhs_expr} {op} {rhs_expr}", False + +def _resolve_arithmetic_unary_expr(op: str, var_type: dtypes.dtype, var_expr: str) -> Tuple[str, bool]: + override_expr = base_utils.get_codegen_backend().arithmetic_unary_expr(op, var_type, var_expr) + if override_expr is not None: + return override_expr, True + return f"{op}{var_expr}", False + +def arithmetic_op_common(var: BaseVariable, + other: Any, + reverse: bool = False, + inplace: bool = False) -> BaseVariable: + assert isinstance(var, BaseVariable), "First argument must be a ShaderVariable" + + result_type = None + + if base_utils.is_scalar_number(other): + result_type = dtypes.cross_type(var.var_type, base_utils.number_to_dtype(other)) + elif isinstance(other, BaseVariable): + result_type = dtypes.cross_type(var.var_type, other.var_type) + elif base_utils.is_complex_number(other): + raise TypeError("Python built-in complex numbers are not supported in arithmetic operations yet!") + else: + raise TypeError(f"Unsupported type for arithmetic op: ShaderVariable and {type(other)}") + + if inplace: + assert var.is_setable(), "Inplace arithmetic requires the variable to be settable." + assert not reverse, "Inplace arithmetic does not support reverse operations." 
+ var.read_callback() + var.write_callback() + assert result_type == var.var_type, "Inplace arithmetic requires the result type to match the variable type." + + if base_utils.is_scalar_number(other): + return result_type + + if inplace: + other.read_callback() + + return dtypes.cross_type(var.var_type, other.var_type) + +def add(var: BaseVariable, other: Any, inplace: bool = False) -> BaseVariable: + return_type = arithmetic_op_common(var, other, inplace=inplace) + + if base_utils.is_scalar_number(other): + scalar_type = base_utils.number_to_dtype(other) + scalar_expr = base_utils.format_number_literal(other) + _mark_arith_binary(var.var_type, scalar_type, "+", inplace=inplace) + expr, use_assignment = _resolve_arithmetic_binary_expr( + "+", + var.var_type, + var.resolve(), + scalar_type, + scalar_expr, + ) + if not inplace: + if use_assignment: + return base_utils.new_base_var( + return_type, + expr, + parents=[var], + ) + return base_utils.new_scaled_var( + return_type, + var.resolve(), + offset=other, + parents=[var]) + + if use_assignment: + base_utils.append_contents(f"{var.resolve()} = {expr};\n") + else: + base_utils.append_contents(f"{var.resolve()} += {scalar_expr};\n") + return var + + assert isinstance(other, BaseVariable) + _mark_arith_binary(var.var_type, other.var_type, "+", inplace=inplace) + expr, use_assignment = _resolve_arithmetic_binary_expr( + "+", + var.var_type, + var.resolve(), + other.var_type, + other.resolve(), + ) + + if not inplace: + return base_utils.new_base_var( + return_type, + expr, + parents=[var, other]) + + if use_assignment: + base_utils.append_contents(f"{var.resolve()} = {expr};\n") + else: + base_utils.append_contents(f"{var.resolve()} += {other.resolve()};\n") + return var + +def sub(var: BaseVariable, other: Any, reverse: bool = False, inplace: bool = False) -> BaseVariable: + return_type = arithmetic_op_common(var, other, reverse=reverse, inplace=inplace) + + if base_utils.is_scalar_number(other): + scalar_type = 
base_utils.number_to_dtype(other) + scalar_expr = base_utils.format_number_literal(other) + if reverse and not inplace: + _mark_arith_unary(var, "-") + _mark_arith_binary(var.var_type, scalar_type, "+", inplace=False) + else: + # Non-reverse scalar subtraction is emitted as `+ (-scalar)` via scaled-var optimization. + _mark_arith_binary(var.var_type, scalar_type, "+" if not inplace else "-", inplace=inplace) + expr, use_assignment = _resolve_arithmetic_binary_expr( + "-", + scalar_type if reverse else var.var_type, + scalar_expr if reverse else var.resolve(), + var.var_type if reverse else scalar_type, + var.resolve() if reverse else scalar_expr, + ) + if not inplace: + if use_assignment: + return base_utils.new_base_var( + return_type, + expr, + parents=[var], + ) + return base_utils.new_scaled_var( + return_type, + f"(-{var.resolve()})" if reverse else var.resolve(), + offset=other if reverse else -other, + parents=[var]) + + if use_assignment: + base_utils.append_contents(f"{var.resolve()} = {expr};\n") + else: + base_utils.append_contents(f"{var.resolve()} -= {scalar_expr};\n") + return var + + assert isinstance(other, BaseVariable) + lhs_type = var.var_type if not reverse else other.var_type + rhs_type = other.var_type if not reverse else var.var_type + _mark_arith_binary(lhs_type, rhs_type, "-", inplace=inplace) + expr, use_assignment = _resolve_arithmetic_binary_expr( + "-", + lhs_type, + var.resolve() if not reverse else other.resolve(), + rhs_type, + other.resolve() if not reverse else var.resolve(), + ) + + if not inplace: + return base_utils.new_base_var( + return_type, + expr, + parents=[var, other]) + + if use_assignment: + base_utils.append_contents(f"{var.resolve()} = {expr};\n") + else: + base_utils.append_contents(f"{var.resolve()} -= {other.resolve()};\n") + return var + +def mul(var: BaseVariable, other: Any, inplace: bool = False) -> BaseVariable: + if base_utils.is_scalar_number(other): + return_type = arithmetic_op_common(var, other, 
inplace=inplace) + scalar_type = base_utils.number_to_dtype(other) + scalar_expr = base_utils.format_number_literal(other) + expr, use_assignment = _resolve_arithmetic_binary_expr( + "*", + var.var_type, + var.resolve(), + scalar_type, + scalar_expr, + ) + if not inplace: + if other == 1: + return var + + if ( + not use_assignment + and dtypes.is_integer_dtype(var.var_type) + and base_utils.is_int_number(other) + and base_utils.is_int_power_of_2(other) + ): + power = my_log2_int(other) + _mark_arith_binary(var.var_type, scalar_type, "<<", inplace=False) + return base_utils.new_base_var(var.var_type, f"{var.resolve()} << {power}", [var]) + + _mark_arith_binary(var.var_type, scalar_type, "*", inplace=False) + if use_assignment: + return base_utils.new_base_var( + return_type, + expr, + parents=[var], + ) + return base_utils.new_scaled_var(return_type, var.resolve(), scale=other, parents=[var]) + + _mark_arith_binary(var.var_type, scalar_type, "*", inplace=True) + if use_assignment: + base_utils.append_contents(f"{var.resolve()} = {expr};\n") + else: + base_utils.append_contents(f"{var.resolve()} *= {scalar_expr};\n") + return var + + assert isinstance(other, BaseVariable) + + if dtypes.is_complex(var.var_type) and dtypes.is_complex(other.var_type): + raise ValueError("Complex multiplication is not supported via the `*` operator.") + + if dtypes.is_matrix(var.var_type) and dtypes.is_matrix(other.var_type): + raise ValueError("Matrix multiplication is not supported via the `*` operator. Use `@` operator instead.") + + return_type = dtypes.cross_multiply_type(var.var_type, other.var_type) + if inplace: + assert var.is_setable(), "Inplace arithmetic requires the variable to be settable." + var.read_callback() + var.write_callback() + other.read_callback() + assert return_type == var.var_type, "Inplace arithmetic requires the result type to match the variable type." 
+ + _mark_arith_binary(var.var_type, other.var_type, "*", inplace=inplace) + expr, use_assignment = _resolve_arithmetic_binary_expr( + "*", + var.var_type, + var.resolve(), + other.var_type, + other.resolve(), + ) + if not inplace: + return base_utils.new_base_var( + return_type, + expr, + parents=[var, other]) + + if use_assignment: + base_utils.append_contents(f"{var.resolve()} = {expr};\n") + else: + base_utils.append_contents(f"{var.resolve()} *= {other.resolve()};\n") + return var + +def truediv(var: BaseVariable, other: Any, reverse: bool = False, inplace: bool = False) -> BaseVariable: + if dtypes.is_integer_dtype(var.var_type) and inplace: + raise ValueError("Inplace true division is not supported for integer types.") + + return_type = arithmetic_op_common(var, other, reverse=reverse, inplace=inplace) + return_type = dtypes.make_floating_dtype(return_type) + + if base_utils.is_scalar_number(other): + scalar_f_type = dtypes.float32 + other_expr = base_utils.format_number_literal(other, force_float32=True) + if not reverse: + _mark_arith_binary(return_type, scalar_f_type, "/", inplace=inplace) + else: + _mark_arith_binary(scalar_f_type, return_type, "/", inplace=inplace) + lhs_expr = base_utils.to_dtype_base(return_type, var).resolve() if not reverse else other_expr + rhs_expr = other_expr if not reverse else base_utils.to_dtype_base(return_type, var).resolve() + lhs_type = return_type if not reverse else scalar_f_type + rhs_type = scalar_f_type if not reverse else return_type + expr, use_assignment = _resolve_arithmetic_binary_expr( + "/", + lhs_type, + lhs_expr, + rhs_type, + rhs_expr, + ) + if not inplace: + return base_utils.new_base_var( + return_type, + expr, + parents=[var]) + + if use_assignment: + inplace_expr, _ = _resolve_arithmetic_binary_expr( + "/", + var.var_type, + var.resolve(), + scalar_f_type, + other_expr, + ) + base_utils.append_contents(f"{var.resolve()} = {inplace_expr};\n") + else: + base_utils.append_contents(f"{var.resolve()} /= 
{other_expr};\n") + return var + + assert isinstance(other, BaseVariable) + + if dtypes.is_complex(var.var_type) and dtypes.is_complex(other.var_type): + raise ValueError("Complex division is not supported.") + + if dtypes.is_matrix(var.var_type) and dtypes.is_matrix(other.var_type): + raise ValueError("Matrix division is not supported.") + + lhs_mark_type = return_type if not reverse else dtypes.make_floating_dtype(other.var_type) + rhs_mark_type = dtypes.make_floating_dtype(other.var_type) if not reverse else return_type + _mark_arith_binary(lhs_mark_type, rhs_mark_type, "/", inplace=inplace) + + lhs_expr = ( + base_utils.to_dtype_base(lhs_mark_type, var).resolve() + if not reverse else + base_utils.to_dtype_base(lhs_mark_type, other).resolve() + ) + rhs_expr = ( + base_utils.to_dtype_base(rhs_mark_type, other).resolve() + if not reverse else + base_utils.to_dtype_base(rhs_mark_type, var).resolve() + ) + expr, use_assignment = _resolve_arithmetic_binary_expr( + "/", + lhs_mark_type, + lhs_expr, + rhs_mark_type, + rhs_expr, + ) + + if not inplace: + return base_utils.new_base_var( + return_type, + expr, + parents=[var, other]) + + if use_assignment: + inplace_expr, _ = _resolve_arithmetic_binary_expr( + "/", + var.var_type, + var.resolve(), + rhs_mark_type, + rhs_expr, + ) + base_utils.append_contents(f"{var.resolve()} = {inplace_expr};\n") + else: + base_utils.append_contents(f"{var.resolve()} /= {rhs_expr};\n") + return var + +def floordiv(var: BaseVariable, other: Any, reverse: bool = False, inplace: bool = False) -> BaseVariable: + assert dtypes.is_integer_dtype(var.var_type), "Floor division is only supported for integer types." + return_type = arithmetic_op_common(var, other, reverse=reverse, inplace=inplace) + assert dtypes.is_integer_dtype(return_type), "Floor division is only supported for integer types." + + if base_utils.is_scalar_number(other): + assert base_utils.is_int_number(other), "Floor division only supports integer scalar values." 
+ + if not inplace: + if other == 1: + return var + + if base_utils.is_int_power_of_2(other): + power = my_log2_int(other) + _mark_arith_binary(var.var_type, base_utils.number_to_dtype(other), ">>", inplace=False) + return base_utils.new_base_var(var.var_type, f"{var.resolve()} >> {power}", [var]) + + scalar_type = base_utils.number_to_dtype(other) + _mark_arith_binary(var.var_type if not reverse else scalar_type, scalar_type if not reverse else var.var_type, "/", inplace=False) + return base_utils.new_base_var( + return_type, + ( + f"{var.resolve()} / {other}" + if not reverse else + f"{other} / {var.resolve()}" + ), + parents=[var]) + + _mark_arith_binary(var.var_type, base_utils.number_to_dtype(other), "/", inplace=True) + base_utils.append_contents(f"{var.resolve()} /= {other};\n") + return var + + assert isinstance(other, BaseVariable) + _mark_arith_binary(var.var_type if not reverse else other.var_type, other.var_type if not reverse else var.var_type, "/", inplace=inplace) + + if not inplace: + return base_utils.new_base_var( + return_type, + ( + f"{var.resolve()} / {other.resolve()}" + if not reverse else + f"{other.resolve()} / {var.resolve()}" + ), + parents=[var, other]) + + base_utils.append_contents(f"{var.resolve()} /= {other.resolve()};\n") + return var + +def mod(var: BaseVariable, other: Any, reverse: bool = False, inplace: bool = False) -> BaseVariable: + assert dtypes.is_integer_dtype(var.var_type), "Modulus is only supported for integer types." + return_type = arithmetic_op_common(var, other, reverse=reverse, inplace=inplace) + assert dtypes.is_integer_dtype(return_type), "Modulus is only supported for integer types." 
+ + if base_utils.is_scalar_number(other): + scalar_type = base_utils.number_to_dtype(other) + _mark_arith_binary(var.var_type if not reverse else scalar_type, scalar_type if not reverse else var.var_type, "%", inplace=inplace) + if not inplace: + return base_utils.new_base_var( + return_type, + ( + f"{var.resolve()} % {other}" + if not reverse else + f"{other} % {var.resolve()}" + ), + parents=[var]) + + base_utils.append_contents(f"{var.resolve()} %= {other};\n") + return var + + assert isinstance(other, BaseVariable) + _mark_arith_binary(var.var_type if not reverse else other.var_type, other.var_type if not reverse else var.var_type, "%", inplace=inplace) + + if not inplace: + return base_utils.new_base_var( + return_type, + ( + f"{var.resolve()} % {other.resolve()}" + if not reverse else + f"{other.resolve()} % {var.resolve()}" + ), + parents=[var, other]) + + base_utils.append_contents(f"{var.resolve()} %= {other.resolve()};\n") + return var + + +def pow_expr(x: Any, y: Any) -> Union[BaseVariable, float]: + if base_utils.is_int_number(y) and y == 0: + return 1 + + if base_utils.is_number(y) and base_utils.is_number(x): + return se.power(x, y) + + if base_utils.is_number(x) and isinstance(y, BaseVariable): + result_type = base_utils.dtype_to_floating(y.var_type) + return base_utils.new_base_var( + result_type, + base_utils.get_codegen_backend().binary_math_expr( + "pow", + dtypes.float32, + base_utils.resolve_input(x), + result_type, + y.resolve(), + ), + parents=[y] + ) + + if base_utils.is_number(y) and isinstance(x, BaseVariable): + result_type = base_utils.dtype_to_floating(x.var_type) + + if base_utils.is_int_number(y) and x.is_register(): + if y > 0 and y <= 4: + expr = " * ".join([x.resolve()] * int(y)) + return base_utils.new_base_var(result_type, expr, parents=[x]) + elif y < 0 and y >= -4: + expr = " * ".join([x.resolve()] * int(-y)) + return base_utils.new_base_var(result_type, f"1 / ({expr})", parents=[x]) + + return base_utils.new_base_var( + 
result_type, + base_utils.get_codegen_backend().binary_math_expr( + "pow", + result_type, + x.resolve(), + dtypes.float32, + base_utils.resolve_input(y), + ), + parents=[x] + ) + + assert isinstance(y, BaseVariable), "First argument must be a ShaderVariable or number" + assert isinstance(x, BaseVariable), "Second argument must be a ShaderVariable or number" + + result_type = base_utils.dtype_to_floating(dtypes.cross_type(x.var_type, y.var_type)) + return base_utils.new_base_var( + result_type, + base_utils.get_codegen_backend().binary_math_expr( + "pow", + base_utils.dtype_to_floating(x.var_type), + x.resolve(), + base_utils.dtype_to_floating(y.var_type), + y.resolve(), + ), + parents=[y, x], + lexical_unit=True + ) + +def pow(var: BaseVariable, other: Any, reverse: bool = False, inplace: bool = False) -> BaseVariable: + _ = arithmetic_op_common(var, other, reverse=reverse, inplace=inplace) + experession = pow_expr(other, var) if reverse else pow_expr(var, other) + + if not inplace: + return experession + + base_utils.append_contents(f"{var.resolve()} = {experession};\n") + return var + +def neg(var: BaseVariable) -> BaseVariable: + _mark_arith_unary(var, "-") + expr, _ = _resolve_arithmetic_unary_expr("-", var.var_type, var.resolve()) + return base_utils.new_base_var( + var.var_type, + expr, + parents=[var]) + +def absolute(var: BaseVariable) -> BaseVariable: + return base_utils.new_base_var( + var.var_type, + f"abs({var.resolve()})", + parents=[var], + lexical_unit=True) diff --git a/vkdispatch/codegen/functions/base_functions/arithmetic_comparisons.py b/vkdispatch/codegen/functions/base_functions/arithmetic_comparisons.py new file mode 100644 index 00000000..d4094258 --- /dev/null +++ b/vkdispatch/codegen/functions/base_functions/arithmetic_comparisons.py @@ -0,0 +1,47 @@ +import vkdispatch.base.dtype as dtypes +from vkdispatch.codegen.variables.base_variable import BaseVariable +from typing import Any + +from . 
import base_utils + +def less_than(var: BaseVariable, other: Any) -> BaseVariable: + return base_utils.new_base_var( + dtypes.int32, + f"{base_utils.resolve_input(var)} < {base_utils.resolve_input(other)}", + parents=[var, other] + ) + +def less_or_equal(var: BaseVariable, other: Any) -> BaseVariable: + return base_utils.new_base_var( + dtypes.int32, + f"{base_utils.resolve_input(var)} <= {base_utils.resolve_input(other)}", + parents=[var, other] + ) + +def equal_to(var: BaseVariable, other: Any) -> BaseVariable: + return base_utils.new_base_var( + dtypes.int32, + f"{base_utils.resolve_input(var)} == {base_utils.resolve_input(other)}", + parents=[var, other] + ) + +def not_equal_to(var: BaseVariable, other: Any) -> BaseVariable: + return base_utils.new_base_var( + dtypes.int32, + f"{base_utils.resolve_input(var)} != {base_utils.resolve_input(other)}", + parents=[var, other] + ) + +def greater_than(var: BaseVariable, other: Any) -> BaseVariable: + return base_utils.new_base_var( + dtypes.int32, + f"{base_utils.resolve_input(var)} > {base_utils.resolve_input(other)}", + parents=[var, other] + ) + +def greater_or_equal(var: BaseVariable, other: Any) -> BaseVariable: + return base_utils.new_base_var( + dtypes.int32, + f"{base_utils.resolve_input(var)} >= {base_utils.resolve_input(other)}", + parents=[var, other] + ) diff --git a/vkdispatch/codegen/functions/base_functions/base_utils.py b/vkdispatch/codegen/functions/base_functions/base_utils.py new file mode 100644 index 00000000..7a5d7d71 --- /dev/null +++ b/vkdispatch/codegen/functions/base_functions/base_utils.py @@ -0,0 +1,141 @@ +import vkdispatch.base.dtype as dtypes +from vkdispatch.codegen.variables.base_variable import BaseVariable + +from typing import Any, Optional + +import numbers +import math + +from ....compat import numpy_compat as npc +from vkdispatch.codegen.shader_writer import new_scaled_var, append_contents, new_name +from vkdispatch.codegen.global_builder import get_codegen_backend + +from 
vkdispatch.codegen.shader_writer import new_var as new_var_impl + +_I32_MIN = -(2 ** 31) +_I32_MAX = 2 ** 31 - 1 +_U32_MAX = 2 ** 32 - 1 + +def new_base_var(var_type: dtypes.dtype, + var_name: Optional[str], + parents: list, + lexical_unit: bool = False, + settable: bool = False, + register: bool = False) -> BaseVariable: + return new_var_impl(var_type, var_name, parents, lexical_unit, settable, register) + +def is_number(x) -> bool: + return isinstance(x, numbers.Number) and not isinstance(x, bool) + +def is_int_number(x) -> bool: + return isinstance(x, numbers.Integral) and not isinstance(x, bool) + +def _is_numpy_float(x) -> bool: + return npc.is_numpy_floating_instance(x) + +def is_float_number(x) -> bool: + return isinstance(x, numbers.Real) and not isinstance(x, numbers.Integral) and not isinstance(x, bool) \ + and (isinstance(x, float) or _is_numpy_float(x)) + +def is_complex_number(x) -> bool: + return isinstance(x, numbers.Complex) and not isinstance(x, numbers.Real) + +def is_scalar_number(x) -> bool: + return is_number(x) and (is_int_number(x) or is_float_number(x)) and not is_complex_number(x) + +def is_int_power_of_2(n: int) -> bool: + """Check if an integer is a power of 2.""" + return n > 0 and (n & (n - 1)) == 0 + +def number_to_dtype(number: numbers.Number): + if is_int_number(number): + if number >= 0: + if number <= _U32_MAX: + return dtypes.uint32 + return dtypes.uint64 + + if number >= _I32_MIN and number <= _I32_MAX: + return dtypes.int32 + return dtypes.int64 + elif is_float_number(number): + return dtypes.float32 + elif is_complex_number(number): + return dtypes.complex64 + else: + raise TypeError(f"Unsupported number type: {type(number)}") + +def _check_is_int_numpy(x) -> bool: + return npc.is_numpy_integer_scalar(x) + +def check_is_int(variable): + return npc.is_integer_scalar(variable) + +def dtype_to_floating(var_type: dtypes.dtype) -> dtypes.dtype: + return dtypes.make_floating_dtype(var_type) + +def _inf_scalar_type(var_type: 
dtypes.dtype) -> dtypes.dtype: + """Extract the scalar float type from any dtype.""" + if dtypes.is_complex(var_type): + return var_type.child_type + if dtypes.is_vector(var_type) or dtypes.is_matrix(var_type): + return var_type.scalar + return var_type + +def format_number_literal(var: numbers.Number, *, force_float32: bool = False, dtype: Optional[dtypes.dtype] = None) -> str: + if is_complex_number(var): + return str(var) + + if is_float_number(var) or (force_float32 and is_int_number(var)): + value = float(var) + + if math.isinf(value): + backend = get_codegen_backend() + scalar = _inf_scalar_type(dtype) if dtype is not None else dtypes.float32 + if scalar is dtypes.float64: + return backend.inf_f64_expr() if value > 0 else backend.ninf_f64_expr() + if scalar is dtypes.float16: + return backend.inf_f16_expr() if value > 0 else backend.ninf_f16_expr() + return backend.inf_f32_expr() if value > 0 else backend.ninf_f32_expr() + + if math.isnan(value): + return "(0.0f / 0.0f)" + + literal = repr(value) + if "e" not in literal and "E" not in literal and "." 
not in literal: + literal += ".0" + return literal + "f" + + return str(var) + +def resolve_input(var: Any, dtype: Optional[dtypes.dtype] = None) -> str: + #print("Resolving input:", var) + + if is_number(var): + return format_number_literal(var, dtype=dtype) + + assert isinstance(var, BaseVariable), "Argument must be a ShaderVariable or number" + return var.resolve() + +def resolve_input_type(var: Any) -> Optional[dtypes.dtype]: + if is_number(var): + return number_to_dtype(var) + + if isinstance(var, BaseVariable): + return var.var_type + + return None + +def backend_constructor(var_type: dtypes.dtype, *args) -> str: + return get_codegen_backend().constructor( + var_type, + [resolve_input(elem, dtype=var_type) for elem in args], + arg_types=[resolve_input_type(elem) for elem in args], + ) + +def to_dtype_base(var_type: dtypes.dtype, *args): + return new_base_var( + var_type, + backend_constructor(var_type, *args), + args, + lexical_unit=True + ) diff --git a/vkdispatch/codegen/functions/base_functions/bitwise.py b/vkdispatch/codegen/functions/base_functions/bitwise.py new file mode 100644 index 00000000..e272817f --- /dev/null +++ b/vkdispatch/codegen/functions/base_functions/bitwise.py @@ -0,0 +1,185 @@ +import vkdispatch.base.dtype as dtypes +from vkdispatch.codegen.variables.base_variable import BaseVariable +from typing import Any + +from . 
import base_utils + + +def _mark_bit_unary(var: BaseVariable, op: str) -> None: + base_utils.get_codegen_backend().mark_composite_unary_op(var.var_type, op) + + +def _mark_bit_binary(lhs_type: dtypes.dtype, rhs_type: dtypes.dtype, op: str, *, inplace: bool = False) -> None: + base_utils.get_codegen_backend().mark_composite_binary_op(lhs_type, rhs_type, op, inplace=inplace) + +def bitwise_op_common(var: BaseVariable, + other: Any, + reverse: bool = False, + inplace: bool = False) -> BaseVariable: + assert isinstance(var, BaseVariable), "First argument must be a ShaderVariable" + assert dtypes.is_integer_dtype(var.var_type), "Bitwise operations only supported on integer types." + + result_type = None + + if base_utils.is_int_number(other): + result_type = dtypes.cross_type(var.var_type, base_utils.number_to_dtype(other)) + elif isinstance(other, BaseVariable): + result_type = dtypes.cross_type(var.var_type, other.var_type) + else: + raise TypeError(f"Unsupported type for bitwise op: ShaderVariable and {type(other)}") + + if inplace: + assert var.is_setable(), "Inplace bitwise requires the variable to be settable." + assert not reverse, "Inplace bitwise does not support reverse operations." + var.read_callback() + var.write_callback() + assert result_type == var.var_type, "Inplace bitwise requires the result type to match the variable type." + + if base_utils.is_int_number(other): + return result_type + + assert dtypes.is_integer_dtype(other.var_type), "Bitwise operations only supported on integer types." 
+ + if inplace: + other.read_callback() + + return dtypes.cross_type(var.var_type, other.var_type) + +def lshift(var: BaseVariable, other: Any, reverse: bool = False, inplace: bool = False): + return_type = bitwise_op_common(var, other, reverse=reverse, inplace=inplace) + + if base_utils.is_int_number(other): + _mark_bit_binary(var.var_type if not reverse else base_utils.number_to_dtype(other), base_utils.number_to_dtype(other) if not reverse else var.var_type, "<<", inplace=inplace) + if not inplace: + return base_utils.new_base_var( + return_type, + ( + f"{var.resolve()} << {other}" + if not reverse else + f"{other} << {var.resolve()}" + ), + parents=[var]) + + base_utils.append_contents(f"{var.resolve()} <<= {other};\n") + return var + + assert isinstance(other, BaseVariable) + _mark_bit_binary(var.var_type if not reverse else other.var_type, other.var_type if not reverse else var.var_type, "<<", inplace=inplace) + + if not inplace: + return base_utils.new_base_var( + return_type, + ( + f"{var.resolve()} << {other.resolve()}" + if not reverse else + f"{other.resolve()} << {var.resolve()}" + ), + parents=[var, other]) + + base_utils.append_contents(f"{var.resolve()} <<= {other.resolve()};\n") + return var + +def rshift(var: BaseVariable, other: Any, reverse: bool = False, inplace: bool = False): + return_type = bitwise_op_common(var, other, reverse=reverse, inplace=inplace) + + if base_utils.is_int_number(other): + _mark_bit_binary(var.var_type if not reverse else base_utils.number_to_dtype(other), base_utils.number_to_dtype(other) if not reverse else var.var_type, ">>", inplace=inplace) + if not inplace: + return base_utils.new_base_var( + return_type, + ( + f"{var.resolve()} >> {other}" + if not reverse else + f"{other} >> {var.resolve()}" + ), + parents=[var]) + + base_utils.append_contents(f"{var.resolve()} >>= {other};\n") + return var + + assert isinstance(other, BaseVariable) + _mark_bit_binary(var.var_type if not reverse else other.var_type, 
other.var_type if not reverse else var.var_type, ">>", inplace=inplace) + + if not inplace: + return base_utils.new_base_var( + return_type, + ( + f"{var.resolve()} >> {other.resolve()}" + if not reverse else + f"{other.resolve()} >> {var.resolve()}" + ), + parents=[var, other]) + + base_utils.append_contents(f"{var.resolve()} >>= {other.resolve()};\n") + return var + +def and_bits(var: BaseVariable, other: Any, inplace: bool = False): + return_type = bitwise_op_common(var, other, inplace=inplace) + + if base_utils.is_int_number(other): + _mark_bit_binary(var.var_type, base_utils.number_to_dtype(other), "&", inplace=inplace) + if not inplace: + return base_utils.new_base_var(return_type, f"{var.resolve()} & {other}",parents=[var]) + + base_utils.append_contents(f"{var.resolve()} &= {other};\n") + return var + + assert isinstance(other, BaseVariable) + _mark_bit_binary(var.var_type, other.var_type, "&", inplace=inplace) + + if not inplace: + return base_utils.new_base_var(return_type, f"{var.resolve()} & {other.resolve()}",parents=[var, other]) + + base_utils.append_contents(f"{var.resolve()} &= {other.resolve()};\n") + return var + +def xor_bits(var: BaseVariable, other: Any, inplace: bool = False): + return_type = bitwise_op_common(var, other, inplace=inplace) + + if base_utils.is_int_number(other): + _mark_bit_binary(var.var_type, base_utils.number_to_dtype(other), "^", inplace=inplace) + if not inplace: + return base_utils.new_base_var(return_type, f"{var.resolve()} ^ {other}",parents=[var]) + + base_utils.append_contents(f"{var.resolve()} ^= {other};\n") + return var + + assert isinstance(other, BaseVariable) + _mark_bit_binary(var.var_type, other.var_type, "^", inplace=inplace) + + if not inplace: + return base_utils.new_base_var(return_type, f"{var.resolve()} ^ {other.resolve()}",parents=[var, other]) + + base_utils.append_contents(f"{var.resolve()} ^= {other.resolve()};\n") + return var + +def or_bits(var: BaseVariable, other: Any, inplace: bool = False): + 
return_type = bitwise_op_common(var, other, inplace=inplace) + + if base_utils.is_int_number(other): + _mark_bit_binary(var.var_type, base_utils.number_to_dtype(other), "|", inplace=inplace) + if not inplace: + return base_utils.new_base_var(return_type, f"{var.resolve()} | {other}",parents=[var]) + + base_utils.append_contents(f"{var.resolve()} |= {other};\n") + return var + + assert isinstance(other, BaseVariable) + _mark_bit_binary(var.var_type, other.var_type, "|", inplace=inplace) + + if not inplace: + return base_utils.new_base_var(return_type, f"{var.resolve()} | {other.resolve()}",parents=[var, other]) + + base_utils.append_contents(f"{var.resolve()} |= {other.resolve()};\n") + return var + +def invert(var: BaseVariable): + assert isinstance(var, BaseVariable), "First argument must be a ShaderVariable" + assert dtypes.is_integer_dtype(var.var_type), "Bitwise operations only supported on integer types." + _mark_bit_unary(var, "~") + + return base_utils.new_base_var( + var.var_type, + f"~{var.resolve()}", + parents=[var] + ) diff --git a/vkdispatch/codegen/functions/block_synchonization.py b/vkdispatch/codegen/functions/block_synchonization.py new file mode 100644 index 00000000..3deccc45 --- /dev/null +++ b/vkdispatch/codegen/functions/block_synchonization.py @@ -0,0 +1,27 @@ +from ..global_builder import get_builder, get_codegen_backend + +from . 
import utils + +def barrier(): + # On Apple devices, a memory barrier is required before a barrier + # to ensure memory operations are visible to all threads + # (for some reason) + if get_builder().is_apple_device and get_codegen_backend().name == "glsl": + memory_barrier() + + utils.append_contents(utils.codegen_backend().barrier_statement() + "\n") + +def memory_barrier(): + utils.append_contents(utils.codegen_backend().memory_barrier_statement() + "\n") + +def memory_barrier_buffer(): + utils.append_contents(utils.codegen_backend().memory_barrier_buffer_statement() + "\n") + +def memory_barrier_shared(): + utils.append_contents(utils.codegen_backend().memory_barrier_shared_statement() + "\n") + +def memory_barrier_image(): + utils.append_contents(utils.codegen_backend().memory_barrier_image_statement() + "\n") + +def group_memory_barrier(): + utils.append_contents(utils.codegen_backend().group_memory_barrier_statement() + "\n") diff --git a/vkdispatch/codegen/functions/builtin_constants.py b/vkdispatch/codegen/functions/builtin_constants.py new file mode 100644 index 00000000..47812331 --- /dev/null +++ b/vkdispatch/codegen/functions/builtin_constants.py @@ -0,0 +1,130 @@ +import vkdispatch.base.dtype as dtypes +from . 
import utils + +def inf_f32(): + return utils.new_var( + dtypes.float32, + utils.codegen_backend().inf_f32_expr(), + [], + lexical_unit=True + ) + +def ninf_f32(): + return utils.new_var( + dtypes.float32, + utils.codegen_backend().ninf_f32_expr(), + [], + lexical_unit=True + ) + +def inf_f64(): + return utils.new_var( + dtypes.float64, + utils.codegen_backend().inf_f64_expr(), + [], + lexical_unit=True + ) + +def ninf_f64(): + return utils.new_var( + dtypes.float64, + utils.codegen_backend().ninf_f64_expr(), + [], + lexical_unit=True + ) + +def inf_f16(): + return utils.new_var( + dtypes.float16, + utils.codegen_backend().inf_f16_expr(), + [], + lexical_unit=True + ) + +def ninf_f16(): + return utils.new_var( + dtypes.float16, + utils.codegen_backend().ninf_f16_expr(), + [], + lexical_unit=True + ) + +def global_invocation_id(): + return utils.new_var( + dtypes.uvec3, + utils.codegen_backend().global_invocation_id_expr(), + [], + lexical_unit=True + ) + +def local_invocation_id(): + return utils.new_var( + dtypes.uvec3, + utils.codegen_backend().local_invocation_id_expr(), + [], + lexical_unit=True + ) + +def local_invocation_index(): + return utils.new_var( + dtypes.uint32, + utils.codegen_backend().local_invocation_index_expr(), + [], + lexical_unit=True + ) + +def workgroup_id(): + return utils.new_var( + dtypes.uvec3, + utils.codegen_backend().workgroup_id_expr(), + [], + lexical_unit=True + ) + +def workgroup_size(): + return utils.new_var( + dtypes.uvec3, + utils.codegen_backend().workgroup_size_expr(), + [], + lexical_unit=True + ) + +def num_workgroups(): + return utils.new_var( + dtypes.uvec3, + utils.codegen_backend().num_workgroups_expr(), + [], + lexical_unit=True + ) + +def num_subgroups(): + return utils.new_var( + dtypes.uint32, + utils.codegen_backend().num_subgroups_expr(), + [], + lexical_unit=True + ) + +def subgroup_id(): + return utils.new_var( + dtypes.uint32, + utils.codegen_backend().subgroup_id_expr(), + [], + lexical_unit=True + ) + +def 
subgroup_size(): + return utils.new_var( + dtypes.uint32, + utils.codegen_backend().subgroup_size_expr(), + [], + lexical_unit=True + ) + +def subgroup_invocation_id(): + return utils.new_var( + dtypes.uint32, + utils.codegen_backend().subgroup_invocation_id_expr(), + [], + lexical_unit=True + ) diff --git a/vkdispatch/codegen/functions/common_builtins.py b/vkdispatch/codegen/functions/common_builtins.py new file mode 100644 index 00000000..e801bdda --- /dev/null +++ b/vkdispatch/codegen/functions/common_builtins.py @@ -0,0 +1,430 @@ +import vkdispatch.base.dtype as dtypes +from ..variables.variables import ShaderVariable +from typing import Any, Union, Tuple + +from . import utils +from . import scalar_eval as se + +def comment(comment: str, preceding_new_line: bool = True) -> None: + comment_text = str(comment).replace("\r\n", "\n").replace("\r", "\n") + comment_lines = comment_text.split("\n") + + if preceding_new_line: + utils.append_contents("\n") + + if len(comment_lines) == 1: + safe_comment = comment_lines[0].replace("*/", "* /") + utils.append_contents(f"/* {safe_comment} */\n") + return + + utils.append_contents("/*\n") + + for line in comment_lines: + safe_line = line.replace("*/", "* /") + + if safe_line: + utils.append_contents(f" * {safe_line}\n") + continue + + utils.append_contents(" *\n") + + utils.append_contents(" */\n") + +def abs(var: Any) -> Union[ShaderVariable, float]: + if utils.is_number(var): + return abs(var) + + assert isinstance(var, ShaderVariable), "Argument must be a ShaderVariable or number" + + return utils.new_var( + utils.dtype_to_floating(var.var_type), + f"abs({var.resolve()})", + parents=[var], + lexical_unit=True + ) + +def sign(var: Any) -> Union[ShaderVariable, float]: + if utils.is_number(var): + return se.sign(var) + + assert isinstance(var, ShaderVariable), "Argument must be a ShaderVariable or number" + + return utils.new_var( + utils.dtype_to_floating(var.var_type), + f"sign({var.resolve()})", + parents=[var], + 
lexical_unit=True + ) + +def floor(var: Any) -> Union[ShaderVariable, float]: + if utils.is_number(var): + return se.floor(var) + + assert isinstance(var, ShaderVariable), "Argument must be a ShaderVariable or number" + + return utils.new_var( + utils.dtype_to_floating(var.var_type), + f"floor({var.resolve()})", + parents=[var], + lexical_unit=True + ) + +def ceil(var: Any) -> Union[ShaderVariable, float]: + if utils.is_number(var): + return se.ceil(var) + + assert isinstance(var, ShaderVariable), "Argument must be a ShaderVariable or number" + + return utils.new_var( + utils.dtype_to_floating(var.var_type), + f"ceil({var.resolve()})", + parents=[var], + lexical_unit=True + ) + +def trunc(var: Any) -> Union[ShaderVariable, float]: + if utils.is_number(var): + return se.trunc(var) + + assert isinstance(var, ShaderVariable), "Argument must be a ShaderVariable or number" + + return utils.new_var( + utils.dtype_to_floating(var.var_type), + f"trunc({var.resolve()})", + parents=[var], + lexical_unit=True + ) + +def round(var: Any) -> Union[ShaderVariable, float]: + if utils.is_number(var): + return se.round(var) + + assert isinstance(var, ShaderVariable), "Argument must be a ShaderVariable or number" + + return utils.new_var( + utils.dtype_to_floating(var.var_type), + f"round({var.resolve()})", + parents=[var], + lexical_unit=True + ) + +def round_even(var: Any) -> Union[ShaderVariable, float]: + if utils.is_number(var): + return se.round(var) + + assert isinstance(var, ShaderVariable), "Argument must be a ShaderVariable or number" + utils.mark_backend_feature("roundEven") + + return utils.new_var( + utils.dtype_to_floating(var.var_type), + f"roundEven({var.resolve()})", + parents=[var], + lexical_unit=True + ) + +def fract(var: Any) -> Union[ShaderVariable, float]: + if utils.is_number(var): + return float(var - se.floor(var)) + + assert isinstance(var, ShaderVariable), "Argument must be a ShaderVariable or number" + utils.mark_backend_feature("fract") + + return 
utils.new_var( + utils.dtype_to_floating(var.var_type), + f"fract({var.resolve()})", + parents=[var], + lexical_unit=True + ) + +def mod(x: Any, y: Any) -> Union[ShaderVariable, float]: + if utils.is_number(y) and utils.is_number(x): + return se.mod(x, y) + + base_var = None + + if isinstance(y, ShaderVariable): + base_var = y + elif isinstance(x, ShaderVariable): + base_var = x + else: + raise AssertionError("Arguments must be ShaderVariables or numbers") + + utils.mark_backend_feature("mod") + + return utils.new_var( + utils.dtype_to_floating(base_var.var_type), + f"mod({utils.resolve_input(x)}, {utils.resolve_input(y)})", + parents=[y, x], + lexical_unit=True + ) + +def modf(x: Any, y: Any) -> Tuple[ShaderVariable, ShaderVariable]: + if utils.is_number(y) and utils.is_number(x): + a, b = se.modf(x, y) + return float(a), float(b) + + if utils.is_number(x) and isinstance(y, ShaderVariable): + utils.mark_backend_feature("mod") + return utils.new_var( + utils.dtype_to_floating(y.var_type), + f"mod({utils.resolve_input(x)}, {y.resolve()})", + parents=[y] + ) + + if utils.is_number(y) and isinstance(x, ShaderVariable): + utils.mark_backend_feature("mod") + return utils.new_var( + utils.dtype_to_floating(x.var_type), + f"mod({x.resolve()}, {utils.resolve_input(y)})", + parents=[x] + ) + + assert isinstance(y, ShaderVariable), "First argument must be a ShaderVariable or number" + assert isinstance(x, ShaderVariable), "Second argument must be a ShaderVariable or number" + utils.mark_backend_feature("mod") + + return utils.new_var( + utils.dtype_to_floating(y.var_type), + f"mod({x.resolve()}, {y.resolve()})", + parents=[y, x], + lexical_unit=True + ) + +def min(x: Any, y: Any) -> Union[ShaderVariable, float]: + if utils.is_number(y) and utils.is_number(x): + return se.minimum(x, y) + + base_var = None + + if isinstance(y, ShaderVariable): + base_var = y + elif isinstance(x, ShaderVariable): + base_var = x + else: + raise AssertionError("Arguments must be ShaderVariables 
or numbers") + + return utils.new_var( + utils.dtype_to_floating(base_var.var_type), + f"min({utils.resolve_input(x)}, {utils.resolve_input(y)})", + parents=[y, x], + lexical_unit=True + ) + +def max(x: Any, y: Any) -> Union[ShaderVariable, float]: + if utils.is_number(y) and utils.is_number(x): + return se.maximum(x, y) + + base_var = None + + if isinstance(y, ShaderVariable): + base_var = y + elif isinstance(x, ShaderVariable): + base_var = x + else: + raise AssertionError("Arguments must be ShaderVariables or numbers") + + return utils.new_var( + utils.dtype_to_floating(base_var.var_type), + f"max({utils.resolve_input(x)}, {utils.resolve_input(y)})", + parents=[y, x], + lexical_unit=True + ) + +def clip(x: Any, min_val: Any, max_val: Any) -> Union[ShaderVariable, float]: + if utils.is_number(x) and utils.is_number(min_val) and utils.is_number(max_val): + return se.clip(x, min_val, max_val) + + base_var = None + + if isinstance(min_val, ShaderVariable): + base_var = min_val + elif isinstance(max_val, ShaderVariable): + base_var = max_val + elif isinstance(x, ShaderVariable): + base_var = x + else: + raise AssertionError("Arguments must be ShaderVariables or numbers") + + return utils.new_var( + utils.dtype_to_floating(base_var.var_type), + f"clamp({utils.resolve_input(x)}, {utils.resolve_input(min_val)}, {utils.resolve_input(max_val)})", + parents=[x, min_val, max_val], + lexical_unit=True + ) + +def clamp(x: Any, min_val: Any, max_val: Any) -> Union[ShaderVariable, float]: + return clip(x, min_val, max_val) + +def mix(x: Any, y: Any, a: Any) -> Union[ShaderVariable, float]: + if utils.is_number(y) and utils.is_number(x) and utils.is_number(a): + return se.interp(a, [0, 1], [x, y]) + + base_var = None + + if isinstance(a, ShaderVariable): + base_var = a + elif isinstance(y, ShaderVariable): + base_var = y + elif isinstance(x, ShaderVariable): + base_var = x + else: + raise AssertionError("Arguments must be ShaderVariables or numbers") + + 
utils.mark_backend_feature("mix") + + return utils.new_var( + utils.dtype_to_floating(base_var.var_type), + f"mix({utils.resolve_input(x)}, {utils.resolve_input(y)}, {utils.resolve_input(a)})", + parents=[y, x, a], + lexical_unit=True + ) + +def step(edge: Any, x: Any) -> Union[ShaderVariable, float]: + if utils.is_number(edge) and utils.is_number(x): + return float(0.0 if x < edge else 1.0) + + base_var = None + + if isinstance(x, ShaderVariable): + base_var = x + elif isinstance(edge, ShaderVariable): + base_var = edge + else: + raise AssertionError("Arguments must be ShaderVariables or numbers") + + utils.mark_backend_feature("step") + + return utils.new_var( + utils.dtype_to_floating(base_var.var_type), + f"step({utils.resolve_input(edge)}, {utils.resolve_input(x)})", + parents=[edge, x], + lexical_unit=True + ) + +def smoothstep(edge0: Any, edge1: Any, x: Any) -> Union[ShaderVariable, float]: + if utils.is_number(edge0) and utils.is_number(edge1) and utils.is_number(x): + t = se.clip((x - edge0) / (edge1 - edge0), 0.0, 1.0) + return float(t * t * (3.0 - 2.0 * t)) + + base_var = None + + if isinstance(x, ShaderVariable): + base_var = x + elif isinstance(edge1, ShaderVariable): + base_var = edge1 + elif isinstance(edge0, ShaderVariable): + base_var = edge0 + else: + raise AssertionError("Arguments must be ShaderVariables or numbers") + + utils.mark_backend_feature("smoothstep") + + return utils.new_var( + utils.dtype_to_floating(base_var.var_type), + f"smoothstep({utils.resolve_input(edge0)}, {utils.resolve_input(edge1)}, {utils.resolve_input(x)})", + parents=[edge0, edge1, x], + lexical_unit=True + ) + +def isnan(var: Any) -> Union[ShaderVariable, bool]: + if utils.is_number(var): + return se.isnan(var) + + assert isinstance(var, ShaderVariable), "Argument must be a ShaderVariable or number" + + return utils.new_var( + dtypes.int32, + f"isnan({var.resolve()})", + parents=[var], + lexical_unit=True + ) + +def isinf(var: Any) -> Union[ShaderVariable, bool]: + if 
utils.is_number(var): + return se.isinf(var) + + assert isinstance(var, ShaderVariable), "Argument must be a ShaderVariable or number" + + return utils.new_var( + dtypes.int32, + f"isinf({var.resolve()})", + parents=[var], + lexical_unit=True + ) + +def float_bits_to_int(var: Any) -> Union[ShaderVariable, int]: + if utils.is_number(var): + return se.float_bits_to_int(var) + + assert isinstance(var, ShaderVariable), "Argument must be a ShaderVariable or number" + + return utils.new_var( + dtypes.int32, + utils.codegen_backend().float_bits_to_int_expr(var.resolve()), + parents=[var], + lexical_unit=True + ) + +def float_bits_to_uint(var: Any) -> Union[ShaderVariable, int]: + if utils.is_number(var): + return se.float_bits_to_uint(var) + + assert isinstance(var, ShaderVariable), "Argument must be a ShaderVariable or number" + + return utils.new_var( + dtypes.uint32, + utils.codegen_backend().float_bits_to_uint_expr(var.resolve()), + parents=[var], + lexical_unit=True + ) + +def int_bits_to_float(var: Any) -> Union[ShaderVariable, float]: + if utils.is_number(var): + return se.int_bits_to_float(var) + + assert isinstance(var, ShaderVariable), "Argument must be a ShaderVariable or number" + + return utils.new_var( + dtypes.float32, + utils.codegen_backend().int_bits_to_float_expr(var.resolve()), + parents=[var], + lexical_unit=True + ) + +def uint_bits_to_float(var: Any) -> Union[ShaderVariable, float]: + if utils.is_number(var): + return se.uint_bits_to_float(var) + + assert isinstance(var, ShaderVariable), "Argument must be a ShaderVariable or number" + + return utils.new_var( + dtypes.float32, + utils.codegen_backend().uint_bits_to_float_expr(var.resolve()), + parents=[var], + lexical_unit=True + ) + +def fma(a: Any, b: Any, c: Any) -> Union[ShaderVariable, float]: + if utils.is_number(a) and utils.is_number(b) and utils.is_number(c): + return float(a * b + c) + + base_var = None + + if isinstance(c, ShaderVariable): + base_var = c + elif isinstance(b, 
ShaderVariable): + base_var = b + elif isinstance(a, ShaderVariable): + base_var = a + else: + raise AssertionError("Arguments must be ShaderVariables or numbers") + + result_type = utils.dtype_to_floating(base_var.var_type) + fma_function = utils.codegen_backend().fma_function_name(result_type) + + return utils.new_var( + result_type, + f"{fma_function}({utils.resolve_input(a)}, {utils.resolve_input(b)}, {utils.resolve_input(c)})", + parents=[a, b, c], + lexical_unit=True + ) diff --git a/vkdispatch/codegen/functions/complex_numbers.py b/vkdispatch/codegen/functions/complex_numbers.py new file mode 100644 index 00000000..e99f3d7b --- /dev/null +++ b/vkdispatch/codegen/functions/complex_numbers.py @@ -0,0 +1,57 @@ +import vkdispatch.base.dtype as dtypes +from ..variables.variables import ShaderVariable +from typing import Any, Union + +from .common_builtins import fma + +from .type_casting import to_complex, to_dtype +from . import utils + +from .trigonometry import cos, sin + +def complex_from_euler_angle(angle: ShaderVariable): + if not isinstance(angle, ShaderVariable): + raise TypeError("complex_from_euler_angle expects a ShaderVariable angle") + + target_complex_type = dtypes.complex_from_float(dtypes.make_floating_dtype(angle.var_type)) + return to_dtype(target_complex_type, cos(angle), sin(angle)) + +def validate_complex_number(arg1: Any) -> Union[ShaderVariable, complex]: + if isinstance(arg1, ShaderVariable): + assert dtypes.is_complex(arg1.var_type), "Input variables to complex multiplication must be complex" + return arg1 + + assert utils.is_number(arg1), "Argument must be ShaderVariable or number" + + return complex(arg1) + +def _new_big_complex(var_type: dtypes.dtype, arg1: Any, arg2: Any): + var_str = utils.backend_constructor(var_type, arg1, arg2) + + return utils.new_var( + var_type, + var_str, + [utils.resolve_input(arg1), utils.resolve_input(arg2)], + lexical_unit=True + ) + +def mult_complex(arg1: ShaderVariable, arg2: ShaderVariable): + a1 = 
validate_complex_number(arg1) + a2 = validate_complex_number(arg2) + + fallback_type = dtypes.complex64 + for normalized_arg in (a1, a2): + if isinstance(normalized_arg, ShaderVariable): + fallback_type = normalized_arg.var_type + break + + result_type = None + for normalized_arg in (a1, a2): + arg_type = normalized_arg.var_type if isinstance(normalized_arg, ShaderVariable) else fallback_type + result_type = arg_type if result_type is None else dtypes.cross_type(result_type, arg_type) + + return _new_big_complex( + result_type, # type: ignore[arg-type] + fma(a1.real, a2.real, -a1.imag * a2.imag), + fma(a1.real, a2.imag, a1.imag * a2.real), + ) diff --git a/vkdispatch/codegen/functions/control_flow.py b/vkdispatch/codegen/functions/control_flow.py new file mode 100644 index 00000000..88fcad45 --- /dev/null +++ b/vkdispatch/codegen/functions/control_flow.py @@ -0,0 +1,91 @@ +import vkdispatch.base.dtype as dtypes +from ..variables.variables import ShaderVariable +from typing import List, Optional, Union +from . 
def proc_bool(arg: Union[ShaderVariable, bool]) -> str:
    """Render a condition operand as text for the generated shader source.

    Args:
        arg: A Python ``bool`` or a ``ShaderVariable``.

    Returns:
        ``"true"`` or ``"false"`` for a Python bool; otherwise whatever
        ``arg.resolve()`` returns (callers interpolate it into source text,
        so it is presumably a string expression).

    Raises:
        TypeError: If ``arg`` is neither a ``bool`` nor a ``ShaderVariable``.
    """
    if isinstance(arg, bool):
        return "true" if arg else "false"

    if isinstance(arg, ShaderVariable):
        return arg.resolve()

    raise TypeError(f"Argument of type {type(arg)} cannot be processed as a boolean.")
def new_scope(indent: bool = True, comment: str = None):
    """Open a brace-delimited block in the generated source.

    Args:
        indent: When true, ``utils.scope_increment()`` is called after the
            opening brace has been emitted.
        comment: Optional text written as a ``/* ... */`` comment directly
            after the opening brace.
    """
    opening = "{\n" if comment is None else "{ " + f"/* {comment} */\n"
    utils.append_contents(opening)

    if not indent:
        return

    utils.scope_increment()
def pow(x: Any, y: Any) -> Union[ShaderVariable, float]:
    """Raise ``x`` to the power ``y``.

    Two plain numbers are evaluated immediately with ``se.power``. Otherwise
    a new shader variable holding the backend's ``pow`` expression is
    returned. Its dtype is the floating-point form of the variable operand's
    dtype, or of the cross type when both operands are variables. A
    plain-number operand is passed to the backend tagged as
    ``dtypes.float32``.

    Args:
        x: Base, a number or a ``ShaderVariable``.
        y: Exponent, a number or a ``ShaderVariable``.

    Raises:
        AssertionError: If an operand is neither a number nor a
            ``ShaderVariable``.
    """
    if utils.is_number(y) and utils.is_number(x):
        return se.power(x, y)

    # NOTE(review): unlike the two-variable case at the bottom, the mixed
    # number/variable cases do not pass lexical_unit=True — confirm intended.
    if utils.is_number(x) and isinstance(y, ShaderVariable):
        result_type = utils.dtype_to_floating(y.var_type)
        return utils.new_var(
            result_type,
            utils.codegen_backend().binary_math_expr(
                "pow",
                dtypes.float32,
                utils.resolve_input(x),
                result_type,
                y.resolve(),
            ),
            parents=[y]
        )

    if utils.is_number(y) and isinstance(x, ShaderVariable):
        result_type = utils.dtype_to_floating(x.var_type)
        return utils.new_var(
            result_type,
            utils.codegen_backend().binary_math_expr(
                "pow",
                result_type,
                x.resolve(),
                dtypes.float32,
                utils.resolve_input(y),
            ),
            parents=[x]
        )

    # x is the first parameter and y the second; the messages name them so.
    assert isinstance(x, ShaderVariable), "First argument must be a ShaderVariable or number"
    assert isinstance(y, ShaderVariable), "Second argument must be a ShaderVariable or number"

    result_type = utils.dtype_to_floating(dtypes.cross_type(x.var_type, y.var_type))
    return utils.new_var(
        result_type,
        utils.codegen_backend().binary_math_expr(
            "pow",
            utils.dtype_to_floating(x.var_type),
            x.resolve(),
            utils.dtype_to_floating(y.var_type),
            y.resolve(),
        ),
        parents=[y, x],
        lexical_unit=True
    )
def length(var: Any) -> Union[ShaderVariable, float]:
    """Magnitude of ``var``.

    A plain number is evaluated immediately with ``se.abs_value``. For a
    ``ShaderVariable`` a new variable holding ``length(<var>)`` is returned.

    Args:
        var: A number or a ``ShaderVariable``.

    Returns:
        A float for numeric input, otherwise a new ``ShaderVariable`` whose
        dtype is the floating-point scalar of the input dtype.

    Raises:
        AssertionError: If ``var`` is neither a number nor a
            ``ShaderVariable``.
    """
    if utils.is_number(var):
        return se.abs_value(var)

    assert isinstance(var, ShaderVariable), "Argument must be a ShaderVariable or number"

    result_type = utils.dtype_to_floating(var.var_type)

    # length() reduces a vector to one scalar, so the result must not keep
    # the vector dtype of its argument.
    if dtypes.is_vector(result_type):
        result_type = result_type.scalar

    return utils.new_var(
        result_type,
        f"length({var.resolve()})",
        parents=[var],
        lexical_unit=True
    )
def sanitize_input(value: Union[ShaderVariable, int, Tuple[int, ...]]) -> List[Union[ShaderVariable, int]]:
    """Normalize an index or shape argument into a flat list of components.

    Args:
        value: An integer-typed scalar or vector ``ShaderVariable``, a single
            integer, or a list/tuple of 1 to 3 integers.

    Returns:
        A list with one entry per axis. A 4-component shader vector yields
        only its first three components.

    Raises:
        AssertionError: If ``value`` has an unsupported type, dtype or
            component count.
    """
    if isinstance(value, ShaderVariable):
        assert dtypes.is_vector(value.var_type) or dtypes.is_scalar(value.var_type), f"Value is of type '{value.var_type.name}', but it must be a vector or integer!"
        assert dtypes.is_integer_dtype(value.var_type), f"Value is of type '{value.var_type.name}', but it must be of integer type!"

        if dtypes.is_scalar(value.var_type):
            return [value]

        elem_count = value.var_type.child_count
        assert 2 <= elem_count <= 4, f"Value is of type '{value.var_type.name}', but it must have 2, 3 or 4 components!"

        # Since buffer shapes store total elem count in the 4th component, we ignore it here.
        if elem_count == 4:
            elem_count = 3

        return [value[i] for i in range(elem_count)]

    if utils.check_is_int(value):
        return [value]

    assert isinstance(value, (list, tuple)), "Value must be a ShaderVariable or a list/tuple of integers!"

    elem_count = len(value)
    # The original `>= 1 or <= 3` was true for every length; both bounds
    # have to hold at once.
    assert 1 <= elem_count <= 3, f"Value has {elem_count} elements, but it must have 1, 2, or 3 elements!"

    axes_lengths = []
    for element in value:
        assert utils.check_is_int(element), "When value is a list/tuple, all its elements must be integers!"
        axes_lengths.append(element)

    return axes_lengths
def unravel_index(index: Union[ShaderVariable, Tuple[int, ...]], shape: Union[ShaderVariable, Tuple[int, ...]]):
    """Collapse a 1-, 2- or 3-component index into a single linear offset.

    Components are combined in row-major order using the trailing entries
    of ``shape``: ``i * s1 + j`` for two components and
    ``i * (s1 * s2) + j * s2 + k`` for three.

    Args:
        index: Integer scalar/vector ``ShaderVariable``, an int, or a
            list/tuple of ints.
        shape: Same accepted forms as ``index``.

    Returns:
        The single index component when ``index`` has one element, otherwise
        the combined linear offset.

    Raises:
        AssertionError: If ``index`` has more components than ``shape``, or
            either argument is rejected by ``sanitize_input``.
        RuntimeError: If ``index`` does not have 1, 2 or 3 components.
    """
    sanitized_shape = sanitize_input(shape)
    sanitized_index = sanitize_input(index)

    # NOTE(review): the message asks for equal lengths while the check
    # accepts a shorter index — confirm which one is intended.
    assert len(sanitized_index) <= len(sanitized_shape), f"Index ({index}) must have the same number of elements as shape ({sanitized_shape})!"

    component_count = len(sanitized_index)

    if component_count == 1:
        # Return the element itself so a one-element list/tuple is
        # unwrapped instead of being handed back as a sequence.
        return sanitized_index[0]

    if component_count == 2:
        return sanitized_index[0] * sanitized_shape[1] + sanitized_index[1]

    if component_count == 3:
        return sanitized_index[0] * (sanitized_shape[1] * sanitized_shape[2]) + sanitized_index[1] * sanitized_shape[2] + sanitized_index[2]

    raise RuntimeError("Unravel index only supports indices with 1, 2 or 3 elements!")
def outer_product(x: ShaderVariable, y: ShaderVariable) -> ShaderVariable:
    """Outer product of two vectors of the same dtype.

    Args:
        x: First vector; must be ``vec2``, ``vec3`` or ``vec4``.
        y: Second vector; must have the same dtype as ``x``.

    Returns:
        A new ``ShaderVariable`` of dtype ``mat2``, ``mat3`` or ``mat4``
        (matching the vector size) holding ``outerProduct(x, y)``.

    Raises:
        AssertionError: If an argument is not a vector ``ShaderVariable``,
            the dtypes differ, or the vector dtype is not supported.
    """
    assert isinstance(y, ShaderVariable), "Second argument must be a ShaderVariable"
    assert isinstance(x, ShaderVariable), "First argument must be a ShaderVariable"

    # These checks require vectors, so the messages say so.
    assert dtypes.is_vector(x.var_type), "First argument must be a vector"
    assert dtypes.is_vector(y.var_type), "Second argument must be a vector"

    assert x.var_type == y.var_type, "Vectors must have the same type"

    if x.var_type == dtypes.vec2:
        out_type = dtypes.mat2
    elif x.var_type == dtypes.vec3:
        out_type = dtypes.mat3
    elif x.var_type == dtypes.vec4:
        out_type = dtypes.mat4
    else:
        raise AssertionError("Unsupported vector type for outer product")

    return utils.new_var(
        out_type,
        f"outerProduct({utils.resolve_input(x)}, {utils.resolve_input(y)})",
        parents=[y, x],
        lexical_unit=True
    )
def print_vars(*args: Any, seperator=" "):
    """Emit one printf statement that prints every argument in order.

    A ``ShaderVariable`` contributes its ``var_type.format_str`` to the
    format string and its ``printf_args()`` to the argument list. Any other
    value is converted with ``str`` and placed directly in the format string.

    Args:
        *args: Values to print.
        seperator: Text inserted between the per-argument format fragments.
    """
    format_parts = []
    printf_arguments = []

    for item in args:
        if not isinstance(item, ShaderVariable):
            format_parts.append(str(item))
            continue

        printf_arguments.append(item.printf_args())
        format_parts.append(item.var_type.format_str)

    format_string = seperator.join(format_parts)
    statement = utils.codegen_backend().printf_statement(format_string, printf_arguments)

    utils.append_contents(statement + "\n")
def new_register(var_type: dtypes.dtype, *args, var_name: Optional[str] = None):
    """Declare a settable register variable and emit its initialization.

    Args:
        var_type: Dtype of the register.
        *args: Initial value components, forwarded to ``to_dtype``. With no
            arguments the register is initialized from ``0``.
        var_name: Optional name passed to ``utils.new_var``.

    Returns:
        The newly created register variable.
    """
    register = utils.new_var(
        var_type,
        var_name,
        [],
        lexical_unit=True,
        settable=True,
        register=True
    )

    # Notify every variable used in the initializer that it is being read.
    for initializer in args:
        if isinstance(initializer, ShaderVariable):
            initializer.read_callback()

    initial_args = args if args else (0,)
    initial_value = to_dtype(var_type, *initial_args).resolve()

    type_name = utils.backend_type_name(register.var_type)
    utils.append_contents(f"{type_name} {register.name} = {initial_value};\n")

    return register
def new_complex_register(*args, var_name: Optional[str] = None):
    """Declare a complex register whose dtype follows its initial value.

    With no arguments a ``dtypes.complex64`` register initialized from ``0``
    is created. Otherwise the arguments are converted with ``to_complex`` and
    the register takes the dtype of that converted value.
    """
    if args:
        initial_value = to_complex(*args)
        return new_register(initial_value.var_type, initial_value, var_name=var_name)

    return new_register(dtypes.complex64, 0, var_name=var_name)
def new_mat4_register(*args, var_name: Optional[str] = None):
    """Declare a register of dtype ``dtypes.mat4``.

    Thin wrapper that forwards ``args`` and ``var_name`` to ``new_register``.
    """
    return new_register(dtypes.mat4, *args, var_name=var_name)
def interp(x: float, xp: Sequence[float], fp: Sequence[float]) -> float:
    """Piecewise-linear interpolation of ``x`` over the samples ``(xp, fp)``.

    Values of ``x`` at or beyond the first/last sample position are clamped
    to the first/last sample value. A single sample always yields that
    sample's value.

    Raises:
        ValueError: If ``xp`` and ``fp`` differ in length or are empty.
    """
    sample_count = len(xp)
    if sample_count != len(fp):
        raise ValueError("xp and fp must have the same length")
    if sample_count == 0:
        raise ValueError("xp and fp must be non-empty")

    if sample_count == 1 or x <= xp[0]:
        return float(fp[0])
    if x >= xp[-1]:
        return float(fp[-1])

    # First sample position that is not below x.
    upper = next((i for i in range(1, sample_count) if x <= xp[i]), None)
    if upper is None:
        return float(fp[-1])

    x_lo, x_hi = xp[upper - 1], xp[upper]
    y_lo, y_hi = fp[upper - 1], fp[upper]

    # Coincident sample positions: use the left value, avoid dividing by zero.
    if x_hi == x_lo:
        return float(y_lo)

    fraction = (x - x_lo) / (x_hi - x_lo)
    return float(y_lo + fraction * (y_hi - y_lo))
def exp2(value: float) -> float:
    """Return ``2 ** value`` as a float.

    Uses ``math.exp2`` when the running interpreter provides it and falls
    back to ``math.pow(2.0, value)`` otherwise.
    """
    native_exp2 = getattr(math, "exp2", None)
    result = math.pow(2.0, value) if native_exp2 is None else native_exp2(value)
    return float(result)
def subgroup_add(arg1: ShaderVariable):
    """Wrap the backend's ``subgroup_add_expr`` for ``arg1`` in a new variable.

    The result has the same dtype as ``arg1`` (presumably the sum of
    ``arg1`` across the subgroup — confirm against the backend).
    """
    backend = utils.codegen_backend()
    expression = backend.subgroup_add_expr(arg1.resolve(), arg1.var_type)

    return utils.new_var(arg1.var_type, expression, [arg1], lexical_unit=True)
def _float64_to_float32_dtype(var_type: dtypes.dtype) -> dtypes.dtype:
    """Map a float64 scalar or vector dtype to its float32 counterpart.

    Raises:
        TypeError: If ``var_type`` is neither ``float64`` nor a vector whose
            scalar type is ``float64``.
    """
    if var_type == dtypes.float64:
        return dtypes.float32

    is_fp64_vector = dtypes.is_vector(var_type) and var_type.scalar == dtypes.float64
    if not is_fp64_vector:
        raise TypeError(f"Unsupported fp64 fallback dtype: {var_type}")

    # Same component count, float32 components.
    return dtypes.to_vector(dtypes.float32, var_type.child_count)
def radians(var: Any) -> Union[ShaderVariable, float]:
    """Convert an angle from degrees to radians.

    A plain number is converted immediately. For a ``ShaderVariable`` the
    ``"radians"`` backend feature is marked and the expression is built with
    ``_unary_math_var``.

    Raises:
        AssertionError: If ``var`` is neither a number nor a
            ``ShaderVariable``.
    """
    if not utils.is_number(var):
        assert isinstance(var, ShaderVariable), "Argument must be a ShaderVariable or number"
        utils.mark_backend_feature("radians")
        return _unary_math_var("radians", var)

    radians_per_degree = 3.141592653589793 / 180.0
    return var * radians_per_degree
_unary_math_var("sin", var) + +def cos(var: Any) -> Union[ShaderVariable, float]: + if utils.is_number(var): + return se.cos(var) + + assert isinstance(var, ShaderVariable), "Argument must be a ShaderVariable or number" + return _unary_math_var("cos", var) + +def tan(var: Any) -> Union[ShaderVariable, float]: + if utils.is_number(var): + return se.tan(var) + + assert isinstance(var, ShaderVariable), "Argument must be a ShaderVariable or number" + return _unary_math_var("tan", var) + +def asin(var: Any) -> Union[ShaderVariable, float]: + if utils.is_number(var): + return se.arcsin(var) + + assert isinstance(var, ShaderVariable), "Argument must be a ShaderVariable or number" + return _unary_math_var("asin", var) + +def acos(var: Any) -> Union[ShaderVariable, float]: + if utils.is_number(var): + return se.arccos(var) + + assert isinstance(var, ShaderVariable), "Argument must be a ShaderVariable or number" + return _unary_math_var("acos", var) + +def atan(var: Any) -> Union[ShaderVariable, float]: + if utils.is_number(var): + return se.arctan(var) + + assert isinstance(var, ShaderVariable), "Argument must be a ShaderVariable or number" + return _unary_math_var("atan", var) + +def atan2(y: Any, x: Any) -> Union[ShaderVariable, float]: + if utils.is_number(y) and utils.is_number(x): + return se.arctan2(y, x) + + if utils.is_number(x) and isinstance(y, ShaderVariable): + result_type = dtype_to_floating(y.var_type) + scalar_result_type = result_type.scalar if dtypes.is_vector(result_type) else result_type + return _binary_math_var( + "atan2", + result_type, + result_type, + y.resolve(), + scalar_result_type, + utils.resolve_input(x), + [y], + ) + + if utils.is_number(y) and isinstance(x, ShaderVariable): + result_type = dtype_to_floating(x.var_type) + scalar_result_type = result_type.scalar if dtypes.is_vector(result_type) else result_type + return _binary_math_var( + "atan2", + result_type, + scalar_result_type, + utils.resolve_input(y), + result_type, + x.resolve(), + 
[x], + ) + + assert isinstance(y, ShaderVariable), "First argument must be a ShaderVariable or number" + assert isinstance(x, ShaderVariable), "Second argument must be a ShaderVariable or number" + + result_type = dtype_to_floating(dtypes.cross_type(y.var_type, x.var_type)) + return _binary_math_var( + "atan2", + result_type, + result_type, + y.resolve(), + dtype_to_floating(x.var_type), + x.resolve(), + [y, x], + lexical_unit=True, + ) + +def sinh(var: Any) -> Union[ShaderVariable, float]: + if utils.is_number(var): + return se.sinh(var) + + assert isinstance(var, ShaderVariable), "Argument must be a ShaderVariable or number" + return _unary_math_var("sinh", var) + +def cosh(var: Any) -> Union[ShaderVariable, float]: + if utils.is_number(var): + return se.cosh(var) + + assert isinstance(var, ShaderVariable), "Argument must be a ShaderVariable or number" + return _unary_math_var("cosh", var) + +def tanh(var: Any) -> Union[ShaderVariable, float]: + if utils.is_number(var): + return se.tanh(var) + + assert isinstance(var, ShaderVariable), "Argument must be a ShaderVariable or number" + return _unary_math_var("tanh", var) + +def asinh(var: Any) -> Union[ShaderVariable, float]: + if utils.is_number(var): + return se.arcsinh(var) + + assert isinstance(var, ShaderVariable), "Argument must be a ShaderVariable or number" + return _unary_math_var("asinh", var) + +def acosh(var: Any) -> Union[ShaderVariable, float]: + if utils.is_number(var): + return se.arccosh(var) + + assert isinstance(var, ShaderVariable), "Argument must be a ShaderVariable or number" + return _unary_math_var("acosh", var) + +def atanh(var: Any) -> Union[ShaderVariable, float]: + if utils.is_number(var): + return se.arctanh(var) + + assert isinstance(var, ShaderVariable), "Argument must be a ShaderVariable or number" + return _unary_math_var("atanh", var) diff --git a/vkdispatch/codegen/functions/type_casting.py b/vkdispatch/codegen/functions/type_casting.py new file mode 100644 index 00000000..276a479a 
--- /dev/null +++ b/vkdispatch/codegen/functions/type_casting.py @@ -0,0 +1,174 @@ +import vkdispatch.base.dtype as dtypes +from typing import Optional + +from . import utils +from ..variables.variables import ShaderVariable + +def to_dtype(var_type: dtypes.dtype, *args): + return utils.new_var( + var_type, + utils.backend_constructor(var_type, *args), + args, + lexical_unit=True + ) + +def str_to_dtype(var_type: dtypes.dtype, + value: str, + parents: Optional[list] = None, + lexical_unit: bool = False, + settable: bool = False, + register: bool = False): + return utils.new_var( + var_type, + value, + parents=parents if parents is not None else [], + lexical_unit=lexical_unit, + settable=settable, + register=register + ) + +def to_float16(*args): + return to_dtype(dtypes.float16, *args) + +def to_float(*args): + return to_dtype(dtypes.float32, *args) + +def to_float64(*args): + return to_dtype(dtypes.float64, *args) + +def to_int16(*args): + return to_dtype(dtypes.int16, *args) + +def to_int(*args): + return to_dtype(dtypes.int32, *args) + +def to_int64(*args): + return to_dtype(dtypes.int64, *args) + +def to_uint16(*args): + return to_dtype(dtypes.uint16, *args) + +def to_uint(*args): + return to_dtype(dtypes.uint32, *args) + +def to_uint64(*args): + return to_dtype(dtypes.uint64, *args) + +def _complex_from_real_arg(arg) -> dtypes.dtype: + if isinstance(arg, ShaderVariable): + if dtypes.is_complex(arg.var_type): + return arg.var_type + if dtypes.is_scalar(arg.var_type): + return dtypes.complex_from_float(dtypes.make_floating_dtype(arg.var_type)) + raise TypeError(f"Unsupported variable type for complex conversion: {arg.var_type}") + + if utils.is_number(arg): + base_type = utils.number_to_dtype(arg) + if dtypes.is_complex(base_type): + return base_type + return dtypes.complex_from_float(dtypes.make_floating_dtype(base_type)) + + raise TypeError(f"Unsupported argument type for complex conversion: {type(arg)}") + +def _infer_complex_dtype(*args) -> dtypes.dtype: + 
complex_type = _complex_from_real_arg(args[0]) + + for arg in args[1:]: + complex_type = dtypes.cross_type(complex_type, _complex_from_real_arg(arg)) + + return complex_type + +def _to_complex_dtype(var_type: dtypes.dtype, *args): + assert len(args) == 1 or len(args) == 2, "Must give one of two arguments for complex init" + + if len(args) == 1 and isinstance(args[0], ShaderVariable) and dtypes.is_complex(args[0].var_type): + return to_dtype(var_type, args[0]) + + if len(args) == 1: + return to_dtype(var_type, args[0], 0) + + return to_dtype(var_type, *args) + +def to_complex32(*args): + return _to_complex_dtype(dtypes.complex32, *args) + +def to_complex(*args): + return _to_complex_dtype(_infer_complex_dtype(*args), *args) + +def to_complex64(*args): + return _to_complex_dtype(dtypes.complex64, *args) + +def to_complex128(*args): + return _to_complex_dtype(dtypes.complex128, *args) + +def to_hvec2(*args): + return to_dtype(dtypes.hvec2, *args) + +def to_hvec3(*args): + return to_dtype(dtypes.hvec3, *args) + +def to_hvec4(*args): + return to_dtype(dtypes.hvec4, *args) + +def to_vec2(*args): + return to_dtype(dtypes.vec2, *args) + +def to_vec3(*args): + return to_dtype(dtypes.vec3, *args) + +def to_vec4(*args): + return to_dtype(dtypes.vec4, *args) + +def to_dvec2(*args): + return to_dtype(dtypes.dvec2, *args) + +def to_dvec3(*args): + return to_dtype(dtypes.dvec3, *args) + +def to_dvec4(*args): + return to_dtype(dtypes.dvec4, *args) + +def to_ihvec2(*args): + return to_dtype(dtypes.ihvec2, *args) + +def to_ihvec3(*args): + return to_dtype(dtypes.ihvec3, *args) + +def to_ihvec4(*args): + return to_dtype(dtypes.ihvec4, *args) + +def to_ivec2(*args): + return to_dtype(dtypes.ivec2, *args) + +def to_ivec3(*args): + return to_dtype(dtypes.ivec3, *args) + +def to_ivec4(*args): + return to_dtype(dtypes.ivec4, *args) + +def to_uhvec2(*args): + return to_dtype(dtypes.uhvec2, *args) + +def to_uhvec3(*args): + return to_dtype(dtypes.uhvec3, *args) + +def to_uhvec4(*args): + 
return to_dtype(dtypes.uhvec4, *args) + +def to_uvec2(*args): + return to_dtype(dtypes.uvec2, *args) + +def to_uvec3(*args): + return to_dtype(dtypes.uvec3, *args) + +def to_uvec4(*args): + return to_dtype(dtypes.uvec4, *args) + +def to_mat2(*args): + return to_dtype(dtypes.mat2, *args) + +def to_mat3(*args): + return to_dtype(dtypes.mat3, *args) + +def to_mat4(*args): + return to_dtype(dtypes.mat4, *args) diff --git a/vkdispatch/codegen/functions/utils.py b/vkdispatch/codegen/functions/utils.py new file mode 100644 index 00000000..ddb866fb --- /dev/null +++ b/vkdispatch/codegen/functions/utils.py @@ -0,0 +1,56 @@ +import vkdispatch.base.dtype as dtypes +from ..variables.variables import ShaderVariable +from typing import List, Optional + +from .base_functions.base_utils import * +from ..global_builder import get_codegen_backend + +from ..shader_writer import scope_increment, scope_decrement + +def new_var(var_type: dtypes.dtype, + var_name: Optional[str], + parents: list, + lexical_unit: bool = False, + settable: bool = False, + register: bool = False) -> ShaderVariable: + return new_base_var(var_type, var_name, parents, lexical_unit, settable, register) + +def codegen_backend(): + return get_codegen_backend() + +def mark_backend_feature(feature_name: str) -> None: + codegen_backend().mark_feature_usage(feature_name) + +def backend_type_name(var_type: dtypes.dtype) -> str: + return codegen_backend().type_name(var_type) + +def _resolve_arg_types(args: tuple) -> List[Optional[dtypes.dtype]]: + resolved_types: List[Optional[dtypes.dtype]] = [] + + for elem in args: + if isinstance(elem, ShaderVariable): + resolved_types.append(elem.var_type) + continue + + if is_number(elem): + resolved_types.append(number_to_dtype(elem)) + continue + + resolved_types.append(None) + + return resolved_types + +def backend_constructor(var_type: dtypes.dtype, *args) -> str: + resolved_types = _resolve_arg_types(args) + return codegen_backend().constructor( + var_type, + 
[resolve_input(elem) for elem in args], + arg_types=resolved_types, + ) + +def backend_constructor_from_resolved( + var_type: dtypes.dtype, + args: List[str], + arg_types: Optional[List[Optional[dtypes.dtype]]] = None, +) -> str: + return codegen_backend().constructor(var_type, args, arg_types=arg_types) diff --git a/vkdispatch/codegen/global_builder.py b/vkdispatch/codegen/global_builder.py index 256efab5..8a14b1b9 100644 --- a/vkdispatch/codegen/global_builder.py +++ b/vkdispatch/codegen/global_builder.py @@ -1,381 +1,90 @@ -import vkdispatch as vd +import threading +import vkdispatch.base.dtype as dtypes +from .shader_writer import set_shader_writer +from .backends import CodeGenBackend, GLSLBackend, CUDABackend, OpenCLBackend +from vkdispatch.base.init import is_cuda, is_opencl +from typing import Optional, TYPE_CHECKING, Union -from .builder import ShaderBuilder, ShaderVariable +if TYPE_CHECKING: + from .builder import ShaderBuilder -import contextlib +_builder_context = threading.local() +_shader_print_line_numbers = threading.local() +_codegen_backend = threading.local() -from typing import List, Union, Optional +def _make_runtime_default_codegen_backend() -> CodeGenBackend: + if is_cuda(): + return CUDABackend() -inf_f32 = "uintBitsToFloat(0x7F800000)" -ninf_f32 = "uintBitsToFloat(0xFF800000)" + if is_opencl(): + return OpenCLBackend() -class GlobalBuilder: - obj = ShaderBuilder() + return GLSLBackend() -def set_global_builder(builder: ShaderBuilder): - old_value = GlobalBuilder.obj - GlobalBuilder.obj = builder # Update the global reference. 
- return old_value +def get_shader_print_line_numbers() -> bool: + return getattr(_shader_print_line_numbers, 'value', False) -@contextlib.contextmanager -def builder_context( - enable_subgroup_ops: bool = True, - enable_atomic_float_ops: bool = True, - enable_printf: bool = True, - enable_exec_bounds: bool = True): +def set_shader_print_line_numbers(value: bool): + _shader_print_line_numbers.value = value - builder = ShaderBuilder( - enable_atomic_float_ops=enable_atomic_float_ops, - enable_subgroup_ops=enable_subgroup_ops, - enable_printf=enable_printf, - enable_exec_bounds=enable_exec_bounds, - is_apple_device=vd.get_context().is_apple() - ) - old_builder = set_global_builder(builder) +def _get_builder() -> Optional['ShaderBuilder']: + return getattr(_builder_context, 'active_builder', None) - try: - yield builder - finally: - set_global_builder(old_builder) +def _get_codegen_backend() -> Optional[CodeGenBackend]: + return getattr(_codegen_backend, 'active_backend', None) -def comment(text: str): - GlobalBuilder.obj.comment(text) +def set_codegen_backend(backend: Optional[Union[CodeGenBackend, str]]): + if backend is None: + _codegen_backend.active_backend = None + return -def global_invocation(): - return GlobalBuilder.obj.global_invocation + if isinstance(backend, str): + backend_name = backend.lower() -def local_invocation(): - return GlobalBuilder.obj.local_invocation + if backend_name == "glsl": + _codegen_backend.active_backend = GLSLBackend() + return -def workgroup(): - return GlobalBuilder.obj.workgroup + if backend_name == "cuda": + _codegen_backend.active_backend = CUDABackend() + return -def workgroup_size(): - return GlobalBuilder.obj.workgroup_size + if backend_name == "opencl": + _codegen_backend.active_backend = OpenCLBackend() + return -def num_workgroups(): - return GlobalBuilder.obj.num_workgroups + raise ValueError(f"Unknown codegen backend '{backend}'") -def num_subgroups(): - return GlobalBuilder.obj.num_subgroups + 
_codegen_backend.active_backend = backend -def subgroup_id(): - return GlobalBuilder.obj.subgroup_id +def get_codegen_backend() -> CodeGenBackend: + builder = _get_builder() -def subgroup_size(): - return GlobalBuilder.obj.subgroup_size + if builder is not None: + return builder.backend -def subgroup_invocation(): - return GlobalBuilder.obj.subgroup_invocation + backend = _get_codegen_backend() -def set_mapping_index(index: ShaderVariable): - GlobalBuilder.obj.set_mapping_index(index) + if backend is None: + backend = _make_runtime_default_codegen_backend() + _codegen_backend.active_backend = backend -def set_kernel_index(index: ShaderVariable): - GlobalBuilder.obj.set_kernel_index(index) + return backend -def set_mapping_registers(registers: ShaderVariable): - GlobalBuilder.obj.set_mapping_registers(registers) +def set_builder(builder: 'ShaderBuilder'): + if builder is None: + _builder_context.active_builder = None + set_shader_writer(None) + return -def mapping_index(): - return GlobalBuilder.obj.mapping_index + assert _get_builder() is None, "A global ShaderBuilder is already set for the current thread!" + set_shader_writer(builder) + _builder_context.active_builder = builder -def kernel_index(): - return GlobalBuilder.obj.kernel_index +def get_builder() -> 'ShaderBuilder': + builder = _get_builder() + assert builder is not None, "No global ShaderBuilder is set for the current thread!" 
+ return builder -def mapping_registers(): - return GlobalBuilder.obj.mapping_registers - -def shared_buffer(var_type: vd.dtype, size: int, var_name: Optional[str] = None): - return GlobalBuilder.obj.shared_buffer(var_type, size, var_name) - -def abs(arg: ShaderVariable): - return GlobalBuilder.obj.abs(arg) - -def acos(arg: ShaderVariable): - return GlobalBuilder.obj.acos(arg) - -def acosh(arg: ShaderVariable): - return GlobalBuilder.obj.acosh(arg) - -def asin(arg: ShaderVariable): - return GlobalBuilder.obj.asin(arg) - -def asinh(arg: ShaderVariable): - return GlobalBuilder.obj.asinh(arg) - -def atan(arg: ShaderVariable): - return GlobalBuilder.obj.atan(arg) - -def atan2(arg1: ShaderVariable, arg2: ShaderVariable): - return GlobalBuilder.obj.atan2(arg1, arg2) - -def atanh(arg: ShaderVariable): - return GlobalBuilder.obj.atanh(arg) - -def atomic_add(arg1: ShaderVariable, arg2: ShaderVariable): - return GlobalBuilder.obj.atomic_add(arg1, arg2) - -def barrier(): - GlobalBuilder.obj.barrier() - -def ceil(arg: ShaderVariable): - return GlobalBuilder.obj.ceil(arg) - -def clamp(arg: ShaderVariable, min_val: ShaderVariable, max_val: ShaderVariable): - return GlobalBuilder.obj.clamp(arg, min_val, max_val) - -def cos(arg: ShaderVariable): - return GlobalBuilder.obj.cos(arg) - -def cosh(arg: ShaderVariable): - return GlobalBuilder.obj.cosh(arg) - -def cross(arg1: ShaderVariable, arg2: ShaderVariable): - return GlobalBuilder.obj.cross(arg1, arg2) - -def degrees(arg: ShaderVariable): - return GlobalBuilder.obj.degrees(arg) - -def determinant(arg: ShaderVariable): - return GlobalBuilder.obj.determinant(arg) - -def distance(arg1: ShaderVariable, arg2: ShaderVariable): - return GlobalBuilder.obj.distance(arg1, arg2) - -def dot(arg1: ShaderVariable, arg2: ShaderVariable): - return GlobalBuilder.obj.dot(arg1, arg2) - -def exp(arg: ShaderVariable): - return GlobalBuilder.obj.exp(arg) - -def exp2(arg: ShaderVariable): - return GlobalBuilder.obj.exp2(arg) - -def float_bits_to_int(arg: 
ShaderVariable): - return GlobalBuilder.obj.float_bits_to_int(arg) - -def float_bits_to_uint(arg: ShaderVariable): - return GlobalBuilder.obj.float_bits_to_uint(arg) - -def floor(arg: ShaderVariable): - return GlobalBuilder.obj.floor(arg) - -def fma(arg1: ShaderVariable, arg2: ShaderVariable, arg3: ShaderVariable): - return GlobalBuilder.obj.fma(arg1, arg2, arg3) - -def int_bits_to_float(arg: ShaderVariable): - return GlobalBuilder.obj.int_bits_to_float(arg) - -def inverse(arg: ShaderVariable): - return GlobalBuilder.obj.inverse(arg) - -def inverse_sqrt(arg: ShaderVariable): - return GlobalBuilder.obj.inverse_sqrt(arg) - -def isinf(arg: ShaderVariable): - return GlobalBuilder.obj.isinf(arg) - -def isnan(arg: ShaderVariable): - return GlobalBuilder.obj.isnan(arg) - -def length(arg: ShaderVariable): - return GlobalBuilder.obj.length(arg) - -def log(arg: ShaderVariable): - return GlobalBuilder.obj.log(arg) - -def log2(arg: ShaderVariable): - return GlobalBuilder.obj.log2(arg) - -def max(arg1: ShaderVariable, arg2: ShaderVariable): - return GlobalBuilder.obj.max(arg1, arg2) - -def memory_barrier(): - GlobalBuilder.obj.memory_barrier() - -def memory_barrier_shared(): - GlobalBuilder.obj.memory_barrier_shared() - -def min(arg1: ShaderVariable, arg2: ShaderVariable): - return GlobalBuilder.obj.min(arg1, arg2) - -def mix(arg1: ShaderVariable, arg2: ShaderVariable, arg3: ShaderVariable): - return GlobalBuilder.obj.mix(arg1, arg2, arg3) - -def mod(arg1: ShaderVariable, arg2: ShaderVariable): - return GlobalBuilder.obj.mod(arg1, arg2) - -def normalize(arg: ShaderVariable): - return GlobalBuilder.obj.normalize(arg) - -def pow(arg1: ShaderVariable, arg2: ShaderVariable): - return GlobalBuilder.obj.pow(arg1, arg2) - -def radians(arg: ShaderVariable): - return GlobalBuilder.obj.radians(arg) - -def round(arg: ShaderVariable): - return GlobalBuilder.obj.round(arg) - -def round_even(arg: ShaderVariable): - return GlobalBuilder.obj.round_even(arg) - -def sign(arg: ShaderVariable): - 
return GlobalBuilder.obj.sign(arg) - -def sin(arg: ShaderVariable): - return GlobalBuilder.obj.sin(arg) - -def sinh(arg: ShaderVariable): - return GlobalBuilder.obj.sinh(arg) - -def smoothstep(arg1: ShaderVariable, arg2: ShaderVariable, arg3: ShaderVariable): - return GlobalBuilder.obj.smoothstep(arg1, arg2, arg3) - -def sqrt(arg: ShaderVariable): - return GlobalBuilder.obj.sqrt(arg) - -def step(arg1: ShaderVariable, arg2: ShaderVariable): - return GlobalBuilder.obj.step(arg1, arg2) - -def tan(arg: ShaderVariable): - return GlobalBuilder.obj.tan(arg) - -def tanh(arg: ShaderVariable): - return GlobalBuilder.obj.tanh(arg) - -def transpose(arg: ShaderVariable): - return GlobalBuilder.obj.transpose(arg) - -def trunc(arg: ShaderVariable): - return GlobalBuilder.obj.trunc(arg) - -def uint_bits_to_float(arg: ShaderVariable): - return GlobalBuilder.obj.uint_bits_to_float(arg) - -def mult_c64(arg1: ShaderVariable, arg2: ShaderVariable): - return GlobalBuilder.obj.mult_c64(arg1, arg2) - -def mult_c64_by_const(arg1: ShaderVariable, number: complex): - return GlobalBuilder.obj.mult_c64_by_const(arg1, number) - -def mult_conj_c64(arg1: ShaderVariable, arg2: ShaderVariable): - return GlobalBuilder.obj.mult_conj_c64(arg1, arg2) - -def if_statement(arg: ShaderVariable, command: Optional[str] = None): - GlobalBuilder.obj.if_statement(arg, command=command) - -def if_any(*args: List[ShaderVariable]): - GlobalBuilder.obj.if_any(*args) - -def if_all(*args: List[ShaderVariable]): - GlobalBuilder.obj.if_all(*args) - -def else_statement(): - GlobalBuilder.obj.else_statement() - -def else_if_statement(arg: ShaderVariable): - GlobalBuilder.obj.else_if_statement(arg) - -def else_if_any(*args: List[ShaderVariable]): - GlobalBuilder.obj.else_if_any(*args) - -def else_if_all(*args: List[ShaderVariable]): - GlobalBuilder.obj.else_if_all(*args) - -def return_statement(arg=None): - GlobalBuilder.obj.return_statement(arg) - -def while_statement(arg: ShaderVariable): - 
GlobalBuilder.obj.while_statement(arg) - -def new_scope(): - GlobalBuilder.obj.new_scope() - -def end(): - GlobalBuilder.obj.end() - -def logical_and(arg1: ShaderVariable, arg2: ShaderVariable): - return GlobalBuilder.obj.logical_and(arg1, arg2) - -def logical_or(arg1: ShaderVariable, arg2: ShaderVariable): - return GlobalBuilder.obj.logical_or(arg1, arg2) - -def subgroup_add(arg1: ShaderVariable): - return GlobalBuilder.obj.subgroup_add(arg1) - -def subgroup_mul(arg1: ShaderVariable): - return GlobalBuilder.obj.subgroup_mul(arg1) - -def subgroup_min(arg1: ShaderVariable): - return GlobalBuilder.obj.subgroup_min(arg1) - -def subgroup_max(arg1: ShaderVariable): - return GlobalBuilder.obj.subgroup_max(arg1) - -def subgroup_and(arg1: ShaderVariable): - return GlobalBuilder.obj.subgroup_and(arg1) - -def subgroup_or(arg1: ShaderVariable): - return GlobalBuilder.obj.subgroup_or(arg1) - -def subgroup_xor(arg1: ShaderVariable): - return GlobalBuilder.obj.subgroup_xor(arg1) - -def subgroup_elect(): - return GlobalBuilder.obj.subgroup_elect() - -def subgroup_barrier(): - GlobalBuilder.obj.subgroup_barrier() - -def new(var_type: vd.dtype, *args, var_name: Optional[str] = None): - return GlobalBuilder.obj.new(var_type, *args, var_name=var_name) - -def new_float(*args, var_name: Optional[str] = None): - return new(vd.float32, *args, var_name=var_name) - -def new_int(*args, var_name: Optional[str] = None): - return new(vd.int32, *args, var_name=var_name) - -def new_uint(*args, var_name: Optional[str] = None): - return new(vd.uint32, *args, var_name=var_name) - -def new_vec2(*args, var_name: Optional[str] = None): - return new(vd.vec2, *args, var_name=var_name) - -def new_vec3(*args, var_name: Optional[str] = None): - return new(vd.vec3, *args, var_name=var_name) - -def new_vec4(*args, var_name: Optional[str] = None): - return new(vd.vec4, *args, var_name=var_name) - -def new_uvec2(*args, var_name: Optional[str] = None): - return new(vd.uvec2, *args, var_name=var_name) - -def 
new_uvec3(*args, var_name: Optional[str] = None): - return new(vd.uvec3, *args, var_name=var_name) - -def new_uvec4(*args, var_name: Optional[str] = None): - return new(vd.uvec4, *args, var_name=var_name) - -def new_ivec2(*args, var_name: Optional[str] = None): - return new(vd.ivec2, *args, var_name=var_name) - -def new_ivec3(*args, var_name: Optional[str] = None): - return new(vd.ivec3, *args, var_name=var_name) - -def new_ivec4(*args, var_name: Optional[str] = None): - return new(vd.ivec4, *args, var_name=var_name) - -def printf(format: str, *args: Union[ShaderVariable, str], seperator=" "): - GlobalBuilder.obj.printf(format, *args, seperator=seperator) - -def print_vars(*args: Union[ShaderVariable, str], seperator=" "): - GlobalBuilder.obj.print_vars(*args, seperator=seperator) - -def unravel_index(index: ShaderVariable, shape: ShaderVariable): - return GlobalBuilder.obj.unravel_index(index, shape) - -def complex_from_euler_angle(angle: ShaderVariable): - return GlobalBuilder.obj.complex_from_euler_angle(angle) \ No newline at end of file +def shared_buffer(var_type: dtypes.dtype, size: int, var_name: Optional[str] = None): + return get_builder().shared_buffer(var_type, size, var_name) diff --git a/vkdispatch/codegen/shader_writer.py b/vkdispatch/codegen/shader_writer.py new file mode 100644 index 00000000..b374588c --- /dev/null +++ b/vkdispatch/codegen/shader_writer.py @@ -0,0 +1,93 @@ +import threading +import vkdispatch.base.dtype as dtypes +from .variables.base_variable import BaseVariable +from typing import Optional + +_thread_context = threading.local() + +def _get_shader_writer() -> Optional['ShaderWriter']: + return getattr(_thread_context, 'writer', None) + +def shader_writer() -> 'ShaderWriter': + writer = _get_shader_writer() + assert writer is not None, "No global ShaderWriter is set for the current thread!" 
+ return writer + +def set_shader_writer(writer: 'ShaderWriter'): + if writer is None: + _thread_context.writer = None + return + + assert _get_shader_writer() is None, "A global ShaderWriter is already set for the current thread!" + _thread_context.writer = writer + +class ShaderWriter: + var_count: int + contents: str + scope_num: int + + def __init__(self): + self.var_count = 0 + self.scope_num = 1 + self.contents = "" + + def append_contents(self, contents: str) -> None: + self.contents += (" " * self.scope_num) + contents + + def new_name(self) -> str: + new_var = f"var{self.var_count}" + self.var_count += 1 + return new_var + + def scope_increment(self): + self.scope_num += 1 + + def scope_decrement(self): + self.scope_num -= 1 + + def new_var(self, + var_type: dtypes.dtype, + var_name: str, + parents: list, + lexical_unit: bool = False, + settable: bool = False, + register: bool = False) -> BaseVariable: + raise NotImplementedError + + def new_scaled_var(self, + var_type: dtypes.dtype, + name: str, + scale: int = 1, + offset: int = 0, + parents: list = None): + raise NotImplementedError + +def append_contents(contents: str): + shader_writer().append_contents(contents) + +def new_name() -> str: + return shader_writer().new_name() + +def scope_increment(): + shader_writer().scope_increment() + +def scope_decrement(): + shader_writer().scope_decrement() + +def scope_indentation() -> str: + return " " * shader_writer().scope_num + +def new_var(var_type: dtypes.dtype, + var_name: Optional[str], + parents: list, + lexical_unit: bool = False, + settable: bool = False, + register: bool = False) -> BaseVariable: + return shader_writer().new_var(var_type, var_name, parents, lexical_unit, settable, register) + +def new_scaled_var(var_type: dtypes.dtype, + name: str, + scale: int = 1, + offset: int = 0, + parents: list = None): + return shader_writer().new_scaled_var(var_type, name, scale, offset, parents) diff --git a/vkdispatch/codegen/variables/__init__.py 
b/vkdispatch/codegen/variables/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/vkdispatch/codegen/variables/base_variable.py b/vkdispatch/codegen/variables/base_variable.py new file mode 100644 index 00000000..cb730815 --- /dev/null +++ b/vkdispatch/codegen/variables/base_variable.py @@ -0,0 +1,82 @@ +import vkdispatch.base.dtype as dtypes +from typing import List, Optional + +class BaseVariable: + var_type: dtypes.dtype + name: str + raw_name: str + can_index: bool + use_child_type: bool + lexical_unit: bool + settable: bool + parents: List["BaseVariable"] + + def __init__(self, + var_type: dtypes.dtype, + name: str, + raw_name: Optional[str] = None, + lexical_unit: bool = False, + settable: bool = False, + register: bool = False, + parents: List["BaseVariable"] = None + ) -> None: + self.var_type = var_type + self.lexical_unit = lexical_unit + self.can_index = False + self.use_child_type = True + + assert name is not None, "Variable name cannot be None!" + + self.name = name + self.raw_name = raw_name if raw_name is not None else self.name + + if register: + assert settable, "An unsettable register makes no sense" + + self.settable = settable + self.register = register + + if parents is None: + parents = [] + + self.parents = [] + + for parent_var in parents: + if isinstance(parent_var, BaseVariable): + self.parents.append(parent_var) + + def is_setable(self): + return self.settable + + def is_register(self): + return self.register + + def resolve(self) -> str: + if self.lexical_unit: + return self.name + + return f"({self.name})" + + def read_callback(self): + for parent in self.parents: + parent.read_callback() + + def write_callback(self): + for parent in self.parents: + parent.write_callback() + + def printf_args(self) -> str: + total_count = 1 # np.prod(self.var_type.shape) + + for dim in self.var_type.shape: + total_count *= dim + + if total_count == 1: + return self.name + + args_list = [] + + for i in range(0, total_count): + 
args_list.append(f"{self.name}[{i}]") + + return ",".join(args_list) \ No newline at end of file diff --git a/vkdispatch/codegen/variables/bound_variables.py b/vkdispatch/codegen/variables/bound_variables.py new file mode 100644 index 00000000..228ff299 --- /dev/null +++ b/vkdispatch/codegen/variables/bound_variables.py @@ -0,0 +1,184 @@ +from .variables import ShaderVariable +import vkdispatch.base.dtype as dtypes + +from ..functions import type_casting +from ..functions.base_functions import base_utils +from ..global_builder import get_codegen_backend + +from typing import Callable, Optional + +class BoundVariable(ShaderVariable): + binding: int = -1 + + def __init__(self, + var_type: dtypes.dtype, + binding: int, + name: str, + ) -> None: + super().__init__(var_type, name, lexical_unit=True) + + self.binding = binding + +class BufferVariable(BoundVariable): + read_lambda: Callable[[], None] + write_lambda: Callable[[], None] + scalar_expr: Optional[str] + codegen_backend: Optional[object] + + def __init__(self, + var_type: dtypes.dtype, + binding: int, + name: str, + shape_var: "ShaderVariable" = None, + shape_var_factory: Optional[Callable[[], "ShaderVariable"]] = None, + shape_name: Optional[str] = None, + raw_name: Optional[str] = None, + scalar_expr: Optional[str] = None, + codegen_backend: Optional[object] = None, + read_lambda: Callable[[], None] = None, + write_lambda: Callable[[], None] = None, + ) -> None: + super().__init__(var_type, binding, name) + + self.name = name if name is not None else self.name + self.raw_name = raw_name if raw_name is not None else self.raw_name + self.settable = True + + self.read_lambda = read_lambda + self.write_lambda = write_lambda + + self._shape_var = shape_var + self._shape_var_factory = shape_var_factory + self.shape_name = shape_name + self.scalar_expr = scalar_expr + self.codegen_backend = codegen_backend + self.can_index = True + self.use_child_type = False + + @property + def shape(self) -> "ShaderVariable": + if 
self._shape_var is None: + assert self._shape_var_factory is not None, "Buffer shape variable factory is not available!" + self._shape_var = self._shape_var_factory() + + return self._shape_var + + def read_callback(self): + self.read_lambda() + + def write_callback(self): + self.write_lambda() + + def __getitem__(self, index) -> "ShaderVariable": + assert self.can_index, f"Variable '{self.resolve()}' of type '{self.var_type.name}' cannot be indexed into!" + + return_type = self.var_type.child_type if self.use_child_type else self.var_type + + if isinstance(index, tuple): + assert len(index) == 1, "Only single index is supported, cannot use multi-dimentional indexing!" + index = index[0] + + if base_utils.is_int_number(index): + return ShaderVariable( + return_type, + f"{self.resolve()}[{index}]", + parents=[self], + settable=self.settable, + lexical_unit=True, + buffer_root=self, + buffer_index_expr=str(index), + ) + + assert isinstance(index, ShaderVariable), f"Index must be a ShaderVariable or int type, not {type(index)}!" + assert dtypes.is_scalar(index.var_type), "Indexing variable must be a scalar!" + assert dtypes.is_integer_dtype(index.var_type), "Indexing variable must be an integer type!" 
+ + return ShaderVariable( + return_type, + f"{self.resolve()}[{index.resolve()}]", + parents=[self, index], + settable=self.settable, + lexical_unit=True, + buffer_root=self, + buffer_index_expr=index.resolve(), + ) + +class ImageVariable(BoundVariable): + dimensions: int = 0 + read_lambda: Callable[[], None] + write_lambda: Callable[[], None] + + def __init__(self, + var_type: dtypes.dtype, + binding: int, + dimensions: int, + name: str, + read_lambda: Callable[[], None] = None, + write_lambda: Callable[[], None] = None, + ) -> None: + super().__init__(var_type, binding, name) + + self.read_lambda = read_lambda + self.write_lambda = write_lambda + self.dimensions = dimensions + + def read_callback(self): + self.read_lambda() + + def write_callback(self): + self.write_lambda() + + def sample(self, coord: "ShaderVariable", lod: "ShaderVariable" = None) -> "ShaderVariable": + if self.dimensions == 0: + raise ValueError("Cannot sample a texture with dimension 0!") + + backend = get_codegen_backend() + backend.mark_texture_sample_dimension(self.dimensions) + + sample_coord_string = "" + + if self.dimensions == 1: + sample_coord_string = f"((({coord.resolve()}) + 0.5) / {backend.texture_size_expr(self.resolve(), 0, self.dimensions)})" + elif self.dimensions == 2: + coord_expr = backend.constructor( + dtypes.vec2, + [ + backend.component_access_expr(coord.resolve(), "x", coord.var_type), + backend.component_access_expr(coord.resolve(), "y", coord.var_type), + ] + ) + tex_size_expr = backend.constructor( + dtypes.vec2, + [backend.texture_size_expr(self.resolve(), 0, self.dimensions)] + ) + sample_coord_string = f"(({coord_expr} + 0.5) / {tex_size_expr})" + elif self.dimensions == 3: + coord_expr = backend.constructor( + dtypes.vec3, + [ + backend.component_access_expr(coord.resolve(), "x", coord.var_type), + backend.component_access_expr(coord.resolve(), "y", coord.var_type), + backend.component_access_expr(coord.resolve(), "z", coord.var_type), + ] + ) + tex_size_expr = 
backend.constructor( + dtypes.vec3, + [backend.texture_size_expr(self.resolve(), 0, self.dimensions)] + ) + sample_coord_string = f"(({coord_expr} + 0.5) / {tex_size_expr})" + else: + raise ValueError("Unsupported number of dimensions!") + + if lod is None: + return type_casting.str_to_dtype( + dtypes.vec4, + backend.sample_texture_expr(self.resolve(), sample_coord_string), + [self], + lexical_unit=True) + + return type_casting.str_to_dtype( + dtypes.vec4, + backend.sample_texture_expr(self.resolve(), sample_coord_string, lod.resolve()), + [self, lod], + lexical_unit=True) + diff --git a/vkdispatch/codegen/variables/variables.py b/vkdispatch/codegen/variables/variables.py new file mode 100644 index 00000000..e8e776ee --- /dev/null +++ b/vkdispatch/codegen/variables/variables.py @@ -0,0 +1,448 @@ +import vkdispatch.base.dtype as dtypes + +from .base_variable import BaseVariable + +from ..functions.base_functions import arithmetic +from ..functions.base_functions import bitwise +from ..functions.base_functions import arithmetic_comparisons +from ..functions.base_functions import base_utils +from ..global_builder import get_codegen_backend + +from typing import List, Union, Optional + +ENABLE_SCALED_AND_OFFSET_INT = True + +def var_types_to_floating(var_type: dtypes.dtype) -> dtypes.dtype: + return dtypes.make_floating_dtype(var_type) + +class ShaderVariable(BaseVariable): + _initilized: bool + is_complex: bool + is_conjugate: Optional[bool] + buffer_root: Optional["ShaderVariable"] + buffer_index_expr: Optional[str] + + def __init__(self, + var_type: dtypes.dtype, + name: Optional[str] = None, + raw_name: Optional[str] = None, + lexical_unit: bool = False, + settable: bool = False, + register: bool = False, + parents: List["ShaderVariable"] = None, + is_conjugate: bool = False, + buffer_root: Optional["ShaderVariable"] = None, + buffer_index_expr: Optional[str] = None, + ) -> None: + super().__setattr__("_initilized", False) + + super().__init__( + var_type, + name 
if name is not None else base_utils.new_name(), + raw_name, + lexical_unit, + settable, + register, + parents + ) + + self.is_complex = False + self.is_conjugate = None + self.buffer_root = buffer_root + self.buffer_index_expr = buffer_index_expr + + if dtypes.is_complex(self.var_type): + self.can_index = True + self.is_complex = True + self.is_conjugate = is_conjugate + + self.real = self.swizzle("x") + self.imag = self.swizzle("y") + + if is_conjugate: + self.imag = -self.imag + + elif dtypes.is_vector(self.var_type): + self.can_index = True + + self.x = self.swizzle("x") + if self.var_type.child_count >= 2: self.y = self.swizzle("y") + if self.var_type.child_count >= 3: self.z = self.swizzle("z") + if self.var_type.child_count == 4: self.w = self.swizzle("w") + elif dtypes.is_matrix(self.var_type): + self.can_index = True + + self._initilized = True + + def _buffer_component_expr(self, component_index_expr: str) -> Optional[str]: + if self.buffer_root is None or self.buffer_index_expr is None: + return None + + if not (dtypes.is_vector(self.var_type) or dtypes.is_complex(self.var_type)): + return None + + scalar_expr = getattr(self.buffer_root, "scalar_expr", None) + if scalar_expr is None: + return None + + backend = getattr(self.buffer_root, "codegen_backend", None) + if backend is None: + backend = get_codegen_backend() + + return backend.buffer_component_expr( + scalar_expr, + self.var_type, + self.buffer_index_expr, + component_index_expr, + ) + + def __getitem__(self, index) -> "ShaderVariable": + assert self.can_index, f"Variable '{self.resolve()}' of type '{self.var_type.name}' cannot be indexed into!" + + return_type = self.var_type.child_type if self.use_child_type else self.var_type + + if isinstance(index, tuple): + assert len(index) == 1, "Only single index is supported, cannot use multi-dimentional indexing!" 
+ index = index[0] + + if base_utils.is_int_number(index): + component_expr = self._buffer_component_expr(str(index)) + if component_expr is not None: + return ShaderVariable( + return_type, + component_expr, + parents=[self], + settable=self.settable, + lexical_unit=True + ) + + return ShaderVariable(return_type, f"{self.resolve()}[{index}]", parents=[self], settable=self.settable, lexical_unit=True) + + assert isinstance(index, ShaderVariable), f"Index must be a ShaderVariable or int type, not {type(index)}!" + assert dtypes.is_scalar(index.var_type), "Indexing variable must be a scalar!" + assert dtypes.is_integer_dtype(index.var_type), "Indexing variable must be an integer type!" + + component_expr = self._buffer_component_expr(index.resolve()) + if component_expr is not None: + return ShaderVariable( + return_type, + component_expr, + parents=[self, index], + settable=self.settable, + lexical_unit=True + ) + + return ShaderVariable(return_type, f"{self.resolve()}[{index.resolve()}]", parents=[self, index], settable=self.settable, lexical_unit=True) + + def swizzle(self, components: str) -> "ShaderVariable": + assert dtypes.is_vector(self.var_type) or dtypes.is_complex(self.var_type) or dtypes.is_scalar(self.var_type), f"Variable '{self.resolve()}' of type '{self.var_type.name}' does not support swizzling!" + assert self.use_child_type, f"Variable '{self.resolve()}' does not support swizzling!" + + assert len(components) >= 1 and len(components) <= 4, f"Swizzle must have between 1 and 4 components, got {len(components)}!" + + for c in components: + assert c in ['x', 'y', 'z', 'w'], f"Invalid swizzle component '{c}'!" 
+ + sample_type = self.var_type if dtypes.is_scalar(self.var_type) else self.var_type.child_type + return_type = sample_type if len(components) == 1 else dtypes.to_vector(sample_type, len(components)) + backend = get_codegen_backend() + base_expr = self.resolve() + + if dtypes.is_scalar(self.var_type): + assert all(c == 'x' for c in components), f"Cannot swizzle scalar variable '{self.resolve()}' with components other than 'x'!" + + scalar_x_expr = backend.component_access_expr(base_expr, "x", self.var_type) + swizzle_expr = scalar_x_expr + if len(components) > 1: + swizzle_expr = backend.constructor( + return_type, + [scalar_x_expr for _ in components] + ) + + return ShaderVariable( + var_type=return_type, + name=swizzle_expr, + parents=[self], + lexical_unit=True, + settable=self.settable and len(components) == 1, + register=self.register and len(components) == 1 + ) + + if self.var_type.shape[0] < 4: + assert 'w' not in components, f"Cannot swizzle variable '{self.resolve()}' of type '{self.var_type.name}' with component 'w'!" + + if self.var_type.shape[0] < 3: + assert 'z' not in components, f"Cannot swizzle variable '{self.resolve()}' of type '{self.var_type.name}' with component 'z'!" + + if self.var_type.shape[0] < 2: + assert 'y' not in components, f"Cannot swizzle variable '{self.resolve()}' of type '{self.var_type.name}' with component 'y'!" 
+ + if len(components) == 1: + component_index = "xyzw".index(components) + component_expr = self._buffer_component_expr(str(component_index)) + if component_expr is not None: + return ShaderVariable( + var_type=return_type, + name=component_expr, + parents=[self], + lexical_unit=True, + settable=self.settable, + register=self.register + ) + + swizzle_expr = backend.component_access_expr(base_expr, components, self.var_type) + if len(components) > 1: + swizzle_expr = backend.constructor( + return_type, + [backend.component_access_expr(base_expr, elem, self.var_type) for elem in components] + ) + + return ShaderVariable( + var_type=return_type, + name=swizzle_expr, + parents=[self], + lexical_unit=True, + settable=self.settable and len(components) == 1, + register=self.register and len(components) == 1 + ) + + def conjugate(self) -> "ShaderVariable": + assert self.is_complex, f"Variable '{self.resolve()}' of type '{self.var_type.name}' is not a complex variable and cannot be conjugated!" + + return ShaderVariable( + var_type=self.var_type, + name=self.name, + raw_name=self.raw_name, + lexical_unit=self.lexical_unit, + settable=False, + register=False, + parents=[self], + is_conjugate=not self.is_conjugate + ) + + def set_value(self, value: "ShaderVariable") -> None: + assert self.settable, f"Cannot set value of '{self.resolve()}' because it is not a settable variable!" 
+ + self.write_callback() + self.read_callback() + + if base_utils.is_number(value): + if dtypes.is_complex(self.var_type): + complex_value = complex(value) + complex_constructor = get_codegen_backend().constructor( + self.var_type, + [ + base_utils.format_number_literal(complex_value.real), + base_utils.format_number_literal(complex_value.imag), + ] + ) + base_utils.append_contents(f"{self.resolve()} = {complex_constructor};\n") + return + + base_utils.append_contents(f"{self.resolve()} = {base_utils.format_number_literal(value)};\n") + return + + assert self.var_type == value.var_type, f"Cannot set variable of type '{self.var_type.name}' to value of type '{value.var_type.name}'!" + value.read_callback() + + base_utils.append_contents(f"{self.resolve()} = {value.resolve()};\n") + + def __setitem__(self, index, value: "ShaderVariable") -> None: + assert self.settable, f"Cannot set value of '{self.resolve()}' because it is not a settable variable!" + + if isinstance(index, slice): + assert index.start is None and index.stop is None and index.step is None, "Only full slice (:) is supported!" 
+ self.set_value(value) + return + + # ignore if setting variable to itself (happens in some inplace operations) + if f"{self.resolve()}[{index}]" == str(value): + return + + self[index].set_value(value) + + def __setattr__(self, name: str, value: "ShaderVariable") -> "ShaderVariable": + if not self._initilized: + super().__setattr__(name, value) + return + + if dtypes.is_complex(self.var_type) and (name == "real" or name == "imag"): + if name == "real": + self.real.set_value(value) + else: + self.imag.set_value(value) + + return + + if dtypes.is_complex(self.var_type) and (name == "x" or name == "y"): + raise ValueError(f"Cannot set attribute '{name}' of complex variable '{self.resolve()}', use 'real' and 'imag' instead!") + + if dtypes.is_vector(self.var_type) and (name == "x" or name == "y" or name == "z" or name == "w"): + if name == "x": + self.x.set_value(value) + elif name == "y": + self.y.set_value(value) + elif name == "z": + assert self.var_type.shape[0] >= 3, f"Variable '{self.resolve()}' of type '{self.var_type.name}' does not have 'z' component!" + self.z.set_value(value) + elif name == "w": + assert self.var_type.shape[0] == 4, f"Variable '{self.resolve()}' of type '{self.var_type.name}' does not have 'w' component!" 
+ self.w.set_value(value) + return + + super().__setattr__(name, value) + + def __bool__(self) -> bool: + raise ValueError(f"Vkdispatch variables cannot be cast to a python boolean") + + def to_register(self, var_name: str = None) -> "ShaderVariable": + new_var = base_utils.new_base_var( + self.var_type, + var_name, + [], + lexical_unit=True, + settable=True, + register=True + ) + + self.read_callback() + base_utils.append_contents(f"{get_codegen_backend().type_name(new_var.var_type)} {new_var.name} = {self.resolve()};\n") + return new_var + + def to_dtype(self, var_type: dtypes.dtype) -> "ShaderVariable": + return base_utils.new_base_var( + var_type, + get_codegen_backend().constructor(var_type, [self.resolve()], arg_types=[self.var_type]), + [self], + lexical_unit=True + ) + + def __lt__(self, other) -> "ShaderVariable": return arithmetic_comparisons.less_than(self, other) + def __le__(self, other) -> "ShaderVariable": return arithmetic_comparisons.less_or_equal(self, other) + def __eq__(self, other) -> "ShaderVariable": return arithmetic_comparisons.equal_to(self, other) + def __ne__(self, other) -> "ShaderVariable": return arithmetic_comparisons.not_equal_to(self, other) + def __gt__(self, other) -> "ShaderVariable": return arithmetic_comparisons.greater_than(self, other) + def __ge__(self, other) -> "ShaderVariable": return arithmetic_comparisons.greater_or_equal(self, other) + + def __add__(self, other) -> "ShaderVariable": return arithmetic.add(self, other) + def __sub__(self, other) -> "ShaderVariable": return arithmetic.sub(self, other) + def __mul__(self, other) -> "ShaderVariable": return arithmetic.mul(self, other) + def __truediv__(self, other) -> "ShaderVariable": return arithmetic.truediv(self, other) + def __floordiv__(self, other) -> 'ShaderVariable': return arithmetic.floordiv(self, other) + def __mod__(self, other) -> "ShaderVariable": return arithmetic.mod(self, other) + def __pow__(self, other) -> "ShaderVariable": return arithmetic.pow(self, 
other) + def __neg__(self) -> "ShaderVariable": return arithmetic.neg(self) + def __abs__(self) -> "ShaderVariable": return arithmetic.absolute(self) + def __invert__(self) -> "ShaderVariable": return bitwise.invert(self) + def __lshift__(self, other) -> "ShaderVariable": return bitwise.lshift(self, other) + def __rshift__(self, other) -> "ShaderVariable": return bitwise.rshift(self, other) + def __and__(self, other) -> "ShaderVariable": return bitwise.and_bits(self, other) + def __xor__(self, other) -> "ShaderVariable": return bitwise.xor_bits(self, other) + def __or__(self, other) -> "ShaderVariable": return bitwise.or_bits(self, other) + + def __radd__(self, other) -> "ShaderVariable": return arithmetic.add(self, other) + def __rsub__(self, other) -> "ShaderVariable": return arithmetic.sub(self, other, reverse=True) + def __rmul__(self, other) -> "ShaderVariable": return arithmetic.mul(self, other) + def __rtruediv__(self, other) -> "ShaderVariable": return arithmetic.truediv(self, other, reverse=True) + def __rfloordiv__(self, other) -> "ShaderVariable": return arithmetic.floordiv(self, other, reverse=True) + def __rmod__(self, other) -> "ShaderVariable": return arithmetic.mod(self, other, reverse=True) + def __rpow__(self, other) -> "ShaderVariable": return arithmetic.pow(self, other, reverse=True) + def __rlshift__(self, other) -> "ShaderVariable": return bitwise.lshift(self, other, reverse=True) + def __rrshift__(self, other) -> "ShaderVariable": return bitwise.rshift(self, other, reverse=True) + def __rand__(self, other) -> "ShaderVariable": return bitwise.and_bits(self, other) + def __rxor__(self, other) -> "ShaderVariable": return bitwise.xor_bits(self, other) + def __ror__(self, other) -> "ShaderVariable": return bitwise.or_bits(self, other) + + def __iadd__(self, other) -> "ShaderVariable": return arithmetic.add(self, other, inplace=True) + def __isub__(self, other) -> "ShaderVariable": return arithmetic.sub(self, other, inplace=True) + def 
__imul__(self, other) -> "ShaderVariable": return arithmetic.mul(self, other, inplace=True) + def __itruediv__(self, other) -> "ShaderVariable": return arithmetic.truediv(self, other, inplace=True) + def __ifloordiv__(self, other) -> "ShaderVariable": return arithmetic.floordiv(self, other, inplace=True) + def __imod__(self, other) -> "ShaderVariable": return arithmetic.mod(self, other, inplace=True) + def __ipow__(self, other) -> "ShaderVariable": return arithmetic.pow(self, other, inplace=True) + def __ilshift__(self, other) -> "ShaderVariable": return bitwise.lshift(self, other, inplace=True) + def __irshift__(self, other) -> "ShaderVariable": return bitwise.rshift(self, other, inplace=True) + def __iand__(self, other) -> "ShaderVariable": return bitwise.and_bits(self, other, inplace=True) + def __ixor__(self, other) -> "ShaderVariable": return bitwise.xor_bits(self, other, inplace=True) + def __ior__(self, other) -> "ShaderVariable": return bitwise.or_bits(self, other, inplace=True) + +class ScaledAndOfftsetIntVariable(ShaderVariable): + def __init__(self, + var_type: dtypes.dtype, + name: str, + scale: int = 1, + offset: int = 0, + parents: List["ShaderVariable"] = None + ) -> None: + # ShaderVariable.__init__ eagerly creates vector swizzles (`x`, `y`, ...), + # which call resolve() during construction. Pre-seed these fields so + # ScaledAndOfftsetIntVariable.resolve() is safe before super().__init__ completes. 
+ object.__setattr__(self, "base_name", str(name)) + object.__setattr__(self, "scale", scale) + object.__setattr__(self, "offset", offset) + super().__init__(var_type, name, parents=parents) + + def new_from_self(self, scale: int = 1, offset: int = 0): + child_vartype = self.var_type + + if base_utils.is_float_number(scale) or base_utils.is_float_number(offset): + child_vartype = var_types_to_floating(self.var_type) + + return ScaledAndOfftsetIntVariable( + child_vartype, + f"{self.name}", + scale=self.scale * scale, + offset=offset + self.offset * scale, + parents=self.parents + ) + + def resolve(self) -> str: + scale_str = ( + f" * {base_utils.format_number_literal(self.scale)}" + if self.scale != 1 else "" + ) + offset_str = ( + f" + {base_utils.format_number_literal(self.offset)}" + if self.offset != 0 else "" + ) + + if scale_str == "" and offset_str == "": + return self.base_name + + return f"({self.base_name}{scale_str}{offset_str})" + + def __add__(self, other) -> "Union[ShaderVariable, ScaledAndOfftsetIntVariable]": + if base_utils.is_scalar_number(other): + return self.new_from_self(offset=other) + + return super().__add__(other) + + def __sub__(self, other): + if isinstance(other, ShaderVariable): + return super().__sub__(other) + + return self.new_from_self(offset=-other) + + def __mul__(self, other): + if isinstance(other, ShaderVariable): + return super().__mul__(other) + + return self.new_from_self(scale=other) + + def __radd__(self, other): + if isinstance(other, ShaderVariable): + return super().__radd__(other) + + return self.new_from_self(offset=other) + + def __rsub__(self, other): + if isinstance(other, ShaderVariable): + return super().__rsub__(other) + + return self.new_from_self(offset=other, scale=-1) + + def __rmul__(self, other): + if isinstance(other, ShaderVariable): + return super().__rmul__(other) + + return self.new_from_self(scale=other) diff --git a/vkdispatch/compat/__init__.py b/vkdispatch/compat/__init__.py new file mode 100644 
index 00000000..bb0d094a --- /dev/null +++ b/vkdispatch/compat/__init__.py @@ -0,0 +1,2 @@ +"""Compatibility helpers for optional runtime dependencies.""" + diff --git a/vkdispatch/compat/numpy_compat.py b/vkdispatch/compat/numpy_compat.py new file mode 100644 index 00000000..7d42ab43 --- /dev/null +++ b/vkdispatch/compat/numpy_compat.py @@ -0,0 +1,364 @@ +from __future__ import annotations + +import builtins +import cmath +import math +import struct + +from dataclasses import dataclass +from typing import Any, Iterable, List, Sequence, Tuple + +try: + import numpy as _np +except Exception: # pragma: no cover - intentionally broad for optional dependency import + _np = None + +HAS_NUMPY = _np is not None +pi = math.pi + + +def require_numpy(feature_name: str) -> None: + if HAS_NUMPY: + return + + raise RuntimeError( + f"{feature_name} requires numpy, but numpy is not available. " + "Install numpy or use the bytes-based API." + ) + + +def numpy_module(): + return _np + + +def prod(values: Iterable[int]) -> int: + values_tuple = tuple(values) + + if HAS_NUMPY: + return int(_np.prod(values_tuple)) + + result = 1 + for value in values_tuple: + result *= int(value) + return result + + +def ceil(value: float) -> float: + if HAS_NUMPY: + return float(_np.ceil(value)) + return float(math.ceil(value)) + +def round(value: float) -> float: + if HAS_NUMPY: + return float(_np.round(value)) + return float(builtins.round(value)) + +def angle(value: complex) -> float: + if HAS_NUMPY: + return float(_np.angle(value)) + return float(cmath.phase(value)) + + +def exp_complex(value: complex) -> complex: + if HAS_NUMPY: + return complex(_np.exp(value)) + return cmath.exp(value) + + +def is_numpy_integer_scalar(value: Any) -> bool: + return bool(HAS_NUMPY and _np.issubdtype(type(value), _np.integer)) + + +def is_integer_scalar(value: Any) -> bool: + return isinstance(value, int) or is_numpy_integer_scalar(value) + + +def is_numpy_floating_instance(value: Any) -> bool: + return 
bool(HAS_NUMPY and isinstance(value, _np.floating)) + + +@dataclass(frozen=True) +class HostDType: + name: str + itemsize: int + struct_format: str + kind: str + + +INT16 = HostDType("int16", 2, "h", "int") +UINT16 = HostDType("uint16", 2, "H", "uint") +INT32 = HostDType("int32", 4, "i", "int") +UINT32 = HostDType("uint32", 4, "I", "uint") +INT64 = HostDType("int64", 8, "q", "int") +UINT64 = HostDType("uint64", 8, "Q", "uint") +FLOAT16 = HostDType("float16", 2, "e", "float") +FLOAT32 = HostDType("float32", 4, "f", "float") +FLOAT64 = HostDType("float64", 8, "d", "float") +COMPLEX32 = HostDType("complex32", 4, "ee", "complex") +COMPLEX64 = HostDType("complex64", 8, "ff", "complex") +COMPLEX128 = HostDType("complex128", 16, "dd", "complex") + +_HOST_DTYPES = { + "int16": INT16, + "uint16": UINT16, + "int32": INT32, + "uint32": UINT32, + "int64": INT64, + "uint64": UINT64, + "float16": FLOAT16, + "float32": FLOAT32, + "float64": FLOAT64, + "complex32": COMPLEX32, + "complex64": COMPLEX64, + "complex128": COMPLEX128, +} + + +def host_dtype(name: str) -> HostDType: + if name not in _HOST_DTYPES: + raise ValueError(f"Unsupported dtype ({name})!") + return _HOST_DTYPES[name] + + +def is_host_dtype(value: Any) -> bool: + return isinstance(value, HostDType) + + +def host_dtype_name(dtype: Any) -> str: + if isinstance(dtype, HostDType): + return dtype.name + + if isinstance(dtype, str): + return dtype + + if HAS_NUMPY: + return str(_np.dtype(dtype).name) + + raise ValueError(f"Unsupported dtype ({dtype})!") + + +def _numpy_dtype_or_none(dtype_name: str): + if not HAS_NUMPY: + return None + + try: + return _np.dtype(dtype_name) + except TypeError: + return None + + +def dtype_itemsize(dtype: Any) -> int: + if isinstance(dtype, HostDType): + return dtype.itemsize + + if HAS_NUMPY: + return int(_np.dtype(dtype).itemsize) + + return host_dtype(host_dtype_name(dtype)).itemsize + + +def dtype_kind(dtype: Any) -> str: + if isinstance(dtype, HostDType): + return dtype.kind + + if 
HAS_NUMPY: + dtype_obj = _np.dtype(dtype) + if _np.issubdtype(dtype_obj, _np.complexfloating): + return "complex" + if _np.issubdtype(dtype_obj, _np.unsignedinteger): + return "uint" + if _np.issubdtype(dtype_obj, _np.integer): + return "int" + if _np.issubdtype(dtype_obj, _np.floating): + return "float" + + return host_dtype(host_dtype_name(dtype)).kind + + +def dtype_struct_format(dtype: Any) -> str: + if isinstance(dtype, HostDType): + return dtype.struct_format + return host_dtype(host_dtype_name(dtype)).struct_format + + +class CompatArray: + def __init__(self, buffer: bytes, dtype: HostDType, shape: Tuple[int, ...]): + self._buffer = bytes(buffer) + self.dtype = dtype + self.shape = tuple(shape) + self.size = prod(self.shape) + + def reshape(self, shape: Tuple[int, ...]) -> "CompatArray": + shape = tuple(shape) + if prod(shape) != self.size: + raise ValueError("Cannot reshape array with mismatched element count") + return CompatArray(self._buffer, self.dtype, shape) + + def tobytes(self) -> bytes: + return bytes(self._buffer) + + @property + def nbytes(self) -> int: + return len(self._buffer) + + def __repr__(self) -> str: + return f"CompatArray(shape={self.shape}, dtype={self.dtype.name}, nbytes={len(self._buffer)})" + + +def is_array_like(value: Any) -> bool: + if HAS_NUMPY and isinstance(value, _np.ndarray): + return True + return isinstance(value, CompatArray) + + +def array_shape(value: Any) -> Tuple[int, ...]: + if HAS_NUMPY and isinstance(value, _np.ndarray): + return tuple(value.shape) + if isinstance(value, CompatArray): + return tuple(value.shape) + raise TypeError(f"Unsupported array-like value ({type(value)})") + + +def array_dtype(value: Any) -> Any: + if HAS_NUMPY and isinstance(value, _np.ndarray): + return value.dtype + if isinstance(value, CompatArray): + return value.dtype + raise TypeError(f"Unsupported array-like value ({type(value)})") + + +def array_nbytes(value: Any) -> int: + if HAS_NUMPY and isinstance(value, _np.ndarray): + return 
int(value.size * value.dtype.itemsize) + if isinstance(value, CompatArray): + return value.nbytes + raise TypeError(f"Unsupported array-like value ({type(value)})") + + +def as_contiguous_bytes(value: Any) -> bytes: + if HAS_NUMPY and isinstance(value, _np.ndarray): + return _np.ascontiguousarray(value).tobytes() + if isinstance(value, CompatArray): + return value.tobytes() + raise TypeError(f"Unsupported array-like value ({type(value)})") + + +def from_buffer(buffer: bytes, dtype: Any, shape: Tuple[int, ...]): + dtype_name = host_dtype_name(dtype) + + if HAS_NUMPY: + np_dtype = _numpy_dtype_or_none(dtype_name) + if np_dtype is not None: + return _np.frombuffer(buffer, dtype=np_dtype).reshape(shape) + + if dtype_name == "complex32": + half_pairs = _np.frombuffer(buffer, dtype=_np.float16).reshape(*shape, 2) + return half_pairs[..., 0].astype(_np.float32) + (1j * half_pairs[..., 1].astype(_np.float32)) + + return CompatArray(buffer, host_dtype(dtype_name), tuple(shape)) + + +def ensure_bytes(value: Any) -> bytes: + if isinstance(value, bytes): + return value + if isinstance(value, bytearray): + return bytes(value) + if isinstance(value, memoryview): + return value.tobytes() + raise TypeError(f"Unsupported bytes-like object ({type(value)})") + + +def is_bytes_like(value: Any) -> bool: + return isinstance(value, (bytes, bytearray, memoryview)) + + +def flatten(value: Any) -> List[Any]: + if isinstance(value, CompatArray): + return unpack_values(value.tobytes(), value.dtype) + + if HAS_NUMPY and isinstance(value, _np.ndarray): + return value.reshape(-1).tolist() + + if isinstance(value, (list, tuple)): + out: List[Any] = [] + for element in value: + out.extend(flatten(element)) + return out + + return [value] + + +def _coerce_scalar(value: Any, dtype: Any): + kind = dtype_kind(dtype) + + if kind == "complex": + if isinstance(value, complex): + return value + if isinstance(value, (list, tuple)): + if len(value) != 2: + raise ValueError("Complex values must be complex 
scalars or pairs") + return complex(float(value[0]), float(value[1])) + return complex(value) + + if kind == "float": + return float(value) + + if kind in ("int", "uint"): + return int(value) + + raise ValueError(f"Unsupported dtype kind ({kind})") + + +def pack_values(values: Sequence[Any], dtype: Any) -> bytes: + values_list = list(values) + dtype_name = host_dtype_name(dtype) + + if HAS_NUMPY: + np_dtype = _numpy_dtype_or_none(dtype_name) + if np_dtype is not None: + array = _np.asarray(values_list, dtype=np_dtype) + return array.tobytes() + + host = host_dtype(dtype_name) + + if host.kind == "complex": + output = bytearray() + pack_fmt = "=" + host.struct_format + for value in values_list: + coerced = _coerce_scalar(value, host) + output.extend(struct.pack(pack_fmt, float(coerced.real), float(coerced.imag))) + return bytes(output) + + pack_fmt = "=" + host.struct_format + output = bytearray() + for value in values_list: + output.extend(struct.pack(pack_fmt, _coerce_scalar(value, host))) + return bytes(output) + + +def unpack_values(data: bytes, dtype: Any) -> List[Any]: + dtype_name = host_dtype_name(dtype) + + if HAS_NUMPY: + np_dtype = _numpy_dtype_or_none(dtype_name) + if np_dtype is not None: + return _np.frombuffer(data, dtype=np_dtype).tolist() + + host = host_dtype(dtype_name) + + if host.kind == "complex": + values: List[Any] = [] + unpack_fmt = "=" + host.struct_format + for real, imag in struct.iter_unpack(unpack_fmt, data): + values.append(complex(real, imag)) + return values + + unpack_fmt = "=" + host.struct_format + stride = struct.calcsize(unpack_fmt) + values = [] + + for offset in range(0, len(data), stride): + values.append(struct.unpack(unpack_fmt, data[offset: offset + stride])[0]) + + return values + diff --git a/vkdispatch/execution_pipeline/__init__.py b/vkdispatch/execution_pipeline/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/vkdispatch/execution_pipeline/buffer_builder.py 
b/vkdispatch/execution_pipeline/buffer_builder.py index 20b39787..01418bae 100644 --- a/vkdispatch/execution_pipeline/buffer_builder.py +++ b/vkdispatch/execution_pipeline/buffer_builder.py @@ -1,44 +1,41 @@ import dataclasses +import enum +from typing import Any from typing import Dict from typing import List +from typing import Optional from typing import Tuple from typing import Union -from typing import Optional - -import enum - -import numpy as np import vkdispatch as vd import vkdispatch.codegen as vc +from ..compat import numpy_compat as npc +from vkdispatch.base.dtype import to_numpy_dtype + + @dataclasses.dataclass class BufferedStructEntry: memory_slice: slice - dtype: Optional[np.dtype] + dtype: Optional[Any] shape: Tuple[int, ...] + class BufferUsage(enum.Enum): PUSH_CONSTANT = 0 UNIFORM_BUFFER = 1 + class BufferBuilder: """ A class for building buffers in memory that can be submitted to a compute pipeline. - - Attributes: - struct_alignment (int): The alignment of the struct in the buffer. - instance_bytes (int): The size of the struct in bytes. - instance_count (int): The number of instances of the struct. - backing_buffer (np.ndarray): The backing buffer for the struct. 
- element_map (Dict[Tuple[str, str], BufferedStructEntry]): A map of the elements in the """ struct_alignment: int = -1 instance_bytes: int = 0 instance_count: int = 0 - backing_buffer: np.ndarray = None + backing_buffer: Any = None element_map: Dict[Tuple[str, str], BufferedStructEntry] @@ -52,54 +49,52 @@ def __init__(self, struct_alignment: Optional[int] = None, usage: Optional[Buffe struct_alignment = vd.get_context().uniform_buffer_alignment else: raise ValueError("Invalid buffer usage!") - + self.struct_alignment = struct_alignment self.reset() - + def reset(self) -> None: self.instance_bytes = 0 self.instance_count = 0 self.backing_buffer = None self.element_map = {} - + def register_struct(self, name: str, elements: List[vc.StructElement]) -> Tuple[int, int]: offset = self.instance_bytes for elem in elements: - np_dtype = np.dtype(vd.to_numpy_dtype(elem.dtype if elem.dtype.scalar is None else elem.dtype.scalar)) + elem_dtype = elem.dtype if elem.dtype.scalar is None else elem.dtype.scalar + host_dtype = to_numpy_dtype(elem_dtype) - np_shape = elem.dtype.numpy_shape + host_shape = elem.dtype.numpy_shape if elem.count > 1: - if np_shape == (1, ): - np_shape = (elem.count,) + if host_shape == (1,): + host_shape = (elem.count,) else: - np_shape = (elem.count, *np_shape) - - element_size = np_dtype.itemsize * np.prod(np_shape) + host_shape = (elem.count, *host_shape) + + element_size = npc.dtype_itemsize(host_dtype) * npc.prod(host_shape) self.element_map[(name, elem.name)] = BufferedStructEntry( slice(self.instance_bytes, self.instance_bytes + element_size), - np_dtype, - np_shape + host_dtype, + host_shape, ) self.instance_bytes += element_size - + if self.struct_alignment != 0: - padded_size = int(np.ceil(self.instance_bytes / self.struct_alignment)) * self.struct_alignment + padded_size = ((self.instance_bytes + self.struct_alignment - 1) // self.struct_alignment) * self.struct_alignment if padded_size != self.instance_bytes: self.instance_bytes = 
padded_size - + return offset, self.instance_bytes - offset - def __setitem__( - self, key: Tuple[str, str], value: Union[np.ndarray, list, tuple, int, float] - ) -> None: - if key not in self.element_map: - raise ValueError(f"Invalid buffer element name '{key}'!") + def _setitem_numpy(self, key: Tuple[str, str], value: Any) -> None: + np = npc.numpy_module() buffer_element = self.element_map[key] @@ -129,7 +124,7 @@ def __setitem__( raise ValueError( f"The shape of {key} is {buffer_element.shape} but a scalar was given!" ) - + if len(buffer_element.shape) > 1: (self.backing_buffer[:, buffer_element.memory_slice]).view(buffer_element.dtype).reshape(-1, *buffer_element.shape)[:] = arr else: @@ -149,21 +144,143 @@ def __setitem__( else: (self.backing_buffer[0, buffer_element.memory_slice]).view(buffer_element.dtype)[:] = arr -# def __repr__(self) -> str: -# result = "Push Constant Buffer:\n" -# -# for elem in self.elements: -# result += f"\t{elem.name} ({elem.dtype.name}): {self.numpy_arrays[elem.index]}\n" -# -# return result[:-1] + def _write_payload(self, instance_index: int, element_slice: slice, payload: bytes) -> None: + expected_size = element_slice.stop - element_slice.start + + if len(payload) != expected_size: + raise ValueError(f"Packed value size mismatch! 
Expected {expected_size}, got {len(payload)}") + + if npc.HAS_NUMPY: + np = npc.numpy_module() + row = self.backing_buffer[instance_index] + row[element_slice] = np.frombuffer(payload, dtype=np.uint8) + return + + start = instance_index * self.instance_bytes + element_slice.start + end = start + expected_size + + self.backing_buffer[start:end] = payload + + def _pack_single_instance_value(self, value: Any, key: Tuple[str, str], buffer_element: BufferedStructEntry) -> bytes: + expected_element_count = npc.prod(buffer_element.shape) + flat_values = npc.flatten(value) + + if expected_element_count == 1 and len(flat_values) == 0: + raise ValueError(f"The shape of {key} is {buffer_element.shape} but no value was given!") + + if len(flat_values) != expected_element_count: + raise ValueError( + f"The shape of {key} is {buffer_element.shape} but {len(flat_values)} elements were given!" + ) + + return npc.pack_values(flat_values, buffer_element.dtype) + + def _setitem_python(self, key: Tuple[str, str], value: Any) -> None: + buffer_element = self.element_map[key] + + if self.instance_count == 1: + payload = self._pack_single_instance_value(value, key, buffer_element) + self._write_payload(0, buffer_element.memory_slice, payload) + return + + # Broadcast scalar values across all instances for scalar fields. 
+ if not isinstance(value, (list, tuple)) and not npc.is_array_like(value) and buffer_element.shape == (1,): + payload = self._pack_single_instance_value([value], key, buffer_element) + for instance_index in range(self.instance_count): + self._write_payload(instance_index, buffer_element.memory_slice, payload) + return + + expected_element_count = npc.prod(buffer_element.shape) + + if npc.is_array_like(value): + flat_values = npc.flatten(value) + expected_total = expected_element_count * self.instance_count + + if len(flat_values) != expected_total: + raise ValueError( + f"The shape of {key} is {(self.instance_count, *buffer_element.shape)} but {len(flat_values)} elements were given!" + ) + + for instance_index in range(self.instance_count): + instance_values = flat_values[ + instance_index * expected_element_count: (instance_index + 1) * expected_element_count + ] + payload = npc.pack_values(instance_values, buffer_element.dtype) + self._write_payload(instance_index, buffer_element.memory_slice, payload) + return + + if not isinstance(value, (list, tuple)): + raise ValueError( + f"The shape of {key} is {(self.instance_count, *buffer_element.shape)} but a scalar was given!" + ) + + if len(value) != self.instance_count: + raise ValueError(f"Invalid shape for {key}! Expected {self.instance_count} but got {len(value)}!") + + for instance_index in range(self.instance_count): + payload = self._pack_single_instance_value(value[instance_index], key, buffer_element) + self._write_payload(instance_index, buffer_element.memory_slice, payload) + + def __setitem__( + self, key: Tuple[str, str], value: Union[Any, list, tuple, int, float] + ) -> None: + if key not in self.element_map: + raise ValueError(f"Invalid buffer element name '{key}'!") + + if self.backing_buffer is None: + raise RuntimeError("BufferBuilder.prepare(...) 
must be called before assigning values") + + buffer_element = self.element_map[key] + + if npc.HAS_NUMPY and not npc.is_host_dtype(buffer_element.dtype): + self._setitem_numpy(key, value) + return + + self._setitem_python(key, value) + + def __repr__(self) -> str: + result = "Push Constant Buffer:\n" + + for key, elem in self.element_map.items(): + buffer_element = self.element_map[key] + + if npc.HAS_NUMPY and not npc.is_host_dtype(buffer_element.dtype): + value = (self.backing_buffer[:, buffer_element.memory_slice]).view(buffer_element.dtype) + else: + decoded_instances = [] + + for instance_index in range(self.instance_count): + start = instance_index * self.instance_bytes + buffer_element.memory_slice.start + end = instance_index * self.instance_bytes + buffer_element.memory_slice.stop + raw = self.tobytes()[start:end] # flat bytes: backing_buffer is 2-D when numpy is present, so flat offsets cannot index it directly + decoded = npc.unpack_values(raw, buffer_element.dtype) + decoded_instances.append(decoded if len(decoded) > 1 else decoded[0]) + + value = decoded_instances + + result += f"\t{key[0]}, {key[1]} ({elem.dtype}): {value}\n" + + return result[:-1] def prepare(self, instance_count: int) -> None: if self.instance_count != instance_count: self.instance_count = instance_count - self.backing_buffer = np.zeros((self.instance_count, self.instance_bytes), dtype=np.uint8) - + + if npc.HAS_NUMPY: + np = npc.numpy_module() + self.backing_buffer = np.zeros((self.instance_count, self.instance_bytes), dtype=np.uint8) + else: + self.backing_buffer = bytearray(self.instance_count * self.instance_bytes) + def toints(self): - return self.backing_buffer.view(np.uint32) - + if npc.HAS_NUMPY: + np = npc.numpy_module() + return self.backing_buffer.view(np.uint32) + + return npc.from_buffer(bytes(self.backing_buffer), dtype=npc.host_dtype("uint32"), shape=(len(self.backing_buffer) // 4,)) + def tobytes(self): - return self.backing_buffer.tobytes() + if npc.HAS_NUMPY: + return self.backing_buffer.tobytes() + + return bytes(self.backing_buffer) diff --git
a/vkdispatch/execution_pipeline/command_graph.py b/vkdispatch/execution_pipeline/command_graph.py index 747572fd..efdfc40f 100644 --- a/vkdispatch/execution_pipeline/command_graph.py +++ b/vkdispatch/execution_pipeline/command_graph.py @@ -1,23 +1,26 @@ from typing import Any -from typing import Callable from typing import List from typing import Dict -from typing import Union -from typing import Tuple -from typing import Optional +from typing import Tuple, Optional import uuid - -import numpy as np +import threading import vkdispatch as vd import vkdispatch.codegen as vc +from vkdispatch.base.command_list import CommandList +from vkdispatch.base.compute_plan import ComputePlan +from vkdispatch.base.descriptor_set import DescriptorSet + from .buffer_builder import BufferUsage from .buffer_builder import BufferBuilder import dataclasses +def _runtime_supports_push_constants() -> bool: + return True + @dataclasses.dataclass class BufferBindInfo: """A dataclass to hold information about a buffer binding.""" @@ -35,8 +38,20 @@ class ImageBindInfo: read_access: bool write_access: bool -class CommandGraph(vd.CommandList): - """TODO: Docstring""" +class CommandGraph(CommandList): + """ + A high-level abstraction over ``CommandList`` that manages resource binding and push constants automatically. + + Unlike a raw ``CommandList``, a ``CommandGraph`` tracks variable state and handles the + complexities of ``BufferBuilder`` for push constants and uniform buffers. It serves + as the default recording target for shader functions. + + :param reset_on_submit: If True, the graph clears its recorded commands immediately after submission. + :type reset_on_submit: bool + :param submit_on_record: If True, commands are submitted to the GPU immediately upon recording + (simulating immediate mode execution). 
+ :type submit_on_record: bool + """ _reset_on_submit: bool submit_on_record: bool @@ -51,9 +66,10 @@ class CommandGraph(vd.CommandList): uniform_bindings: Any uniform_constants_size: int - uniform_constants_buffer: vd.Buffer + uniform_constants_buffer: Optional[vd.Buffer] - uniform_descriptors: List[Tuple[vd.DescriptorSet, int, int]] + uniform_descriptors: List[Tuple[DescriptorSet, int, int]] + recorded_descriptor_sets: List[DescriptorSet] name_to_pc_key_dict: Dict[str, List[Tuple[str, str]]] queued_pc_values: Dict[Tuple[str, str], Any] @@ -72,12 +88,75 @@ def __init__(self, reset_on_submit: bool = False, submit_on_record: bool = False self.queued_pc_values = {} self.uniform_descriptors = [] + self.recorded_descriptor_sets = [] self._reset_on_submit = reset_on_submit self.submit_on_record = submit_on_record + # Lazily allocate host-uploaded UBO backing only when needed by non-CUDA backends. self.uniform_constants_size = 0 - self.uniform_constants_buffer = vd.Buffer(shape=(4096,), var_type=vd.uint32) # Create a base static constants buffer at size 4k bytes + self.uniform_constants_buffer = None + + def _ensure_uniform_constants_capacity(self, uniform_word_size: int) -> None: + if self.uniform_constants_buffer is not None and uniform_word_size <= self.uniform_constants_size: + return + + # Grow exponentially to reduce reallocation churn for larger UBO layouts. 
+ if self.uniform_constants_size == 0: + self.uniform_constants_size = max(4096, uniform_word_size) + else: + self.uniform_constants_size = max(uniform_word_size, self.uniform_constants_size * 2) + self.uniform_constants_buffer = vd.Buffer(shape=(self.uniform_constants_size,), var_type=vd.uint32) + + def _prepare_submission_state(self, instance_count: int) -> None: + if len(self.pc_builder.element_map) > 0 and ( + self.pc_builder.instance_count != instance_count or not self.buffers_valid + ): + + assert _runtime_supports_push_constants(), ( + "Push constants not supported for backends without push-constant support " + "(OpenCL). Use UBO-backed variables instead." + ) + + self.pc_builder.prepare(instance_count) + + for key, value in self.pc_values.items(): + self.pc_builder[key] = value + + if len(self.uniform_builder.element_map) > 0 and not self.buffers_valid: + self.uniform_builder.prepare(1) + + for key, value in self.uniform_values.items(): + self.uniform_builder[key] = value + + uniform_word_size = (self.uniform_builder.instance_bytes + 3) // 4 + uniform_payload = self.uniform_builder.tobytes() + + if vd.is_cuda(): + for descriptor_set, offset, size in self.uniform_descriptors: + descriptor_set.set_inline_uniform_payload(uniform_payload[offset:offset + size]) + else: + self._ensure_uniform_constants_capacity(uniform_word_size) + assert self.uniform_constants_buffer is not None + + for descriptor_set, offset, size in self.uniform_descriptors: + descriptor_set.bind_buffer(self.uniform_constants_buffer, 0, offset, size, True, write_access=False) + + self.uniform_constants_buffer.write(uniform_payload) + + if not self.buffers_valid: + self.buffers_valid = True + + def prepare_for_cuda_graph_capture(self, instance_count: int = None) -> None: + """Initialize internal data uploads before torch CUDA graph capture. 
+ + This method performs one-time uniform/push-constant staging without submitting + the command list, so only kernel launches are captured by ``torch.cuda.graph``. + """ + if instance_count is None: + instance_count = 1 + + self._prepare_submission_state(instance_count) def reset(self) -> None: """Reset the command graph by clearing the push constant buffer and descriptor @@ -88,15 +167,29 @@ def reset(self) -> None: self.pc_builder.reset() self.uniform_builder.reset() - self.pc_values = {} - self.uniform_values = {} - self.name_to_pc_key_dict = {} - self.queued_pc_values = {} + for descriptor_set in self.recorded_descriptor_sets: + descriptor_set.destroy() + + self.pc_values.clear() + self.uniform_values.clear() + self.name_to_pc_key_dict.clear() + self.queued_pc_values.clear() + self.uniform_descriptors.clear() + self.recorded_descriptor_sets.clear() - self.uniform_descriptors = [] self.buffers_valid = False + + def _destroy(self) -> None: + self.reset() + super()._destroy() def bind_var(self, name: str): + if not _runtime_supports_push_constants(): + raise RuntimeError( + "CommandGraph.bind_var() is disabled for backends without push-constant " + "support (OpenCL). Pass Variable values directly at shader invocation." + ) + def register_var(key: Tuple[str, str]): if not name in self.name_to_pc_key_dict.keys(): self.name_to_pc_key_dict[name] = [] @@ -106,6 +199,12 @@ def register_var(key: Tuple[str, str]): return register_var def set_var(self, name: str, value: Any): + if not _runtime_supports_push_constants(): + raise RuntimeError( + "CommandGraph.set_var() is disabled for backends without push-constant " + "support (OpenCL). Pass Variable values directly at shader invocation." 
+ ) + if name not in self.name_to_pc_key_dict.keys(): raise ValueError("Variable not bound!") @@ -113,7 +212,7 @@ def set_var(self, name: str, value: Any): self.queued_pc_values[key] = value def record_shader(self, - plan: vd.ComputePlan, + plan: ComputePlan, shader_description: vc.ShaderDescription, exec_limits: Tuple[int, int, int], blocks: Tuple[int, int, int], @@ -123,19 +222,55 @@ def record_shader(self, pc_values: Dict[str, Any] = {}, shader_uuid: str = None ) -> None: - descriptor_set = vd.DescriptorSet(plan) + """ + Internal method to record a high-level shader execution. + + This method handles the creation of ``DescriptorSet`` objects, binding of buffers + and images, and populating push constant/uniform data before calling the base + ``record_compute_plan``. + + :param plan: The compute plan to execute. + :param shader_description: Metadata about the shader source and layout. + :param exec_limits: The execution limits (grid size) in x, y, z. + :param blocks: The number of workgroups to dispatch. + :param bound_buffers: List of buffers to bind. + :param bound_samplers: List of images/samplers to bind. + :param uniform_values: Dictionary of values for uniform buffer objects. + :param pc_values: Dictionary of values for push constants. + :param shader_uuid: Unique identifier for this shader instance (for caching). + """ + + descriptor_set = DescriptorSet(plan) + self.recorded_descriptor_sets.append(descriptor_set) if shader_uuid is None: shader_uuid = shader_description.name + "_" + str(uuid.uuid4()) + if (not _runtime_supports_push_constants()) and len(pc_values) > 0: + raise RuntimeError( + "Push-constant Variable payloads are disabled for backends without " + "push-constant support (OpenCL). " + "Variable values must be UBO-backed and provided at shader invocation." 
+ ) + if len(shader_description.pc_structure) != 0: + if not _runtime_supports_push_constants(): + raise RuntimeError( + "Kernels should not emit push-constant layouts for backends without " + "push-constant support (OpenCL). Use UBO-backed variables." + ) self.pc_builder.register_struct(shader_uuid, shader_description.pc_structure) - - uniform_offset, uniform_range = self.uniform_builder.register_struct(shader_uuid, shader_description.uniform_structure) - self.uniform_descriptors.append((descriptor_set, uniform_offset, uniform_range)) + uniform_field_names = {elem.name for elem in shader_description.uniform_structure} + resolved_uniform_values: Dict[Tuple[str, str], Any] = {} - self.uniform_values[(shader_uuid, shader_description.exec_count_name)] = [exec_limits[0], exec_limits[1], exec_limits[2], 0] + if shader_description.exec_count_name is not None: + resolved_uniform_values[(shader_uuid, shader_description.exec_count_name)] = [ + exec_limits[0], + exec_limits[1], + exec_limits[2], + 0, + ] for buffer_bind_info in bound_buffers: descriptor_set.bind_buffer( @@ -145,7 +280,8 @@ def record_shader(self, write_access=buffer_bind_info.write_access, ) - self.uniform_values[(shader_uuid, buffer_bind_info.shape_name)] = buffer_bind_info.buffer.shader_shape + if buffer_bind_info.shape_name in uniform_field_names: + resolved_uniform_values[(shader_uuid, buffer_bind_info.shape_name)] = buffer_bind_info.buffer.shader_shape for sampler_bind_info in bound_samplers: descriptor_set.bind_sampler( @@ -156,7 +292,14 @@ def record_shader(self, ) for key, value in uniform_values.items(): - self.uniform_values[(shader_uuid, key)] = value + resolved_uniform_values[(shader_uuid, key)] = value + + if len(shader_description.uniform_structure) > 0: + uniform_offset, uniform_range = self.uniform_builder.register_struct(shader_uuid, shader_description.uniform_structure) + self.uniform_descriptors.append((descriptor_set, uniform_offset, uniform_range)) + + for key, value in 
resolved_uniform_values.items(): + self.uniform_values[key] = value for key, value in pc_values.items(): self.pc_values[(shader_uuid, key)] = value @@ -164,11 +307,15 @@ def record_shader(self, super().record_compute_plan(plan, descriptor_set, blocks) self.buffers_valid = False - + if self.submit_on_record: self.submit() - - def submit(self, instance_count: int = None, queue_index: int = -2) -> None: + + def submit( + self, + instance_count: int = None, + queue_index: int = -2 + ) -> None: """Submit the command list to the specified device with additional data to append to the front of the command list. @@ -180,30 +327,8 @@ def submit(self, instance_count: int = None, queue_index: int = -2) -> None: if instance_count is None: instance_count = 1 - - if len(self.pc_builder.element_map) > 0 and ( - self.pc_builder.instance_count != instance_count or not self.buffers_valid - ): - - self.pc_builder.prepare(instance_count) - - for key, value in self.pc_values.items(): - self.pc_builder[key] = value - - if len(self.uniform_builder.element_map) > 0 and not self.buffers_valid: - - self.uniform_builder.prepare(1) - for key, value in self.uniform_values.items(): - self.uniform_builder[key] = value - - for descriptor_set, offset, size in self.uniform_descriptors: - descriptor_set.bind_buffer(self.uniform_constants_buffer, 0, offset, size, True, write_access=False) - - self.uniform_constants_buffer.write(self.uniform_builder.tobytes()) - - if not self.buffers_valid: - self.buffers_valid = True + self._prepare_submission_state(instance_count) for key, val in self.queued_pc_values.items(): self.pc_builder[key] = val @@ -213,7 +338,12 @@ def submit(self, instance_count: int = None, queue_index: int = -2) -> None: if len(self.pc_builder.element_map) > 0: my_data = self.pc_builder.tobytes() - super().submit(data=my_data, queue_index=queue_index, instance_count=instance_count) + super().submit( + data=my_data, + queue_index=queue_index, + instance_count=instance_count, + 
cuda_stream=None, + ) if self._reset_on_submit: self.reset() @@ -221,27 +351,29 @@ def submit(self, instance_count: int = None, queue_index: int = -2) -> None: def submit_any(self, instance_count: int = None) -> None: self.submit(instance_count=instance_count, queue_index=-1) -__default_graph = None -__custom_graph = None +_global_graph = threading.local() -def default_graph() -> CommandGraph: - global __default_graph +def _get_global_graph() -> Optional[CommandGraph]: + return getattr(_global_graph, 'custom_graph', None) - if __default_graph is None: - __default_graph = CommandGraph(reset_on_submit=True, submit_on_record=True) +def default_graph() -> CommandGraph: + if not hasattr(_global_graph, 'default_graph'): + _global_graph.default_graph = CommandGraph(reset_on_submit=True, submit_on_record=True) - return __default_graph + return _global_graph.default_graph def global_graph() -> CommandGraph: - global __custom_graph + custom_graph = _get_global_graph() - if __custom_graph is not None: - return __custom_graph + if custom_graph is not None: + return custom_graph return default_graph() def set_global_graph(graph: CommandGraph = None) -> CommandGraph: - global __custom_graph - old_value = __custom_graph - __custom_graph = graph - return old_value \ No newline at end of file + if graph is None: + _global_graph.custom_graph = None + return + + assert _get_global_graph() is None, "A global CommandGraph is already set for the current thread!" 
+ _global_graph.custom_graph = graph diff --git a/vkdispatch/execution_pipeline/cuda_graph_capture.py b/vkdispatch/execution_pipeline/cuda_graph_capture.py new file mode 100644 index 00000000..a96f6a9e --- /dev/null +++ b/vkdispatch/execution_pipeline/cuda_graph_capture.py @@ -0,0 +1,51 @@ +import vkdispatch as vd + +from contextlib import contextmanager + +import threading + +import typing + +class CUDAGraphCapture: + cuda_stream: typing.Any # assigned per capture by cuda_graph_capture() + uniform_buffers: typing.List[typing.Any] # assigned per capture by cuda_graph_capture() + + def add_uniform_buffer(self, buffer): + self.uniform_buffers.append(buffer) + +_cap = threading.local() + +def _set_capture(capture): + _cap.capture = capture + +def get_cuda_capture() -> typing.Optional[CUDAGraphCapture]: + return getattr(_cap, "capture", None) + +@contextmanager +def cuda_graph_capture(cuda_stream=None): + assert vd.is_cuda(), "CUDA graph capture is only supported when using the CUDA backend." + + cap = CUDAGraphCapture() + cap.cuda_stream = cuda_stream + cap.uniform_buffers = [] + + _set_capture(cap) + + try: + yield cap + finally: + _set_capture(None) + +@contextmanager +def suspend_cuda_capture(): + """Temporarily disable vkdispatch CUDA capture state for non-captured ops.""" + cap = get_cuda_capture() + if cap is None: + yield + return + + _set_capture(None) + try: + yield + finally: + _set_capture(cap) diff --git a/vkdispatch/fft/__init__.py b/vkdispatch/fft/__init__.py index 550cc7fd..5dab17ff 100644 --- a/vkdispatch/fft/__init__.py +++ b/vkdispatch/fft/__init__.py @@ -1,18 +1,36 @@ -from .config import FFTConfig, FFTParams +from .config import FFTConfig +from .grid_manager import FFTGridManager +from .sdata_manager import FFTSDataManager +from .registers import FFTRegisters -from .resources import FFTResources, allocate_fft_resources +from .resources import FFTResources + +from .memory_iterators import memory_reads_iterator, memory_writes_iterator, MemoryOp + +from .global_memory_iterators import global_writes_iterator, GlobalWriteOp +from .global_memory_iterators
import global_reads_iterator, GlobalReadOp +from .global_memory_iterators import global_trasposed_write_iterator, GlobalTransposedWriteOp from .io_proxy import IOProxy -from .io_manager import IOManager +from .io_manager import IOManager, read_op, write_op from .context import fft_context -from .shader import make_fft_shader, get_cache_info, cache_clear, print_cache_info -from .shader import make_convolution_shader +from .shader_factories import make_fft_shader, get_cache_info, cache_clear, print_cache_info, mapped_kernel_index +from .shader_factories import make_convolution_shader, make_transpose_shader, get_transposed_size from .functions import fft, fft2, fft3, ifft, ifft2, ifft3 from .functions import rfft, rfft2, rfft3, irfft, irfft2, irfft3 -from .functions import convolve, convolve2D, convolve2DR +from .src_functions import fft_src, fft2_src, fft3_src, ifft_src, ifft2_src, ifft3_src +from .src_functions import rfft_src, rfft2_src, rfft3_src, irfft_src, irfft2_src, irfft3_src + +from .src_functions import fft_print_src, fft2_print_src, fft3_print_src, ifft_print_src, ifft2_print_src, ifft3_print_src +from .src_functions import rfft_print_src, rfft2_print_src, rfft3_print_src, irfft_print_src, irfft2_print_src, irfft3_print_src + +from .functions import convolve, convolve2D, convolve2DR, transpose + +from .src_functions import convolve_src, convolve2D_src, convolve2DR_src, transpose_src +from .src_functions import convolve_print_src, convolve2D_print_src, convolve2DR_print_src from .prime_utils import pad_dim \ No newline at end of file diff --git a/vkdispatch/fft/config.py b/vkdispatch/fft/config.py index 520ed9c6..02628e84 100644 --- a/vkdispatch/fft/config.py +++ b/vkdispatch/fft/config.py @@ -1,142 +1,196 @@ import vkdispatch as vd -import numpy as np +import vkdispatch.codegen as vc import dataclasses -from typing import List, Tuple, Optional +from typing import List, Tuple, Optional, Dict +from ..compat import numpy_compat as npc +import 
vkdispatch.base.dtype as dtypes from .prime_utils import prime_factors, group_primes, default_register_limit, default_max_prime -@dataclasses.dataclass -class FFTRegisterStageConfig: - """ - Configuration for an FFT register stage. +from .stages import FFTRegisterStageConfig - Attributes: +def plan_fft_stages(N: int, max_register_count: int, compute_item_size: int) -> Tuple[FFTRegisterStageConfig]: + all_factors = prime_factors(N) - primes (Tuple[int]): The prime numbers used for factorization. - fft_length (int): The length of each FFT stage. - instance_count (int): The number of instances required to achieve the desired level of parallelism. - registers_used (int): The total number of registers used by the FFT stage. - remainder (int): The remainder of `N` divided by `registers_used`. - remainder_offset (int): A flag indicating whether the remainder is non-zero. - extra_ffts (int): The additional number of FFT stages required to process the remainder. - thread_count (int): The total number of threads used in the computation. - sdata_size (int): The size of the shared memory buffer used to store intermediate results. - sdata_width (int): The width of each element in the shared memory buffer. - sdata_width_padded (int): The padded width of each element in the shared memory buffer. + for factor in all_factors: + assert factor <= default_max_prime(), f"A prime factor of {N} is {factor}, which exceeds the maximum prime supported {default_max_prime()}" - """ + prime_groups = group_primes(all_factors, max_register_count) - primes: Tuple[int] - fft_length: int - instance_count: int - registers_used: int - remainder: int - remainder_offset: int - extra_ffts: int - thread_count: int - sdata_size: int - sdata_width: int - sdata_width_padded: int + stages = [] + input_stride = 1 - def __init__(self, primes: List[int], max_register_count: int, N: int): - """ - Initializes the FFTRegisterStageConfig object. 
+ for group in prime_groups: + stage = FFTRegisterStageConfig( + group, + max_register_count, + N, + compute_item_size, + input_stride + ) + stages.append(stage) + input_stride = stage.output_stride - Parameters: + return tuple(stages) - primes (List[int]): The prime numbers to use for factorization. - max_register_count (int): The maximum number of registers allowed per thread. - N (int): The length of the input data. +@dataclasses.dataclass +class FFTPlanCandidate: + max_register_count: int + stages: Tuple[FFTRegisterStageConfig] + register_count: int + batch_threads: int + transfer_count: Optional[int] = None + + def __init__(self, N: int, max_register_count: int,compute_item_size: int): + stages = plan_fft_stages(N, max_register_count, compute_item_size) + register_count = max(stage.registers_used for stage in stages) + batch_threads = max(stage.thread_count for stage in stages) + + if register_count > max_register_count: + self.max_register_count = None + self.stages = None + self.register_count = None + self.batch_threads = None + self.transfer_count = None + return + + transfer_count = 0 + output_stride = 1 + + for stage_index in range(len(stages) - 1): + output_stage = stages[stage_index] + input_stage = stages[stage_index + 1] + + output_keys = output_stage.get_output_format(register_count).keys() + input_keys = input_stage.get_input_format(register_count).keys() + + if output_keys != input_keys: + transfer_count += 1 + + output_stride *= output_stage.fft_length + + self.max_register_count = max_register_count + self.stages = stages + self.register_count = register_count + self.batch_threads = batch_threads + self.transfer_count = transfer_count + +def register_limit_candidates(N: int, initial_limit: int) -> List[int]: + divisors = {1} + + for factor in prime_factors(N): + divisors.update(divisor * factor for divisor in tuple(divisors)) + + candidates = [initial_limit] + candidates.extend( + divisor + for divisor in sorted(divisors) + if initial_limit < 
divisor <= N + ) + return candidates + +def required_batch_threads_limit(batch_inner_count: int) -> int: + context = vd.get_context() + thread_dimension_limit = ( + context.max_workgroup_size[1] + if batch_inner_count > 1 + else context.max_workgroup_size[0] + ) + return max(1, min(int(thread_dimension_limit), int(context.max_workgroup_invocations))) + +def select_fft_plan_candidate( + N: int, + batch_inner_count: int, + compute_item_size: int, + max_register_count: Optional[int], +) -> FFTPlanCandidate: + batch_threads_limit = required_batch_threads_limit(batch_inner_count) + dimension_name = "y" if batch_inner_count > 1 else "x" + + if max_register_count is not None: + requested_limit = min(max_register_count, N) + candidate = FFTPlanCandidate( + N=N, + max_register_count=requested_limit, + compute_item_size=compute_item_size, + ) - """ - self.primes = tuple(primes) - self.fft_length = int(np.round(np.prod(primes))) - instance_primes = prime_factors(N // self.fft_length) - - self.instance_count = 1 + assert candidate.stages is not None, f"Failed to create an FFT plan candidate for N={N} with max_register_count={requested_limit}" - while len(instance_primes) > 0: - if self.instance_count * self.fft_length * instance_primes[0] > max_register_count: - break - self.instance_count *= instance_primes[0] - instance_primes = instance_primes[1:] + if candidate.batch_threads <= batch_threads_limit: + return candidate - self.registers_used = self.fft_length * self.instance_count + best_candidate = candidate + explicit_text = "requested" + searched_limit = requested_limit + else: + max_registers = default_register_limit() - self.remainder = N % self.registers_used - assert self.remainder % self.fft_length == 0, "Remainder must be divisible by the FFT length" - self.remainder_offset = 1 if self.remainder != 0 else 0 - self.extra_ffts = self.remainder // self.fft_length + if N in (2, 4, 8, 16) and vd.get_devices()[0].is_nvidia(): # gate every small-N case on NVIDIA; and/or precedence previously gated only N==2 + max_registers = max(2, N//2) -
self.thread_count = N // self.registers_used + self.remainder_offset + baseline_limit = min(8, N) + requested_limit = baseline_limit + candidate_limits = register_limit_candidates(max_registers, baseline_limit) + searched_limit = candidate_limits[-1] - self.sdata_width = self.registers_used + baseline_candidate = FFTPlanCandidate( + N=N, + max_register_count=baseline_limit, + compute_item_size=compute_item_size, + ) + best_candidate = baseline_candidate if baseline_candidate.stages is not None else None - threads_primes = prime_factors(self.thread_count) + if best_candidate is not None and baseline_candidate.batch_threads <= batch_threads_limit: + for candidate_limit in candidate_limits[1:]: + candidate = FFTPlanCandidate( + N=N, + max_register_count=candidate_limit, + compute_item_size=compute_item_size, + ) - while self.sdata_width < 16 and len(threads_primes) > 0: - self.sdata_width *= threads_primes[0] - threads_primes = threads_primes[1:] + if candidate.stages is None: + continue - self.sdata_width_padded = self.sdata_width + if best_candidate is None or candidate.batch_threads < best_candidate.batch_threads: + best_candidate = candidate - if self.sdata_width_padded % 2 == 0: - self.sdata_width_padded += 1 + if candidate.batch_threads > batch_threads_limit: + continue - self.sdata_size = self.sdata_width_padded * int(np.prod(threads_primes)) + if candidate.transfer_count < baseline_candidate.transfer_count: + return candidate - if self.sdata_size > vd.get_context().max_shared_memory // vd.complex64.item_size: - self.sdata_width_padded = self.sdata_width - self.sdata_size = self.sdata_width_padded * int(np.prod(threads_primes)) + return baseline_candidate - def __str__(self): - """ - Returns a string representation of the FFTRegisterStageConfig object. 
- - """ - return f""" -FFT Stage Config: - primes: {self.primes} - fft_length: {self.fft_length} - instance_count: {self.instance_count} - registers_used: {self.registers_used} - remainder: {self.remainder} - remainder_offset: {self.remainder_offset} - extra_ffts: {self.extra_ffts} - thread_count: {self.thread_count} - sdata_size: {self.sdata_size} - sdata_width: {self.sdata_width} - sdata_width_padded: {self.sdata_width_padded}""" - - def __repr__(self): - """ - Returns a string representation of the FFTRegisterStageConfig object. + for candidate_limit in candidate_limits[1:]: + candidate = FFTPlanCandidate( + N=N, + max_register_count=candidate_limit, + compute_item_size=compute_item_size, + ) + if candidate.stages is None: + continue - """ - return str(self) + if best_candidate is None or candidate.batch_threads < best_candidate.batch_threads: + best_candidate = candidate -@dataclasses.dataclass -class FFTParams: - config: "FFTConfig" = None - inverse: bool = False - normalize: bool = True - r2c: bool = False - batch_outer_stride: int = None - batch_inner_stride: int = None - fft_stride: int = None - angle_factor: float = None - input_sdata: bool = False - input_buffers: List[vd.Buffer] = None - output_buffers: List[vd.Buffer] = None - passthrough: bool = False - - sdata_row_size: Optional[int] = None - sdata_row_size_padded: Optional[int] = None + if candidate.batch_threads <= batch_threads_limit: + return candidate + + explicit_text = "default" + raise ValueError( + f"Unable to build an FFT plan for size {N}: minimum achievable batch thread count " + f"{best_candidate.batch_threads} exceeds the device's local {dimension_name}-dimension " + f"limit {batch_threads_limit} (starting from {explicit_text} max_register_count=" + f"{requested_limit}, searched up to {searched_limit})." 
+ ) @dataclasses.dataclass class FFTConfig: N: int + compute_type: dtypes.dtype register_count: int max_prime_radix: int stages: Tuple[FFTRegisterStageConfig] @@ -144,39 +198,40 @@ class FFTConfig: fft_stride: int batch_outer_stride: int batch_outer_count: int - batch_inner_stride: int batch_inner_count: int batch_threads: int sdata_allocation: int - sdata_row_size: Optional[int] - sdata_row_size_padded: Optional[int] + sdata_row_size: int + sdata_row_size_padded: int - def __init__(self, buffer_shape: Tuple, axis: int = None, max_register_count: int = None): + def __init__( + self, + buffer_shape: Tuple, + axis: int = None, + max_register_count: int = None, + compute_type: dtypes.dtype = vd.complex64, + ): if axis is None: axis = len(buffer_shape) - 1 - total_buffer_length = np.round(np.prod(buffer_shape)).astype(np.int32) + if not dtypes.is_complex(compute_type): + raise ValueError(f"compute_type must be a complex dtype, got {compute_type}") + + self.compute_type = compute_type + + total_buffer_length = int(round(npc.prod(buffer_shape))) N = buffer_shape[axis] - self.fft_stride = np.round(np.prod(buffer_shape[axis + 1:])).astype(np.int32) + self.fft_stride = int(round(npc.prod(buffer_shape[axis + 1:]))) self.batch_outer_stride = self.fft_stride * N self.batch_outer_count = total_buffer_length // self.batch_outer_stride - self.batch_inner_stride = 1 self.batch_inner_count = self.fft_stride self.N = N - if max_register_count is None: - max_register_count = default_register_limit() - - if N == 16 and vd.get_devices()[0].is_nvidia(): - max_register_count = 15 # Special case for 16-point FFTs because this is faster - - max_register_count = min(max_register_count, N) - all_factors = prime_factors(N) for factor in all_factors: @@ -184,15 +239,18 @@ def __init__(self, buffer_shape: Tuple, axis: int = None, max_register_count: in self.max_prime_radix = max(all_factors) - prime_groups = group_primes(all_factors, max_register_count) - - self.stages = 
tuple([FFTRegisterStageConfig(group, max_register_count, N) for group in prime_groups]) - register_utilizations = [stage.registers_used for stage in self.stages] - self.register_count = max(register_utilizations) - - assert self.register_count <= max_register_count, f"Register count {self.register_count} exceeds max register count {max_register_count}" + plan_candidate = select_fft_plan_candidate( + N=N, + batch_inner_count=self.batch_inner_count, + compute_item_size=self.compute_type.item_size, + max_register_count=max_register_count, + ) + self.stages = plan_candidate.stages + self.register_count = plan_candidate.register_count - self.sdata_allocation = 1 + self.sdata_allocation = 1 + self.sdata_row_size = 1 + self.sdata_row_size_padded = 1 for stage in self.stages: if stage.sdata_size < self.sdata_allocation: @@ -202,9 +260,9 @@ def __init__(self, buffer_shape: Tuple, axis: int = None, max_register_count: in self.sdata_row_size = stage.sdata_width self.sdata_row_size_padded = stage.sdata_width_padded - self.thread_counts = [stage.thread_count for stage in self.stages] + self.thread_counts = tuple(stage.thread_count for stage in self.stages) - self.batch_threads = max(self.thread_counts) + self.batch_threads = plan_candidate.batch_threads def __str__(self): return f"FFT Config:\nN: {self.N}\nregister_count: {self.register_count}\nstages:\n{self.stages}\nlocal_size: {self.thread_counts}" @@ -212,28 +270,5 @@ def __str__(self): def __repr__(self): return str(self) - def params(self, - inverse: bool = False, - normalize: bool = True, - r2c: bool = False, - input_sdata: bool = False, - input_buffers: List[vd.Buffer] = None, - output_buffers: List[vd.Buffer] = None, - passthrough: bool = False) -> FFTParams: - return FFTParams( - config=self, - inverse=inverse, - normalize=normalize, - r2c=r2c, - batch_outer_stride=self.batch_outer_stride, - batch_inner_stride=self.batch_inner_stride, - fft_stride=self.fft_stride, - angle_factor=2 * np.pi * (1 if inverse else -1), - 
input_sdata=input_sdata, - input_buffers=input_buffers, - output_buffers=output_buffers, - passthrough=passthrough, - sdata_row_size=self.sdata_row_size, - sdata_row_size_padded=self.sdata_row_size_padded - ) - + def angle_factor(self, inverse: bool) -> float: + return 2 * npc.pi * (1 if inverse else -1) diff --git a/vkdispatch/fft/context.py b/vkdispatch/fft/context.py index db2fe16d..8a6bc7cc 100644 --- a/vkdispatch/fft/context.py +++ b/vkdispatch/fft/context.py @@ -1,33 +1,196 @@ import vkdispatch as vd import vkdispatch.codegen as vc +import vkdispatch.base.dtype as dtypes + import contextlib -from typing import Union, Tuple +from typing import Optional, Tuple, Union, List, Dict -from .manager import FFTManager +from .io_manager import IOManager +from .config import FFTConfig +from .grid_manager import FFTGridManager +from .sdata_manager import FFTSDataManager +from .resources import FFTResources +from .registers import FFTRegisters +from .cooley_tukey import radix_composite +from .global_memory_iterators import global_reads_iterator, global_writes_iterator -@contextlib.contextmanager -def fft_context(buffer_shape: Tuple, +class FFTContext: + shader_context: vd.ShaderContext + config: FFTConfig + grid: FFTGridManager + registers: FFTRegisters + sdata: FFTSDataManager + resources: FFTResources + fft_callable: vd.ShaderFunction + name: str + + declared_shader_args: bool + declarer: str + + def __init__(self, + shader_context: vd.ShaderContext, + buffer_shape: Tuple, axis: int = None, max_register_count: int = None, - output_map: Union[vd.MappingFunction, type, None] = None, - input_map: Union[vd.MappingFunction, type, None] = None, - kernel_map: Union[vd.MappingFunction, type, None] = None): + compute_type: dtypes.dtype = vd.complex64, + name: str = None): + self.shader_context = shader_context + self.declared_shader_args = False + self.declarer = None + + self.config = FFTConfig(buffer_shape, axis, max_register_count, compute_type=compute_type) + self.grid = 
FFTGridManager(self.config, True, True) + self.resources = FFTResources(self.config, self.grid) + + self.registers = self.allocate_registers("fft") + + self.sdata = FFTSDataManager(self.config, self.grid, self.registers) + + self.fft_callable = None + self.name = name if name is not None else f"fft_shader_{buffer_shape}_{axis}" + + def allocate_registers(self, name: str, count: int = None) -> FFTRegisters: + assert name is not None, "Must provide a name for allocated registers" + + if count is None: + count = self.config.register_count + + return FFTRegisters(self.resources, count, name) + + def declare_shader_args(self, types: List) -> List[vc.ShaderVariable]: + assert not self.declared_shader_args, f"Shader arguments already declared with {self.declarer}" + self.declared_shader_args = True + self.declarer = "declare_shader_args" + return self.shader_context.declare_input_arguments(types) + + def make_io_manager(self, + output_map: Optional[vd.MappingFunction], + output_type: dtypes.dtype = vd.complex64, + input_type: Optional[dtypes.dtype] = None, + input_map: Optional[vd.MappingFunction] = None, + kernel_map: Optional[vd.MappingFunction] = None) -> IOManager: + assert not self.declared_shader_args, f"Shader arguments already declared with {self.declarer}" + self.declared_shader_args = True + self.declarer = "make_io_manager" + return IOManager( + default_registers=self.registers, + shader_context=self.shader_context, + output_map=output_map, + output_type=output_type, + input_type=input_type, + input_map=input_map, + kernel_map=kernel_map + ) + + def reads_iter(self, + r2c: bool = False, + inverse: Optional[bool] = None, + format_transposed: bool = False, + inner_only: bool = False, + signal_range: Optional[Tuple[Optional[int], Optional[int]]] = None): + return global_reads_iterator( + self.registers, + r2c=r2c, + inverse=inverse, + format_transposed=format_transposed, + inner_only=inner_only, + signal_range=signal_range + ) + + def writes_iter(self, + r2c: bool 
= False, + inverse: Optional[bool] = None): + return global_writes_iterator( + self.registers, + r2c=r2c, + inverse=inverse + ) + + def register_shuffle(self, + registers: Optional[FFTRegisters] = None, + output_stage: int = -1, + input_stage: int = 0) -> bool: + if registers is None: + registers = self.registers + + if registers.try_shuffle( + output_stage=output_stage, + input_stage=input_stage + ): + return True + + vc.comment("Register shuffle not possible, falling back to shared memory shuffle.", preceding_new_line=False) + self.sdata.write_to_sdata( + registers=registers, + stage_index=output_stage + ) + + self.sdata.read_from_sdata( + registers=registers, + stage_index=input_stage + ) + + def compile_shader(self): + self.fft_callable = self.shader_context.get_function( + local_size=self.grid.local_size, + exec_count=self.grid.exec_size, + name=self.name + ) + + def get_callable(self) -> vd.ShaderFunction: + assert self.fft_callable is not None, "Shader not compiled yet... something is wrong" + return self.fft_callable + + def execute(self, inverse: bool): + stage_count = len(self.config.stages) + + for i in range(stage_count): + stage = self.config.stages[i] + + vc.comment(f"""FFT stage {i + 1}/{stage_count}. +Prime group {stage.primes}: execute {stage.instance_count} radix-{stage.fft_length} sub-FFTs per invocation. 
+Register-group coverage this stage: {self.config.N // stage.registers_used}.""") + + if i != 0: + self.register_shuffle(output_stage=i-1, input_stage=i) + + self.resources.stage_begin(i) + for ii, invocation in enumerate(self.config.stages[i].invocations): + self.resources.invocation_gaurd(i, ii) + + self.registers.slice_set(invocation.register_selection, radix_composite( + resources=self.resources, + inverse=inverse, + register_list=self.registers.register_slice(invocation.register_selection), + primes=stage.primes, + twiddle_index=invocation.get_inner_block_offset(self.resources.tid), + twiddle_N=invocation.block_width + )) + + self.resources.invocation_end(i) + self.resources.stage_end(i) + +@contextlib.contextmanager +def fft_context(buffer_shape: Tuple, + axis: Optional[int] = None, + max_register_count: Optional[int] = None, + compute_type: dtypes.dtype = vd.complex64, + name: Optional[str] = None): try: - with vc.builder_context(enable_exec_bounds=False) as builder: - manager = FFTManager( - builder=builder, + with vd.shader_context(vc.ShaderFlags.NO_EXEC_BOUNDS) as context: + fft_context = FFTContext( + shader_context=context, buffer_shape=buffer_shape, axis=axis, max_register_count=max_register_count, - output_map=output_map, - input_map=input_map, - kernel_map=kernel_map + compute_type=compute_type, + name=name ) - yield manager + yield fft_context - manager.compile_shader() + fft_context.compile_shader() finally: - pass \ No newline at end of file + pass diff --git a/vkdispatch/fft/cooley_tukey.py b/vkdispatch/fft/cooley_tukey.py new file mode 100644 index 00000000..f2821907 --- /dev/null +++ b/vkdispatch/fft/cooley_tukey.py @@ -0,0 +1,180 @@ +import vkdispatch.codegen as vc +from .resources import FFTResources + +from typing import List, Union + +from ..compat import numpy_compat as npc + +def get_angle_factor(inverse: bool) -> float: + return 2 * npc.pi * (1 if inverse else -1) + +def _apply_right_angle_twiddle(resources: FFTResources, register: 
vc.ShaderVariable, angle_int: int) -> bool: + if angle_int == 0: + return True + + if angle_int == 1: + resources.radix_registers[0].real = register.real + register.real = -register.imag + register.imag = resources.radix_registers[0].real + return True + + if angle_int == -1: + resources.radix_registers[0].real = register.real + register.real = register.imag + register.imag = -resources.radix_registers[0].real + return True + + if angle_int == 2 or angle_int == -2: + register[:] = -register + return True + + return False + +def _apply_constant_twiddle(resources: FFTResources, register: vc.ShaderVariable, omega: complex) -> bool: + scaled_angle = 2 * npc.angle(omega) / npc.pi + rounded_angle = npc.round(scaled_angle) + + if abs(scaled_angle - rounded_angle) >= 1e-8: + return False + + return _apply_right_angle_twiddle(resources, register, int(rounded_angle)) + +def _apply_twiddle_to_register( + resources: FFTResources, + register: vc.ShaderVariable, + twiddle: Union[complex, vc.ShaderVariable]): + if isinstance(twiddle, complex): + if _apply_constant_twiddle(resources, register, twiddle): + return + + twiddle = vc.to_dtype(register.var_type, twiddle.real, twiddle.imag) + + resources.radix_registers[0][:] = vc.mult_complex(register, twiddle) + register[:] = resources.radix_registers[0] + +def radix_P(resources: FFTResources, inverse: bool, register_list: List[vc.ShaderVariable]): + assert len(register_list) <= len(resources.radix_registers), "Too many registers for radix_P" + + if len(register_list) == 1: + return + + if len(register_list) == 2: + vc.comment("Radix-2 butterfly base case", preceding_new_line=False) + resources.radix_registers[0][:] = register_list[1] + register_list[1][:] = register_list[0] - resources.radix_registers[0] + register_list[0][:] = register_list[0] + resources.radix_registers[0] + return + + vc.comment(f"Radix-{len(register_list)} DFT", preceding_new_line=False) + + angle_factor = get_angle_factor(inverse) + + for i in range(0, 
len(register_list)): + for j in range(0, len(register_list)): + if j == 0: + resources.radix_registers[i][:] = register_list[j] + continue + + if i == 0: + resources.radix_registers[i] += register_list[j] + continue + + if i * j == len(register_list) // 2 and len(register_list) % 2 == 0: + resources.radix_registers[i] -= register_list[j] + continue + + omega = npc.exp_complex(1j * angle_factor * i * j / len(register_list)) + typed_omega = vc.to_dtype(register_list[j].var_type, omega.real, omega.imag) + resources.omega_register[:] = vc.mult_complex(register_list[j], typed_omega) + resources.radix_registers[i] += resources.omega_register + + for i in range(0, len(register_list)): + register_list[i][:] = resources.radix_registers[i] + +def apply_twiddle_factors( + resources: FFTResources, + inverse: bool, + register_list: List[vc.ShaderVariable], + twiddle_index: Union[int, vc.ShaderVariable] = 0, + twiddle_N: int = 1): + + if isinstance(twiddle_index, int) and twiddle_index == 0: + return + + twiddle_index_str = str(twiddle_index) if isinstance(twiddle_index, int) else twiddle_index.resolve() + vc.comment(f"""Applying Cooley-Tukey inter-stage twiddle factors before the next butterfly pass. +Twiddle domain size: N = {twiddle_N}. Twiddle index source: {twiddle_index_str}. +For each non-DC lane i>0, multiply by W_N^(i * twiddle_index). 
+This phase-aligns each sub-FFT with its parent decomposition stage.""") + + angle_factor = get_angle_factor(inverse) + + for i in range(len(register_list)): + if i == 0: + continue + + if isinstance(twiddle_index, int): + if twiddle_index == 0: + continue + + omega = npc.exp_complex(1j * angle_factor * i * twiddle_index / twiddle_N) + + _apply_twiddle_to_register(resources, register_list[i], omega) + continue + + angle_scale = vc.to_dtype(resources.omega_register.real.var_type, angle_factor * i / twiddle_N) + twiddle_scale = vc.to_dtype(resources.omega_register.real.var_type, twiddle_index) + resources.omega_register.real = angle_scale * twiddle_scale + resources.omega_register[:] = vc.complex_from_euler_angle(resources.omega_register.real) + resources.radix_registers[0][:] = vc.mult_complex(register_list[i], resources.omega_register) + register_list[i][:] = resources.radix_registers[0] + +def radix_composite( + resources: FFTResources, + inverse: bool, + register_list: List[vc.ShaderVariable], + primes: List[int], + twiddle_index: Union[int, vc.ShaderVariable] = 0, + twiddle_N: int = 1): + if len(register_list) == 1: + return + + N = len(register_list) + + assert N == npc.prod(primes), "Product of primes must be equal to the number of registers" + + vc.comment(f"""Starting mixed-radix FFT decomposition for this invocation on {N} register samples. +Radix factorization sequence: {primes}. 
+At each level: partition lanes into stage-local sub-sequences, apply twiddles, +run radix-P butterflies, then reassemble in stride-consistent order for downstream stages.""") + + apply_twiddle_factors( + resources=resources, + inverse=inverse, + register_list=register_list, + twiddle_index=twiddle_index, + twiddle_N=twiddle_N + ) + + output_stride = 1 + + for prime in primes: + sub_squences = [register_list[i::N//prime] for i in range(N//prime)] + + block_width = output_stride * prime + + for i in range(0, N // prime): + inner_block_offset = i % output_stride + block_index = (i * prime) // block_width + + apply_twiddle_factors(resources, inverse, sub_squences[i], twiddle_index=inner_block_offset, twiddle_N=block_width) + radix_P(resources, inverse, sub_squences[i]) + + sub_sequence_offset = block_index * block_width + inner_block_offset + + for j in range(prime): + register_list[sub_sequence_offset + j * output_stride] = sub_squences[i][j] + + output_stride *= prime + + return register_list diff --git a/vkdispatch/fft/functions.py b/vkdispatch/fft/functions.py index b35a8f4c..0818a8eb 100644 --- a/vkdispatch/fft/functions.py +++ b/vkdispatch/fft/functions.py @@ -1,8 +1,105 @@ import vkdispatch as vd -from .shader import make_fft_shader, make_convolution_shader +from .shader_factories import make_fft_shader, make_convolution_shader, make_transpose_shader, get_transposed_size +from .precision import ( + ensure_supported_complex_precision, + resolve_compute_precision, + validate_complex_precision, +) -from typing import Tuple, Union +from typing import List, Tuple, Union, Optional + + +def _validate_map_argument_annotations(map_fn: vd.MappingFunction, map_name: str) -> None: + for buffer_type in map_fn.buffer_types: + if not hasattr(buffer_type, "__args__") or len(buffer_type.__args__) != 1: + raise ValueError( + f"{map_name} contains an annotation without exactly one type argument: {buffer_type}" + ) + + +def _resolve_output_precision( + buffers: Tuple[vd.Buffer, 
...], + output_map: Optional[vd.MappingFunction], + output_type: Optional[vd.dtype], +) -> Optional[vd.dtype]: + if output_map is not None: + if output_type is not None: + raise ValueError("output_type cannot be provided when output_map is used") + return None + + resolved_output = buffers[0].var_type if output_type is None else output_type + validate_complex_precision(resolved_output, arg_name="output_type") + ensure_supported_complex_precision(resolved_output, role="Output") + return resolved_output + + +def _resolve_input_precision( + buffers: Tuple, + input_map: Optional[vd.MappingFunction], + output_map: Optional[vd.MappingFunction], + input_type: Optional[vd.dtype], + output_precision: Optional[vd.dtype], +) -> Optional[vd.dtype]: + if input_map is not None: + if input_type is not None: + raise ValueError("input_type cannot be provided when input_map is used") + return None + + if output_map is not None: + output_arg_count = len(output_map.buffer_types) + if len(buffers) <= output_arg_count: + raise ValueError( + "When output_map is used without input_map, an input buffer argument must be provided " + "after output_map arguments" + ) + + resolved_input = input_type + if resolved_input is None: + inferred_input = buffers[output_arg_count] + if not hasattr(inferred_input, "var_type"): + raise ValueError( + "When output_map is used without input_map, the argument after output_map arguments " + "must be a buffer" + ) + resolved_input = inferred_input.var_type + + validate_complex_precision(resolved_input, arg_name="input_type") + ensure_supported_complex_precision(resolved_input, role="Input") + return resolved_input + + if output_precision is None: + raise ValueError("output_precision must be provided when output_map is not used") + + resolved_input = output_precision if input_type is None else input_type + validate_complex_precision(resolved_input, arg_name="input_type") + ensure_supported_complex_precision(resolved_input, role="Input") + + if resolved_input != 
output_precision: + raise ValueError( + "input_type must match output_type when input_map is None (default FFT path is in-place)" + ) + + return resolved_input + + +def _resolve_kernel_precision( + buffers: Tuple[vd.Buffer, ...], + kernel_map: Optional[vd.MappingFunction], + kernel_type: Optional[vd.dtype], +) -> Optional[vd.dtype]: + if kernel_map is not None: + if kernel_type is not None: + raise ValueError("kernel_type cannot be provided when kernel_map is used") + return None + + if len(buffers) < 2: + raise ValueError("Kernel precision inference requires a kernel buffer argument") + + resolved_kernel = buffers[1].var_type if kernel_type is None else kernel_type + validate_complex_precision(resolved_kernel, arg_name="kernel_type") + ensure_supported_complex_precision(resolved_kernel, role="Kernel") + return resolved_kernel def fft( *buffers: vd.Buffer, @@ -15,13 +112,37 @@ def fft( normalize_inverse: bool = True, r2c: bool = False, input_map: vd.MappingFunction = None, - output_map: vd.MappingFunction = None): + output_map: vd.MappingFunction = None, + output_type: vd.dtype = None, + input_type: vd.dtype = None, + compute_type: vd.dtype = None, + input_signal_range: Union[Tuple[Optional[int], Optional[int]], None] = None): assert len(buffers) >= 1, "At least one buffer must be provided" + + if input_map is None and output_map is None and len(buffers) != 1: + raise ValueError("fft() expects exactly one buffer unless input_map/output_map are used") if buffer_shape is None: buffer_shape = buffers[0].shape + resolved_output_type = _resolve_output_precision(buffers, output_map, output_type) + resolved_input_type = _resolve_input_precision(buffers, input_map, output_map, input_type, resolved_output_type) + + io_precisions: List[vd.dtype] = [] + if output_map is None: + io_precisions.append(resolved_output_type) + else: + _validate_map_argument_annotations(output_map, "output_map") + + if input_map is None: + if resolved_input_type is not None: + 
io_precisions.append(resolved_input_type) + else: + _validate_map_argument_annotations(input_map, "input_map") + + resolved_compute_type = resolve_compute_precision(io_precisions, compute_type) + fft_shader = make_fft_shader( tuple(buffer_shape), axis, @@ -29,25 +150,91 @@ def fft( normalize_inverse=normalize_inverse, r2c=r2c, input_map=input_map, - output_map=output_map) + output_map=output_map, + input_type=resolved_input_type, + output_type=resolved_output_type, + compute_type=resolved_compute_type, + input_signal_range=input_signal_range) if print_shader: print(fft_shader) fft_shader(*buffers, graph=graph) -def fft2(buffer: vd.Buffer, graph: vd.CommandGraph = None, print_shader: bool = False, input_map: vd.MappingFunction = None, output_map: vd.MappingFunction = None): +def fft2( + buffer: vd.Buffer, + graph: vd.CommandGraph = None, + print_shader: bool = False, + input_map: vd.MappingFunction = None, + output_map: vd.MappingFunction = None, + output_type: vd.dtype = None, + input_type: vd.dtype = None, + compute_type: vd.dtype = None, +): assert len(buffer.shape) == 2 or len(buffer.shape) == 3, 'Buffer must have 2 or 3 dimensions' - fft(buffer, graph=graph, print_shader=print_shader, axis=len(buffer.shape) - 2, input_map=input_map) - fft(buffer, graph=graph, print_shader=print_shader, axis=len(buffer.shape) - 1, output_map=output_map) - -def fft3(buffer: vd.Buffer, graph: vd.CommandGraph = None, print_shader: bool = False, input_map: vd.MappingFunction = None, output_map: vd.MappingFunction = None): + fft( + buffer, + graph=graph, + print_shader=print_shader, + axis=len(buffer.shape) - 2, + input_map=input_map, + output_type=output_type, + input_type=input_type, + compute_type=compute_type, + ) + fft( + buffer, + graph=graph, + print_shader=print_shader, + axis=len(buffer.shape) - 1, + output_map=output_map, + output_type=output_type if output_map is None else None, + input_type=input_type if output_map is None else None, + compute_type=compute_type, + ) + 
+def fft3( + buffer: vd.Buffer, + graph: vd.CommandGraph = None, + print_shader: bool = False, + input_map: vd.MappingFunction = None, + output_map: vd.MappingFunction = None, + output_type: vd.dtype = None, + input_type: vd.dtype = None, + compute_type: vd.dtype = None, +): assert len(buffer.shape) == 3, 'Buffer must have 3 dimensions' - fft(buffer, graph=graph, print_shader=print_shader, axis=0, input_map=input_map) - fft(buffer, graph=graph, print_shader=print_shader, axis=1) - fft(buffer, graph=graph, print_shader=print_shader, axis=2, output_map=output_map) + fft( + buffer, + graph=graph, + print_shader=print_shader, + axis=0, + input_map=input_map, + output_type=output_type, + input_type=input_type, + compute_type=compute_type, + ) + fft( + buffer, + graph=graph, + print_shader=print_shader, + axis=1, + output_type=output_type, + input_type=input_type, + compute_type=compute_type, + ) + fft( + buffer, + graph=graph, + print_shader=print_shader, + axis=2, + output_map=output_map, + output_type=output_type if output_map is None else None, + input_type=input_type if output_map is None else None, + compute_type=compute_type, + ) def ifft( @@ -58,54 +245,225 @@ def ifft( name: str = None, normalize: bool = True, input_map: vd.MappingFunction = None, - output_map: vd.MappingFunction = None): - fft(buffer, graph=graph, print_shader=print_shader, axis=axis, name=name, inverse=True, normalize_inverse=normalize, input_map=input_map, output_map=output_map) - -def ifft2(buffer: vd.Buffer, graph: vd.CommandGraph = None, print_shader: bool = False, normalize: bool = True, input_map: vd.MappingFunction = None, output_map: vd.MappingFunction = None): + output_map: vd.MappingFunction = None, + output_type: vd.dtype = None, + input_type: vd.dtype = None, + compute_type: vd.dtype = None): + fft( + buffer, + graph=graph, + print_shader=print_shader, + axis=axis, + name=name, + inverse=True, + normalize_inverse=normalize, + input_map=input_map, + output_map=output_map, + 
output_type=output_type, + input_type=input_type, + compute_type=compute_type, + ) + +def ifft2( + buffer: vd.Buffer, + graph: vd.CommandGraph = None, + print_shader: bool = False, + normalize: bool = True, + input_map: vd.MappingFunction = None, + output_map: vd.MappingFunction = None, + output_type: vd.dtype = None, + input_type: vd.dtype = None, + compute_type: vd.dtype = None, +): assert len(buffer.shape) == 2 or len(buffer.shape) == 3, 'Buffer must have 2 or 3 dimensions' - ifft(buffer, graph=graph, print_shader=print_shader, axis=len(buffer.shape) - 2, normalize=normalize, input_map=input_map) - ifft(buffer, graph=graph, print_shader=print_shader, axis=len(buffer.shape) - 1, normalize=normalize, output_map=output_map) - -def ifft3(buffer: vd.Buffer, graph: vd.CommandGraph = None, print_shader: bool = False, normalize: bool = True, input_map: vd.MappingFunction = None, output_map: vd.MappingFunction = None): + ifft( + buffer, + graph=graph, + print_shader=print_shader, + axis=len(buffer.shape) - 2, + normalize=normalize, + input_map=input_map, + output_type=output_type, + input_type=input_type, + compute_type=compute_type, + ) + ifft( + buffer, + graph=graph, + print_shader=print_shader, + axis=len(buffer.shape) - 1, + normalize=normalize, + output_map=output_map, + output_type=output_type if output_map is None else None, + input_type=input_type if output_map is None else None, + compute_type=compute_type, + ) + +def ifft3( + buffer: vd.Buffer, + graph: vd.CommandGraph = None, + print_shader: bool = False, + normalize: bool = True, + input_map: vd.MappingFunction = None, + output_map: vd.MappingFunction = None, + output_type: vd.dtype = None, + input_type: vd.dtype = None, + compute_type: vd.dtype = None, +): assert len(buffer.shape) == 3, 'Buffer must have 3 dimensions' - ifft(buffer, graph=graph, print_shader=print_shader, axis=0, normalize=normalize, input_map=input_map) - ifft(buffer, graph=graph, print_shader=print_shader, axis=1, normalize=normalize) - 
ifft(buffer, graph=graph, print_shader=print_shader, axis=2, normalize=normalize, output_map=output_map) - - -def rfft(buffer: vd.RFFTBuffer, graph: vd.CommandGraph = None, print_shader: bool = False, name: str = None): - fft(buffer, buffer_shape=buffer.real_shape, graph=graph, print_shader=print_shader, name=name, r2c=True) - -def rfft2(buffer: vd.RFFTBuffer, graph: vd.CommandGraph = None, print_shader: bool = False): + ifft( + buffer, + graph=graph, + print_shader=print_shader, + axis=0, + normalize=normalize, + input_map=input_map, + output_type=output_type, + input_type=input_type, + compute_type=compute_type, + ) + ifft( + buffer, + graph=graph, + print_shader=print_shader, + axis=1, + normalize=normalize, + output_type=output_type, + input_type=input_type, + compute_type=compute_type, + ) + ifft( + buffer, + graph=graph, + print_shader=print_shader, + axis=2, + normalize=normalize, + output_map=output_map, + output_type=output_type if output_map is None else None, + input_type=input_type if output_map is None else None, + compute_type=compute_type, + ) + + +def rfft( + buffer: vd.RFFTBuffer, + graph: vd.CommandGraph = None, + print_shader: bool = False, + name: str = None, + compute_type: vd.dtype = None, +): + fft( + buffer, + buffer_shape=buffer.real_shape, + graph=graph, + print_shader=print_shader, + name=name, + r2c=True, + output_type=buffer.var_type, + input_type=buffer.var_type, + compute_type=compute_type, + ) + +def rfft2(buffer: vd.RFFTBuffer, graph: vd.CommandGraph = None, print_shader: bool = False, compute_type: vd.dtype = None): assert len(buffer.real_shape) == 2 or len(buffer.real_shape) == 3, 'Buffer must have 2 or 3 dimensions' - rfft(buffer, graph=graph, print_shader=print_shader) - fft(buffer, graph=graph, print_shader=print_shader, axis=len(buffer.real_shape) - 2) - -def rfft3(buffer: vd.RFFTBuffer, graph: vd.CommandGraph = None, print_shader: bool = False): + rfft(buffer, graph=graph, print_shader=print_shader, compute_type=compute_type) 
+ fft( + buffer, + graph=graph, + print_shader=print_shader, + axis=len(buffer.real_shape) - 2, + output_type=buffer.var_type, + input_type=buffer.var_type, + compute_type=compute_type, + ) + +def rfft3(buffer: vd.RFFTBuffer, graph: vd.CommandGraph = None, print_shader: bool = False, compute_type: vd.dtype = None): assert len(buffer.real_shape) == 3, 'Buffer must have 3 dimensions' - rfft(buffer, graph=graph, print_shader=print_shader) - fft(buffer, graph=graph, print_shader=print_shader, axis=1) - fft(buffer, graph=graph, print_shader=print_shader, axis=0) - -def irfft(buffer: vd.RFFTBuffer, graph: vd.CommandGraph = None, print_shader: bool = False, name: str = None, normalize: bool = True): - fft(buffer, buffer_shape=buffer.real_shape, graph=graph, print_shader=print_shader, name=name, inverse=True, normalize_inverse=normalize, r2c=True) - -def irfft2(buffer: vd.RFFTBuffer, graph: vd.CommandGraph = None, print_shader: bool = False, normalize: bool = True): + rfft(buffer, graph=graph, print_shader=print_shader, compute_type=compute_type) + fft( + buffer, + graph=graph, + print_shader=print_shader, + axis=1, + output_type=buffer.var_type, + input_type=buffer.var_type, + compute_type=compute_type, + ) + fft( + buffer, + graph=graph, + print_shader=print_shader, + axis=0, + output_type=buffer.var_type, + input_type=buffer.var_type, + compute_type=compute_type, + ) + +def irfft( + buffer: vd.RFFTBuffer, + graph: vd.CommandGraph = None, + print_shader: bool = False, + name: str = None, + normalize: bool = True, + compute_type: vd.dtype = None, +): + fft( + buffer, + buffer_shape=buffer.real_shape, + graph=graph, + print_shader=print_shader, + name=name, + inverse=True, + normalize_inverse=normalize, + r2c=True, + output_type=buffer.var_type, + input_type=buffer.var_type, + compute_type=compute_type, + ) + +def irfft2(buffer: vd.RFFTBuffer, graph: vd.CommandGraph = None, print_shader: bool = False, normalize: bool = True, compute_type: vd.dtype = None): assert 
len(buffer.real_shape) == 2 or len(buffer.real_shape) == 3, 'Buffer must have 2 or 3 dimensions' - ifft(buffer, graph=graph, print_shader=print_shader, axis=len(buffer.real_shape) - 2, normalize=normalize) - irfft(buffer, graph=graph, print_shader=print_shader, normalize=normalize) + ifft( + buffer, + graph=graph, + print_shader=print_shader, + axis=len(buffer.real_shape) - 2, + normalize=normalize, + output_type=buffer.var_type, + input_type=buffer.var_type, + compute_type=compute_type, + ) + irfft(buffer, graph=graph, print_shader=print_shader, normalize=normalize, compute_type=compute_type) -def irfft3(buffer: vd.RFFTBuffer, graph: vd.CommandGraph = None, print_shader: bool = False, normalize: bool = True): +def irfft3(buffer: vd.RFFTBuffer, graph: vd.CommandGraph = None, print_shader: bool = False, normalize: bool = True, compute_type: vd.dtype = None): assert len(buffer.real_shape) == 3, 'Buffer must have 3 dimensions' - ifft(buffer, graph=graph, print_shader=print_shader, axis=0, normalize=normalize) - ifft(buffer, graph=graph, print_shader=print_shader, axis=1, normalize=normalize) - irfft(buffer, graph=graph, print_shader=print_shader, normalize=normalize) + ifft( + buffer, + graph=graph, + print_shader=print_shader, + axis=0, + normalize=normalize, + output_type=buffer.var_type, + input_type=buffer.var_type, + compute_type=compute_type, + ) + ifft( + buffer, + graph=graph, + print_shader=print_shader, + axis=1, + normalize=normalize, + output_type=buffer.var_type, + input_type=buffer.var_type, + compute_type=compute_type, + ) + irfft(buffer, graph=graph, print_shader=print_shader, normalize=normalize, compute_type=compute_type) def convolve( *buffers: vd.Buffer, @@ -117,19 +475,62 @@ def convolve( axis: int = None, normalize: bool = True, name: str = None, + transposed_kernel: bool = False, + kernel_inner_only: bool = False, input_map: vd.MappingFunction = None, - output_map: vd.MappingFunction = None): + output_map: vd.MappingFunction = None, + 
output_type: vd.dtype = None, + input_type: vd.dtype = None, + kernel_type: vd.dtype = None, + compute_type: vd.dtype = None, + input_signal_range: Union[Tuple[Optional[int], Optional[int]], None] = None): + assert len(buffers) >= 1, "At least one buffer must be provided" + + if kernel_map is None and len(buffers) < 2: + raise ValueError("convolve() requires at least an output buffer and kernel buffer") + if buffer_shape is None: buffer_shape = buffers[0].shape + resolved_output_type = _resolve_output_precision(buffers, output_map, output_type) + resolved_input_type = _resolve_input_precision(buffers, input_map, output_map, input_type, resolved_output_type) + resolved_kernel_type = _resolve_kernel_precision(buffers, kernel_map, kernel_type) + + io_precisions: List[vd.dtype] = [] + + if output_map is None: + io_precisions.append(resolved_output_type) + else: + _validate_map_argument_annotations(output_map, "output_map") + + if input_map is None: + if resolved_input_type is not None: + io_precisions.append(resolved_input_type) + else: + _validate_map_argument_annotations(input_map, "input_map") + + if kernel_map is None: + io_precisions.append(resolved_kernel_type) + else: + _validate_map_argument_annotations(kernel_map, "kernel_map") + + resolved_compute_type = resolve_compute_precision(io_precisions, compute_type) + fft_shader = make_convolution_shader( tuple(buffer_shape), kernel_map, kernel_num, axis, + transposed_kernel=transposed_kernel, + kernel_inner_only=kernel_inner_only, normalize=normalize, input_map=input_map, - output_map=output_map) + output_map=output_map, + input_type=resolved_input_type, + output_type=resolved_output_type, + kernel_type=resolved_kernel_type, + compute_type=resolved_compute_type, + input_signal_range=input_signal_range) if print_shader: print(fft_shader) @@ -144,8 +545,14 @@ def convolve2D( graph: vd.CommandGraph = None, print_shader: bool = False, normalize: bool = True, + transposed_kernel: bool = False, + kernel_inner_only: bool = 
False, input_map: vd.MappingFunction = None, - output_map: vd.MappingFunction = None): + output_map: vd.MappingFunction = None, + output_type: vd.dtype = None, + input_type: vd.dtype = None, + kernel_type: vd.dtype = None, + compute_type: vd.dtype = None): assert len(buffer.shape) == 2 or len(buffer.shape) == 3, 'Buffer must have 2 or 3 dimensions' @@ -158,21 +565,135 @@ def convolve2D( if output_map is not None: output_buffers.append(buffer) - fft(*input_buffers, graph=graph, print_shader=print_shader, input_map=input_map) - convolve(buffer, kernel, kernel_map=kernel_map, buffer_shape=buffer_shape, graph=graph, print_shader=print_shader, axis=len(buffer.shape) - 2, normalize=normalize) - ifft(*output_buffers, graph=graph, print_shader=print_shader, normalize=normalize, output_map=output_map) + fft( + *input_buffers, + graph=graph, + print_shader=print_shader, + input_map=input_map, + output_type=output_type, + input_type=input_type, + compute_type=compute_type, + ) + convolve( + buffer, + kernel, + kernel_map=kernel_map, + buffer_shape=buffer_shape, + graph=graph, + transposed_kernel=transposed_kernel, + kernel_inner_only=kernel_inner_only, + print_shader=print_shader, + axis=len(buffer.shape) - 2, + normalize=normalize, + output_type=output_type, + input_type=input_type, + kernel_type=kernel_type, + compute_type=compute_type, + ) + ifft( + *output_buffers, + graph=graph, + print_shader=print_shader, + normalize=normalize, + output_map=output_map, + output_type=output_type if output_map is None else None, + input_type=input_type if output_map is None else None, + compute_type=compute_type, + ) def convolve2DR( buffer: vd.RFFTBuffer, kernel: vd.RFFTBuffer, kernel_map: vd.MappingFunction = None, buffer_shape: Tuple = None, + transposed_kernel: bool = False, + kernel_inner_only: bool = False, graph: vd.CommandGraph = None, print_shader: bool = False, - normalize: bool = True): + normalize: bool = True, + compute_type: vd.dtype = None): assert len(buffer.shape) == 2 
or len(buffer.shape) == 3, 'Buffer must have 2 or 3 dimensions' - rfft(buffer, graph=graph, print_shader=print_shader) - convolve(buffer, kernel, kernel_map=kernel_map, buffer_shape=buffer_shape, graph=graph, print_shader=print_shader, axis=len(buffer.shape) - 2, normalize=normalize) - irfft(buffer, graph=graph, print_shader=print_shader, normalize=normalize) \ No newline at end of file + rfft(buffer, graph=graph, print_shader=print_shader, compute_type=compute_type) + convolve( + buffer, + kernel, + kernel_map=kernel_map, + buffer_shape=buffer_shape, + graph=graph, + transposed_kernel=transposed_kernel, + kernel_inner_only=kernel_inner_only, + print_shader=print_shader, + axis=len(buffer.shape) - 2, + normalize=normalize, + output_type=buffer.var_type, + input_type=buffer.var_type, + kernel_type=kernel.var_type, + compute_type=compute_type, + ) + irfft(buffer, graph=graph, print_shader=print_shader, normalize=normalize, compute_type=compute_type) + +def transpose( + in_buffer: vd.Buffer, + conv_shape: Tuple = None, + axis: int = None, + out_buffer: vd.Buffer = None, + graph: vd.CommandGraph = None, + kernel_inner_only: bool = False, + print_shader: bool = False, + input_type: vd.dtype = None, + output_type: vd.dtype = None, + compute_type: vd.dtype = None) -> vd.Buffer: + + resolved_input_type = in_buffer.var_type if input_type is None else input_type + validate_complex_precision(resolved_input_type, arg_name="input_type") + ensure_supported_complex_precision(resolved_input_type, role="Input") + + resolved_output_type = ( + out_buffer.var_type if (out_buffer is not None and output_type is None) + else in_buffer.var_type if output_type is None + else output_type + ) + validate_complex_precision(resolved_output_type, arg_name="output_type") + ensure_supported_complex_precision(resolved_output_type, role="Output") + + resolved_compute_type = resolve_compute_precision( + [resolved_input_type, resolved_output_type], + compute_type, + ) + + transposed_size = 
get_transposed_size( + tuple(in_buffer.shape), + axis=axis, + compute_type=resolved_compute_type, + ) + + if out_buffer is None: + out_buffer = vd.Buffer((transposed_size,), var_type=resolved_output_type) + else: + if out_buffer.var_type != resolved_output_type: + raise ValueError( + f"out_buffer type ({out_buffer.var_type.name}) does not match output_type ({resolved_output_type.name})" + ) + + assert out_buffer.size >= transposed_size, f"Output buffer size {out_buffer.size} does not match expected transposed size {transposed_size}" + + if conv_shape is None: + conv_shape = in_buffer.shape + + transpose_shader = make_transpose_shader( + tuple(conv_shape), + axis=axis, + kernel_inner_only=kernel_inner_only, + input_type=resolved_input_type, + output_type=resolved_output_type, + compute_type=resolved_compute_type, + ) + + if print_shader: + print(transpose_shader) + + transpose_shader(out_buffer, in_buffer, graph=graph) + + return out_buffer diff --git a/vkdispatch/fft/global_memory_iterators.py b/vkdispatch/fft/global_memory_iterators.py new file mode 100644 index 00000000..c621f6b6 --- /dev/null +++ b/vkdispatch/fft/global_memory_iterators.py @@ -0,0 +1,334 @@ +import vkdispatch.codegen as vc + +from typing import Optional, Tuple + +import dataclasses + +from .registers import FFTRegisters +from .memory_iterators import memory_reads_iterator, memory_writes_iterator, MemoryOp + + +def _cast_if_needed(value: vc.ShaderVariable, dst_type): + if value.var_type == dst_type: + return value + + return vc.to_dtype(dst_type, value) + +def global_batch_offset( + registers: FFTRegisters, + r2c: bool = False, + is_output: bool = None, + inverse: bool = None, + inner_only: bool = False) -> vc.ShaderVariable: + config = registers.config + grid = registers.resources.grid + + if inner_only: + return grid.global_inner_offset + + outer_batch_stride = config.N * config.fft_stride + + if r2c: + assert inverse is not None, "Must specify inverse for r2c io" + assert is_output is not 
None, "Must specify is_output for r2c io" + assert config.fft_stride == 1, "R2C io only supported for contiguous data" + + outer_batch_stride = (config.N // 2) + 1 + + # for inverse-r2c write and forward-r2c read, the + # outer batch stride is doubled, since we are writting + # floats and not vec2s + if inverse == is_output: + outer_batch_stride *= 2 + + return grid.global_outer_offset * outer_batch_stride + grid.global_inner_offset + +@dataclasses.dataclass +class GlobalWriteOp(MemoryOp): + register: vc.ShaderVariable + io_index: vc.ShaderVariable + r2c: bool + inverse: Optional[bool] + + @classmethod + def from_memory_op(cls, + base: MemoryOp, + register: vc.ShaderVariable, + io_index: vc.ShaderVariable, + r2c: bool, + inverse: Optional[bool] = None) -> 'GlobalWriteOp': + return cls(**vars(base), + register=register, + io_index=io_index, + r2c=r2c, + inverse=inverse) + + def write_to_buffer(self, + buffer: vc.Buffer, + register: Optional[vc.ShaderVariable] = None, + io_index: Optional[vc.ShaderVariable] = None): + if register is None: + register = self.register + + if io_index is None: + io_index = self.io_index + + if not self.r2c: + buffer[io_index] = _cast_if_needed(register, buffer.var_type) + return + + if not self.inverse: + vc.if_statement(self.fft_index < (self.fft_size // 2) + 1) + buffer[io_index] = _cast_if_needed(register, buffer.var_type) + vc.end() + return + + out_scalar_type = buffer.var_type.child_type + out_real = _cast_if_needed(register.real, out_scalar_type) + buffer[io_index // 2][io_index % 2] = out_real + +def global_writes_iterator( + registers: FFTRegisters, + r2c: bool = False, + inverse: bool = None): + + extra_comment_lines = "" + + if r2c: + assert inverse is not None, "Must specify inverse for r2c io" + + if inverse: + extra_comment_lines = "\nDoing R2C inverse write, applying Hermitian reconstruction and packed-real rules as needed." 
+ else: + extra_comment_lines = "\nDoing R2C forward write, applying Hermitian-half truncation and packed-real rules as needed." + + vc.comment(f"""Writing register-resident FFT outputs to global memory. +Addressing uses computed batch offsets plus FFT-lane stride.{extra_comment_lines}""") + + resources = registers.resources + config = registers.config + + resources.output_batch_offset[:] = global_batch_offset(registers, r2c=r2c, is_output=True, inverse=inverse) + + for write_op in memory_writes_iterator(resources, -1): + resources.io_index[:] = resources.output_batch_offset + write_op.fft_index * config.fft_stride + + global_write_op = GlobalWriteOp.from_memory_op( + base=write_op, + register=registers[write_op.register_id], + io_index=resources.io_index, + r2c=r2c, + inverse=inverse + ) + + yield global_write_op + +@dataclasses.dataclass +class GlobalReadOp(MemoryOp): + register: vc.ShaderVariable + io_index: vc.ShaderVariable + io_index_2: vc.ShaderVariable + r2c: bool + inverse: Optional[bool] + r2c_inverse_offset: vc.ShaderVariable + format_transposed: bool + signal_range: Tuple[int, int] + + @classmethod + def from_memory_op(cls, + base: MemoryOp, + register: vc.ShaderVariable, + io_index: vc.ShaderVariable, + io_index_2: vc.ShaderVariable, + r2c: bool, + inverse: Optional[bool], + r2c_inverse_offset: vc.ShaderVariable, + format_transposed: bool, + signal_range: Tuple[int, int]) -> 'GlobalReadOp': + return cls(**vars(base), + register=register, + io_index=io_index, + io_index_2=io_index_2, + r2c=r2c, + inverse=inverse, + r2c_inverse_offset=r2c_inverse_offset, + format_transposed=format_transposed, + signal_range=signal_range + ) + + def check_in_signal_range(self) -> bool: + if self.signal_range == (0, self.fft_size): + return + + if self.signal_range[0] == 0: + vc.if_statement(self.fft_index < self.signal_range[1]) + return + + if self.signal_range[1] == self.fft_size: + vc.if_statement(self.fft_index >= self.signal_range[0]) + return + + 
vc.if_all(self.fft_index >= self.signal_range[0], self.fft_index < self.signal_range[1]) + + def signal_range_end(self, register: vc.ShaderVariable): + if self.signal_range == (0, self.fft_size): + return + + vc.else_statement() + register[:] = vc.to_dtype(register.var_type, 0) + vc.end() + + def read_from_buffer(self, + buffer: vc.Buffer, + register: Optional[vc.ShaderVariable] = None, + io_index: Optional[vc.ShaderVariable] = None): + self.check_in_signal_range() + + if io_index is None: + io_index = self.io_index + + if register is None: + register = self.register + + if not self.r2c: + register[:] = _cast_if_needed(buffer[io_index], register.var_type) + self.signal_range_end(register) + return + + if not self.inverse: + packed_real = buffer[io_index // 2][io_index % 2] + packed_complex = vc.to_complex(packed_real) + register[:] = _cast_if_needed(packed_complex, register.var_type) + self.signal_range_end(register) + return + + vc.if_statement(self.fft_index >= (self.fft_size // 2) + 1) + self.io_index_2[:] = self.r2c_inverse_offset - io_index + register[:] = _cast_if_needed(buffer[self.io_index_2], register.var_type) + register.imag = -register.imag + vc.else_statement() + register[:] = _cast_if_needed(buffer[io_index], register.var_type) + vc.end() + + self.signal_range_end(register) + +def resolve_signal_range( + signal_range: Optional[Tuple[Optional[int], Optional[int]]], + N: int) -> Tuple[int, int]: + if signal_range is None: + return 0, N + + start, end = signal_range + + if start is None: + start = 0 + if end is None: + end = N + + return start, end + +def global_reads_iterator( + registers: FFTRegisters, + r2c: bool = False, + inverse: bool = None, + format_transposed: bool = False, + inner_only: bool = False, + signal_range: Optional[Tuple[Optional[int], Optional[int]]] = None): + + signal_range = resolve_signal_range(signal_range, registers.config.N) + + transpose_comment_str = "" + if format_transposed: + transpose_comment_str = "\nReading in 
transposed format, using grid-mapped indices." + + signal_range_comment_str = "" + if signal_range != (0, registers.config.N): + signal_range_comment_str = f"\nApplying signal-range masking for FFT lanes outside [{signal_range[0]}, {signal_range[1]})." + + r2c_comment_str = "" + if r2c: + if inverse: + r2c_comment_str = "\nDoing R2C inverse read, applying Hermitian reconstruction and packed-real rules as needed." + else: + r2c_comment_str = "\nDoing R2C forward read, applying packed-real format rules as needed." + + vc.comment(f"""Reading input samples from global memory into FFT registers.{transpose_comment_str}{signal_range_comment_str}{r2c_comment_str}""") + + if r2c: + assert not format_transposed, "R2C transposed format not supported" + + resources = registers.resources + config = registers.config + + r2c_inverse_offset = None + + if not format_transposed: + resources.input_batch_offset[:] = global_batch_offset(registers, r2c=r2c, is_output=False, inverse=inverse, inner_only=inner_only) + r2c_inverse_offset = 2 * resources.input_batch_offset + config.N * config.fft_stride + + for read_op in memory_reads_iterator(resources, 0): + if format_transposed: + resources.io_index[:] = resources.grid.get_transposed_index(read_op.register_id, inner_only=inner_only) + else: + resources.io_index[:] = resources.input_batch_offset + read_op.fft_index * config.fft_stride + + global_read_op = GlobalReadOp.from_memory_op( + base=read_op, + register=registers[read_op.register_id], + io_index=resources.io_index, + io_index_2=resources.io_index_2, + r2c=r2c, + inverse=inverse, + r2c_inverse_offset=r2c_inverse_offset, + format_transposed=format_transposed, + signal_range=signal_range + ) + + yield global_read_op + + +@dataclasses.dataclass +class GlobalTransposedWriteOp(MemoryOp): + register: vc.ShaderVariable + io_index: vc.ShaderVariable + + @classmethod + def from_memory_op(cls, + base: MemoryOp, + register: vc.ShaderVariable, + io_index: vc.ShaderVariable) -> 
'GlobalTransposedWriteOp': + return cls(**vars(base), + register=register, + io_index=io_index + ) + + def write_to_buffer(self, + buffer: vc.Buffer, + register: Optional[vc.ShaderVariable] = None, + io_index: Optional[vc.ShaderVariable] = None): + if io_index is None: + io_index = self.io_index + + if register is None: + register = self.register + + buffer[io_index] = _cast_if_needed(register, buffer.var_type) + +def global_trasposed_write_iterator(registers: FFTRegisters, inner_only: bool = False): + vc.comment("""Writing registers to global memory in transposed order. +Indices come from the grid transposition map. +This produces axis-swapped, coalesced tiles for downstream kernels without +an additional reorder pass.""") + + resources = registers.resources + + for read_op in memory_reads_iterator(resources, 0): # Iterate in read order to match register format when reading + resources.io_index[:] = resources.grid.get_transposed_index(read_op.register_id, inner_only=inner_only) + + global_trasposed_write_op = GlobalTransposedWriteOp.from_memory_op( + base=read_op, + register=registers[read_op.register_id], + io_index=resources.io_index + ) + + yield global_trasposed_write_op diff --git a/vkdispatch/fft/grid_manager.py b/vkdispatch/fft/grid_manager.py new file mode 100644 index 00000000..5d6aa4e9 --- /dev/null +++ b/vkdispatch/fft/grid_manager.py @@ -0,0 +1,262 @@ +import vkdispatch as vd +import vkdispatch.codegen as vc + +from typing import Optional, Tuple, Union, Literal + +from .config import FFTConfig +from .prime_utils import prime_factors + +from ..compat import numpy_compat as npc + +def allocation_valid(workgroup_size: int, shared_memory_size: int): + valid_workgroup = workgroup_size <= vd.get_context().max_workgroup_invocations + valid_shared_memory = shared_memory_size <= vd.get_context().max_shared_memory + return valid_workgroup and valid_shared_memory + +def allocate_inline_batches( + batch_num: int, + batch_threads: int, + shared_elements: int, + 
element_size: int, + max_workgroup_size: int, + max_total_threads: int): + + shared_memory_allocation = shared_elements * element_size + batch_num_primes = prime_factors(batch_num) + prime_index = 0 + workgroup_size = batch_threads + inline_batches = 1 + + while allocation_valid(workgroup_size, shared_memory_allocation) and \ + prime_index < len(batch_num_primes) and \ + inline_batches <= max_workgroup_size and \ + workgroup_size <= max_total_threads: + + test_prime = batch_num_primes[prime_index] + + is_valid = allocation_valid(workgroup_size * test_prime, shared_memory_allocation * test_prime) + + is_valid = is_valid and inline_batches * test_prime <= max_workgroup_size + is_valid = is_valid and workgroup_size * test_prime <= max_total_threads + + if is_valid: + workgroup_size *= test_prime + shared_memory_allocation *= test_prime + inline_batches *= test_prime + + prime_index += 1 + + return inline_batches + +def set_to_multiple_with_max(count, max_count): + if count <= max_count: + return count + + count_primes = prime_factors(count) + + result_count = 1 + for prime in count_primes: + if result_count * prime > max_count: + break + result_count *= prime + + return result_count + +def allocate_workgroups(total_count: int, declare_variables: bool = True) -> Tuple[vc.ShaderVariable, Tuple[int, int, int]]: + workgroups_x = set_to_multiple_with_max( + total_count, + vd.get_context().max_workgroup_count[0] + ) + workgroups_y = 1 + workgroups_z = 1 + + if not declare_variables: + return None, (workgroups_x, workgroups_y, workgroups_z) + + workgroup_index = vc.new_uint_register( + vc.workgroup_id().x, + var_name="workgroup_index" + ) + + if workgroups_x != total_count: + workgroups_y = set_to_multiple_with_max( + total_count // workgroups_x, + vd.get_context().max_workgroup_count[1] + ) + + workgroup_index += workgroups_x * vc.workgroup_id().y + + if workgroups_y != total_count // workgroups_x: + workgroups_z = set_to_multiple_with_max( + total_count // (workgroups_x * 
workgroups_y), + vd.get_context().max_workgroup_count[2] + ) + + workgroup_index += workgroups_x * workgroups_y * vc.workgroup_id().z + + return workgroup_index, (workgroups_x, workgroups_y, workgroups_z) + +def decompose_workgroup_index( + workgroup_index: vc.ShaderVariable, + inner_batch_count: int, + fft_threads: int, + local_size: Tuple[int, int, int]) -> Tuple[vc.ShaderVariable, vc.ShaderVariable]: + + if inner_batch_count == None: + if fft_threads == 1: + return None, workgroup_index * local_size[0] + vc.local_invocation_id().x + + return None, workgroup_index * local_size[1] + vc.local_invocation_id().y + + global_inner_offset = vc.new_uint_register( + (workgroup_index % inner_batch_count) * local_size[0] + vc.local_invocation_id().x, + var_name="global_inner_index" + ) + + global_outer_offset = vc.new_uint_register( + (workgroup_index // inner_batch_count) * local_size[2] + vc.local_invocation_id().z, + var_name="global_outer_index" + ) + + return global_inner_offset, global_outer_offset + +class FFTGridManager: + config: FFTConfig + + shared_memory_enabled: bool + shared_memory_allocation: int + + inline_batches_inner: int + inline_batches_outer: int + + local_inner: Optional[vc.ShaderVariable] + local_outer: vc.ShaderVariable + + tid: vc.ShaderVariable + + global_inner_offset: Union[vc.ShaderVariable, Literal[0]] + global_outer_offset: vc.ShaderVariable + + local_size: Tuple[int, int, int] + workgroup_count: Tuple[int, int, int] + exec_size: Tuple[int, int, int] + + workgroup_index: vc.ShaderVariable + + transposed_offset: Optional[vc.ShaderVariable] + transposed_stride: int + + transposed_inner_offset: Optional[vc.ShaderVariable] + transposed_inner_stride: int + + def __init__(self, config: FFTConfig, force_sdata: bool = False, declare_variables: bool = True): + self.config = config + + make_sdata_buffer = config.batch_threads > 1 or force_sdata + + self.inline_batches_inner = allocate_inline_batches( + config.batch_inner_count, + config.batch_threads, + 
config.sdata_allocation if make_sdata_buffer else 0, + config.compute_type.item_size, + min(vd.get_context().max_workgroup_size[0], 4), + vd.get_context().max_workgroup_invocations) + + max_inline_outer_batches = vd.get_context().max_workgroup_size[ + 1 if config.batch_inner_count == 1 else 2 + ] + + # For some reason it's better not to have too many inline outer batches + max_inline_outer_batches = min(max_inline_outer_batches, vd.get_context().subgroup_size) + + self.inline_batches_outer = allocate_inline_batches( + config.batch_outer_count, + config.batch_threads * self.inline_batches_inner, + config.sdata_allocation * self.inline_batches_inner if make_sdata_buffer else 0, + config.compute_type.item_size, + vd.get_context().max_workgroup_size[ + 1 if self.inline_batches_inner == 1 else 2 + ], + max_inline_outer_batches) + + + if config.batch_inner_count > 1: + self.local_size = (self.inline_batches_inner, config.batch_threads, self.inline_batches_outer) + + inner_workgroups = config.batch_inner_count // self.inline_batches_inner + outer_workgroups = config.batch_outer_count // self.inline_batches_outer + + self.workgroup_index, self.workgroup_count = allocate_workgroups( + inner_workgroups * outer_workgroups, + declare_variables=declare_variables + ) + + if declare_variables: + self.local_inner = vc.local_invocation_id().x + self.local_outer = vc.local_invocation_id().z + + self.global_inner_offset, self.global_outer_offset = decompose_workgroup_index( + self.workgroup_index, + inner_workgroups, + config.batch_threads, + self.local_size + ) + + self.tid = vc.local_invocation_id().y.to_register("tid") + else: + self.local_inner = None + self.global_inner_offset = 0 + + if config.batch_threads > 1: + self.local_size = (config.batch_threads, self.inline_batches_outer, 1) + else: + self.local_size = (self.inline_batches_outer, 1, 1) + + self.workgroup_index, self.workgroup_count = allocate_workgroups( + config.batch_outer_count // self.inline_batches_outer, + 
declare_variables=declare_variables + ) + + if declare_variables: + if config.batch_threads > 1: + self.tid = vc.local_invocation_id().x.to_register("tid") + self.local_outer = vc.local_invocation_id().y + else: + self.tid = 0 + self.local_outer = vc.local_invocation_id().x + + _, self.global_outer_offset = decompose_workgroup_index( + self.workgroup_index, + None, + config.batch_threads, + self.local_size + ) + + self.exec_size = ( + self.local_size[0] * self.workgroup_count[0], + self.local_size[1] * self.workgroup_count[1], + self.local_size[2] * self.workgroup_count[2] + ) + + if not declare_variables: + return + + self.transposed_stride = npc.prod(self.local_size) + self.transposed_offset = vc.local_invocation_index() + self.transposed_stride * self.config.register_count * self.workgroup_index + + self.transposed_inner_stride = None + self.transposed_inner_offset = None + + if config.batch_inner_count > 1: + self.transposed_inner_stride = self.local_size[0] * self.local_size[1] + self.transposed_inner_offset = vc.local_invocation_id().x + self.local_size[0] * vc.local_invocation_id().y + \ + self.transposed_inner_stride * self.config.register_count * (self.workgroup_index % inner_workgroups) + else: + self.transposed_inner_stride = self.local_size[0] + self.transposed_inner_offset = vc.local_invocation_id().x + + def get_transposed_index(self, register_id: int, inner_only: bool = False) -> vc.ShaderVariable: + if not inner_only: + return self.transposed_offset + register_id * self.transposed_stride + + return self.transposed_inner_offset + register_id * self.transposed_inner_stride diff --git a/vkdispatch/fft/io_manager.py b/vkdispatch/fft/io_manager.py index 5807b440..b91d6bd9 100644 --- a/vkdispatch/fft/io_manager.py +++ b/vkdispatch/fft/io_manager.py @@ -1,26 +1,78 @@ import vkdispatch as vd import vkdispatch.codegen as vc +import vkdispatch.base.dtype as dtypes -from typing import Optional +from typing import Optional, Tuple + +import threading from 
.io_proxy import IOProxy +from .registers import FFTRegisters +from .global_memory_iterators import global_writes_iterator, global_reads_iterator +from .global_memory_iterators import GlobalWriteOp, GlobalReadOp + +_write_op = threading.local() +_read_op = threading.local() + +def _get_write_op() -> Optional[GlobalWriteOp]: + return getattr(_write_op, 'op', None) + +def _get_read_op() -> Optional[GlobalReadOp]: + return getattr(_read_op, 'op', None) + +def write_op() -> GlobalWriteOp: + op = _get_write_op() + assert op is not None, "No global write operation is set for the current thread!" + return op + +def read_op() -> GlobalReadOp: + op = _get_read_op() + assert op is not None, "No global read operation is set for the current thread!" + return op + +def set_write_op(op: GlobalWriteOp): + if op is None: + _write_op.op = None + return + + assert _get_write_op() is None, "A global write operation is already set for the current thread!" + _write_op.op = op + +def set_read_op(op: GlobalReadOp): + if op is None: + _read_op.op = None + return + + assert _get_read_op() is None, "A global read operation is already set for the current thread!" 
+ _read_op.op = op class IOManager: + default_registers: FFTRegisters output_proxy: IOProxy input_proxy: IOProxy kernel_proxy: IOProxy - signature: vd.ShaderSignature - def __init__(self, - builder: vc.ShaderBuilder, - output: Optional[vd.MappingFunction], - input: Optional[vd.MappingFunction] = None, - kernel: Optional[vd.MappingFunction] = None): + default_registers: FFTRegisters, + shader_context: vd.ShaderContext, + output_map: Optional[vd.MappingFunction], + output_type: dtypes.dtype = vd.complex64, + input_type: Optional[dtypes.dtype] = None, + input_map: Optional[vd.MappingFunction] = None, + kernel_map: Optional[vd.MappingFunction] = None): + self.default_registers = default_registers + self.output_proxy = IOProxy(output_type if output_map is None else output_map, "Output") + + if input_map is not None: + self.input_proxy = IOProxy(input_map, "Input") + elif output_map is not None: + if input_type is None: + raise ValueError("input_type must be provided when output_map is used without input_map") + self.input_proxy = IOProxy(input_type, "Input") + else: + self.input_proxy = IOProxy(None, "Input") - self.output_proxy = IOProxy(vd.complex64 if output is None else output, "Output") - self.input_proxy = IOProxy(input, "Input") - self.kernel_proxy = IOProxy(kernel, "Kernel") + self.kernel_proxy = IOProxy(kernel_map, "Kernel") output_types = self.output_proxy.buffer_types input_types = self.input_proxy.buffer_types @@ -31,8 +83,7 @@ def __init__(self, if len(all_types) == 0: raise ValueError("A big error happened") - self.signature = vd.ShaderSignature.from_type_annotations(builder, all_types) - sig_vars = self.signature.get_variables() + sig_vars = shader_context.declare_input_arguments(all_types) output_count = len(output_types) input_count = len(input_types) @@ -43,3 +94,85 @@ def __init__(self, if input_count == 0: self.input_proxy = self.output_proxy + + def read_from_proxy(self, + proxy: IOProxy, + registers: Optional[FFTRegisters] = None, + r2c: bool = 
False, + inverse: bool = None, + format_transposed: bool = False, + inner_only: bool = False, + signal_range: Optional[Tuple[Optional[int], Optional[int]]] = None): + + if registers is None: + registers = self.default_registers + + for read_op in global_reads_iterator( + registers=registers, + r2c=r2c, + inverse=inverse, + format_transposed=format_transposed, + inner_only=inner_only, + signal_range=signal_range + ): + + if proxy.has_callback(): + set_read_op(read_op) + proxy.do_callback() + set_read_op(None) + else: + read_op.read_from_buffer(proxy.buffer_variables[0]) + + def write_to_proxy(self, + proxy: IOProxy, + registers: Optional[FFTRegisters] = None, + r2c: bool = False, + inverse: bool = None): + + if registers is None: + registers = self.default_registers + + for write_op in global_writes_iterator( + registers=registers, + r2c=r2c, + inverse=inverse + ): + + if proxy.has_callback(): + set_write_op(write_op) + proxy.do_callback() + set_write_op(None) + else: + write_op.write_to_buffer(proxy.buffer_variables[0]) + + def read_input(self, + registers: Optional[FFTRegisters] = None, + r2c: bool = False, + inverse: bool = None, + signal_range: Optional[Tuple[Optional[int], Optional[int]]] = None): + self.read_from_proxy( + self.input_proxy, + registers, + r2c=r2c, + inverse=inverse, + signal_range=signal_range + ) + + def write_output(self, + registers: Optional[FFTRegisters] = None, + r2c: bool = False, + inverse: bool = None): + self.write_to_proxy( + self.output_proxy, + registers, + r2c=r2c, + inverse=inverse + ) + + def read_kernel(self, registers: Optional[FFTRegisters] = None, format_transposed: bool = False, inner_only: bool = False): + self.read_from_proxy( + self.kernel_proxy, + registers, + format_transposed=format_transposed, + inner_only=inner_only + ) diff --git a/vkdispatch/fft/io_proxy.py b/vkdispatch/fft/io_proxy.py index f6674176..5744b1ba 100644 --- a/vkdispatch/fft/io_proxy.py +++ b/vkdispatch/fft/io_proxy.py @@ -43,89 +43,9 @@ def 
set_variables(self, vars: List[vc.Buffer]) -> None: self.buffer_variables = vars - def read(self, - register: vc.ShaderVariable, - memory_index: vc.ShaderVariable, - spare_register: vc.ShaderVariable = None, - r2c: bool = False) -> vc.ShaderVariable: - assert self.enabled, f"{self.name} IOProxy is not enabled" - - if self.map_func is not None: - assert spare_register is not None, "Spare register must be provided when using a mapping function" - - vc.set_mapping_index(memory_index) - vc.set_mapping_registers([register, spare_register]) - - self.map_func.callback(*self.buffer_variables) - - return - - if not r2c: - register[:] = self.buffer_variables[0][memory_index] - return - - real_value = self.buffer_variables[0][memory_index / 2][memory_index % 2] - register[:] = f"vec2({real_value}, 0)" - - def read_r2c_inverse(self, - register: vc.ShaderVariable, - memory_index: vc.ShaderVariable, - fft_index: vc.ShaderVariable, - spare_index: vc.ShaderVariable, - input_batch_offset: vc.ShaderVariable, - fft_size: int, - fft_stride: int) -> vc.ShaderVariable: - assert self.enabled, f"{self.name} IOProxy is not enabled" - - assert self.map_func is None, "Mapping functions do not support inverse r2c operations" - - vc.if_statement(fft_index >= (fft_size // 2) + 1) - spare_index[:] = 2 * input_batch_offset + fft_size * fft_stride - memory_index - register[:] = self.buffer_variables[0][spare_index] - register.y = -register.y - vc.else_statement() - register[:] = self.buffer_variables[0][memory_index] - vc.end() - - def write(self, - register: vc.ShaderVariable, - memory_index: vc.ShaderVariable, - r2c: bool = False, - inverse: bool = False, - fft_index: vc.ShaderVariable = None, - fft_size: int = None) -> vc.ShaderVariable: - assert self.enabled, f"{self.name} IOProxy is not enabled" - - if self.map_func is not None: - - if not inverse and r2c: - assert fft_size is not None, "FFT size must be provided for forward r2c write" - assert fft_index is not None, "FFT index must be 
provided for forward r2c write" - - vc.if_statement(fft_index < (fft_size // 2) + 1) - - vc.set_mapping_index(memory_index) - vc.set_mapping_registers([register]) - self.map_func.callback(*self.buffer_variables) - - if not inverse and r2c: - vc.end() - - return - - if not r2c: - self.buffer_variables[0][memory_index] = register - return - - if not inverse: - assert fft_size is not None, "FFT size must be provided for forward r2c write" - assert fft_index is not None, "FFT index must be provided for forward r2c write" - - vc.if_statement(fft_index < (fft_size // 2) + 1) - self.buffer_variables[0][memory_index] = register - vc.end() - return - + def has_callback(self) -> bool: + return self.map_func is not None - self.buffer_variables[0][memory_index / 2][memory_index % 2] = register.x - \ No newline at end of file + def do_callback(self): + assert self.map_func is not None, "IOProxy does not have a mapping function" + self.map_func.callback(*self.buffer_variables) diff --git a/vkdispatch/fft/manager.py b/vkdispatch/fft/manager.py deleted file mode 100644 index ed65d79f..00000000 --- a/vkdispatch/fft/manager.py +++ /dev/null @@ -1,59 +0,0 @@ -import vkdispatch as vd -import vkdispatch.codegen as vc - -from typing import Optional, Tuple, Union - -from .io_manager import IOManager -from .config import FFTConfig -from .resources import FFTResources, allocate_fft_resources - -class FFTCallable: - shader_object: vd.ShaderObject - exec_size: Tuple[int, int, int] - - def __init__(self, shader_object: vd.ShaderObject, exec_size: Tuple[int, int, int]): - self.shader_object = shader_object - self.exec_size = exec_size - - def __call__(self, *args, **kwargs): - self.shader_object(*args, exec_size=self.exec_size, **kwargs) - - def __repr__(self): - return repr(self.shader_object) - -class FFTManager: - builder: vc.ShaderBuilder - io_manager: IOManager - config: FFTConfig - resources: FFTResources - fft_callable: FFTCallable - name: str - - def __init__(self, - builder: 
vc.ShaderBuilder, - buffer_shape: Tuple, - axis: int = None, - max_register_count: int = None, - output_map: Union[vd.MappingFunction, type, None] = None, - input_map: Union[vd.MappingFunction, type, None] = None, - kernel_map: Union[vd.MappingFunction, type, None] = None, - name: str = None): - self.builder = builder - self.io_manager = IOManager(builder, output_map, input_map, kernel_map) - self.config = FFTConfig(buffer_shape, axis, max_register_count) - self.resources = allocate_fft_resources(self.config, True) - self.fft_callable = None - self.name = name if name is not None else f"fft_shader_{buffer_shape}_{axis}" - - def compile_shader(self): - self.fft_callable = FFTCallable(vd.ShaderObject( - self.builder.build(self.name), - self.io_manager.signature, - local_size=self.resources.local_size - ), - self.resources.exec_size - ) - - def get_callable(self) -> FFTCallable: - assert self.fft_callable is not None, "Shader not compiled yet... something is wrong" - return self.fft_callable diff --git a/vkdispatch/fft/memory_io.py b/vkdispatch/fft/memory_io.py deleted file mode 100644 index 5727fb91..00000000 --- a/vkdispatch/fft/memory_io.py +++ /dev/null @@ -1,182 +0,0 @@ -import vkdispatch as vd -import vkdispatch.codegen as vc -from vkdispatch.codegen.abreviations import * - -from typing import List, Tuple, Optional - -from .resources import FFTResources -from .config import FFTRegisterStageConfig, FFTParams - -from .io_proxy import IOProxy - -import dataclasses - -@dataclasses.dataclass -class FFTRegisterStageInvocation: - stage: FFTRegisterStageConfig - output_stride: int - block_width: int - inner_block_offset: int - block_index: int - sub_sequence_offset: int - register_selection: slice - - def __init__(self, stage: FFTRegisterStageConfig, output_stride: int, instance_index: int, tid: vc.ShaderVariable, N: int): - self.stage = stage - self.output_stride = output_stride - - self.block_width = output_stride * stage.fft_length - - instance_index_stride = N // 
(stage.fft_length * stage.instance_count) - - self.instance_id = tid + instance_index_stride * instance_index - - self.inner_block_offset = self.instance_id % output_stride - - if output_stride == 1: - self.inner_block_offset = 0 - - self.sub_sequence_offset = self.instance_id * stage.fft_length - self.inner_block_offset * (stage.fft_length - 1) - - if self.block_width == N: - self.inner_block_offset = self.instance_id - self.sub_sequence_offset = self.inner_block_offset - - self.register_selection = slice(instance_index * stage.fft_length, (instance_index + 1) * stage.fft_length) - -def load_sdata_state_to_registers( - resources: FFTResources, - params: FFTParams, - offset: Const[u32], - stride: int, - register_list: List[vc.ShaderVariable] = None, - do_sdata_padding: bool = False) -> None: - - for i in range(len(register_list)): - resources.io_index[:] = i * stride + offset - - if resources.sdata_offset is not None: - resources.io_index[:] = resources.io_index + resources.sdata_offset - - if do_sdata_padding: - resources.io_index[:] = resources.io_index + resources.io_index / params.sdata_row_size - - register_list[i][:] = resources.sdata[resources.io_index] - -def load_buffer_to_registers( - resources: FFTResources, - params: FFTParams, - buffer: Optional[IOProxy], - offset: Const[u32], - stride: int, - register_list: List[vc.ShaderVariable] = None, - do_sdata_padding: bool = False) -> None: - if register_list is None: - register_list = resources.registers - - vc.comment(f"Loading to registers from buffer {buffer} at offset {offset} and stride {stride}") - - if buffer is not None: - resources.io_index[:] = offset * params.fft_stride + resources.input_batch_offset - - for i in range(len(register_list)): - if i != 0: - resources.io_index += stride * params.fft_stride - - if params.r2c and params.inverse: - buffer.read_r2c_inverse( - register=register_list[i], - memory_index=resources.io_index, - fft_index=i * stride + offset, - spare_index=resources.io_index_2, - 
input_batch_offset=resources.input_batch_offset, - fft_size=params.config.N, - fft_stride=params.fft_stride - ) - else: - buffer.read(register_list[i], resources.io_index, spare_register=resources.omega_register, r2c=params.r2c) - - return - - if resources.sdata_offset is not None: - resources.io_index[:] = offset + resources.sdata_offset - else: - resources.io_index[:] = offset - - for i in range(len(register_list)): - if do_sdata_padding: - resources.io_index_2[:] = resources.io_index + stride * i + ((resources.io_index + stride * i) / params.sdata_row_size) - register_list[i][:] = resources.sdata[resources.io_index_2] - else: - register_list[i][:] = resources.sdata[resources.io_index + stride * i] - -def store_register( - resources: FFTResources, - params: FFTParams, - buffer: Optional[IOProxy], - offset: Const[u32], - register: vc.ShaderVariable, - do_sdata_padding: bool = False) -> None: - if buffer is None: - sdata_index = offset - - if resources.sdata_offset is not None: - sdata_index = sdata_index + resources.sdata_offset - - if do_sdata_padding: - resources.io_index[:] = sdata_index - resources.io_index[:] = resources.io_index + resources.io_index / params.sdata_row_size - sdata_index = resources.io_index - - resources.sdata[sdata_index] = register - else: - if params.normalize and params.inverse: - register[:] = register / params.config.N - - buffer.write( - register=register, - memory_index=resources.io_index, - r2c=params.r2c, - inverse=params.inverse, - fft_size=params.config.N, - fft_index=offset - ) - -def store_registers_from_stages( - resources: FFTResources, - params: FFTParams, - stage: FFTRegisterStageConfig, - stage_invocations: List[FFTRegisterStageInvocation], - output: IOProxy, - stride: int): - - sdata_padding = params.sdata_row_size != params.sdata_row_size_padded and stride < 32 and output is None - - if output is not None: - resources.io_index[:] = resources.tid * params.fft_stride + resources.output_batch_offset - - vc.comment(f"Storing 
from registers to buffer {output} ") - - instance_index_stride = params.config.N // (stage.fft_length * stage.instance_count) - - for jj in range(stage.fft_length): - for ii, invocation in enumerate(stage_invocations): - if stage.remainder_offset == 1 and ii == stage.extra_ffts: - vc.if_statement(resources.tid < params.config.N // stage.registers_used) - - if output is not None and jj != 0 or ii != 0: - resources.io_index += instance_index_stride * params.fft_stride - - store_register( - resources=resources, - params=params, - buffer=output, - offset=invocation.sub_sequence_offset + jj * stride, - register=resources.registers[invocation.register_selection][jj], - do_sdata_padding=sdata_padding - ) - - if stage.remainder_offset == 1: - vc.end() - - return sdata_padding \ No newline at end of file diff --git a/vkdispatch/fft/memory_iterators.py b/vkdispatch/fft/memory_iterators.py new file mode 100644 index 00000000..a7793ab7 --- /dev/null +++ b/vkdispatch/fft/memory_iterators.py @@ -0,0 +1,90 @@ +import vkdispatch as vd +import vkdispatch.codegen as vc + +from .resources import FFTResources + +import dataclasses + +@dataclasses.dataclass +class MemoryOp: + fft_offset: vc.ShaderVariable + fft_stride: int + fft_index: vc.ShaderVariable + fft_size: int + register_id: int + register_count: int + element_id: int + element_count: int + instance_id: int + instance_count: int + +def memory_reads_iterator(resources: FFTResources, stage_index: int = 0): + resources.stage_begin(stage_index) + + index_list = list(range(resources.config.register_count)) + invocations = resources.config.stages[stage_index].invocations + + for ii, invocation in enumerate(invocations): + resources.invocation_gaurd(stage_index, ii) + + register_indicies = index_list[invocation.register_selection] + + offset = invocation.get_offset(resources.tid) + stride = resources.config.N // resources.config.stages[stage_index].fft_length + + for i in range(len(register_indicies)): + fft_index = i * stride + 
offset + + read_op = MemoryOp( + fft_offset=offset, + fft_stride=stride, + fft_index=fft_index, + fft_size=resources.config.N, + register_id=register_indicies[i], + register_count=resources.config.register_count, + element_id=i, + element_count=len(register_indicies), + instance_id=ii, + instance_count=len(invocations) + ) + + yield read_op + + resources.invocation_end(stage_index) + resources.stage_end(stage_index) + +def memory_writes_iterator(resources: FFTResources, stage_index: int = -1): + resources.stage_begin(stage_index) + + index_list = list(range(resources.config.register_count)) + element_count = resources.config.stages[stage_index].fft_length + invocations = resources.config.stages[stage_index].invocations + + for i in range(element_count): + for ii, invocation in enumerate(invocations): + resources.invocation_gaurd(stage_index, ii) + + offset = invocation.get_sub_sequence_offset(resources.tid) + stride = resources.config.stages[stage_index].input_stride + + fft_index = offset + i * stride + + register_indicies = index_list[invocation.register_selection] + + write_op = MemoryOp( + fft_offset=offset, + fft_stride=stride, + fft_index=fft_index, + fft_size=resources.config.N, + register_id=register_indicies[i], + register_count=resources.config.register_count, + element_id=i, + element_count=element_count, + instance_id=ii, + instance_count=len(invocations) + ) + + yield write_op + + resources.invocation_end(stage_index) + resources.stage_end(stage_index) \ No newline at end of file diff --git a/vkdispatch/fft/plan.py b/vkdispatch/fft/plan.py deleted file mode 100644 index 15d92117..00000000 --- a/vkdispatch/fft/plan.py +++ /dev/null @@ -1,264 +0,0 @@ -import vkdispatch as vd -import vkdispatch.codegen as vc -from vkdispatch.codegen.abreviations import * - -import dataclasses -from typing import List, Tuple -from functools import lru_cache -import numpy as np - -from .resources import FFTResources -from .config import FFTRegisterStageConfig, FFTParams - 
-from .io_proxy import IOProxy - -from .memory_io import load_buffer_to_registers, store_registers_from_stages, FFTRegisterStageInvocation - -def set_batch_offsets(resources: FFTResources, params: FFTParams): - input_batch_stride_y = params.batch_outer_stride - output_batch_stride_y = params.batch_outer_stride - - if params.r2c and not params.inverse: - output_batch_stride_y = (params.config.N // 2) + 1 - input_batch_stride_y = output_batch_stride_y * 2 - - if params.r2c and params.inverse: - input_batch_stride_y = (params.config.N // 2) + 1 - output_batch_stride_y = input_batch_stride_y * 2 - - resources.input_batch_offset[:] = resources.global_outer_index * input_batch_stride_y + resources.global_inner_index * params.batch_inner_stride - resources.output_batch_offset[:] = resources.global_outer_index * output_batch_stride_y + resources.global_inner_index * params.batch_inner_stride - -def do_c64_mult_const(register_out: vc.ShaderVariable, register_in: vc.ShaderVariable, constant: complex): - vc.comment(f"Multiplying {register_in} by {constant}") - - register_out.x = register_in.y * -constant.imag - register_out.x = vc.fma(register_in.x, constant.real, register_out.x) - - register_out.y = register_in.y * constant.real - register_out.y = vc.fma(register_in.x, constant.imag, register_out.y) - -def radix_P(resources: FFTResources, params: FFTParams, register_list: List[vc.ShaderVariable]): - assert len(register_list) <= len(resources.radix_registers), "Too many registers for radix_P" - - if len(register_list) == 1: - return - - if len(register_list) == 2: - vc.comment(f"Performing a DFT for Radix-2 FFT") - resources.radix_registers[0][:] = register_list[1] - register_list[1][:] = register_list[0] - resources.radix_registers[0] - register_list[0][:] = register_list[0] + resources.radix_registers[0] - return - - vc.comment(f"Performing a DFT for Radix-{len(register_list)} FFT") - - for i in range(0, len(register_list)): - for j in range(0, len(register_list)): - if j 
== 0: - resources.radix_registers[i][:] = register_list[j] - continue - - if i == 0: - resources.radix_registers[i] += register_list[j] - continue - - if i * j == len(register_list) // 2 and len(register_list) % 2 == 0: - resources.radix_registers[i] -= register_list[j] - continue - - omega = np.exp(1j * params.angle_factor * i * j / len(register_list)) - do_c64_mult_const(resources.omega_register, register_list[j], omega) - resources.radix_registers[i] += resources.omega_register - - for i in range(0, len(register_list)): - register_list[i][:] = resources.radix_registers[i] - -def apply_cooley_tukey_twiddle_factors(resources: FFTResources, params: FFTParams, register_list: List[vc.ShaderVariable], twiddle_index: int = 0, twiddle_N: int = 1): - if isinstance(twiddle_index, int) and twiddle_index == 0: - return - - vc.comment(f"Applying Cooley-Tukey twiddle factors for twiddle index {twiddle_index} and twiddle N {twiddle_N}") - - if not isinstance(twiddle_index, int): - resources.omega_register.x = params.angle_factor * twiddle_index / twiddle_N - resources.omega_register[:] = vc.complex_from_euler_angle(resources.omega_register.x) - - inited_radix = False - - for i in range(len(register_list)): - if i == 0: - continue - - if isinstance(twiddle_index, int): - if twiddle_index == 0: - continue - - omega = np.exp(1j * params.angle_factor * i * twiddle_index / twiddle_N) - - scaled_angle = 2 * np.angle(omega) / np.pi - rounded_angle = np.round(scaled_angle) - - if np.abs(scaled_angle - rounded_angle) < 1e-8: - angle_int = int(rounded_angle) - - if angle_int == 1: - resources.omega_register.x = register_list[i].x - register_list[i].x = -register_list[i].y - register_list[i].y = resources.omega_register.x - elif angle_int == -1: - resources.omega_register.x = register_list[i].x - register_list[i].x = register_list[i].y - register_list[i].y = -resources.omega_register.x - elif angle_int == 2 or angle_int == -2: - register_list[i][:] = -register_list[i] - - continue - - 
do_c64_mult_const(resources.omega_register, register_list[i], omega) - register_list[i][:] = resources.omega_register - continue - - if not inited_radix: - resources.radix_registers[1][:] = resources.omega_register - inited_radix = True - - do_c64_mult_const(resources.radix_registers[0], register_list[i], resources.radix_registers[1]) - register_list[i][:] = resources.radix_registers[0] - - if i < len(register_list) - 1: - do_c64_mult_const(resources.radix_registers[0], resources.omega_register, resources.radix_registers[1]) - resources.radix_registers[1][:] = resources.radix_registers[0] - -def register_radix_composite(resources: FFTResources, params: FFTParams, register_list: List[vc.ShaderVariable], primes: List[int]): - if len(register_list) == 1: - return - - N = len(register_list) - - assert N == np.prod(primes), "Product of primes must be equal to the number of registers" - - vc.comment(f"Performing a Radix-{primes} FFT on {N} registers") - - output_stride = 1 - - for prime in primes: - sub_squences = [register_list[i::N//prime] for i in range(N//prime)] - - block_width = output_stride * prime - - for i in range(0, N // prime): - inner_block_offset = i % output_stride - block_index = (i * prime) // block_width - - apply_cooley_tukey_twiddle_factors(resources, params, sub_squences[i], twiddle_index=inner_block_offset, twiddle_N=block_width) - radix_P(resources, params, sub_squences[i]) - - sub_sequence_offset = block_index * block_width + inner_block_offset - - for j in range(prime): - register_list[sub_sequence_offset + j * output_stride] = sub_squences[i][j] - - output_stride *= prime - - return register_list - -def process_fft_register_stage(resources: FFTResources, - params: FFTParams, - stage: FFTRegisterStageConfig, - output_stride: int, - input = None, - output = None, - do_sdata_padding: bool = False) -> bool: - do_runtime_if = stage.thread_count < params.config.batch_threads - - vc.comment(f"Processing prime group {stage.primes} by doing 
{stage.instance_count} radix-{stage.fft_length} FFTs on {params.config.N // stage.registers_used} groups") - if do_runtime_if: vc.if_statement(resources.tid < stage.thread_count) - - stage_invocations: List[FFTRegisterStageInvocation] = [] - - for i in range(stage.instance_count): - stage_invocations.append(FFTRegisterStageInvocation(stage, output_stride, i, resources.tid, params.config.N)) - - for ii, invocation in enumerate(stage_invocations): - if stage.remainder_offset == 1 and ii == stage.extra_ffts: - vc.if_statement(resources.tid < params.config.N // stage.registers_used) - - load_buffer_to_registers( - resources=resources, - params=params, - buffer=input, - offset=invocation.instance_id, - stride=params.config.N // stage.fft_length, - register_list=resources.registers[invocation.register_selection], - do_sdata_padding=do_sdata_padding - ) - - apply_cooley_tukey_twiddle_factors( - resources=resources, - params=params, - register_list=resources.registers[invocation.register_selection], - twiddle_index=invocation.inner_block_offset, - twiddle_N=invocation.block_width - ) - - resources.registers[invocation.register_selection] = register_radix_composite( - resources=resources, - params=params, - register_list=resources.registers[invocation.register_selection], - primes=stage.primes - ) - - if stage.remainder_offset == 1: - vc.end() - - if do_runtime_if: vc.end() - - if (input is None and output is None) or params.input_sdata: - vc.barrier() - - if do_runtime_if: vc.if_statement(resources.tid < stage.thread_count) - - do_padding_next = store_registers_from_stages( - resources=resources, - params=params, - stage=stage, - stage_invocations=stage_invocations, - output=output, - stride=output_stride - ) - - - if do_runtime_if: vc.end() - - return do_padding_next - -def plan( - resources: FFTResources, - params: FFTParams, - input: IOProxy = None, - output: IOProxy = None, - do_sdata_padding: bool = False) -> bool: - - set_batch_offsets(resources, params) - - 
output_stride = 1 - - stage_count = len(params.config.stages) - - for i in range(stage_count): - do_sdata_padding = process_fft_register_stage( - resources, - params, - params.config.stages[i], - output_stride, - input=input if i == 0 else None, - output=output if i == stage_count - 1 else None, - do_sdata_padding=do_sdata_padding) - - output_stride *= params.config.stages[i].fft_length - - if i < stage_count - 1: - vc.barrier() - - return do_sdata_padding \ No newline at end of file diff --git a/vkdispatch/fft/precision.py b/vkdispatch/fft/precision.py new file mode 100644 index 00000000..d9d6d640 --- /dev/null +++ b/vkdispatch/fft/precision.py @@ -0,0 +1,99 @@ +import vkdispatch as vd + +from typing import Iterable, List, Optional + + +_COMPLEX_PRECISION_ORDER = (vd.complex32, vd.complex64, vd.complex128) +_COMPLEX_PRECISION_RANK = {dtype: rank for rank, dtype in enumerate(_COMPLEX_PRECISION_ORDER)} + + +def is_complex_precision(dtype) -> bool: + return dtype in _COMPLEX_PRECISION_RANK + + +def validate_complex_precision(dtype, *, arg_name: str) -> None: + if not is_complex_precision(dtype): + raise ValueError(f"{arg_name} must be one of complex32, complex64, or complex128 (got {dtype})") + + +def promote_complex_precisions(dtypes: Iterable) -> vd.dtype: + candidates = list(dtypes) + if len(candidates) == 0: + raise ValueError("At least one complex dtype is required for promotion") + + for candidate in candidates: + validate_complex_precision(candidate, arg_name="dtype") + + return max(candidates, key=lambda dtype: _COMPLEX_PRECISION_RANK[dtype]) + + +def default_compute_precision(io_precisions: Iterable) -> vd.dtype: + promoted = promote_complex_precisions(io_precisions) + + # Default to at least complex64 for numerical stability. 
+ if _COMPLEX_PRECISION_RANK[promoted] < _COMPLEX_PRECISION_RANK[vd.complex64]: + return vd.complex64 + + return promoted + + +def supports_complex_precision(dtype) -> bool: + validate_complex_precision(dtype, arg_name="dtype") + scalar_type = dtype.child_type + + for device in vd.get_context().device_infos: + if scalar_type == vd.float16: + if device.float_16_support != 1: + return False + + # Half precision in storage buffers typically needs one of these capabilities. + if ( + device.storage_buffer_16_bit_access != 1 + and device.uniform_and_storage_buffer_16_bit_access != 1 + ): + return False + + if scalar_type == vd.float64 and device.float_64_support != 1: + return False + + return True + + +def ensure_supported_complex_precision(dtype, *, role: str) -> None: + if not supports_complex_precision(dtype): + raise ValueError(f"{role} precision '{dtype.name}' is not supported on the active device set") + + +def resolve_compute_precision(io_precisions: List, compute_precision: Optional[vd.dtype]) -> vd.dtype: + if compute_precision is not None: + validate_complex_precision(compute_precision, arg_name="compute_type") + ensure_supported_complex_precision(compute_precision, role="Compute") + return compute_precision + + for io_precision in io_precisions: + validate_complex_precision(io_precision, arg_name="io_precision") + + if len(io_precisions) == 0: + for candidate in (vd.complex64, vd.complex32): + if supports_complex_precision(candidate): + return candidate + + raise ValueError( + "Unable to resolve a default compute precision supported by all active devices" + ) + + target = default_compute_precision(io_precisions) + if supports_complex_precision(target): + return target + + # Auto fallback: drop from complex128 to complex64 when fp64 is unsupported. 
+ for candidate in (vd.complex64, vd.complex32): + if ( + _COMPLEX_PRECISION_RANK[candidate] <= _COMPLEX_PRECISION_RANK[target] + and supports_complex_precision(candidate) + ): + return candidate + + raise ValueError( + "Unable to resolve an auto compute precision supported by all active devices" + ) diff --git a/vkdispatch/fft/prime_utils.py b/vkdispatch/fft/prime_utils.py index 783ed6e6..ee1624fa 100644 --- a/vkdispatch/fft/prime_utils.py +++ b/vkdispatch/fft/prime_utils.py @@ -1,13 +1,10 @@ -import numpy as np from typing import List import vkdispatch as vd +from ..compat import numpy_compat as npc def default_register_limit(): - if vd.get_devices()[0].is_nvidia(): - return 16 - - return 15 + return 16 def default_max_prime(): return 13 @@ -42,7 +39,7 @@ def group_primes(primes, register_count): groups.append([prime]) continue - if np.prod(groups[-1]) * prime <= register_count: + if npc.prod(groups[-1]) * prime <= register_count: groups[-1].append(prime) continue @@ -63,4 +60,4 @@ def pad_dim(dim: int, max_register_count: int = None): current_dim += 1 current_primes = prime_factors(current_dim) - return current_dim \ No newline at end of file + return current_dim diff --git a/vkdispatch/fft/registers.py b/vkdispatch/fft/registers.py new file mode 100644 index 00000000..31c79e32 --- /dev/null +++ b/vkdispatch/fft/registers.py @@ -0,0 +1,85 @@ +import vkdispatch as vd +import vkdispatch.codegen as vc + +from typing import List, Dict + +from .config import FFTConfig +from .resources import FFTResources + +import dataclasses + +@dataclasses.dataclass +class RegisterIOOp: + register: vc.ShaderVariable + offset: vc.ShaderVariable + stride: int + fft_index: vc.ShaderVariable + register_id: int + register_count: int + element_id: int + element_count: int + instance_id: int + instance_count: int + +class FFTRegisters: + resources: FFTResources + config: FFTConfig + registers: List[vc.ShaderVariable] + count: int + + def __init__(self, resources: FFTResources, count: int, 
name: str): + self.resources = resources + self.config = resources.config + + self.registers = [ + vc.new_register(self.config.compute_type, var_name=f"{name}_reg_{i}") for i in range(count) + ] + + self.count = count + + def clear(self): + for reg in self.registers: + reg[:] = 0 + + def register_slice(self, slc: slice) -> List[vc.ShaderVariable]: + return self.registers[slc] + def slice_set(self, slc: slice, values: List[vc.ShaderVariable]): + self.registers[slc] = values + + def __getitem__(self, index: int) -> vc.ShaderVariable: + return self.registers[index] + + def __setitem__(self, index: int, value: vc.ShaderVariable): + self.registers[index][:] = value + + def normalize(self): + normalization = vc.to_dtype(self.config.compute_type.child_type, self.config.N) + for i in range(self.count): + self.registers[i][:] = self.registers[i] / normalization + + def try_shuffle(self, output_stage: int = -1, input_stage: int = 0) -> bool: + out_format = self.config.stages[output_stage].get_output_format(len(self.registers)) + in_format = self.config.stages[input_stage].get_input_format(len(self.registers)) + + if out_format.keys() != in_format.keys(): + return False + + vc.comment("Performing register shuffle w/o shared memory.", preceding_new_line=False) + + # Some stages can use fewer registers than config.register_count. + # Shuffle only registers that appear in the input format. 
+ shuffled_registers = list(self.registers) + + for format_key, input_register in in_format.items(): + shuffled_registers[input_register] = self.registers[out_format[format_key]] + + for i in range(len(self.registers)): + self.registers[i] = shuffled_registers[i] + + return True + + def read_from_registers(self, other: "FFTRegisters") -> "FFTRegisters": + assert self.count == other.count, "Register counts must match for copy" + + for i in range(self.count): + self.registers[i][:] = other.registers[i] diff --git a/vkdispatch/fft/resources.py b/vkdispatch/fft/resources.py index 2115544f..f63bd04e 100644 --- a/vkdispatch/fft/resources.py +++ b/vkdispatch/fft/resources.py @@ -2,230 +2,64 @@ import vkdispatch.codegen as vc from vkdispatch.codegen.abreviations import * -import numpy as np import dataclasses -from typing import List, Tuple +from typing import List from .config import FFTConfig -from .prime_utils import prime_factors, default_register_limit +from .grid_manager import FFTGridManager -def allocation_valid(workgroup_size: int, shared_memory: int): - return workgroup_size <= vd.get_context().max_workgroup_invocations and shared_memory <= vd.get_context().max_shared_memory - -def allocate_inline_batches(batch_num: int, batch_threads: int, N: int, max_workgroup_size: int, max_total_threads: int): - shared_memory_allocation = N * vd.complex64.item_size - batch_num_primes = prime_factors(batch_num) - prime_index = 0 - workgroup_size = batch_threads - inline_batches = 1 +@dataclasses.dataclass +class FFTResources: + input_batch_offset: vc.ShaderVariable + output_batch_offset: vc.ShaderVariable + omega_register: vc.ShaderVariable + subsequence_offset: Const[u32] + io_index: Const[u32] + io_index_2: Const[u32] - while allocation_valid(workgroup_size, shared_memory_allocation) and prime_index < len(batch_num_primes) and inline_batches <= max_workgroup_size and workgroup_size <= max_total_threads: - test_prime = batch_num_primes[prime_index] + radix_registers: 
List[vc.ShaderVariable] - is_valid = allocation_valid(workgroup_size * test_prime, shared_memory_allocation * test_prime) + tid: vc.ShaderVariable - is_valid = is_valid and inline_batches * test_prime <= max_workgroup_size - is_valid = is_valid and workgroup_size * test_prime <= max_total_threads + grid: FFTGridManager - if is_valid: - workgroup_size *= test_prime - shared_memory_allocation *= test_prime - inline_batches *= test_prime - - prime_index += 1 + config: FFTConfig - return inline_batches + def __init__(self, config: FFTConfig, grid: FFTGridManager): + self.tid = grid.tid + self.grid = grid + self.config = config + self.input_batch_offset = vc.new_uint_register(var_name="input_batch_offset") + self.output_batch_offset = vc.new_uint_register(var_name="output_batch_offset") + self.omega_register = vc.new_register(config.compute_type, var_name="omega_register") + self.subsequence_offset = vc.new_uint_register(var_name="subsequence_offset") + self.io_index = vc.new_uint_register(var_name="io_index") + self.io_index_2 = vc.new_uint_register(var_name="io_index_2") -def allocate_workgroups(total_count: int) -> Tuple[vc.ShaderVariable, Tuple[int, int, int]]: - def set_to_multiple_with_max(count, max_count): - if count <= max_count: - return count - - count_primes = prime_factors(count) + self.radix_registers = [ + vc.new_register(config.compute_type, var_name=f"radix_register_{i}") for i in range(config.max_prime_radix) + ] - result_count = 1 - for prime in count_primes: - if result_count * prime > max_count: - break - result_count *= prime + def stage_begin(self, stage_index: int): + thread_count = self.config.stages[stage_index].thread_count - return result_count + if thread_count < self.config.batch_threads: + vc.if_statement(self.tid < thread_count) - workgroups_x = set_to_multiple_with_max( - total_count, - vd.get_context().max_workgroup_count[0] - ) - workgroups_y = 1 - workgroups_z = 1 - - workgroup_index = vc.new_uint( - vc.workgroup().x, - 
var_name="workgroup_index" - ) - - if workgroups_x != total_count: - workgroups_y = set_to_multiple_with_max( - total_count // workgroups_x, - vd.get_context().max_workgroup_count[1] - ) - - workgroup_index += workgroups_x * vc.workgroup().y - - if workgroups_y != total_count // workgroups_x: - workgroups_z = set_to_multiple_with_max( - total_count // (workgroups_x * workgroups_y), - vd.get_context().max_workgroup_count[2] - ) + def stage_end(self, stage_index: int): + thread_count = self.config.stages[stage_index].thread_count - workgroup_index += workgroups_x * workgroups_y * vc.workgroup().z + if thread_count < self.config.batch_threads: + vc.end() - return workgroup_index, (workgroups_x, workgroups_y, workgroups_z) + def invocation_gaurd(self, stage_index: int, invocation_index: int): + stage = self.config.stages[stage_index] -def decompose_workgroup_index(workgroup_index: vc.ShaderVariable, inner_batch_count: int, fft_threads: int, local_size: Tuple[int, int, int]) -> Tuple[vc.ShaderVariable, vc.ShaderVariable]: - if inner_batch_count == None: - if fft_threads == 1: - return None, workgroup_index * local_size[0] + vc.local_invocation().x + if stage.remainder_offset == 1 and invocation_index == stage.extra_ffts: + vc.if_statement(self.tid < self.config.N // stage.registers_used) - return None, workgroup_index * local_size[1] + vc.local_invocation().y - - global_inner = vc.new_uint( - (workgroup_index % inner_batch_count) * local_size[0] + vc.local_invocation().x, - var_name="global_inner_index" - ) - - global_outer = vc.new_uint( - (workgroup_index / inner_batch_count) * local_size[2] + vc.local_invocation().z, - var_name="global_outer_index" - ) - - return global_inner, global_outer - -@dataclasses.dataclass -class FFTResources: - registers: List[vc.ShaderVariable] - radix_registers: List[vc.ShaderVariable] - omega_register: vc.ShaderVariable - tid: Const[u32] - input_batch_offset: Const[u32] - output_batch_offset: Const[u32] - subsequence_offset: Const[u32] - 
sdata: Buff[c64] - sdata_offset: Const[u32] - io_index: Const[u32] - io_index_2: Const[u32] - global_inner_index: Const[u32] - global_outer_index: Const[u32] - exec_size: Tuple[int, int, int] - - shared_memory_size: int - local_size: Tuple[int, int, int] - -def allocate_fft_resources(config: FFTConfig, convolve: bool = False) -> FFTResources: - make_sdata_buffer = config.batch_threads > 1 or convolve - - inline_batch_inner = allocate_inline_batches( - config.batch_inner_count, - config.batch_threads, - config.sdata_allocation if make_sdata_buffer else 0, - min(vd.get_context().max_workgroup_size[0], 4), - vd.get_context().max_workgroup_invocations) - - max_inline_outer_batches = vd.get_context().max_workgroup_size[1 if config.batch_inner_count == 1 else 2] - - # For some reason it's better not to have too many inline outer batches - max_inline_outer_batches = min(max_inline_outer_batches, vd.get_context().subgroup_size) - - inline_batch_outer = allocate_inline_batches( - config.batch_outer_count, - config.batch_threads * inline_batch_inner, - config.sdata_allocation * inline_batch_inner if make_sdata_buffer else 0, - vd.get_context().max_workgroup_size[1 if inline_batch_inner == 1 else 2], - max_inline_outer_batches) - - sdata_buffer = None - - if make_sdata_buffer: - sdata_buffer = vc.shared_buffer( - vd.complex64, - config.sdata_allocation * inline_batch_outer * inline_batch_inner, - var_name="sdata") - - - if config.batch_inner_count > 1: - local_inner = vc.local_invocation().x - local_outer = vc.local_invocation().z - local_size = (inline_batch_inner, config.batch_threads, inline_batch_outer) - - inner_workgroups = config.batch_inner_count // inline_batch_inner - outer_workgroups = config.batch_outer_count // inline_batch_outer - - workgroup_index, workgroups = allocate_workgroups(inner_workgroups * outer_workgroups) - - global_inner, global_outer = decompose_workgroup_index( - workgroup_index, - inner_workgroups, - config.batch_threads, - local_size - ) - - 
exec_size = ( - local_size[0] * workgroups[0], - local_size[1] * workgroups[1], - local_size[2] * workgroups[2] - ) - - tid = vc.local_invocation().y.copy("tid") - else: - local_inner = None - global_inner = 0 - - if config.batch_threads > 1: - tid = vc.local_invocation().x.copy("tid") - local_outer = vc.local_invocation().y - local_size = (config.batch_threads, inline_batch_outer, 1) - else: - tid = vc.new_uint(0, var_name="tid") - local_outer = vc.local_invocation().x - local_size = (inline_batch_outer, 1, 1) - - workgroup_index, workgroups = allocate_workgroups(config.batch_outer_count // inline_batch_outer) - - _, global_outer = decompose_workgroup_index(workgroup_index, None, config.batch_threads, local_size) - - exec_size = ( - local_size[0] * workgroups[0], - local_size[1] * workgroups[1], - local_size[2] * workgroups[2] - ) - - sdata_offset = None - - if inline_batch_outer > 1 or inline_batch_inner > 1: - sdata_offset_value = local_outer * inline_batch_inner * config.N - - if local_inner is not None: - sdata_offset_value = sdata_offset_value + local_inner * config.N - - sdata_offset = vc.new_uint(sdata_offset_value, var_name="sdata_offset") - - resources = FFTResources( - registers=[vc.new(c64, 0, var_name=f"register_{i}") for i in range(config.register_count)], - radix_registers=[vc.new(c64, 0, var_name=f"radix_{i}") for i in range(config.max_prime_radix)], - omega_register=vc.new(c64, 0, var_name="omega_register"), - tid=tid, - input_batch_offset=vc.new_uint(var_name="input_batch_offset"), - output_batch_offset=vc.new_uint(var_name="output_batch_offset"), - subsequence_offset=vc.new_uint(0, var_name="subsequence_offset"), - sdata=sdata_buffer, - sdata_offset=sdata_offset, - io_index=vc.new_uint(0, var_name="io_index"), - io_index_2=vc.new_uint(0, var_name="io_index_2"), - shared_memory_size=config.N * inline_batch_outer * inline_batch_inner * vd.complex64.item_size, - local_size=local_size, - global_inner_index=global_inner, - 
global_outer_index=global_outer, - exec_size=exec_size - ) - - return resources + def invocation_end(self, stage_index: int): + stage = self.config.stages[stage_index] + if stage.remainder_offset == 1: + vc.end() diff --git a/vkdispatch/fft/sdata_manager.py b/vkdispatch/fft/sdata_manager.py new file mode 100644 index 00000000..d00ff31e --- /dev/null +++ b/vkdispatch/fft/sdata_manager.py @@ -0,0 +1,104 @@ +import vkdispatch as vd +import vkdispatch.codegen as vc + +from typing import Literal, Union, List, Optional + +from .config import FFTConfig +from .grid_manager import FFTGridManager +from .resources import FFTResources +from .registers import FFTRegisters + +from .memory_iterators import memory_reads_iterator, memory_writes_iterator + +class FFTSDataManager: + sdata: vc.Buffer + sdata_offset: Union[vc.Const[vc.u32], Literal[0]] + + sdata_row_size: int + sdata_row_size_padded: int + padding_enabled: bool + + # None: not set yet + # True: last operation was write + # False: last operation was read + last_op: bool + + use_padding: bool + + tid: vc.ShaderVariable + fft_N: int + + resources: FFTResources + default_registers: FFTRegisters + + + def __init__(self, config: FFTConfig, grid: FFTGridManager, default_registers: FFTRegisters): + self.sdata_row_size = config.sdata_row_size + self.sdata_row_size_padded = config.sdata_row_size_padded + self.padding_enabled = self.sdata_row_size != self.sdata_row_size_padded + self.use_padding = False + self.fft_N = config.N + self.tid = grid.tid + self.last_op = None + self.default_registers = default_registers + self.resources = default_registers.resources + + total_inner_batches = grid.inline_batches_inner * grid.inline_batches_outer + + self.sdata = vc.shared_buffer( + config.compute_type, + config.sdata_allocation * total_inner_batches, + var_name="sdata") + + self.sdata_offset = 0 + + if total_inner_batches > 1: + sdata_offset_value = grid.local_outer * grid.inline_batches_inner * config.N + + if grid.local_inner is not 
None: + sdata_offset_value = sdata_offset_value + grid.local_inner * config.N + + self.sdata_offset = vc.new_uint_register(sdata_offset_value, var_name="sdata_offset") + + + def do_op(self, op: bool): + if self.last_op is not None and self.last_op != op: + vc.barrier() + + self.last_op = op + + def op_read(self) -> bool: + self.do_op(False) + + def op_write(self) -> bool: + self.do_op(True) + + def read_from_sdata(self, registers: Optional[FFTRegisters] = None, stage_index: int = 0): + self.op_read() + + if registers is None: + registers = self.default_registers + + for read_op in memory_reads_iterator(self.resources, stage_index): + self.resources.io_index[:] = read_op.fft_index + self.sdata_offset + + if self.use_padding: + self.resources.io_index[:] = self.resources.io_index + (self.resources.io_index // self.sdata_row_size) + + registers[read_op.register_id] = self.sdata[self.resources.io_index] + + def write_to_sdata(self, registers: Optional[FFTRegisters] = None, stage_index: int = -1): + self.op_write() + + self.use_padding = self.padding_enabled and self.resources.config.stages[stage_index].input_stride < 32 + + if registers is None: + registers = self.default_registers + + for write_op in memory_writes_iterator(self.resources, stage_index): + self.resources.io_index[:] = write_op.fft_index + self.sdata_offset + + if self.use_padding: + self.resources.io_index[:] = self.resources.io_index + (self.resources.io_index // self.sdata_row_size) + + self.sdata[self.resources.io_index] = registers[write_op.register_id] diff --git a/vkdispatch/fft/shader.py b/vkdispatch/fft/shader.py deleted file mode 100644 index 0facb61c..00000000 --- a/vkdispatch/fft/shader.py +++ /dev/null @@ -1,164 +0,0 @@ -import vkdispatch as vd -import vkdispatch.codegen as vc -from vkdispatch.codegen.abreviations import * - -from typing import List, Tuple, Union -from functools import lru_cache -import numpy as np - -from .memory_io import load_sdata_state_to_registers, 
FFTRegisterStageInvocation - -from .plan import plan - -@lru_cache(maxsize=None) -def make_fft_shader( - buffer_shape: Tuple, - axis: int = None, - inverse: bool = False, - normalize_inverse: bool = True, - r2c: bool = False, - input_map: vd.MappingFunction = None, - output_map: vd.MappingFunction = None) -> Tuple[vd.ShaderObject, Tuple[int, int, int]]: - - with vd.fft.fft_context( - buffer_shape, - axis=axis, - input_map=input_map, - output_map=output_map - ) as manager: - - plan( - manager.resources, - manager.config.params( - inverse, - normalize_inverse, - r2c), - input=manager.io_manager.input_proxy, - output=manager.io_manager.output_proxy) - - return manager.get_callable() - -@lru_cache(maxsize=None) -def make_convolution_shader( - buffer_shape: Tuple, - kernel_map: vd.MappingFunction = None, - kernel_num: int = 1, - axis: int = None, - normalize: bool = True, - input_map: vd.MappingFunction = None, - output_map: vd.MappingFunction = None) -> Tuple[vd.ShaderObject, Tuple[int, int, int]]: - - if kernel_map is None: - def kernel_map_func(kernel_buffer: vc.Buffer[c64]): - img_val = vc.mapping_registers()[0] - read_register = vc.mapping_registers()[1] - - read_register[:] = kernel_buffer[vc.mapping_index()] - img_val[:] = vc.mult_conj_c64(img_val, read_register) - - kernel_map = vd.map(kernel_map_func, register_types=[c64], input_types=[vc.Buffer[c64]]) - - with vd.fft.fft_context( - buffer_shape, - axis=axis, - input_map=input_map, - output_map=output_map, - kernel_map=kernel_map - ) as manager: - vc.comment("Performing forward FFT stage in convolution shader") - - do_sdata_padding = plan( - manager.resources, - manager.config.params( - inverse=False, - ), - input=manager.io_manager.input_proxy) - - vc.barrier() - - vc.comment("Performing convolution stage in convolution shader") - - inverse_params = manager.config.params( - inverse=True, - normalize=normalize) - - assert inverse_params.config.stages[0].instance_count == 1, "Something is very wrong" - - 
invocation = FFTRegisterStageInvocation( - inverse_params.config.stages[0], - 1, 0, - manager.resources.tid, - inverse_params.config.N - ) - - vc.comment(f"Loading state to registers in convolution shader") - - if kernel_num == 1: - load_sdata_state_to_registers( - manager.resources, - inverse_params, - invocation.instance_id, - inverse_params.config.N // inverse_params.config.stages[0].fft_length, - manager.resources.registers[invocation.register_selection], - do_sdata_padding - ) - - vc.comment("Performing IFFT stage in convolution shader") - - vc.barrier() - - vc.set_kernel_index(0) - - plan( - manager.resources, - inverse_params, - input=manager.io_manager.kernel_proxy, - output=manager.io_manager.output_proxy, - do_sdata_padding=do_sdata_padding) - - else: - backup_registers = [] - for i in range(len(manager.resources.registers)): - backup_registers.append(vc.new(c64, 0, var_name=f"backup_register_{i}")) - - load_sdata_state_to_registers( - manager.resources, - inverse_params, - invocation.instance_id, - inverse_params.config.N // inverse_params.config.stages[0].fft_length, - backup_registers[invocation.register_selection], - do_sdata_padding - ) - - vc.comment("Performing IFFT stage in convolution shader") - - for kern_index in range(kernel_num): - vc.barrier() - - for i in range(len(manager.resources.registers)): - manager.resources.registers[i][:] = backup_registers[i] - - vc.set_kernel_index(kern_index) - - plan( - manager.resources, - inverse_params, - input=manager.io_manager.kernel_proxy, - output=manager.io_manager.output_proxy, - do_sdata_padding=do_sdata_padding) - - return manager.get_callable() - -def get_cache_info(): - return make_fft_shader.cache_info() - -def get_convoliution_cache_info(): - return make_convolution_shader.cache_info() - -def print_cache_info(): - print(get_cache_info()) - print(get_convoliution_cache_info()) - -def cache_clear(): - make_convolution_shader.cache_clear() - make_fft_shader.cache_clear() \ No newline at end of file 
diff --git a/vkdispatch/fft/shader_factories.py b/vkdispatch/fft/shader_factories.py new file mode 100644 index 00000000..28a481fd --- /dev/null +++ b/vkdispatch/fft/shader_factories.py @@ -0,0 +1,209 @@ +import vkdispatch as vd +import vkdispatch.codegen as vc +from vkdispatch.codegen.abreviations import * + +from ..compat import numpy_compat as npc + +from typing import Tuple, Optional +from functools import lru_cache +import threading + +@lru_cache(maxsize=None) +def make_fft_shader( + buffer_shape: Tuple, + axis: int = None, + inverse: bool = False, + normalize_inverse: bool = True, + r2c: bool = False, + input_map: vd.MappingFunction = None, + output_map: vd.MappingFunction = None, + input_type: vd.dtype = None, + output_type: vd.dtype = None, + compute_type: vd.dtype = None, + input_signal_range: Optional[Tuple[Optional[int], Optional[int]]] = None) -> vd.ShaderFunction: + + if output_type is None: + output_type = vd.complex64 + + if input_type is None and input_map is None: + input_type = output_type + + if compute_type is None: + compute_type = vd.complex64 + + name = f"fft_shader_{buffer_shape}_{axis}_{inverse}_{normalize_inverse}_{r2c}" + + with vd.fft.fft_context(buffer_shape, axis=axis, compute_type=compute_type, name=name) as ctx: + io_manager = ctx.make_io_manager( + input_map=input_map, + output_map=output_map, + output_type=output_type, + input_type=input_type, + ) + + io_manager.read_input( + r2c=r2c, + inverse=inverse, + signal_range=input_signal_range + ) + + ctx.execute(inverse=inverse) + + if inverse and normalize_inverse: + ctx.registers.normalize() + + io_manager.write_output( + r2c=r2c, + inverse=inverse + ) + + return ctx.get_callable() + +@lru_cache(maxsize=None) +def get_transposed_size( + buffer_shape: Tuple, + axis: int = None, + compute_type: vd.dtype = vd.complex64) -> vd.ShaderFunction: + + config = vd.fft.FFTConfig(buffer_shape, axis, compute_type=compute_type) + grid = vd.fft.FFTGridManager(config, True, False) + + return 
npc.prod(grid.local_size) * npc.prod(grid.workgroup_count) * config.register_count + +@lru_cache(maxsize=None) +def make_transpose_shader( + buffer_shape: Tuple, + axis: int = None, + kernel_inner_only: bool = False, + input_type: vd.dtype = vd.complex64, + output_type: vd.dtype = vd.complex64, + compute_type: vd.dtype = vd.complex64) -> vd.ShaderFunction: + + with vd.fft.fft_context(buffer_shape, axis=axis, compute_type=compute_type) as ctx: + args = ctx.declare_shader_args([vc.Buffer[output_type], vc.Buffer[input_type]]) + + if kernel_inner_only: + vc.if_statement(ctx.grid.global_outer_offset == 0) + + for read_op in vd.fft.global_reads_iterator(ctx.registers, format_transposed=False): + read_op.read_from_buffer(args[1]) + + for write_op in vd.fft.global_trasposed_write_iterator(ctx.registers, inner_only=kernel_inner_only): + write_op.write_to_buffer(args[0]) + + if kernel_inner_only: + vc.end() + + return ctx.get_callable() + +_kernel_index_state = threading.local() + +def set_global_kernel_index(index: Optional[int]): + _kernel_index_state.index = index + +def mapped_kernel_index() -> Optional[int]: + return getattr(_kernel_index_state, "index", None) + +@lru_cache(maxsize=None) +def make_convolution_shader( + buffer_shape: Tuple, + kernel_map: vd.MappingFunction = None, + kernel_num: int = 1, + axis: int = None, + normalize: bool = True, + transposed_kernel: bool = False, + kernel_inner_only: bool = False, + input_map: vd.MappingFunction = None, + output_map: vd.MappingFunction = None, + input_type: vd.dtype = None, + output_type: vd.dtype = None, + kernel_type: vd.dtype = None, + compute_type: vd.dtype = None, + input_signal_range: Optional[Tuple[Optional[int], Optional[int]]] = None) -> vd.ShaderFunction: + + if output_type is None: + output_type = vd.complex64 + + if input_type is None and input_map is None: + input_type = output_type + + if kernel_type is None: + kernel_type = vd.complex64 + + if compute_type is None: + compute_type = vd.complex64 + + if 
kernel_map is None: + def kernel_map_func(kernel_buffer: vc.Buffer[kernel_type]): + read_op = vd.fft.read_op() + + kernel_val = vc.new_register(compute_type) + read_op.read_from_buffer(kernel_buffer, register=kernel_val) + + read_op.register[:] = vc.mult_complex(read_op.register, kernel_val.conjugate()) + + kernel_map = vd.map(kernel_map_func, input_types=[vc.Buffer[kernel_type]]) + + name = f"convolution_shader_{buffer_shape}_{axis}" + + with vd.fft.fft_context(buffer_shape, axis=axis, compute_type=compute_type, name=name) as ctx: + io_manager = ctx.make_io_manager( + input_map=input_map, + output_map=output_map, + output_type=output_type, + input_type=input_type, + kernel_map=kernel_map + ) + + vc.comment("""Convolution pipeline phase 1/3. +Load spatial-domain input samples and run a forward FFT into frequency space. +Then shuffle registers so lane layout matches kernel application and inverse passes.""") + + io_manager.read_input(signal_range=input_signal_range) + ctx.execute(inverse=False) + ctx.register_shuffle() + + backup_registers = None + + if kernel_num > 1: + backup_registers = ctx.allocate_registers("backup") + backup_registers.read_from_registers(ctx.registers) + + for kern_index in range(kernel_num): + vc.comment(f"""Convolution pipeline phase 2/3. Kernel {kern_index + 1}/{kernel_num}. +Map this kernel onto the current spectrum.""") + + if backup_registers is not None: + ctx.registers.read_from_registers(backup_registers) + + set_global_kernel_index(kern_index) + io_manager.read_kernel(format_transposed=transposed_kernel, inner_only=kernel_inner_only) + + vc.comment(f"""Convolution pipeline phase 3/3. 
+Run inverse FFT back to the spatial domain, optionally normalize by length, +and write this kernel's output slice to global memory.""") + + ctx.execute(inverse=True) + + if normalize: + ctx.registers.normalize() + + io_manager.write_output(inverse=True) + + set_global_kernel_index(None) + + return ctx.get_callable() + +def get_cache_info(): + return make_fft_shader.cache_info() + +def get_convoliution_cache_info(): + return make_convolution_shader.cache_info() + +def print_cache_info(): + print(get_cache_info()) + print(get_convoliution_cache_info()) + +def cache_clear(): + make_convolution_shader.cache_clear() + make_fft_shader.cache_clear() diff --git a/vkdispatch/fft/src_functions.py b/vkdispatch/fft/src_functions.py new file mode 100644 index 00000000..e8952bb3 --- /dev/null +++ b/vkdispatch/fft/src_functions.py @@ -0,0 +1,342 @@ +import vkdispatch as vd + +from .shader_factories import make_fft_shader, make_convolution_shader, make_transpose_shader, get_transposed_size + +from typing import Tuple, Union, Optional + +def fft_src( + buffer_shape: Tuple, + axis: int = None, + inverse: bool = False, + normalize_inverse: bool = True, + r2c: bool = False, + input_map: vd.MappingFunction = None, + output_map: vd.MappingFunction = None, + input_signal_range: Union[Tuple[Optional[int], Optional[int]], None] = None, + line_numbers: bool = False) -> vd.ShaderSource: + + fft_shader = make_fft_shader( + tuple(buffer_shape), + axis, + inverse=inverse, + normalize_inverse=normalize_inverse, + r2c=r2c, + input_map=input_map, + output_map=output_map, + input_signal_range=input_signal_range) + + return fft_shader.get_src(line_numbers=line_numbers) + +def fft2_src(buffer_shape: Tuple, input_map: vd.MappingFunction = None, output_map: vd.MappingFunction = None): + assert len(buffer_shape) == 2 or len(buffer_shape) == 3, 'Buffer Shape must have 2 or 3 dimensions' + + return ( + fft_src(axis=len(buffer_shape) - 2, input_map=input_map), + fft_src(axis=len(buffer_shape) - 1, 
output_map=output_map) + ) + +def fft3_src(buffer_shape: Tuple, input_map: vd.MappingFunction = None, output_map: vd.MappingFunction = None): + assert len(buffer_shape) == 3, 'Buffer must have 3 dimensions' + + return ( + fft_src(buffer_shape, axis=0, input_map=input_map), + fft_src(buffer_shape, axis=1), + fft_src(buffer_shape, axis=2, output_map=output_map) + ) + + +def ifft_src( + buffer_shape: Tuple, + axis: int = None, + normalize: bool = True, + input_map: vd.MappingFunction = None, + output_map: vd.MappingFunction = None): + return fft_src(buffer_shape, axis=axis, inverse=True, normalize_inverse=normalize, input_map=input_map, output_map=output_map) + +def ifft2_src(buffer_shape: Tuple, normalize: bool = True, input_map: vd.MappingFunction = None, output_map: vd.MappingFunction = None): + assert len(buffer_shape) == 2 or len(buffer_shape) == 3, 'Buffer must have 2 or 3 dimensions' + + return ( + ifft_src(buffer_shape, axis=len(buffer_shape) - 2, normalize=normalize, input_map=input_map), + ifft_src(buffer_shape, axis=len(buffer_shape) - 1, normalize=normalize, output_map=output_map) + ) + +def ifft3_src(buffer_shape: Tuple, normalize: bool = True, input_map: vd.MappingFunction = None, output_map: vd.MappingFunction = None): + assert len(buffer_shape) == 3, 'Buffer must have 3 dimensions' + + return ( + ifft_src(buffer_shape, axis=0, normalize=normalize, input_map=input_map), + ifft_src(buffer_shape, axis=1, normalize=normalize), + ifft_src(buffer_shape, axis=2, normalize=normalize, output_map=output_map) + ) + + +def rfft_src(buffer_shape: Tuple): + return fft_src(buffer_shape, r2c=True) + +def rfft2_src(buffer_shape: Tuple): + assert len(buffer_shape) == 2 or len(buffer_shape) == 3, 'Buffer must have 2 or 3 dimensions' + + return ( + rfft_src(buffer_shape), + fft_src(buffer_shape, axis=len(buffer_shape) - 2) + ) + +def rfft3_src(buffer_shape: Tuple): + assert len(buffer_shape) == 3, 'Buffer must have 3 dimensions' + + return ( + rfft_src(buffer_shape), + 
fft_src(buffer_shape, axis=1), + fft_src(buffer_shape, axis=0) + ) + +def irfft_src(buffer_shape: Tuple, normalize: bool = True): + return fft_src(buffer_shape, inverse=True, normalize_inverse=normalize, r2c=True) + +def irfft2_src(buffer_shape: Tuple, normalize: bool = True): + assert len(buffer_shape) == 2 or len(buffer_shape) == 3, 'Buffer must have 2 or 3 dimensions' + + return ( + ifft_src(buffer_shape, axis=len(buffer_shape) - 2, normalize=normalize), + irfft_src(buffer_shape, normalize=normalize) + ) + +def irfft3_src(buffer_shape: Tuple, normalize: bool = True): + assert len(buffer_shape) == 3, 'Buffer must have 3 dimensions' + + return ( + ifft_src(buffer_shape, axis=0, normalize=normalize), + ifft_src(buffer_shape, axis=1, normalize=normalize), + irfft_src(buffer_shape, normalize=normalize) + ) + +def convolve_src( + buffer_shape: Tuple, + kernel_map: vd.MappingFunction = None, + kernel_num: int = 1, + axis: int = None, + normalize: bool = True, + transposed_kernel: bool = False, + kernel_inner_only: bool = False, + input_map: vd.MappingFunction = None, + output_map: vd.MappingFunction = None, + input_signal_range: Union[Tuple[Optional[int], Optional[int]], None] = None, + line_numbers: bool = False) -> vd.ShaderSource: + + fft_shader = make_convolution_shader( + tuple(buffer_shape), + kernel_map, + kernel_num, + axis, + transposed_kernel=transposed_kernel, + kernel_inner_only=kernel_inner_only, + normalize=normalize, + input_map=input_map, + output_map=output_map, + input_signal_range=input_signal_range) + + return fft_shader.get_src(line_numbers=line_numbers) + +def convolve2D_src( + buffer_shape: Tuple, + kernel_map: vd.MappingFunction = None, + normalize: bool = True, + transposed_kernel: bool = False, + kernel_inner_only: bool = False, + input_map: vd.MappingFunction = None, + output_map: vd.MappingFunction = None): + + assert len(buffer_shape) == 2 or len(buffer_shape) == 3, 'Buffer must have 2 or 3 dimensions' + + return ( + fft_src(buffer_shape, 
input_map=input_map), + convolve_src( + buffer_shape, + kernel_map=kernel_map, + transposed_kernel=transposed_kernel, + kernel_inner_only=kernel_inner_only, + axis=len(buffer_shape) - 2, + normalize=normalize + ), + ifft_src(buffer_shape, normalize=normalize, output_map=output_map) + ) + +def convolve2DR_src( + buffer_shape: Tuple, + kernel_map: vd.MappingFunction = None, + transposed_kernel: bool = False, + kernel_inner_only: bool = False, + normalize: bool = True): + + assert len(buffer_shape) == 2 or len(buffer_shape) == 3, 'Buffer must have 2 or 3 dimensions' + + return ( + rfft_src(buffer_shape), + convolve_src( + buffer_shape, + kernel_map=kernel_map, + transposed_kernel=transposed_kernel, + kernel_inner_only=kernel_inner_only, + axis=len(buffer_shape) - 2, + normalize=normalize + ), + irfft_src(buffer_shape, normalize=normalize) + ) + +def transpose_src( + buffer_shape: Tuple, + axis: int = None, + kernel_inner_only: bool = False, + line_numbers: bool = False) -> vd.Buffer: + + transpose_shader = make_transpose_shader( + tuple(buffer_shape), + axis=axis, + kernel_inner_only=kernel_inner_only + ) + + return transpose_shader.get_src(line_numbers=line_numbers) + + +def fft_print_src( + buffer_shape: Tuple, + axis: int = None, + inverse: bool = False, + normalize_inverse: bool = True, + r2c: bool = False, + input_map: vd.MappingFunction = None, + output_map: vd.MappingFunction = None, + input_signal_range: Union[Tuple[Optional[int], Optional[int]], None] = None, + line_numbers: bool = False) -> vd.ShaderSource: + + print(fft_src( + buffer_shape, + axis, + inverse=inverse, + normalize_inverse=normalize_inverse, + r2c=r2c, + input_map=input_map, + output_map=output_map, + input_signal_range=input_signal_range, + line_numbers=line_numbers)) + +def fft2_print_src(buffer_shape: Tuple, input_map: vd.MappingFunction = None, output_map: vd.MappingFunction = None): + srcs = fft2_src(buffer_shape, input_map=input_map, output_map=output_map) + print(f"// FFT Stage 1 (axis 
{len(buffer_shape) - 2}):\n{srcs[0]}\n// FFT Stage 2 (axis {len(buffer_shape) - 1}):\n{srcs[1]}") + +def fft3_print_src(buffer_shape: Tuple, input_map: vd.MappingFunction = None, output_map: vd.MappingFunction = None): + srcs = fft3_src(buffer_shape, input_map=input_map, output_map=output_map) + print(f"// FFT Stage 1 (axis 0):\n{srcs[0]}\n// FFT Stage 2 (axis 1):\n{srcs[1]}\n// FFT Stage 3 (axis 2):\n{srcs[2]}") + +def ifft_print_src( + buffer_shape: Tuple, + axis: int = None, + normalize: bool = True, + input_map: vd.MappingFunction = None, + output_map: vd.MappingFunction = None): + print(ifft_src(buffer_shape, axis=axis, normalize=normalize, input_map=input_map, output_map=output_map)) + +def ifft2_print_src(buffer_shape: Tuple, normalize: bool = True, input_map: vd.MappingFunction = None, output_map: vd.MappingFunction = None): + srcs = ifft2_src(buffer_shape, normalize=normalize, input_map=input_map, output_map=output_map) + print(f"// IFFT Stage 1 (axis {len(buffer_shape) - 2}):\n{srcs[0]}\n// IFFT Stage 2 (axis {len(buffer_shape) - 1}):\n{srcs[1]}") + +def ifft3_print_src(buffer_shape: Tuple, normalize: bool = True, input_map: vd.MappingFunction = None, output_map: vd.MappingFunction = None): + srcs = ifft3_src(buffer_shape, normalize=normalize, input_map=input_map, output_map=output_map) + print(f"// IFFT Stage 1 (axis 0):\n{srcs[0]}\n// IFFT Stage 2 (axis 1):\n{srcs[1]}\n// IFFT Stage 3 (axis 2):\n{srcs[2]}") + +def rfft_print_src(buffer_shape: Tuple): + print(rfft_src(buffer_shape)) + +def rfft2_print_src(buffer_shape: Tuple): + srcs = rfft2_src(buffer_shape) + print(f"// RFFT Stage 1:\n{srcs[0]}\n// RFFT Stage 2 (axis {len(buffer_shape) - 2}):\n{srcs[1]}") + +def rfft3_print_src(buffer_shape: Tuple): + srcs = rfft3_src(buffer_shape) + print(f"// RFFT Stage 1:\n{srcs[0]}\n// RFFT Stage 2 (axis 1):\n{srcs[1]}\n// RFFT Stage 3 (axis 0):\n{srcs[2]}") + +def irfft_print_src(buffer_shape: Tuple, normalize: bool = True): + print(irfft_src(buffer_shape, 
normalize=normalize)) + +def irfft2_print_src(buffer_shape: Tuple, normalize: bool = True): + srcs = irfft2_src(buffer_shape, normalize=normalize) + print(f"// IRFFT Stage 1 (axis {len(buffer_shape) - 2}):\n{srcs[0]}\n// IRFFT Stage 2:\n{srcs[1]}") + +def irfft3_print_src(buffer_shape: Tuple, normalize: bool = True): + srcs = irfft3_src(buffer_shape, normalize=normalize) + print(f"// IRFFT Stage 1 (axis 0):\n{srcs[0]}\n// IRFFT Stage 2 (axis 1):\n{srcs[1]}\n// IRFFT Stage 3:\n{srcs[2]}") + +def convolve_print_src( + buffer_shape: Tuple, + kernel_map: vd.MappingFunction = None, + kernel_num: int = 1, + axis: int = None, + normalize: bool = True, + transposed_kernel: bool = False, + kernel_inner_only: bool = False, + input_map: vd.MappingFunction = None, + output_map: vd.MappingFunction = None, + input_signal_range: Union[Tuple[Optional[int], Optional[int]], None] = None, + line_numbers: bool = False) -> vd.ShaderSource: + + print(convolve_src( + buffer_shape, + kernel_map=kernel_map, + kernel_num=kernel_num, + axis=axis, + normalize=normalize, + transposed_kernel=transposed_kernel, + kernel_inner_only=kernel_inner_only, + input_map=input_map, + output_map=output_map, + input_signal_range=input_signal_range, + line_numbers=line_numbers + )) + +def convolve2D_print_src( + buffer_shape: Tuple, + kernel_map: vd.MappingFunction = None, + normalize: bool = True, + transposed_kernel: bool = False, + kernel_inner_only: bool = False, + input_map: vd.MappingFunction = None, + output_map: vd.MappingFunction = None): + srcs = convolve2D_src( + buffer_shape, + kernel_map=kernel_map, + normalize=normalize, + transposed_kernel=transposed_kernel, + kernel_inner_only=kernel_inner_only, + input_map=input_map, + output_map=output_map + ) + print(f"// FFT Stage (axis {len(buffer_shape) - 2}):\n{srcs[0]}\n// Convolution Stage (axis {len(buffer_shape) - 2}):\n{srcs[1]}\n// IFFT Stage:\n{srcs[2]}") + +def convolve2DR_print_src( + buffer_shape: Tuple, + kernel_map: vd.MappingFunction = 
None, + transposed_kernel: bool = False, + kernel_inner_only: bool = False, + normalize: bool = True): + srcs = convolve2DR_src( + buffer_shape, + kernel_map=kernel_map, + transposed_kernel=transposed_kernel, + kernel_inner_only=kernel_inner_only, + normalize=normalize + ) + print(f"// RFFT Stage:\n{srcs[0]}\n// Convolution Stage (axis {len(buffer_shape) - 2}):\n{srcs[1]}\n// IRFFT Stage:\n{srcs[2]}") + +def transpose_print_src( + buffer_shape: Tuple, + axis: int = None, + kernel_inner_only: bool = False, + line_numbers: bool = False) -> vd.Buffer: + + print(transpose_src( + buffer_shape, + axis=axis, + kernel_inner_only=kernel_inner_only, + line_numbers=line_numbers + )) \ No newline at end of file diff --git a/vkdispatch/fft/stages.py b/vkdispatch/fft/stages.py new file mode 100644 index 00000000..0cb348fd --- /dev/null +++ b/vkdispatch/fft/stages.py @@ -0,0 +1,198 @@ +import vkdispatch as vd +import vkdispatch.codegen as vc +import dataclasses +from typing import List, Tuple, Dict + +from ..compat import numpy_compat as npc +from .prime_utils import prime_factors + +@dataclasses.dataclass +class FFTStagePlanInvocation: + fft_length: int + input_stride: int + instance_index: int + instance_index_stride: int + block_width: int + full_width_block: bool + instance_id0: int + inner_block_offset0: int + sub_sequence_offset0: int + register_selection: slice + + def __init__(self, + stage_fft_length: int, + stage_instance_count: int, + input_stride: int, + instance_index: int, + N: int): + self.fft_length = stage_fft_length + self.input_stride = input_stride + self.instance_index = instance_index + self.block_width = input_stride * stage_fft_length + self.instance_index_stride = N // (stage_fft_length * stage_instance_count) + + self.full_width_block = self.block_width == N + + # pretend tid is 0, used for calculating register shuffles + self.instance_id0 = self.instance_index_stride * instance_index + self.inner_block_offset0 = self.instance_id0 % input_stride + 
self.sub_sequence_offset0 = self.instance_id0 * stage_fft_length - self.inner_block_offset0 * (stage_fft_length - 1) + + self.register_selection = slice(instance_index * stage_fft_length, (instance_index + 1) * stage_fft_length) + + def get_offset(self, tid: vc.ShaderVariable): + return tid + self.instance_index_stride * self.instance_index + + def get_inner_block_offset(self, tid: vc.ShaderVariable): + if self.input_stride == 1: + return 0 + + if self.full_width_block: + return self.get_offset(tid) + + return self.get_offset(tid) % self.input_stride + + def get_sub_sequence_offset(self, tid: vc.ShaderVariable): + if self.full_width_block: + return self.get_offset(tid) + + return self.get_offset(tid) * self.fft_length - self.get_inner_block_offset(tid) * (self.fft_length - 1) + + def get_write_index(self, fft_index: int): + return self.sub_sequence_offset0 + fft_index * self.input_stride + + def get_read_index(self, offset: int): + return self.instance_id0 + offset + +@dataclasses.dataclass +class FFTRegisterStageConfig: + """ + Configuration for an FFT register stage. + + Attributes: + + primes (Tuple[int]): The prime numbers used for factorization. + fft_length (int): The length of each FFT stage. + instance_count (int): The number of instances required to achieve the desired level of parallelism. + registers_used (int): The total number of registers used by the FFT stage. + remainder (int): The remainder of `N` divided by `registers_used`. + remainder_offset (int): A flag indicating whether the remainder is non-zero. + extra_ffts (int): The additional number of FFT stages required to process the remainder. + thread_count (int): The total number of threads used in the computation. + sdata_size (int): The size of the shared memory buffer used to store intermediate results. + sdata_width (int): The width of each element in the shared memory buffer. + sdata_width_padded (int): The padded width of each element in the shared memory buffer. 
+ + """ + + N: int + primes: Tuple[int] + fft_length: int + instance_count: int + registers_used: int + remainder: int + remainder_offset: int + extra_ffts: int + thread_count: int + sdata_size: int + sdata_width: int + sdata_width_padded: int + input_stride: int + output_stride: int + invocations: Tuple[FFTStagePlanInvocation] + + def __init__(self, primes: List[int], + max_register_count: int, + N: int, + compute_item_size: int, + input_stride: int): + """ + Initializes the FFTRegisterStageConfig object. + + Parameters: + + primes (List[int]): The prime numbers to use for factorization. + max_register_count (int): The maximum number of registers allowed per thread. + N (int): The length of the input data. + + """ + self.N = N + self.primes = tuple(primes) + self.input_stride = input_stride + self.fft_length = int(round(npc.prod(primes))) + self.output_stride = self.input_stride * self.fft_length + instance_primes = prime_factors(N // self.fft_length) + + self.instance_count = 1 + + while len(instance_primes) > 0: + if self.instance_count * self.fft_length * instance_primes[0] > max_register_count: + break + self.instance_count *= instance_primes[0] + instance_primes = instance_primes[1:] + + self.registers_used = self.fft_length * self.instance_count + + self.remainder = N % self.registers_used + assert self.remainder % self.fft_length == 0, "Remainder must be divisible by the FFT length" + self.remainder_offset = 1 if self.remainder != 0 else 0 + self.extra_ffts = self.remainder // self.fft_length + + self.thread_count = N // self.registers_used + self.remainder_offset + + self.sdata_width = self.registers_used + + threads_primes = prime_factors(self.thread_count) + + while self.sdata_width < 16 and len(threads_primes) > 0: + self.sdata_width *= threads_primes[0] + threads_primes = threads_primes[1:] + + self.sdata_width_padded = self.sdata_width + + if self.sdata_width_padded % 2 == 0: + self.sdata_width_padded += 1 + + self.sdata_size = self.sdata_width_padded 
* int(npc.prod(threads_primes)) + + if self.sdata_size > vd.get_context().max_shared_memory // compute_item_size: + self.sdata_width_padded = self.sdata_width + self.sdata_size = self.sdata_width_padded * int(npc.prod(threads_primes)) + + invocations = [] + for instance_index in range(self.instance_count): + invocations.append(FFTStagePlanInvocation( + stage_fft_length=self.fft_length, + stage_instance_count=self.instance_count, + input_stride=input_stride, + instance_index=instance_index, + N=N + )) + + self.invocations = tuple(invocations) + + def get_input_format(self, register_count: int) -> Dict[int, int]: + in_format = {} + + stride = self.N // self.fft_length + + register_index_list = list(range(register_count)) + + for invocation in self.invocations: + sub_registers = register_index_list[invocation.register_selection] + + for i in range(len(sub_registers)): + in_format[invocation.get_read_index(stride * i)] = sub_registers[i] + + return in_format + + def get_output_format(self, register_count: int) -> Dict[int, int]: + out_format = {} + + register_index_list = list(range(register_count)) + + for jj in range(self.fft_length): + for invocation in self.invocations: + out_format[invocation.get_write_index(jj)] = register_index_list[invocation.register_selection][jj] + + return out_format \ No newline at end of file diff --git a/vkdispatch/reduce/__init__.py b/vkdispatch/reduce/__init__.py new file mode 100644 index 00000000..3eb2279d --- /dev/null +++ b/vkdispatch/reduce/__init__.py @@ -0,0 +1,8 @@ +from .operations import ReduceOp, SubgroupAdd, SubgroupMul, SubgroupMin +from .operations import SubgroupMax, SubgroupAnd, SubgroupOr, SubgroupXor + +from .stage import make_reduction_stage, ReductionParams, mapped_io_index #, mapped_reduce_op + +from .reduce_function import ReduceFunction + +from .decorator import reduce, map_reduce \ No newline at end of file diff --git a/vkdispatch/shader_generation/decorators.py b/vkdispatch/reduce/decorator.py similarity index 
50% rename from vkdispatch/shader_generation/decorators.py rename to vkdispatch/reduce/decorator.py index def19c0f..0cc1e189 100644 --- a/vkdispatch/shader_generation/decorators.py +++ b/vkdispatch/reduce/decorator.py @@ -4,6 +4,9 @@ import inspect from typing import Callable, TypeVar +from .stage import mapped_io_index, ReduceOp +from .reduce_function import ReduceFunction + import sys RetType = TypeVar('RetType') @@ -12,44 +15,9 @@ if sys.version_info >= (3, 10): from typing import ParamSpec P = ParamSpec('P') - P2 = ParamSpec('P2') else: P = ... # Placeholder for older Python versions - P2 = ... # Placeholder for older Python versions - -def shader( - exec_size=None, - local_size=None, - workgroups=None, - enable_subgroup_ops: bool = True, - enable_atomic_float_ops: bool = True, - enable_printf: bool = True, - enable_exec_bounds: bool = True): - if workgroups is not None and exec_size is not None: - raise ValueError("Cannot specify both 'workgroups' and 'exec_size'") - - def decorator(func: Callable[P, None]) -> Callable[P, None]: - shader_name = f"{func.__module__}.{func.__name__}" - with vc.builder_context( - enable_subgroup_ops=enable_subgroup_ops, - enable_atomic_float_ops=enable_atomic_float_ops, - enable_printf=enable_printf, - enable_exec_bounds=enable_exec_bounds - ) as builder: - signature = vd.ShaderSignature.from_inspectable_function(builder, func) - - func(*signature.get_variables()) - - return vd.ShaderObject( - builder.build(shader_name), - signature, - local_size=local_size, - workgroups=workgroups, - exec_count=exec_size - ) - - return decorator def reduce(identity, axes=None, group_size=None, mapping_function: vd.MappingFunction = None): def decorator(func: Callable[..., RetType]) -> Callable[[vd.Buffer[RetType]], vd.Buffer[RetType]]: @@ -62,14 +30,14 @@ def decorator(func: Callable[..., RetType]) -> Callable[[vd.Buffer[RetType]], vd if used_mapping_function is None: used_mapping_function = vd.map( - func = lambda buffer: 
buffer[vc.mapping_index()], + func = lambda buffer: buffer[mapped_io_index()], return_type=func_signature.return_annotation, input_types=[vc.Buffer[func_signature.return_annotation]]) else: assert used_mapping_function.return_type == func_signature.return_annotation, "Mapping function return type must match the return type of the reduction function" - return vd.ReductionObject( - reduction=vd.ReductionOperation( + return ReduceFunction( + reduction=ReduceOp( name=func.__name__, reduction=func, identity=identity @@ -82,15 +50,15 @@ def decorator(func: Callable[..., RetType]) -> Callable[[vd.Buffer[RetType]], vd return decorator -def map_reduce(reduction: vd.ReductionOperation, axes=None, group_size=None): - def decorator(func: Callable[P2, RetType2]) -> Callable[P2, vd.Buffer[RetType2]]: +def map_reduce(reduction: ReduceOp, axes=None, group_size=None): + def decorator_callback(func: Callable[P, RetType2]) -> Callable[P, vd.Buffer[RetType2]]: mapping_func = vd.map(func) - return vd.ReductionObject( + return ReduceFunction( reduction=reduction, group_size=group_size, axes=axes, mapping_function=mapping_func ) - return decorator \ No newline at end of file + return decorator_callback \ No newline at end of file diff --git a/vkdispatch/shader_generation/reduction_operations.py b/vkdispatch/reduce/operations.py similarity index 78% rename from vkdispatch/shader_generation/reduction_operations.py rename to vkdispatch/reduce/operations.py index 4d8ddce9..4081982b 100644 --- a/vkdispatch/shader_generation/reduction_operations.py +++ b/vkdispatch/reduce/operations.py @@ -7,58 +7,60 @@ from typing import Union from typing import Optional + + @dataclasses.dataclass -class ReductionOperation: +class ReduceOp: name: str reduction: Callable[[vc.ShaderVariable, vc.ShaderVariable], vc.ShaderVariable] identity: Union[int, float, str] subgroup_reduction: Optional[Callable[[vc.ShaderVariable], vc.ShaderVariable]] = None -SubgroupAdd = ReductionOperation( +SubgroupAdd = ReduceOp( 
name="add", reduction=lambda x, y: x + y, identity=0, subgroup_reduction=vc.subgroup_add ) -SubgroupMul = ReductionOperation( +SubgroupMul = ReduceOp( name="mul", reduction=lambda x, y: x * y, identity=1, subgroup_reduction=vc.subgroup_mul ) -SubgroupMin = ReductionOperation( +SubgroupMin = ReduceOp( name="min", reduction=lambda x, y: vc.min(x, y), - identity=vc.inf_f32, + identity="inf", subgroup_reduction=vc.subgroup_min ) -SubgroupMax = ReductionOperation( +SubgroupMax = ReduceOp( name="max", reduction=lambda x, y: vc.max(x, y), - identity=vc.ninf_f32, + identity="-inf", subgroup_reduction=vc.subgroup_max ) -SubgroupAnd = ReductionOperation( +SubgroupAnd = ReduceOp( name="and", reduction=lambda x, y: x & y, identity=-1, subgroup_reduction=vc.subgroup_and ) -SubgroupOr = ReductionOperation( +SubgroupOr = ReduceOp( name="or", reduction=lambda x, y: x | y, identity=0, subgroup_reduction=vc.subgroup_or ) -SubgroupXor = ReductionOperation( +SubgroupXor = ReduceOp( name="xor", reduction=lambda x, y: x ^ y, identity=0, subgroup_reduction=vc.subgroup_xor -) \ No newline at end of file +) diff --git a/vkdispatch/shader_generation/reduction_object.py b/vkdispatch/reduce/reduce_function.py similarity index 74% rename from vkdispatch/shader_generation/reduction_object.py rename to vkdispatch/reduce/reduce_function.py index 88de652d..e8438498 100644 --- a/vkdispatch/shader_generation/reduction_object.py +++ b/vkdispatch/reduce/reduce_function.py @@ -1,26 +1,26 @@ import vkdispatch as vd import vkdispatch.codegen as vc -from typing import Callable -from typing import List +from .operations import ReduceOp +from .stage import make_reduction_stage, ReductionParams -import numpy as np +from typing import List, Optional -class ReductionObject: +from ..compat import numpy_compat as npc + +class ReduceFunction: def __init__(self, - reduction: vd.ReductionOperation, + reduction: ReduceOp, group_size: int = None, axes: List[int] = None, - mapping_function: vd.MappingFunction = None): 
+ mapping_function: Optional[vd.MappingFunction] = None): self.reduction = reduction - self.out_type = mapping_function.return_type #out_type + self.out_type = mapping_function.return_type self.group_size = group_size - self.map_func = mapping_function.callback # map_func - self.input_types = mapping_function.buffer_types # input_types if input_types is not None else [vc.Buffer[out_type]] + self.map_func = mapping_function + self.input_types = mapping_function.buffer_types self.axes = axes - assert len(mapping_function.register_types) == 0, "ReductionObject needs a MappingFunction with no registers!" - self.stage1 = None self.stage2 = None @@ -34,7 +34,7 @@ def make_stages(self): if self.group_size % vd.get_context().subgroup_size != 0: raise ValueError("Group size must be a multiple of the sub-group size!") - self.stage1 = vd.make_reduction_stage( + self.stage1 = make_reduction_stage( self.reduction, self.out_type, self.group_size, @@ -43,17 +43,32 @@ def make_stages(self): input_types=self.input_types ) - self.stage2 = vd.make_reduction_stage( + self.stage2 = make_reduction_stage( self.reduction, self.out_type, self.group_size, True, ) + + def get_src(self, line_numbers: bool = None) -> str: + self.make_stages() + + return [ + self.stage1.get_src(line_numbers), + self.stage2.get_src(line_numbers) + ] + + def print_src(self, line_numbers: bool = None): + srcs = self.get_src(line_numbers) + + print(f"// Reduction Stage 1:\n{srcs[0]}\n// Reduction Stage 2:\n{srcs[1]}") def __repr__(self) -> str: self.make_stages() - return f"Stage 1:\n{self.stage1}\nStage 2:\n{self.stage2}" + srcs = self.get_src() + + return f"// Reduction Stage 1:\n{srcs[0]}\n// Reduction Stage 2:\n{srcs[1]}" def __call__(self, *args, **kwargs) -> vd.Buffer: self.make_stages() @@ -98,7 +113,7 @@ def __call__(self, *args, **kwargs) -> vd.Buffer: assert input_stride == 1, "Reduction axes must be contiguous!" 
- workgroups_x = int(np.ceil(input_size / (self.group_size * input_stride))) + workgroups_x = int(npc.ceil(input_size / (self.group_size * input_stride))) if workgroups_x > self.group_size: workgroups_x = self.group_size @@ -113,7 +128,7 @@ def __call__(self, *args, **kwargs) -> vd.Buffer: reduction_buffer = vd.Buffer(tuple(output_buffer_shape), self.out_type) - stage1_params = vd.ReductionParams( + stage1_params = ReductionParams( input_offset=0, input_size=input_size, input_stride=input_stride, @@ -129,7 +144,7 @@ def __call__(self, *args, **kwargs) -> vd.Buffer: self.stage1(reduction_buffer, *args, stage1_params, exec_size=stage1_exec_size, graph=my_graph) - stage2_params = vd.ReductionParams( + stage2_params = ReductionParams( input_offset=batch_count, input_size=workgroups_x, input_stride=1, @@ -145,4 +160,4 @@ def __call__(self, *args, **kwargs) -> vd.Buffer: self.stage2(reduction_buffer, stage2_params, exec_size=stage2_exec_size, graph=my_graph) - return reduction_buffer \ No newline at end of file + return reduction_buffer diff --git a/vkdispatch/reduce/stage.py b/vkdispatch/reduce/stage.py new file mode 100644 index 00000000..3bce6759 --- /dev/null +++ b/vkdispatch/reduce/stage.py @@ -0,0 +1,190 @@ +import vkdispatch as vd +import vkdispatch.codegen as vc +from typing import List, Optional + +from .operations import ReduceOp + +import dataclasses + +@dataclasses.dataclass +class ReductionParams: + input_offset: vd.uint32 + input_size: vd.uint32 + input_stride: vd.uint32 + input_y_batch_stride: vd.uint32 + input_z_batch_stride: vd.uint32 + + output_offset: vd.uint32 + output_stride: vd.uint32 + output_y_batch_stride: vd.uint32 + output_z_batch_stride: vd.uint32 + +__static_global_io_index: vc.ShaderVariable = None + +def set_mapped_io_index(io_index: vc.ShaderVariable): + global __static_global_io_index + __static_global_io_index = io_index + +def mapped_io_index() -> vc.ShaderVariable: + return __static_global_io_index + +def global_reduce( + reduction: 
ReduceOp, + out_type: vd.dtype, + buffers: List[vc.BufferVariable], + params: ReductionParams, + map_func: Optional[vd.MappingFunction] = None): + + ind = (vc.global_invocation_id().x * params.input_stride).to_register("ind") + + reduction_identity = reduction.identity + if reduction_identity == "inf": + reduction_identity = vc.inf_f32() if out_type == vd.float32 else vc.inf_f64() + elif reduction_identity == "-inf": + reduction_identity = vc.ninf_f32() if out_type == vd.float32 else vc.ninf_f64() + + reduction_aggregate = vc.new_register(out_type, reduction_identity, var_name="reduction_aggregate") + + batch_offset = vc.workgroup_id().y * params.input_y_batch_stride + inside_batch_offset = vc.workgroup_id().z * params.input_z_batch_stride + + start_index = vc.new_uint_register(params.input_offset + inside_batch_offset + batch_offset, var_name="start_index") + + current_index = vc.new_uint_register(start_index + ind, var_name="current_index") + + end_index = vc.new_uint_register(start_index + params.input_size, var_name="end_index") + + vc.while_statement(current_index < end_index) + + mapped_value = buffers[0][current_index] + + if map_func is not None: + set_mapped_io_index(current_index) + mapped_value = map_func.callback(*buffers) + set_mapped_io_index(None) + + reduction_aggregate[:] = reduction.reduction(reduction_aggregate, mapped_value) + + current_index += vc.workgroup_size().x * vc.num_workgroups().x + + vc.end() + + return reduction_aggregate + +def workgroup_reduce( + reduction_aggregate: vc.ShaderVariable, + reduction: ReduceOp, + out_type: vd.dtype, + group_size: int): + tid = vc.local_invocation_id().x + + sdata = vc.shared_buffer(out_type, group_size, var_name="sdata") + + sdata[tid] = reduction_aggregate + + vc.barrier() + + subgroup_reduce_size = vd.get_context().subgroup_size + + if not vd.get_context().subgroup_enabled: + subgroup_reduce_size = 1 + + current_size = group_size // 2 + while current_size > subgroup_reduce_size: + 
vc.if_statement(tid < current_size) + sdata[tid] = reduction.reduction(sdata[tid], sdata[tid + current_size]) + if current_size // 2 > subgroup_reduce_size: + vc.end() + else: + tid_limit = 2 + + if subgroup_reduce_size != 1: + tid_limit = 2*vc.subgroup_size() + + vc.else_if_statement(tid < tid_limit) + sdata[tid] = vc.new_register(out_type, 0) + vc.end() + + vc.barrier() + + current_size //= 2 + + return sdata + +def subgroup_reduce( + sdata: vc.ShaderVariable, + reduction: ReduceOp, + group_size: int): + tid = vc.local_invocation_id().x + subgroup_reduce_size = vd.get_context().subgroup_size + + if not vd.get_context().subgroup_enabled: + subgroup_reduce_size = 1 + + if group_size > subgroup_reduce_size: + vc.if_statement(tid < subgroup_reduce_size) + sdata[tid] = reduction.reduction(sdata[tid], sdata[tid + subgroup_reduce_size]) + vc.end() + + if subgroup_reduce_size == 1: + return sdata[tid].to_register("local_var") + + vc.subgroup_barrier() + + if reduction.subgroup_reduction is not None: + local_var = sdata[tid].to_register("local_var") + local_var[:] = reduction.subgroup_reduction(local_var) + + return local_var + else: + current_size = subgroup_reduce_size // 2 + while current_size > 1: + vc.if_statement(tid < current_size) + sdata[tid] = reduction.reduction(sdata[tid], sdata[tid + current_size]) + vc.end() + vc.subgroup_barrier() + + current_size //= 2 + + result = reduction.reduction(sdata[tid], sdata[tid + current_size]) + + return result + +def make_reduction_stage( + reduction: ReduceOp, + out_type: vd.dtype, + group_size: int, + output_is_input: bool, + map_func: Optional[vd.MappingFunction] = None, + input_types: List = None) -> vd.ShaderFunction: + + name = f"reduction_stage_{reduction.name}_{out_type.name}_{input_types}_{group_size}" + + with vd.shader_context() as context: + signature_type_array = [] + + signature_type_array.append(vc.Buffer[out_type]) + + if input_types is not None: + signature_type_array.extend(input_types) + + 
signature_type_array.append(ReductionParams) + + input_variables = context.declare_input_arguments(signature_type_array) + + params: ReductionParams = input_variables[-1] + + input_buffers = input_variables[:-1] if output_is_input else input_variables[1:-1] + + reduction_aggregate = global_reduce(reduction, out_type, input_buffers, params, map_func) + sdata = workgroup_reduce(reduction_aggregate, reduction, out_type, group_size) + local_var = subgroup_reduce(sdata, reduction, group_size) + + batch_offset = vc.workgroup_id().y * params.output_y_batch_stride + output_offset = vc.workgroup_id().x * params.output_stride + + vc.if_statement(vc.local_invocation_id().x == 0) + input_variables[0][batch_offset + output_offset + params.output_offset] = local_var + vc.end() + + return context.get_function(local_size=(group_size, 1, 1), name=name) diff --git a/vkdispatch/shader/__init__.py b/vkdispatch/shader/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/vkdispatch/shader/context.py b/vkdispatch/shader/context.py new file mode 100644 index 00000000..2351ae8a --- /dev/null +++ b/vkdispatch/shader/context.py @@ -0,0 +1,47 @@ +import vkdispatch as vd +import vkdispatch.codegen as vc + +from .signature import ShaderSignature + +from typing import List, Optional + +import contextlib + +class ShaderContext: + builder: vc.ShaderBuilder + signature: ShaderSignature + shader_function: vd.ShaderFunction + + def __init__(self, builder: vc.ShaderBuilder): + self.builder = builder + self.signature = None + + def get_function(self, + local_size=None, + workgroups=None, + exec_count=None, + name: Optional[str] = None) -> vd.ShaderFunction: + return vd.ShaderFunction.from_description( + self.builder.build("shader" if name is None else name), + self.signature, + local_size=local_size, + workgroups=workgroups, + exec_count=exec_count + ) + + def declare_input_arguments(self, annotations: List): + self.signature = ShaderSignature.from_type_annotations(self.builder, 
annotations) + return self.signature.get_variables() + +@contextlib.contextmanager +def shader_context(flags: vc.ShaderFlags = vc.ShaderFlags.NONE): + + builder = vc.ShaderBuilder(flags=flags, is_apple_device=vd.get_context().is_apple()) + old_builder = vc.set_builder(builder) + + context = ShaderContext(builder) + + try: + yield context + finally: + vc.set_builder(old_builder) \ No newline at end of file diff --git a/vkdispatch/shader/decorator.py b/vkdispatch/shader/decorator.py new file mode 100644 index 00000000..88e2ab8e --- /dev/null +++ b/vkdispatch/shader/decorator.py @@ -0,0 +1,54 @@ +import vkdispatch as vd +import vkdispatch.codegen as vc + +import inspect +from typing import Callable, TypeVar + +import sys + +if sys.version_info >= (3, 10): + from typing import ParamSpec + P = ParamSpec('P') +else: + P = ... # Placeholder for older Python versions + +def shader( + exec_size=None, + local_size=None, + workgroups=None, + flags: vc.ShaderFlags = vc.ShaderFlags.NONE): + """ + A decorator that transforms a Python function into a GPU Compute Shader. + + The decorated function will undergo runtime inspection. Operations performed on + ``vkdispatch`` types (buffers, registers) within the function are recorded and + transpiled to GLSL. + + :param exec_size: The total number of threads to dispatch (x, y, z). The number of + workgroups is calculated automatically based on ``local_size``. + Mutually exclusive with ``workgroups``. + :type exec_size: Union[int, Tuple[int, ...], Callable] + :param local_size: The number of threads per workgroup (x, y, z). Defaults to + the device's maximum supported workgroup size. + :type local_size: Union[int, Tuple[int, ...]] + :param workgroups: The explicit number of workgroups to dispatch (x, y, z). + Mutually exclusive with ``exec_size``. + :type workgroups: Union[int, Tuple[int, ...], Callable] + :param flags: Compilation flags (e.g., ``vc.ShaderFlags.NO_EXEC_BOUNDS``). 
+ :type flags: vkdispatch.codegen.ShaderFlags + :return: A ``ShaderFunction`` wrapper that can be called to execute the kernel. + :raises ValueError: If both ``exec_size`` and ``workgroups`` are provided. + """ + if workgroups is not None and exec_size is not None: + raise ValueError("Cannot specify both 'workgroups' and 'exec_size'") + + def decorator_callback(func: Callable[P, None]) -> Callable[P, None]: + return vd.ShaderFunction( + func, + local_size=local_size, + workgroups=workgroups, + exec_count=exec_size, + flags=flags + ) + + return decorator_callback diff --git a/vkdispatch/shader_generation/mapping_shader.py b/vkdispatch/shader/map.py similarity index 74% rename from vkdispatch/shader_generation/mapping_shader.py rename to vkdispatch/shader/map.py index 3e85c928..6d27ccb6 100644 --- a/vkdispatch/shader_generation/mapping_shader.py +++ b/vkdispatch/shader/map.py @@ -10,7 +10,6 @@ @dataclasses.dataclass(frozen=True) class MappingFunction: buffer_types: List[vd.dtype] - register_types: List[vd.dtype] return_type: vd.dtype mapping_function: Callable @@ -29,23 +28,20 @@ def __eq__(self, other): def callback(self, *args): if self.return_type is None: - vc.new_scope() + vc.new_scope(indent=False) self.mapping_function(*args) - vc.end() + vc.end(indent=False) return - return_var = vc.new(self.return_type) + return_var = vc.new_register(self.return_type) - vc.new_scope() + vc.new_scope(indent=False) return_var[:] = self.mapping_function(*args) - vc.end() + vc.end(indent=False) return return_var -def map(func: Callable, register_types: List[vd.dtype] = None, return_type: vd.dtype = None, input_types: List[vd.dtype] = None) -> MappingFunction: - if register_types is None: - register_types = [] - +def map(func: Callable, return_type: vd.dtype = None, input_types: List[vd.dtype] = None) -> MappingFunction: if return_type is None: func_signature = inspect.signature(func) @@ -71,12 +67,5 @@ def map(func: Callable, register_types: List[vd.dtype] = None, return_type: 
vd.d return MappingFunction( buffer_types=input_types, return_type=return_type, - mapping_function=func, - register_types=register_types - ) - -def map_registers(register_types: List[vd.dtype]) -> Callable[[Callable], MappingFunction]: - def decorator(func: Callable): - return map(func, register_types) - - return decorator \ No newline at end of file + mapping_function=func + ) \ No newline at end of file diff --git a/vkdispatch/shader_generation/shader_object.py b/vkdispatch/shader/shader_function.py similarity index 57% rename from vkdispatch/shader_generation/shader_object.py rename to vkdispatch/shader/shader_function.py index 583731f3..8f155d75 100644 --- a/vkdispatch/shader_generation/shader_object.py +++ b/vkdispatch/shader/shader_function.py @@ -7,11 +7,15 @@ from typing import List from typing import Any +from vkdispatch.base.compute_plan import ComputePlan + +from .signature import ShaderArgumentType, ShaderSignature + import uuid import dataclasses -import numpy as np +from ..compat import numpy_compat as npc class LaunchParametersHolder: def __init__(self, names_and_defaults, args, kwargs) -> None: @@ -52,7 +56,7 @@ class ExectionBounds: def __init__(self, names_and_defaults, local_size, workgroups, exec_size) -> None: self.names_and_defaults = names_and_defaults - self.local_size = local_size + self.local_size = tuple(local_size) self.workgroups = workgroups self.exec_size = exec_size @@ -67,7 +71,7 @@ def process_input(self, in_val, args, kwargs) -> Tuple[int, int, int]: if callable(in_val): in_val = in_val(LaunchParametersHolder(self.names_and_defaults, args, kwargs)) - if isinstance(in_val, int) or np.issubdtype(type(in_val), np.integer): + if npc.is_integer_scalar(in_val): return (in_val, 1, 1) # type: ignore if not isinstance(in_val, tuple): @@ -79,7 +83,7 @@ def process_input(self, in_val, args, kwargs) -> Tuple[int, int, int]: return_val = [1, 1, 1] for ii, val in enumerate(in_val): - if not isinstance(val, int) and not np.issubdtype(type(val), 
np.integer): + if not npc.is_integer_scalar(val): raise ValueError("All dimensions must be integers!") return_val[ii] = val @@ -128,24 +132,70 @@ def get_blocks_and_limits(self, args, kwargs) -> Tuple[Tuple[int, int, int], Tup return (my_blocks, my_limits) -class ShaderObject: - plan: vd.ComputePlan +@dataclasses.dataclass +class ShaderSource: + name: str + code: str + local_size: Tuple[int, int, int] + + def __repr__(self): + return f"// ====== Source Code for '{self.name}', workgroup_size: {self.local_size} ======\n{self.code}" + +class ShaderFunction: + plan: ComputePlan + func: Callable shader_description: vc.ShaderDescription - shader_signature: vd.ShaderSignature + shader_signature: ShaderSignature bounds: ExectionBounds ready: bool + name: str source: str + flags: vc.ShaderFlags + local_size: Union[Tuple[int, int, int], Callable, None] + workgroups: Union[Tuple[int, int, int], Callable, None] + exec_size: Union[Tuple[int, int, int], Callable, None] - def __init__(self, description: vc.ShaderDescription, signature: vd.ShaderSignature, local_size=None, workgroups=None, exec_count=None) -> None: + def __init__(self, + func: Callable, + local_size=None, + workgroups=None, + exec_count=None, + flags: vc.ShaderFlags = vc.ShaderFlags.NONE, + name: str = None) -> None: + self.plan = None - self.shader_description = description - self.shader_signature = signature + self.func = func + self.shader_description = None + self.shader_signature = None self.bounds = None self.ready = False + self.name = name if name is not None else func.__name__ if func is not None else None self.source = None self.local_size = local_size self.workgroups = workgroups self.exec_size = exec_count + self.flags = flags + + def from_description( + shader_description: vc.ShaderDescription, + shader_signature: ShaderSignature, + local_size=None, + workgroups=None, + exec_count=None, + + ) -> "ShaderFunction": + shader_obj = ShaderFunction( + func=None, + local_size=local_size, + 
workgroups=workgroups, + exec_count=exec_count, + flags=vc.ShaderFlags.NONE + ) + + shader_obj.shader_description = shader_description + shader_obj.shader_signature = shader_signature + + return shader_obj def build(self): if self.ready: @@ -157,41 +207,136 @@ def build(self): else [vd.get_context().max_workgroup_size[0], 1, 1] ) + if self.shader_description is None or self.shader_signature is None: + assert self.shader_description is None and self.shader_signature is None, "Shader description and signature must both be set or both be None!" + assert self.func is not None, "Cannot build a shader without a function!" + + builder = vc.ShaderBuilder( + flags=self.flags, + is_apple_device=vd.get_context().is_apple() + ) + old_builder = vc.set_builder(builder) + + try: + signature = ShaderSignature.from_inspectable_function(builder, self.func) + self.func(*signature.get_variables()) + except Exception as e: + print(f"Error during shader inspection: {e}") + raise e + finally: + vc.set_builder(old_builder) + + self.shader_description = builder.build(self.func.__module__ + "." + self.func.__name__) + self.shader_signature = signature + + # Resource bindings are declared before final shader layout is known. + # For some shader construction paths (e.g. from_description), signatures are + # pre-populated and still hold logical bindings assuming a reserved UBO at 0. 
+ binding_shift = self.shader_description.resource_binding_base - 1 + if binding_shift != 0: + binding_access_len = len(self.shader_description.binding_access) + needs_remap = False + + for shader_arg in self.shader_signature.arguments: + if ( + shader_arg.binding is not None + and ( + shader_arg.arg_type == ShaderArgumentType.BUFFER + or shader_arg.arg_type == ShaderArgumentType.IMAGE + ) + and shader_arg.binding >= binding_access_len + ): + needs_remap = True + break + + if needs_remap: + for shader_arg in self.shader_signature.arguments: + if ( + shader_arg.binding is not None + and ( + shader_arg.arg_type == ShaderArgumentType.BUFFER + or shader_arg.arg_type == ShaderArgumentType.IMAGE + ) + ): + shader_arg.binding += binding_shift + self.bounds = ExectionBounds(self.shader_signature.get_names_and_defaults(), my_local_size, self.workgroups, self.exec_size) + shader_backend_name = ( + self.shader_description.backend.name + if self.shader_description.backend is not None + else "glsl" + ) + + if vd.is_dummy(): + pass + elif vd.is_cuda() and shader_backend_name != "cuda": + raise RuntimeError( + "The selected CUDA runtime backend requires CUDA codegen output. " + "Call vd.initialize(backend='cuda') " + "before building shaders." + ) + elif vd.is_opencl() and shader_backend_name != "opencl": + raise RuntimeError( + "The selected OpenCL runtime backend requires OpenCL codegen output. " + "Call vd.initialize(backend='opencl') " + "before building shaders." + ) + elif vd.is_vulkan() and shader_backend_name == "cuda": + raise RuntimeError( + "Vulkan runtime backend cannot execute CUDA codegen output. " + "Use GLSL codegen or initialize with backend='cuda'." + ) + elif vd.is_vulkan() and shader_backend_name == "opencl": + raise RuntimeError( + "Vulkan runtime backend cannot execute OpenCL codegen output. " + "Use GLSL codegen or initialize with backend='opencl'." 
+ ) + self.source = self.shader_description.make_source( my_local_size[0], my_local_size[1], my_local_size[2] ) try: - self.plan = vd.ComputePlan( - self.source, - self.shader_description.binding_type_list, - self.shader_description.pc_size, - self.shader_description.name - ) + if not vd.is_dummy(): + self.plan = ComputePlan( + self.source, + self.shader_description.binding_type_list, + self.shader_description.pc_size, + self.shader_description.name + ) except Exception as e: print(f"Error building shader: {e}") - print(self.make_repr()) + print(self.get_src(build=False, line_numbers=True)) raise e self.ready = True def __repr__(self) -> str: - self.build() - return self.make_repr() + return self.get_src().__repr__() - def make_repr(self, line_numbers: bool = True) -> str: + def get_src(self, line_numbers: bool = None, build: bool = True) -> ShaderSource: + if build: + self.build() + result = "" + if line_numbers is None: + line_numbers = vc.get_shader_print_line_numbers() + for ii, line in enumerate(self.source.split("\n")): line_prefix = f"{ii + 1:4d}: " if line_numbers else "" result += f"{line_prefix}{line}\n" - return result + return ShaderSource(name=self.name, code=result, local_size=self.bounds.local_size) + + def print_src(self, line_numbers: bool = None): + print(self.get_src(line_numbers)) def __call__(self, *args, **kwargs): + assert not vd.is_dummy(), "Cannot execute shader functions with dummy backend!" 
+ self.build() if not self.ready: @@ -231,9 +376,9 @@ def __call__(self, *args, **kwargs): else: arg = kwargs[shader_arg.name] - if shader_arg.arg_type == vd.ShaderArgumentType.BUFFER: + if shader_arg.arg_type == ShaderArgumentType.BUFFER: if not isinstance(arg, vd.Buffer): - raise ValueError(f"Expected a buffer for argument '{shader_arg.name}'!") + raise ValueError(f"Expected a buffer for argument '{shader_arg.name}' but got '{arg}'!") bound_buffers.append(vd.BufferBindInfo( buffer=arg, @@ -243,7 +388,7 @@ def __call__(self, *args, **kwargs): write_access=self.shader_description.binding_access[shader_arg.binding][1] )) - elif shader_arg.arg_type == vd.ShaderArgumentType.IMAGE: + elif shader_arg.arg_type == ShaderArgumentType.IMAGE: if not isinstance(arg, vd.Sampler): raise ValueError(f"Expected an image for argument '{shader_arg.name}'!") @@ -254,20 +399,20 @@ def __call__(self, *args, **kwargs): write_access=self.shader_description.binding_access[shader_arg.binding][1] )) - elif shader_arg.arg_type == vd.ShaderArgumentType.CONSTANT: + elif shader_arg.arg_type == ShaderArgumentType.CONSTANT: if callable(arg): raise ValueError("Cannot use LaunchVariables for Constants") uniform_values[shader_arg.shader_name] = arg - elif shader_arg.arg_type == vd.ShaderArgumentType.CONSTANT_DATACLASS: + elif shader_arg.arg_type == ShaderArgumentType.CONSTANT_DATACLASS: if callable(arg): raise ValueError("Cannot use LaunchVariables for Constants") for field in dataclasses.fields(arg): uniform_values[shader_arg.shader_name[field.name]] = getattr(arg, field.name) - elif shader_arg.arg_type == vd.ShaderArgumentType.VARIABLE: + elif shader_arg.arg_type == ShaderArgumentType.VARIABLE: if len(self.shader_description.pc_structure) == 0: raise ValueError("Something went wrong with push constants!!") @@ -292,4 +437,3 @@ def __call__(self, *args, **kwargs): pc_values, shader_uuid=shader_uuid ) - diff --git a/vkdispatch/shader_generation/signature.py b/vkdispatch/shader/signature.py similarity 
index 93% rename from vkdispatch/shader_generation/signature.py rename to vkdispatch/shader/signature.py index 4c8b808d..8d6f4a46 100644 --- a/vkdispatch/shader_generation/signature.py +++ b/vkdispatch/shader/signature.py @@ -19,6 +19,16 @@ import enum +_PUSH_CONSTANT_UNSUPPORTED_BACKENDS = set() + + +def _push_constant_not_supported_error(backend_name: str) -> str: + return ( + f"Push Constants are not supported for the {backend_name.upper()} backend. " + "Use Const instead." + ) + + class ShaderArgumentType(enum.Enum): BUFFER = 0 IMAGE = 1 @@ -139,6 +149,9 @@ def from_type_annotations(cls, value_name = shader_param.raw_name arg_type = ShaderArgumentType.CONSTANT elif(issubclass(annotations[i].__origin__, vc.Variable)): + if builder.backend.name in _PUSH_CONSTANT_UNSUPPORTED_BACKENDS: + raise NotImplementedError(_push_constant_not_supported_error(builder.backend.name)) + shader_param = builder.declare_variable(annotations[i].__args__[0]) arg_type = ShaderArgumentType.VARIABLE value_name = shader_param.raw_name @@ -164,6 +177,3 @@ def get_variables(self) -> List[vc.ShaderVariable]: def get_names_and_defaults(self) -> List[Tuple[str, Any]]: return [(arg.name, arg.default_value) for arg in self.arguments] - -# def get_func_args(self) -> List[Tuple[str, str, Any]]: -# return [(arg.shader_name, arg.name, arg.default_value) for arg in self.arguments] diff --git a/vkdispatch/shader_generation/reduction_stage.py b/vkdispatch/shader_generation/reduction_stage.py deleted file mode 100644 index fce7f1ec..00000000 --- a/vkdispatch/shader_generation/reduction_stage.py +++ /dev/null @@ -1,161 +0,0 @@ -import vkdispatch as vd -import vkdispatch.codegen as vc - -from typing import Callable -from typing import List - -import dataclasses - -@dataclasses.dataclass -class ReductionParams: - input_offset: vd.int32 - input_size: vd.int32 - input_stride: vd.int32 - input_y_batch_stride: vd.int32 - input_z_batch_stride: vd.int32 - - output_offset: vd.int32 - output_stride: vd.int32 - 
output_y_batch_stride: vd.int32 - output_z_batch_stride: vd.int32 - -def global_reduce( - reduction: vd.ReductionOperation, - out_type: vd.dtype, - buffers: List[vc.BufferVariable], - params: ReductionParams, - map_func: Callable = None): - - ind = (vc.global_invocation().x * params.input_stride).copy("ind") - reduction_aggregate = vc.new(out_type, reduction.identity, var_name="reduction_aggregate") - - batch_offset = vc.workgroup().y * params.input_y_batch_stride - inside_batch_offset = vc.workgroup().z * params.input_z_batch_stride - - start_index = vc.new_uint(params.input_offset + inside_batch_offset + batch_offset, var_name="start_index") - - current_index = vc.new_uint(start_index + ind, var_name="current_index") - - end_index = vc.new_uint(start_index + params.input_size, var_name="end_index") - - vc.while_statement(current_index < end_index) - - mapped_value = buffers[0][current_index] - - - if map_func is not None: - vc.set_mapping_index(current_index) - mapped_value = map_func(*buffers) - - reduction_aggregate[:] = reduction.reduction(reduction_aggregate, mapped_value) - - current_index += vc.workgroup_size().x * vc.num_workgroups().x - - vc.end() - - return reduction_aggregate - -def workgroup_reduce( - reduction_aggregate: vc.ShaderVariable, - reduction: vd.ReductionOperation, - out_type: vd.dtype, - group_size: int): - tid = vc.local_invocation().x - - sdata = vc.shared_buffer(out_type, group_size, var_name="sdata") - - sdata[tid] = reduction_aggregate - - vc.barrier() - - current_size = group_size // 2 - while current_size > vd.get_context().subgroup_size: - vc.if_statement(tid < current_size) - sdata[tid] = reduction.reduction(sdata[tid], sdata[tid + current_size]) - if current_size // 2 > vd.get_context().subgroup_size: - vc.end() - else: - vc.else_if_statement(tid < 2*vc.subgroup_size()) - sdata[tid] = vc.new(out_type, 0) - vc.end() - - vc.barrier() - - current_size //= 2 - - return sdata - -def subgroup_reduce( - sdata: vc.ShaderVariable, - 
reduction: vd.ReductionOperation, - group_size: int): - tid = vc.local_invocation().x - subgroup_size = vd.get_context().subgroup_size - - if group_size > subgroup_size: - vc.if_all(tid < subgroup_size) - sdata[tid] = reduction.reduction(sdata[tid], sdata[tid + subgroup_size]) - vc.end() - vc.subgroup_barrier() - - - if reduction.subgroup_reduction is not None: - local_var = sdata[tid].copy("local_var") - local_var[:] = reduction.subgroup_reduction(local_var) - - return local_var - else: - current_size = subgroup_size // 2 - while current_size > 1: - vc.if_statement(tid < current_size) - sdata[tid] = reduction.reduction(sdata[tid], sdata[tid + current_size]) - vc.end() - vc.subgroup_barrier() - - current_size //= 2 - - result = reduction.reduction(sdata[tid], sdata[tid + current_size]) - - return result - -def make_reduction_stage( - reduction: vd.ReductionOperation, - out_type: vd.dtype, - group_size: int, - output_is_input: bool, - name: str = None, - map_func: Callable = None, - input_types: List = None) -> vd.ShaderObject: - - if name is None: - name = f"reduction_stage_{reduction.name}_{out_type.name}_{input_types}_{group_size}" - - with vc.builder_context() as builder: - signature_type_array = [] - - signature_type_array.append(vc.Buffer[out_type]) - - if input_types is not None: - signature_type_array.extend(input_types) - - signature_type_array.append(ReductionParams) - - signature = vd.ShaderSignature.from_type_annotations(builder, signature_type_array) - input_variables = signature.get_variables() - - params: ReductionParams = input_variables[-1] - - input_buffers = input_variables[:-1] if output_is_input else input_variables[1:-1] - - reduction_aggregate = global_reduce(reduction, out_type, input_buffers, params, map_func) - sdata = workgroup_reduce(reduction_aggregate, reduction, out_type, group_size) - local_var = subgroup_reduce(sdata, reduction, group_size) - - batch_offset = vc.workgroup().y * params.output_y_batch_stride - output_offset = 
vc.workgroup().x * params.output_stride - - vc.if_statement(vc.local_invocation().x == 0) - input_variables[0][batch_offset + output_offset + params.output_offset] = local_var - vc.end() - - return vd.ShaderObject(builder.build(name), signature, local_size=(group_size, 1, 1)) diff --git a/vkdispatch/tests/test_builder.py b/vkdispatch/tests/test_builder.py deleted file mode 100644 index b5ed2538..00000000 --- a/vkdispatch/tests/test_builder.py +++ /dev/null @@ -1,110 +0,0 @@ -import vkdispatch as vd -import vkdispatch.codegen as vc - -import numpy as np - -vd.initialize(log_level=vd.LogLevel.WARNING) - -def test_builder_basic(): - buff = vd.asbuffer(np.array([1, 2, 3, 4], dtype=np.float32)) - buff2 = vd.asbuffer(np.array([10, 20, 30, 40], dtype=np.float32)) - - uniform_buffer = vd.Buffer((vd.get_context().uniform_buffer_alignment, ), vd.float32) - - my_builder = vc.ShaderBuilder() - - var_buff = my_builder.declare_buffer(vc.f32) - var_buff2 = my_builder.declare_buffer(vc.f32) - - uniform_var = my_builder.declare_constant(vc.f32) - - var_buff[my_builder.global_invocation.x] += var_buff2[my_builder.global_invocation.x] - uniform_var - - shader_description = my_builder.build("my_shader") - - source = shader_description.make_source(4, 1, 1) - - compute_plan = vd.ComputePlan(source, shader_description.binding_type_list, shader_description.pc_size, shader_description.name) - - descriptor_set = vd.DescriptorSet(compute_plan) - - descriptor_set.bind_buffer(uniform_buffer, 0, uniform=True) - descriptor_set.bind_buffer(buff, var_buff.binding) - descriptor_set.bind_buffer(buff2, var_buff2.binding) - - uniform_buffer_builder = vd.BufferBuilder(usage=vd.BufferUsage.UNIFORM_BUFFER) - uniform_buffer_builder.register_struct("my_shader", shader_description.uniform_structure) - uniform_buffer_builder.prepare(1) - uniform_buffer_builder[("my_shader", shader_description.exec_count_name)] = [2, 1, 1, 0] - uniform_buffer_builder[("my_shader", uniform_var.raw_name)] = 5 - - 
uniform_buffer.write(uniform_buffer_builder.tobytes()) - - cmd_list = vd.CommandList() - - cmd_list.record_compute_plan(compute_plan, descriptor_set, [1, 1, 1]) - - cmd_list.submit(instance_count=1) - cmd_list.submit(instance_count=1) - - assert np.allclose(buff.read(0), np.array([11, 32, 3, 4], dtype=np.float32)) - - -def test_custom_GLSL_shader(): - buff = vd.asbuffer(np.array([1, 2, 3, 4], dtype=np.float32)) - buff2 = vd.asbuffer(np.array([10, 20, 30, 40], dtype=np.float32)) - - uniform_buffer = vd.Buffer((vd.get_context().uniform_buffer_alignment, ), vd.float32) - - source = """ -#version 450 -#extension GL_ARB_separate_shader_objects : enable -#extension GL_KHR_shader_subgroup_arithmetic : enable -#extension GL_EXT_debug_printf : enable - -layout(set = 0, binding = 0) uniform UniformObjectBuffer { - uvec4 exec_count; - float var0; -} UBO; -layout(set = 0, binding = 1) buffer Buffer1 { float data[]; } buf1; -layout(set = 0, binding = 2) buffer Buffer2 { float data[]; } buf2; - -layout(local_size_x = 4, local_size_y = 1, local_size_z = 1) in; -void main() { - if((UBO.exec_count.x <= gl_GlobalInvocationID.x)) { - return ; - } - buf1.data[gl_GlobalInvocationID.x] += (buf2.data[gl_GlobalInvocationID.x] - UBO.var0); - -} -""" - - shader_uniform_structure = [ - vc.StructElement("exec_count", vc.uv4, 1), - vc.StructElement("var0", vc.f32, 1) - ] - - compute_plan = vd.ComputePlan(source, [3, 1, 1], 0, "my_shader") - - descriptor_set = vd.DescriptorSet(compute_plan) - - descriptor_set.bind_buffer(uniform_buffer, 0, uniform=True) - descriptor_set.bind_buffer(buff, 1) - descriptor_set.bind_buffer(buff2, 2) - - uniform_buffer_builder = vd.BufferBuilder(usage=vd.BufferUsage.UNIFORM_BUFFER) - uniform_buffer_builder.register_struct("my_shader", shader_uniform_structure) - uniform_buffer_builder.prepare(1) - uniform_buffer_builder[("my_shader", "exec_count")] = [2, 1, 1, 0] - uniform_buffer_builder[("my_shader", "var0")] = 5 - - 
uniform_buffer.write(uniform_buffer_builder.tobytes()) - - cmd_list = vd.CommandList() - - cmd_list.record_compute_plan(compute_plan, descriptor_set, [1, 1, 1]) - - cmd_list.submit(instance_count=1) - cmd_list.submit(instance_count=1) - - assert np.allclose(buff.read(0), np.array([11, 32, 3, 4], dtype=np.float32)) \ No newline at end of file diff --git a/vkdispatch/vkfft/__init__.py b/vkdispatch/vkfft/__init__.py index 69d9e6dd..2d96d064 100644 --- a/vkdispatch/vkfft/__init__.py +++ b/vkdispatch/vkfft/__init__.py @@ -1,9 +1,9 @@ -from .fft_plan import VkFFTPlan +from .vkfft_plan import VkFFTPlan -from .fft_dispatcher import fft, fft2, fft3 -from .fft_dispatcher import ifft, ifft2, ifft3 -from .fft_dispatcher import rfft, rfft2, rfft3 -from .fft_dispatcher import irfft, irfft2, irfft3 -from .fft_dispatcher import clear_plan_cache, convolve_2D +from .vkfft_dispatcher import fft, fft2, fft3 +from .vkfft_dispatcher import ifft, ifft2, ifft3 +from .vkfft_dispatcher import rfft, rfft2, rfft3 +from .vkfft_dispatcher import irfft, irfft2, irfft3 +from .vkfft_dispatcher import clear_plan_cache, convolve2D, transpose_kernel2D #from .fft_dispatcher import ifft, irfft, create_kernel_2Dreal, convolve_2Dreal #from .fft_dispatcher import reset_fft_plans \ No newline at end of file diff --git a/vkdispatch/vkfft/fft_dispatcher.py b/vkdispatch/vkfft/vkfft_dispatcher.py similarity index 76% rename from vkdispatch/vkfft/fft_dispatcher.py rename to vkdispatch/vkfft/vkfft_dispatcher.py index 383e3d8f..e289293b 100644 --- a/vkdispatch/vkfft/fft_dispatcher.py +++ b/vkdispatch/vkfft/vkfft_dispatcher.py @@ -1,15 +1,13 @@ - from typing import Tuple -from typing import Union +from typing import Union, Optional from typing import List -import numpy as np - import vkdispatch as vd -from .fft_plan import VkFFTPlan +from .vkfft_plan import VkFFTPlan import dataclasses +from functools import lru_cache from typing import Dict from typing import Union @@ -39,15 +37,42 @@ def 
sanitize_input_tuple(input: Tuple) -> Tuple: return tuple(input) -__fft_plans: Dict[FFTConfig, VkFFTPlan] = {} +@lru_cache(maxsize=None) +def get_fft_plan( + shape: Tuple[int, ...], + do_r2c: bool = False, + axes: Tuple[int] = None, + normalize: bool = False, + padding: Tuple[Tuple[int, int]] = None, + pad_frequency_domain: bool = False, + kernel_count: int = 0, + input_shape: Tuple[int, ...] = None, + input_type: vd.dtype = None, + kernel_convolution: bool = False, + conjugate_convolution: bool = False, + convolution_features: int = 1, + num_batches: int = 1, + keep_shader_code: bool = False) -> VkFFTPlan: + + return VkFFTPlan( + shape=shape, + do_r2c=do_r2c, + axes=axes, + normalize=normalize, + padding=padding, + pad_frequency_domain=pad_frequency_domain, + kernel_count=kernel_count, + input_shape=input_shape, + input_type=input_type, + kernel_convolution=kernel_convolution, + conjugate_convolution=conjugate_convolution, + convolution_features=convolution_features, + num_batches=num_batches, + keep_shader_code=keep_shader_code + ) def clear_plan_cache(): - global __fft_plans - - for plan in __fft_plans.values(): - plan.destroy() - - __fft_plans = {} + get_fft_plan.cache_clear() def execute_fft_plan( buffer: vd.Buffer, @@ -55,12 +80,11 @@ def execute_fft_plan( config: FFTConfig, kernel: vd.Buffer = None, input: vd.Buffer = None, - graph: Union[vd.CommandList, vd.CommandGraph, None] = None): + graph: Optional[vd.CommandGraph] = None): if graph is None: graph = vd.global_graph() - if config not in __fft_plans: - __fft_plans[config] = VkFFTPlan( + plan = get_fft_plan( shape=config.shape, do_r2c=config.do_r2c, axes=config.axes, @@ -76,8 +100,6 @@ def execute_fft_plan( num_batches=config.num_batches, keep_shader_code=config.keep_shader_code ) - - plan = __fft_plans[config] plan.record(graph, buffer, inverse, kernel, input) if isinstance(graph, vd.CommandGraph): @@ -103,7 +125,7 @@ def convolve_2Dreal( input: Union[vd.Buffer[vd.float32], vd.RFFTBuffer] = None, 
normalize: bool = False, conjugate_kernel: bool = False, - graph: Union[vd.CommandList, vd.CommandGraph, None] = None, + graph: Optional[vd.CommandGraph] = None, keep_shader_code: bool = False): buffer_shape = sanitize_2d_convolution_buffer_shape(buffer) @@ -147,7 +169,7 @@ def create_kernel_2Dreal( kernel: vd.RFFTBuffer, shape: Tuple[int, ...] = None, feature_count: int = 1, - graph: Union[vd.CommandList, vd.CommandGraph, None] = None, + graph: Optional[vd.CommandGraph] = None, keep_shader_code: bool = False) -> vd.RFFTBuffer: if shape is None: @@ -174,13 +196,12 @@ def create_kernel_2Dreal( return kernel - def convolve_2D( buffer: vd.Buffer, - kernel: Union[vd.Buffer[vd.float32], vd.Buffer], + kernel: vd.Buffer, normalize: bool = False, conjugate_kernel: bool = False, - graph: Union[vd.CommandList, vd.CommandGraph, None] = None, + graph: Optional[vd.CommandGraph] = None, keep_shader_code: bool = False, padding: Tuple[Tuple[int, int]] = None): @@ -215,6 +236,66 @@ def convolve_2D( kernel=kernel ) + +def transpose_kernel2D( + kernel: vd.Buffer, + shape: Tuple[int, ...] = None, + graph: Optional[vd.CommandGraph] = None, + keep_shader_code: bool = False): + if shape is None: + shape = kernel.shape + + if len(shape) == 2: + shape = (1,) + shape + + assert len(shape) == 3, "Kernel shape must be 2D or 3D!" 
+ + execute_fft_plan( + kernel, + False, + graph = graph, + config = FFTConfig( + buffer_handle=kernel._handle, + shape=shape[1:], + kernel_convolution=True, + convolution_features=1, + num_batches=shape[0], + keep_shader_code=keep_shader_code + ) + ) + +def convolve2D( + buffer: vd.Buffer, + kernel: Union[vd.Buffer[vd.float32], vd.Buffer], + normalize: bool = False, + conjugate_kernel: bool = False, + graph: Optional[vd.CommandGraph] = None, + keep_shader_code: bool = False, + padding: Tuple[Tuple[int, int]] = None): + + in_shape = sanitize_input_tuple(buffer.shape) + + if len(in_shape) == 2: + in_shape = (1,) + in_shape + + execute_fft_plan( + buffer, + False, + graph = graph, + config = FFTConfig( + buffer_handle=buffer._handle, + shape=in_shape[1:], + normalize=normalize, + kernel_count=1, + conjugate_convolution=conjugate_kernel, + convolution_features=1, + keep_shader_code=keep_shader_code, + num_batches=buffer.shape[0], + padding=padding + ), + kernel=kernel + ) + def fft( buffer: vd.Buffer, input_buffer: vd.Buffer = None, @@ -315,4 +396,4 @@ def irfft2(buffer: vd.RFFTBuffer, graph: vd.CommandGraph = None, print_shader: b def irfft3(buffer: vd.RFFTBuffer, graph: vd.CommandGraph = None, print_shader: bool = False): assert len(buffer.real_shape) == 3, 'Buffer must have 3 dimensions' - irfft(buffer, graph=graph, print_shader=print_shader, axis=(0, 1, 2)) \ No newline at end of file + irfft(buffer, graph=graph, print_shader=print_shader, axis=(0, 1, 2)) diff --git a/vkdispatch/vkfft/fft_plan.py b/vkdispatch/vkfft/vkfft_plan.py similarity index 69% rename from vkdispatch/vkfft/fft_plan.py rename to vkdispatch/vkfft/vkfft_plan.py index 511e23ac..0ad12dea 100644 --- a/vkdispatch/vkfft/fft_plan.py +++ b/vkdispatch/vkfft/vkfft_plan.py @@ -1,12 +1,12 @@ -import numpy as np - -import vkdispatch_native +from vkdispatch.backends.backend_selection import native import vkdispatch as vd from typing import List from typing import Tuple +from vkdispatch.base.errors import 
check_for_errors + from ..base.context import get_context, Context, Handle class VkFFTPlan(Handle): @@ -35,9 +35,9 @@ def __init__(self, self.shape = shape self.do_r2c = do_r2c - self.mem_size = ( - np.prod(shape) * np.dtype(np.complex64).itemsize - ) # currently only support complex64 + self.mem_size = vd.complex64.item_size + for dim in shape: + self.mem_size *= dim if axes is None: axes = [0, 1, 2] @@ -58,14 +58,13 @@ def __init__(self, input_size = 0 if input_shape is not None: - input_buffer_type = np.dtype(np.complex64) + input_buffer_type = vd.complex64 if input_type is None else input_type - if input_type is not None: - input_buffer_type = np.dtype(vd.to_numpy_dtype(input_type)) + input_size = input_buffer_type.item_size + for dim in input_shape: + input_size *= dim - input_size = np.prod(input_shape) * input_buffer_type.itemsize - - handle = vkdispatch_native.stage_fft_plan_create( + handle = native.stage_fft_plan_create( self.context._handle, list(reversed(self.shape)), [axis for axis in flipped_axes if axis >= 0 and axis < 3], @@ -84,31 +83,30 @@ def __init__(self, single_kernel_multiple_batches, keep_shader_code ) - vd.check_for_errors() + check_for_errors() self.register_handle(handle) def _destroy(self): - vkdispatch_native.stage_fft_plan_destroy(self._handle) - vd.check_for_errors() + native.stage_fft_plan_destroy(self._handle) + check_for_errors() def __del__(self): self.destroy() - def record(self, command_list: vd.CommandList, buffer: vd.Buffer, inverse: bool = False, kernel: vd.Buffer = None, input: vd.Buffer = None): - vkdispatch_native.stage_fft_record( - command_list._handle, + def record(self, graph: vd.CommandGraph, buffer: vd.Buffer, inverse: bool = False, kernel: vd.Buffer = None, input: vd.Buffer = None): + native.stage_fft_record( + graph._handle, self._handle, buffer._handle, 1 if inverse else -1, kernel._handle if kernel is not None else 0, input._handle if input is not None else 0 ) - vd.check_for_errors() - - def record_forward(self, 
command_list: vd.CommandList, buffer: vd.Buffer): - self.record(command_list, buffer, False) + check_for_errors() - def record_inverse(self, command_list: vd.CommandList, buffer: vd.Buffer): - self.record(command_list, buffer, True) + def record_forward(self, graph: vd.CommandGraph, buffer: vd.Buffer): + self.record(graph, buffer, False) + def record_inverse(self, graph: vd.CommandGraph, buffer: vd.Buffer): + self.record(graph, buffer, True) diff --git a/vkdispatch_native/context/context.cpp b/vkdispatch_native/context/context.cpp index 91bcfd76..4d92935c 100644 --- a/vkdispatch_native/context/context.cpp +++ b/vkdispatch_native/context/context.cpp @@ -18,8 +18,6 @@ #include "../objects/command_list.hh" #include "../objects/objects_extern.hh" -//#include "../internal.hh" - void inplace_min(int* a, int b) { if(b < *a) { *a = b; @@ -34,7 +32,6 @@ struct Context* context_create_extern(int* device_indicies, int* queue_counts, i ctx->deviceCount = device_count; ctx->physicalDevices.resize(device_count); ctx->devices.resize(device_count); - //ctx->queues.resize(device_count); ctx->queue_index_map.resize(device_count); ctx->allocators.resize(device_count); ctx->glslang_resource_limits = new glslang_resource_t(); @@ -62,6 +59,16 @@ struct Context* context_create_extern(int* device_indicies, int* queue_counts, i struct PhysicalDeviceDetails* details = &_instance.device_details[device_indicies[i]]; + if(!details->timeline_semaphores) { + LOG_ERROR("Physical device %d does not support timeline semaphores", device_indicies[i]); + return nullptr; + } + + if(!details->scalar_block_layout) { + LOG_ERROR("Physical device %d does not support scalar block layout", device_indicies[i]); + return nullptr; + } + inplace_min(&resource->max_compute_work_group_size_x, details->max_workgroup_size_x); inplace_min(&resource->max_compute_work_group_size_y, details->max_workgroup_size_y); inplace_min(&resource->max_compute_work_group_size_z, details->max_workgroup_size_z); @@ -197,7 +204,7 @@ 
struct Context* context_create_extern(int* device_indicies, int* queue_counts, i LOG_INFO("Created context at %p with %d devices", ctx, device_count); - context_queue_wait_idle_extern(ctx, -1); + //context_queue_wait_idle_extern(ctx, -1); ctx->handle_manager = new HandleManager(ctx); @@ -206,45 +213,74 @@ struct Context* context_create_extern(int* device_indicies, int* queue_counts, i return ctx; } -void wait_for_queue(struct Context* ctx, int queue_index) { - LOG_INFO("Waiting for queue %d to finish execution...", queue_index); +bool signal_wait_extern(void* signal_ptr, bool wait_for_timestamp, int queue_index) { + Signal* signal = reinterpret_cast(signal_ptr); + LOG_VERBOSE("Waiting on signal %p (wait_for_timestamp=%d, queue_index=%d)...", signal, wait_for_timestamp, queue_index); + return signal->try_wait(wait_for_timestamp, queue_index); +} - uint64_t* p_timestamp = new uint64_t(); - Signal* signal = new Signal(ctx); +void* signal_insert_extern(struct Context* context, int queue_index) { + LOG_VERBOSE("Inserting signal into queue %d", queue_index); - *p_timestamp = 0; + Signal* signal = new Signal(context); - context_submit_command(ctx, "queue-wait-idle", queue_index, RECORD_TYPE_SYNC, - [ctx, signal, p_timestamp](VkCommandBuffer cmd_buffer, ExecIndicies indicies, void* pc_data, BarrierManager* barrier_manager, uint64_t timestamp){ - LOG_VERBOSE("Waiting for queue %d to finish execution...", indicies.queue_index); - *p_timestamp = timestamp; - signal->notify(); + context_submit_command(context, "queue-wait-idle", queue_index, RECORD_TYPE_SYNC, + [context, signal](VkCommandBuffer cmd_buffer, ExecIndicies indicies, void* pc_data, BarrierManager* barrier_manager, uint64_t timestamp){ + LOG_VERBOSE("Inserting signal to queue %d...", indicies.queue_index); + signal->notify(indicies.queue_index, timestamp); } ); - signal->wait(); + LOG_VERBOSE("Inserted signal %p into queue %d", signal, queue_index); - if(*p_timestamp == 0) { - if 
(ctx->running.load(std::memory_order_acquire)) - LOG_WARNING("Queue %d did not finish execution", queue_index); - } else { - LOG_INFO("Queue %d finished execution at timestamp %llu", queue_index, *p_timestamp); - } - - ctx->queues[queue_index]->wait_for_timestamp(*p_timestamp); + return reinterpret_cast(signal); +} +void signal_destroy_extern(void* signal_ptr) { + Signal* signal = reinterpret_cast(signal_ptr); delete signal; } -void context_queue_wait_idle_extern(struct Context* context, int queue_index) { - if(queue_index == -1) { - for(int i = 0; i < context->queues.size(); i++) { - wait_for_queue(context, i); - } - } else { - wait_for_queue(context, queue_index); - } -} + +// void wait_for_queue(struct Context* ctx, int queue_index) { +// LOG_INFO("Waiting for queue %d to finish execution...", queue_index); + +// uint64_t* p_timestamp = new uint64_t(); +// Signal* signal = new Signal(ctx); + +// *p_timestamp = 0; + +// context_submit_command(ctx, "queue-wait-idle", queue_index, RECORD_TYPE_SYNC, +// [ctx, signal, p_timestamp](VkCommandBuffer cmd_buffer, ExecIndicies indicies, void* pc_data, BarrierManager* barrier_manager, uint64_t timestamp){ +// LOG_VERBOSE("Waiting for queue %d to finish execution...", indicies.queue_index); +// *p_timestamp = timestamp; +// signal->notify(timestamp); +// } +// ); + +// signal->wait(); + +// if(*p_timestamp == 0) { +// if (ctx->running.load(std::memory_order_acquire)) +// LOG_WARNING("Queue %d did not finish execution", queue_index); +// } else { +// LOG_INFO("Queue %d finished execution at timestamp %llu", queue_index, *p_timestamp); +// } + +// ctx->queues[queue_index]->wait_for_timestamp(*p_timestamp); + +// delete signal; +// } + +// bool context_queue_wait_idle_extern(struct Context* context, int queue_index) { +// if(queue_index == -1) { +// for(int i = 0; i < context->queues.size(); i++) { +// wait_for_queue(context, i); +// } +// } else { +// wait_for_queue(context, queue_index); +// } +// } void 
context_submit_command( Context* context, @@ -256,7 +292,9 @@ void context_submit_command( LOG_INFO("Submitting command '%s' to queue %d", name, queue_index); command_list_record_command(context->command_list, name, 0, VK_PIPELINE_STAGE_TRANSFER_BIT, func); - command_list_submit_extern(context->command_list, NULL, 1, queue_index, record_type); + while(!command_list_submit_extern(context->command_list, NULL, 1, queue_index, record_type, name)) { + RETURN_ON_ERROR(;) + } command_list_reset_extern(context->command_list); RETURN_ON_ERROR(;) } @@ -264,7 +302,6 @@ void context_submit_command( void context_destroy_extern(struct Context* context) { LOG_INFO("Destroying context %p with %d devices...", context, context->deviceCount); LOG_INFO("Waiting for all queues to finish..."); - context_queue_wait_idle_extern(context, -1); context->work_queue->stop(); @@ -308,4 +345,4 @@ void context_stop_threads_extern(struct Context* context) { for(int i = 0; i < context->queues.size(); i++) { context->queues[i]->signal_stop(); } -} \ No newline at end of file +} diff --git a/vkdispatch_native/context/context_extern.hh b/vkdispatch_native/context/context_extern.hh index 27368ad4..3f0f7293 100644 --- a/vkdispatch_native/context/context_extern.hh +++ b/vkdispatch_native/context/context_extern.hh @@ -60,6 +60,11 @@ struct PhysicalDeviceDetails { unsigned int queue_family_count; struct QueueFamilyProperties* queue_family_properties; + + int scalar_block_layout; + int timeline_semaphores; + + unsigned char* uuid; }; void init_extern(bool debug, LogLevel log_level); @@ -70,7 +75,9 @@ void log_extern(LogLevel log_level, const char* text, const char* file_str, int void set_log_level_extern(LogLevel log_level); struct Context* context_create_extern(int* device_indicies, int* queue_counts, int* queue_families, int device_count); -void context_queue_wait_idle_extern(struct Context* context, int queue_index); +bool signal_wait_extern(void* signal_ptr, bool wait_for_timestamp, int queue_index); 
+void* signal_insert_extern(struct Context* context, int queue_index); +void signal_destroy_extern(void* signal_ptr); void context_destroy_extern(struct Context* context); void context_stop_threads_extern(struct Context* context); diff --git a/vkdispatch_native/context/context_extern.pxd b/vkdispatch_native/context/context_extern.pxd index febd5c36..873a38b7 100644 --- a/vkdispatch_native/context/context_extern.pxd +++ b/vkdispatch_native/context/context_extern.pxd @@ -66,6 +66,11 @@ cdef extern from "context/context_extern.hh": unsigned int queue_family_count QueueFamilyProperties* queue_family_properties + + int scalar_block_layout + int timeline_semaphores + + unsigned char* uuid void init_extern(bool debug, LogLevel log_level) PhysicalDeviceDetails* get_devices_extern(int* count) @@ -75,7 +80,10 @@ cdef extern from "context/context_extern.hh": struct Context Context* context_create_extern(int* device_indicies, int* queue_counts, int* queue_families, int device_count) - void context_queue_wait_idle_extern(Context* context, int queue_index); + bool signal_wait_extern(void* signal_ptr, bool wait_for_timestamp, int queue_index) + void* signal_insert_extern(Context* context, int queue_index) + void signal_destroy_extern(void* signal_ptr) + void context_destroy_extern(Context* device_context); const char* get_error_string_extern() @@ -138,7 +146,10 @@ cpdef inline get_devices(): device.supported_operations, device.quad_operations_in_all_stages, device.max_compute_shared_memory_size, - [(device.queue_family_properties[j].queueCount, device.queue_family_properties[j].queueFlags) for j in range(device.queue_family_count)] + [(device.queue_family_properties[j].queueCount, device.queue_family_properties[j].queueFlags) for j in range(device.queue_family_count)], + device.scalar_block_layout, + device.timeline_semaphores, + bytes([device.uuid[k] for k in range(16)]) if device.uuid != NULL else None ) device_list.append(device_info) @@ -177,8 +188,15 @@ cpdef inline 
context_create(list[int] device_indicies, list[list[int]] queue_fam return result -cpdef inline void context_queue_wait_idle(unsigned long long context, int queue_index): - context_queue_wait_idle_extern(context, queue_index) +cpdef inline bool signal_wait(unsigned long long signal_ptr, bool wait_for_timestamp, int queue_index): + return signal_wait_extern(signal_ptr, wait_for_timestamp, queue_index) + +cpdef inline unsigned long long signal_insert(unsigned long long context, int queue_index): + cdef void* signal_ptr = signal_insert_extern(context, queue_index) + return signal_ptr + +cpdef inline signal_destroy(unsigned long long signal_ptr): + signal_destroy_extern(signal_ptr) cpdef inline context_destroy(unsigned long long context): context_destroy_extern(context) diff --git a/vkdispatch_native/context/init.cpp b/vkdispatch_native/context/init.cpp index 07449cbb..86ef05f2 100644 --- a/vkdispatch_native/context/init.cpp +++ b/vkdispatch_native/context/init.cpp @@ -132,10 +132,10 @@ void init_extern(bool debug, LogLevel log_level) { } -#ifdef __APPLE__ - extensions.push_back(VK_KHR_PORTABILITY_ENUMERATION_EXTENSION_NAME); - flags |= VK_INSTANCE_CREATE_ENUMERATE_PORTABILITY_BIT_KHR; -#endif +//#ifdef __APPLE__ + //extensions.push_back(VK_KHR_PORTABILITY_ENUMERATION_EXTENSION_NAME); + //flags |= VK_INSTANCE_CREATE_ENUMERATE_PORTABILITY_BIT_KHR; +//#endif uint32_t layer_count = 0; VK_CALL(vkEnumerateInstanceLayerProperties(&layer_count, nullptr)); @@ -186,7 +186,7 @@ void init_extern(bool debug, LogLevel log_level) { VkInstanceCreateInfo instanceCreateInfo = {}; instanceCreateInfo.sType = VK_STRUCTURE_TYPE_INSTANCE_CREATE_INFO; - instanceCreateInfo.pNext = &validationFeatures; + if (debug) instanceCreateInfo.pNext = &validationFeatures; instanceCreateInfo.pApplicationInfo = &appInfo; instanceCreateInfo.flags = flags; instanceCreateInfo.enabledExtensionCount = supportedExtensions.size(); @@ -211,7 +211,6 @@ void init_extern(bool debug, LogLevel log_level) { if(debug) { 
LOG_INFO("Initializing Vulkan Debug Messenger..."); - VkDebugUtilsMessengerCreateInfoEXT debugCreateInfo = {}; debugCreateInfo.sType = VK_STRUCTURE_TYPE_DEBUG_UTILS_MESSENGER_CREATE_INFO_EXT; debugCreateInfo.pNext = NULL; @@ -235,31 +234,37 @@ void init_extern(bool debug, LogLevel log_level) { VK_CALL(vkEnumeratePhysicalDevices(_instance.instance, &device_count, nullptr)); _instance.physicalDevices.resize(device_count); _instance.features.resize(device_count); - _instance.atomicFloatFeatures.resize(device_count); - _instance.float16int8Features.resize(device_count); + _instance.scalar_block_layout_features.resize(device_count); + _instance.atomic_float_features.resize(device_count); + _instance.float16_int8_features.resize(device_count); _instance.storage16bit.resize(device_count); _instance.properties.resize(device_count); _instance.subgroup_properties.resize(device_count); + _instance.id_properties.resize(device_count); _instance.device_details.resize(device_count); _instance.queue_family_properties.resize(device_count); _instance.timeline_semaphore_features.resize(device_count); VK_CALL(vkEnumeratePhysicalDevices(_instance.instance, &device_count, _instance.physicalDevices.data())); for(int i = 0; i < _instance.physicalDevices.size(); i++) { + _instance.scalar_block_layout_features[i] = {}; + _instance.scalar_block_layout_features[i].sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SCALAR_BLOCK_LAYOUT_FEATURES; + _instance.timeline_semaphore_features[i] = {}; _instance.timeline_semaphore_features[i].sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_TIMELINE_SEMAPHORE_FEATURES; + _instance.timeline_semaphore_features[i].pNext = &_instance.scalar_block_layout_features[i]; - _instance.atomicFloatFeatures[i] = {}; - _instance.atomicFloatFeatures[i].sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SHADER_ATOMIC_FLOAT_FEATURES_EXT; - _instance.atomicFloatFeatures[i].pNext = &_instance.timeline_semaphore_features[i]; + _instance.atomic_float_features[i] = {}; + 
_instance.atomic_float_features[i].sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SHADER_ATOMIC_FLOAT_FEATURES_EXT; + _instance.atomic_float_features[i].pNext = &_instance.timeline_semaphore_features[i]; - _instance.float16int8Features[i] = {}; - _instance.float16int8Features[i].sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SHADER_FLOAT16_INT8_FEATURES; - _instance.float16int8Features[i].pNext = &_instance.atomicFloatFeatures[i]; + _instance.float16_int8_features[i] = {}; + _instance.float16_int8_features[i].sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SHADER_FLOAT16_INT8_FEATURES; + _instance.float16_int8_features[i].pNext = &_instance.atomic_float_features[i]; _instance.storage16bit[i] = {}; _instance.storage16bit[i].sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_16BIT_STORAGE_FEATURES; - _instance.storage16bit[i].pNext = &_instance.float16int8Features[i]; + _instance.storage16bit[i].pNext = &_instance.float16_int8_features[i]; _instance.features[i] = {}; _instance.features[i].sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_FEATURES_2; @@ -268,10 +273,19 @@ void init_extern(bool debug, LogLevel log_level) { vkGetPhysicalDeviceFeatures2(_instance.physicalDevices[i], &_instance.features[i]); VkPhysicalDeviceFeatures features = _instance.features[i].features; - VkPhysicalDeviceShaderAtomicFloatFeaturesEXT atomicFloatFeatures = _instance.atomicFloatFeatures[i]; + _instance.features[i].features = {}; + _instance.features[i].features.shaderInt16 = features.shaderInt16; + _instance.features[i].features.shaderInt64 = features.shaderInt64; + _instance.features[i].features.shaderFloat64 = features.shaderFloat64; + + VkPhysicalDeviceShaderAtomicFloatFeaturesEXT atomicFloatFeatures = _instance.atomic_float_features[i]; + + _instance.id_properties[i] = {}; + _instance.id_properties[i].sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_ID_PROPERTIES; _instance.subgroup_properties[i] = {}; _instance.subgroup_properties[i].sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SUBGROUP_PROPERTIES; + 
_instance.subgroup_properties[i].pNext = &_instance.id_properties[i]; _instance.properties[i] = {}; _instance.properties[i].sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_PROPERTIES_2; @@ -304,7 +318,7 @@ void init_extern(bool debug, LogLevel log_level) { strcpy((char*)_instance.device_details[i].device_name, properties.deviceName); _instance.device_details[i].float_64_support = features.shaderFloat64; - _instance.device_details[i].float_16_support = _instance.float16int8Features[i].shaderFloat16; + _instance.device_details[i].float_16_support = _instance.float16_int8_features[i].shaderFloat16; _instance.device_details[i].int_64_support = features.shaderInt64; _instance.device_details[i].int_16_support = features.shaderInt16; @@ -346,6 +360,11 @@ void init_extern(bool debug, LogLevel log_level) { _instance.device_details[i].shader_buffer_float32_atomics = atomicFloatFeatures.shaderBufferFloat32Atomics; _instance.device_details[i].shader_buffer_float32_atomic_add = atomicFloatFeatures.shaderBufferFloat32AtomicAdd; + + _instance.device_details[i].timeline_semaphores = _instance.timeline_semaphore_features[i].timelineSemaphore; + _instance.device_details[i].scalar_block_layout = _instance.scalar_block_layout_features[i].scalarBlockLayout; + + _instance.device_details[i].uuid = _instance.id_properties[i].deviceUUID; } } diff --git a/vkdispatch_native/context/init.hh b/vkdispatch_native/context/init.hh index 475edea1..518c1351 100644 --- a/vkdispatch_native/context/init.hh +++ b/vkdispatch_native/context/init.hh @@ -14,6 +14,7 @@ * - Debug messenger (VkDebugUtilsMessengerEXT) * - Physical devices (VkPhysicalDevice) * - Features of the physical devices (VkPhysicalDeviceFeatures2) + * - Scalar block layout features (VkPhysicalDeviceScalarBlockLayoutFeatures) * - Shader atomic float features (VkPhysicalDeviceShaderAtomicFloatFeaturesEXT) * - Shader float16 and int8 features (VkPhysicalDeviceShaderFloat16Int8Features) * - 16-bit storage features 
(VkPhysicalDevice16BitStorageFeatures) @@ -32,11 +33,13 @@ typedef struct { VkDebugUtilsMessengerEXT debug_messenger; std::vector physicalDevices; std::vector features; - std::vector atomicFloatFeatures; - std::vector float16int8Features; + std::vector scalar_block_layout_features; + std::vector atomic_float_features; + std::vector float16_int8_features; std::vector storage16bit; std::vector properties; std::vector subgroup_properties; + std::vector id_properties; std::vector device_details; std::vector> queue_family_properties; std::vector timeline_semaphore_features; diff --git a/vkdispatch_native/objects/buffer.cpp b/vkdispatch_native/objects/buffer.cpp index 3b4b00bf..ede3347d 100644 --- a/vkdispatch_native/objects/buffer.cpp +++ b/vkdispatch_native/objects/buffer.cpp @@ -80,7 +80,7 @@ struct Buffer* buffer_create_extern(struct Context* ctx, unsigned long long size ctx->handle_manager->set_handle(indicies.queue_index, staging_allocations_handle, (uint64_t)h_staging_allocation); Signal* signal = (Signal*)ctx->handle_manager->get_handle(indicies.queue_index, signals_pointers_handle, 0); - signal->notify(); + signal->notify(indicies.queue_index, timestamp); }); return buffer; @@ -96,7 +96,7 @@ void buffer_destroy_extern(struct Buffer* buffer) { Signal* signal = (Signal*)ctx->handle_manager->get_handle(queue_index, signals_pointers_handle, 0); // wait for the recording thread to finish - signal->wait(); + //signal->wait(); ctx->handle_manager->destroy_handle(queue_index, buffer->signals_pointers_handle); @@ -136,26 +136,72 @@ void buffer_destroy_extern(struct Buffer* buffer) { delete buffer; } -void write_to_buffer(Context* ctx, struct Buffer* buffer, void* data, unsigned long long offset, unsigned long long size, int queue_index) { - int device_index = ctx->queues[queue_index]->device_index; +void* buffer_get_queue_signal_extern(struct Buffer* buffer, int queue_index) { + struct Context* ctx = buffer->ctx; uint64_t signals_pointers_handle = 
buffer->signals_pointers_handle; Signal* signal = (Signal*)ctx->handle_manager->get_handle(queue_index, signals_pointers_handle, 0); - // wait for the recording thread to finish - signal->wait(); - signal->reset(); + return (void*)signal; +} + +bool buffer_wait_staging_idle_extern(struct Buffer* buffer, int queue_index) { + struct Context* ctx = buffer->ctx; - // wait for the staging buffer to be ready uint64_t staging_buffer_timestamp = ctx->handle_manager->get_handle_timestamp(queue_index, buffer->staging_buffers_handle); - ctx->queues[queue_index]->wait_for_timestamp(staging_buffer_timestamp); + return ctx->queues[queue_index]->try_wait_for_timestamp(staging_buffer_timestamp); +} + +void buffer_write_staging_extern(struct Buffer* buffer, int queue_index, void* data, unsigned long long size) { + struct Context* ctx = buffer->ctx; + int device_index = ctx->queues[queue_index]->device_index; VmaAllocation staging_allocation = (VmaAllocation)ctx->handle_manager->get_handle(queue_index, buffer->staging_allocations_handle, 0); void* mapped; VK_CALL(vmaMapMemory(ctx->allocators[device_index], staging_allocation, &mapped)); memcpy(mapped, data, size); + VK_CALL(vmaFlushAllocation(ctx->allocators[device_index], staging_allocation, 0, size)); vmaUnmapMemory(ctx->allocators[device_index], staging_allocation); +} + +void buffer_read_staging_extern(struct Buffer* buffer, int queue_index, void* data, unsigned long long size) { + struct Context* ctx = buffer->ctx; + int device_index = ctx->queues[queue_index]->device_index; + + VmaAllocation staging_allocation = (VmaAllocation)ctx->handle_manager->get_handle(queue_index, buffer->staging_allocations_handle, 0); + + void* mapped; + VK_CALL(vmaMapMemory(ctx->allocators[device_index], staging_allocation, &mapped)); + VK_CALL(vmaInvalidateAllocation(ctx->allocators[device_index], staging_allocation, 0, size)); + memcpy(data, mapped, size); + vmaUnmapMemory(ctx->allocators[device_index], staging_allocation); +} + +void 
buffer_write_extern(struct Buffer* buffer, unsigned long long offset, unsigned long long size, int queue_index) { + LOG_INFO("Writing data to buffer (%p) at offset %d with size %d", buffer, offset, size); + + struct Context* ctx = buffer->ctx; + + int device_index = ctx->queues[queue_index]->device_index; + + uint64_t signals_pointers_handle = buffer->signals_pointers_handle; + Signal* signal = (Signal*)ctx->handle_manager->get_handle(queue_index, signals_pointers_handle, 0); + + // wait for the recording thread to finish + //signal->wait(); + signal->reset(); + + // wait for the staging buffer to be ready + // uint64_t staging_buffer_timestamp = ctx->handle_manager->get_handle_timestamp(queue_index, buffer->staging_buffers_handle); + // ctx->queues[queue_index]->wait_for_timestamp(staging_buffer_timestamp); + + // VmaAllocation staging_allocation = (VmaAllocation)ctx->handle_manager->get_handle(queue_index, buffer->staging_allocations_handle, 0); + + // void* mapped; + // VK_CALL(vmaMapMemory(ctx->allocators[device_index], staging_allocation, &mapped)); + // memcpy(mapped, data, size); + // vmaUnmapMemory(ctx->allocators[device_index], staging_allocation); uint64_t buffers_handle = buffer->buffers_handle; uint64_t staging_buffers_handle = buffer->staging_buffers_handle; @@ -187,28 +233,28 @@ void write_to_buffer(Context* ctx, struct Buffer* buffer, void* data, unsigned l vkCmdCopyBuffer(cmd_buffer, stagingBuffer, buffer, 1, &bufferCopy); + VkMemoryBarrier compute_barrier = { + VK_STRUCTURE_TYPE_MEMORY_BARRIER, + 0, + VK_ACCESS_TRANSFER_WRITE_BIT, + VK_ACCESS_SHADER_READ_BIT | VK_ACCESS_SHADER_WRITE_BIT | VK_ACCESS_UNIFORM_READ_BIT, + }; + + vkCmdPipelineBarrier( + cmd_buffer, + VK_PIPELINE_STAGE_TRANSFER_BIT, + VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT, + 0, + 1, &compute_barrier, 0, NULL, 0, NULL + ); + Signal* signal = (Signal*)ctx->handle_manager->get_handle(indicies.queue_index, signals_pointers_handle, 0); - signal->notify(); + signal->notify(indicies.queue_index, 
timestamp); } ); } -void buffer_write_extern(struct Buffer* buffer, void* data, unsigned long long offset, unsigned long long size, int index) { - LOG_INFO("Writing data to buffer (%p) at offset %d with size %d", buffer, offset, size); - - struct Context* ctx = buffer->ctx; - - if(index != -1) { - write_to_buffer(ctx, buffer, data, offset, size, index); - return; - } - - for(int i = 0; i < ctx->queues.size(); i++) { - write_to_buffer(ctx, buffer, data, offset, size, i); - } -} - -void buffer_read_extern(struct Buffer* buffer, void* data, unsigned long long offset, unsigned long long size, int queue_index) { +void buffer_read_extern(struct Buffer* buffer, unsigned long long offset, unsigned long long size, int queue_index) { LOG_INFO("Reading data from buffer (%p) at offset %d with size %d", buffer, offset, size); struct Context* ctx = buffer->ctx; @@ -217,7 +263,7 @@ void buffer_read_extern(struct Buffer* buffer, void* data, unsigned long long of Signal* signal = (Signal*)ctx->handle_manager->get_handle(queue_index, signals_pointers_handle, 0); // wait for the recording thread to finish - signal->wait(); + //signal->wait(); signal->reset(); uint64_t buffers_handle = buffer->buffers_handle; @@ -229,6 +275,21 @@ void buffer_read_extern(struct Buffer* buffer, void* data, unsigned long long of VkBuffer stagingBuffer = (VkBuffer)ctx->handle_manager->get_handle(indicies.queue_index, staging_buffers_handle, timestamp); VkBuffer buffer = (VkBuffer)ctx->handle_manager->get_handle(indicies.queue_index, buffers_handle, timestamp); + VkMemoryBarrier compute_barrier = { + VK_STRUCTURE_TYPE_MEMORY_BARRIER, + 0, + VK_ACCESS_SHADER_READ_BIT | VK_ACCESS_SHADER_WRITE_BIT | VK_ACCESS_UNIFORM_READ_BIT, + VK_ACCESS_TRANSFER_WRITE_BIT | VK_ACCESS_TRANSFER_READ_BIT, + }; + + vkCmdPipelineBarrier( + cmd_buffer, + VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT, + VK_PIPELINE_STAGE_TRANSFER_BIT, + 0, + 1, &compute_barrier, 0, NULL, 0, NULL + ); + VkBufferCopy bufferCopy; bufferCopy.size = size; 
bufferCopy.dstOffset = 0; @@ -239,7 +300,7 @@ void buffer_read_extern(struct Buffer* buffer, void* data, unsigned long long of VkMemoryBarrier barrier = { VK_STRUCTURE_TYPE_MEMORY_BARRIER, 0, - VK_ACCESS_TRANSFER_WRITE_BIT, + VK_ACCESS_TRANSFER_WRITE_BIT | VK_ACCESS_TRANSFER_READ_BIT, VK_ACCESS_HOST_READ_BIT, }; vkCmdPipelineBarrier( @@ -251,23 +312,23 @@ void buffer_read_extern(struct Buffer* buffer, void* data, unsigned long long of ); Signal* signal = (Signal*)ctx->handle_manager->get_handle(indicies.queue_index, signals_pointers_handle, 0); - signal->notify(); + signal->notify(indicies.queue_index, timestamp); } ); // wait for the recording thread to finish again - signal->wait(); + // signal->wait(); - // wait for the staging buffer to be ready - uint64_t staging_buffer_timestamp = ctx->handle_manager->get_handle_timestamp(queue_index, buffer->staging_buffers_handle); - ctx->queues[queue_index]->wait_for_timestamp(staging_buffer_timestamp); + // // wait for the staging buffer to be ready + // uint64_t staging_buffer_timestamp = ctx->handle_manager->get_handle_timestamp(queue_index, buffer->staging_buffers_handle); + // ctx->queues[queue_index]->wait_for_timestamp(staging_buffer_timestamp); - int device_index = ctx->queues[queue_index]->device_index; + // int device_index = ctx->queues[queue_index]->device_index; - VmaAllocation staging_allocation = (VmaAllocation)ctx->handle_manager->get_handle(queue_index, buffer->staging_allocations_handle, 0); + // VmaAllocation staging_allocation = (VmaAllocation)ctx->handle_manager->get_handle(queue_index, buffer->staging_allocations_handle, 0); - void* mapped; - VK_CALL(vmaMapMemory(ctx->allocators[device_index], staging_allocation, &mapped)); - memcpy(data, mapped, size); - vmaUnmapMemory(ctx->allocators[device_index], staging_allocation); -} \ No newline at end of file + // void* mapped; + // VK_CALL(vmaMapMemory(ctx->allocators[device_index], staging_allocation, &mapped)); + // memcpy(data, mapped, size); + // 
vmaUnmapMemory(ctx->allocators[device_index], staging_allocation); +} diff --git a/vkdispatch_native/objects/buffer.hh b/vkdispatch_native/objects/buffer.hh index a6393ded..63594996 100644 --- a/vkdispatch_native/objects/buffer.hh +++ b/vkdispatch_native/objects/buffer.hh @@ -20,11 +20,6 @@ struct Buffer { uint64_t allocations_handle; uint64_t staging_buffers_handle; uint64_t staging_allocations_handle; - - //std::vector buffers; - //std::vector allocations; - //std::vector stagingBuffers; - //std::vector stagingAllocations; }; #endif // SRC_BUFFER_H_ \ No newline at end of file diff --git a/vkdispatch_native/objects/command_list.cpp b/vkdispatch_native/objects/command_list.cpp index 4bb33c5c..a273823e 100644 --- a/vkdispatch_native/objects/command_list.cpp +++ b/vkdispatch_native/objects/command_list.cpp @@ -55,16 +55,18 @@ void command_list_reset_extern(struct CommandList* command_list) { LOG_INFO("Command list reset"); } -void command_list_submit_extern(struct CommandList* command_list, void* instance_buffer, unsigned int instance_count, int index, int recordType) { +bool command_list_submit_extern(struct CommandList* command_list, void* instance_buffer, unsigned int instance_count, int index, int recordType, const char* name) { struct Context* ctx = command_list->ctx; LOG_INFO("Submitting command list with handle %p to queue %d", command_list, index); - if(index == -2) { - for(int i = 0; i < ctx->queues.size(); i++) { - ctx->work_queue->push(command_list, instance_buffer, instance_count, i, recordType); - } - } else { - ctx->work_queue->push(command_list, instance_buffer, instance_count, index, recordType); + if(index != -2) + return ctx->work_queue->push(command_list, instance_buffer, instance_count, index, recordType, name); + + for(int i = 0; i < ctx->queues.size(); i++) { + if(!ctx->work_queue->push(command_list, instance_buffer, instance_count, i, recordType, name)) + return false; } + + return true; } \ No newline at end of file diff --git 
a/vkdispatch_native/objects/image.cpp b/vkdispatch_native/objects/image.cpp index 1ef3c91d..0a40b1ae 100644 --- a/vkdispatch_native/objects/image.cpp +++ b/vkdispatch_native/objects/image.cpp @@ -175,7 +175,7 @@ struct Image* image_create_extern(struct Context* context, VkExtent3D a_extent, ctx->handle_manager->set_handle(indicies.queue_index, staging_allocations_handle, (uint64_t)h_staging_allocation); Signal* signal = (Signal*)ctx->handle_manager->get_handle(indicies.queue_index, signals_pointers_handle, 0); - signal->notify(); + signal->notify(indicies.queue_index, timestamp); } ); @@ -190,7 +190,9 @@ void image_destroy_extern(struct Image* image) { Signal* signal = (Signal*)ctx->handle_manager->get_handle(queue_index, image->signals_pointers_handle, 0); // wait for the recording thread to finish - signal->wait(); + while(!signal->try_wait(false, queue_index)) { + LOG_INFO("Waiting for image %p signal %p queue %d to be notified before destroying", image, signal, queue_index); + } ctx->handle_manager->destroy_handle(queue_index, image->signals_pointers_handle); @@ -325,7 +327,9 @@ void write_to_image(struct Context* ctx, struct Image* image, void* data, VkOffs LOG_INFO("waiting for recording thread to finish for image %p signal %p queue %d", image, signal, queue_index); // wait for the recording thread to finish - signal->wait(); + while(!signal->try_wait(false, queue_index)) { + LOG_INFO("Waiting for image %p signal %p queue %d to be notified before destroying", image, signal, queue_index); + } signal->reset(); LOG_INFO( @@ -440,7 +444,7 @@ void write_to_image(struct Context* ctx, struct Image* image, void* data, VkOffs } Signal* signal = (Signal*)ctx->handle_manager->get_handle(indicies.queue_index, signals_pointers_handle, 0); - signal->notify(); + signal->notify(indicies.queue_index, timestamp); } ); } @@ -469,7 +473,9 @@ void image_read_extern(struct Image* image, void* data, VkOffset3D offset, VkExt Signal* signal = 
(Signal*)ctx->handle_manager->get_handle(queue_index, signals_pointers_handle, 0); // wait for the recording thread to finish - signal->wait(); + while(!signal->try_wait(false, queue_index)) { + LOG_INFO("Waiting for image %p signal %p queue %d to be notified before destroying", image, signal, queue_index); + } signal->reset(); uint64_t images_handle = image->images_handle; @@ -508,11 +514,13 @@ void image_read_extern(struct Image* image, void* data, VkOffset3D offset, VkExt insert_barrier(cmd_buffer, barrier, VK_PIPELINE_STAGE_TRANSFER_BIT, VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT); Signal* signal = (Signal*)ctx->handle_manager->get_handle(indicies.queue_index, signals_pointers_handle, 0); - signal->notify(); + signal->notify(indicies.queue_index, timestamp); } ); - signal->wait(); + while(!signal->try_wait(false, queue_index)) { + LOG_INFO("Waiting for image %p signal %p queue %d to be notified before destroying", image, signal, queue_index); + } // wait for the staging buffer to be ready uint64_t staging_buffer_timestamp = ctx->handle_manager->get_handle_timestamp(queue_index, image->staging_buffers_handle); diff --git a/vkdispatch_native/objects/objects_extern.hh b/vkdispatch_native/objects/objects_extern.hh index 7bd1c0d1..cebe4058 100644 --- a/vkdispatch_native/objects/objects_extern.hh +++ b/vkdispatch_native/objects/objects_extern.hh @@ -39,8 +39,14 @@ struct ImageReadInfo { struct Buffer* buffer_create_extern(struct Context* context, unsigned long long size, int per_device); void buffer_destroy_extern(struct Buffer* buffer); -void buffer_write_extern(struct Buffer* buffer, void* data, unsigned long long offset, unsigned long long size, int index); -void buffer_read_extern(struct Buffer* buffer, void* data, unsigned long long offset, unsigned long long size, int index); +void* buffer_get_queue_signal_extern(struct Buffer* buffer, int queue_index); +bool buffer_wait_staging_idle_extern(struct Buffer* buffer, int queue_index); + +void 
buffer_write_staging_extern(struct Buffer* buffer, int queue_index, void* data, unsigned long long size); +void buffer_read_staging_extern(struct Buffer* buffer, int queue_index, void* data, unsigned long long size); + +void buffer_write_extern(struct Buffer* buffer, unsigned long long offset, unsigned long long size, int index); +void buffer_read_extern(struct Buffer* buffer, unsigned long long offset, unsigned long long size, int index); struct CommandList* command_list_create_extern(struct Context* context); void command_list_destroy_extern(struct CommandList* command_list); @@ -48,7 +54,7 @@ void command_list_destroy_extern(struct CommandList* command_list); unsigned long long command_list_get_instance_size_extern(struct CommandList* command_list); void command_list_reset_extern(struct CommandList* command_list); -void command_list_submit_extern(struct CommandList* command_list, void* instance_buffer, unsigned int instanceCount, int index, int recordType); +bool command_list_submit_extern(struct CommandList* command_list, void* instance_buffer, unsigned int instanceCount, int index, int recordType, const char* name); struct DescriptorSet* descriptor_set_create_extern(struct ComputePlan* plan); void descriptor_set_destroy_extern(struct DescriptorSet* descriptor_set); diff --git a/vkdispatch_native/objects/objects_extern.pxd b/vkdispatch_native/objects/objects_extern.pxd index 1c97cb35..cbefeed7 100644 --- a/vkdispatch_native/objects/objects_extern.pxd +++ b/vkdispatch_native/objects/objects_extern.pxd @@ -26,14 +26,20 @@ cdef extern from "objects/objects_extern.hh": Buffer* buffer_create_extern(Context* context, unsigned long long size, int per_device) void buffer_destroy_extern(Buffer* buffer) - void buffer_write_extern(Buffer* buffer, void* data, unsigned long long offset, unsigned long long size, int index) - void buffer_read_extern(Buffer* buffer, void* data, unsigned long long offset, unsigned long long size, int index) + void* 
buffer_get_queue_signal_extern(Buffer* buffer, int queue_index) + bool buffer_wait_staging_idle_extern(Buffer* buffer, int queue_index) + + void buffer_write_staging_extern(Buffer* buffer, int queue_index, void* data, unsigned long long size) + void buffer_read_staging_extern(Buffer* buffer, int queue_index, void* data, unsigned long long size) + + void buffer_write_extern(Buffer* buffer, unsigned long long offset, unsigned long long size, int index) + void buffer_read_extern(Buffer* buffer, unsigned long long offset, unsigned long long size, int index) CommandList* command_list_create_extern(Context* context) void command_list_destroy_extern(CommandList* command_list) unsigned long long command_list_get_instance_size_extern(CommandList* command_list) void command_list_reset_extern(CommandList* command_list) - void command_list_submit_extern(CommandList* command_list, void* instance_buffer, unsigned int instanceCount, int index, int recordType) + bool command_list_submit_extern(CommandList* command_list, void* instance_buffer, unsigned int instanceCount, int index, int recordType, const char* name) DescriptorSet* descriptor_set_create_extern(ComputePlan* plan) void descriptor_set_destroy_extern(DescriptorSet* descriptor_set) @@ -71,18 +77,30 @@ cpdef inline buffer_create(unsigned long long context, unsigned long long size, cpdef inline buffer_destroy(unsigned long long buffer): buffer_destroy_extern(buffer) -cpdef inline buffer_write(unsigned long long buffer, bytes data, unsigned long long offset, unsigned long long size, int index): +cpdef inline buffer_get_queue_signal(unsigned long long buffer, int queue_index): + return buffer_get_queue_signal_extern(buffer, queue_index) + +cpdef inline buffer_wait_staging_idle(unsigned long long buffer, int queue_index): + return buffer_wait_staging_idle_extern(buffer, queue_index) + +cpdef inline buffer_write_staging(unsigned long long buffer, int queue_index, bytes data, unsigned long long size): cdef const char* data_view 
= data - buffer_write_extern(buffer, data_view, offset, size, index) + buffer_write_staging_extern(buffer, queue_index, data_view, size) -cpdef inline buffer_read(unsigned long long buffer, unsigned long long offset, unsigned long long size, int index): +cpdef inline buffer_read_staging(unsigned long long buffer, int queue_index, unsigned long long size): cdef bytes data = bytes(size) cdef char* data_view = data - buffer_read_extern(buffer, data_view, offset, size, index) + buffer_read_staging_extern(buffer, queue_index, data_view, size) return data +cpdef inline buffer_write(unsigned long long buffer, unsigned long long offset, unsigned long long size, int index): + buffer_write_extern(buffer, offset, size, index) + +cpdef inline buffer_read(unsigned long long buffer, unsigned long long offset, unsigned long long size, int index): + buffer_read_extern(buffer,offset, size, index) + cpdef inline command_list_create(unsigned long long context): return command_list_create_extern(context) @@ -100,7 +118,7 @@ cpdef inline command_list_submit(unsigned long long command_list, bytes data, un if data is not None: data_view = data - command_list_submit_extern(command_list, data_view, instance_count, index, 0) + return command_list_submit_extern(command_list, data_view, instance_count, index, 0, "User Command List") cpdef inline descriptor_set_create(unsigned long long plan): cdef ComputePlan* p = plan diff --git a/vkdispatch_native/queue/queue.cpp b/vkdispatch_native/queue/queue.cpp index fa2e6351..ae5ac2e6 100644 --- a/vkdispatch_native/queue/queue.cpp +++ b/vkdispatch_native/queue/queue.cpp @@ -79,6 +79,7 @@ Queue::Queue( this->run_queue.store(true); if(this->recording_thread_count > 1) { + LOG_INFO("Starting ingest, %d record, and submit threads for queue %d", recording_thread_count, this->queue_index); submit_thread = std::thread([this]() { this->submit_worker(); }); record_threads = new std::thread[recording_thread_count]; @@ -88,6 +89,7 @@ Queue::Queue( ingest_thread = 
std::thread([this]() { this->ingest_worker(); }); } else { + LOG_INFO("Starting fused worker thread for queue %d", this->queue_index); submit_thread = std::thread([this]() { this->fused_worker(); }); } } @@ -138,32 +140,41 @@ void Queue::destroy() { recording_results.clear(); } -void Queue::wait_for_timestamp(uint64_t timestamp) { +bool Queue::try_wait_for_timestamp(uint64_t timestamp) { uint64_t last_completed = 0; - VK_CALL(vkGetSemaphoreCounterValue(device, timeline_semaphore, &last_completed)); + VK_CALL_RETURN(vkGetSemaphoreCounterValue(device, timeline_semaphore, &last_completed), true); if (last_completed >= timestamp) { - return; + return true; } - while(last_completed < timestamp) { - VkSemaphoreWaitInfo wi = {}; - wi.sType = VK_STRUCTURE_TYPE_SEMAPHORE_WAIT_INFO; - wi.semaphoreCount = 1; - wi.pSemaphores = &timeline_semaphore; - wi.pValues = &timestamp; - VkResult result = vkWaitSemaphores(device, &wi, 1000000000); - if (result != VK_TIMEOUT) { - if(result != VK_SUCCESS) { - set_error("Failed to wait for semaphore: %d", result); - } - return; - } + LOG_INFO("Last completed timestamp: %llu, waiting for timestamp: %llu on queue %d", last_completed, timestamp, this->queue_index); + + VkSemaphoreWaitInfo wi = {}; + wi.sType = VK_STRUCTURE_TYPE_SEMAPHORE_WAIT_INFO; + wi.semaphoreCount = 1; + wi.pSemaphores = &timeline_semaphore; + wi.pValues = &timestamp; + VkResult result = vkWaitSemaphores(device, &wi, 1000000000); + + if (result == VK_TIMEOUT) { + LOG_INFO("Timeout while waiting for semaphore %d on queue %d", timestamp, this->queue_index); + return false; + } + + if(result != VK_SUCCESS) { + set_error("Failed to wait for semaphore: %d", result); + } + + return true; +} + +void Queue::wait_for_timestamp(uint64_t timestamp) { + while(!try_wait_for_timestamp(timestamp)) { + LOG_VERBOSE("Timeout while waiting for timestamp %llu on queue %d, (running=%d) checking again...", timestamp, this->queue_index, this->run_queue.load()); if(!this->run_queue.load()) { return; } - - 
VK_CALL(vkGetSemaphoreCounterValue(device, timeline_semaphore, &last_completed)); } } @@ -177,11 +188,12 @@ void ingest_work_item( LOG_VERBOSE("Ingesting work item for queue %d, current index %llu", queue->queue_index, current_index); if (current_index + 1 > queue->inflight_cmd_buffer_count) { + LOG_VERBOSE("Waiting for timestamp %llu on queue %d", current_index + 1 - queue->inflight_cmd_buffer_count, queue->queue_index); queue->wait_for_timestamp(current_index + 1 - queue->inflight_cmd_buffer_count); } if(!work_queue->pop(&work_header, queue->queue_index)) { - LOG_INFO("Thread worker for device %d, queue %d has no more work", queue->device_index, queue->queue_index); + LOG_VERBOSE("Thread worker for device %d, queue %d has no more work", queue->device_index, queue->queue_index); queue->run_queue.store(false); return; } @@ -222,7 +234,7 @@ void Queue::ingest_worker() { } } - LOG_INFO("Thread worker for device %d, queue %d has quit", device_index, queue_index); + LOG_VERBOSE("Thread worker for device %d, queue %d has quit", device_index, queue_index); } int record_work_item( @@ -253,7 +265,7 @@ int record_work_item( exec_indices.queue_index = queue->queue_index; exec_indices.recorder_index = worker_id; - LOG_INFO("Recording work item %p on queue %d, worker %d, instance count %d", work_item.work_header, queue->queue_index, worker_id, work_item.work_header->instance_count); + LOG_VERBOSE("Recording work item %p on queue %d, worker %d, instance count %d", work_item.work_header, queue->queue_index, worker_id, work_item.work_header->instance_count); char* current_instance_data = (char*)&work_item.work_header[1]; for(size_t instance = 0; instance < work_item.work_header->instance_count; instance++) { @@ -273,7 +285,7 @@ int record_work_item( queue->ctx->work_queue->finish(work_item.work_header); - LOG_INFO("Finished recording work item %p on queue %d, worker %d, instance count %d", work_item.work_header, queue->queue_index, worker_id, 
work_item.work_header->instance_count); + LOG_VERBOSE("Finished recording work item %p on queue %d, worker %d, instance count %d", work_item.work_header, queue->queue_index, worker_id, work_item.work_header->instance_count); return cmd_buffer_index; } @@ -393,7 +405,7 @@ void submit_work_item( submit_info.signalSemaphoreCount = 1; submit_info.pSignalSemaphores = &queue->timeline_semaphore; - LOG_INFO("Submitting command buffer %p with signal value %llu to queue %d", work_item.recording_result->commandBuffer, signalValue, queue->queue_index); + LOG_INFO("Submitting command buffer %p with signal value %llu to queue %d with name '%s'", work_item.recording_result->commandBuffer, signalValue, queue->queue_index, work_item.work_header->name); VK_CALL(vkQueueSubmit(queue->queue, 1, &submit_info, VK_NULL_HANDLE)); diff --git a/vkdispatch_native/queue/queue.hh b/vkdispatch_native/queue/queue.hh index 629ec42f..ef00e292 100644 --- a/vkdispatch_native/queue/queue.hh +++ b/vkdispatch_native/queue/queue.hh @@ -17,7 +17,6 @@ struct RecordingResultData { struct WorkQueueItem { uint64_t current_index; struct WorkHeader* work_header; - //Signal* signal; RecordingResultData* recording_result; VkPipelineStageFlags* waitStage; }; @@ -41,6 +40,7 @@ public: void record_worker(int worker_id); void submit_worker(); + bool try_wait_for_timestamp(uint64_t timestamp); void wait_for_timestamp(uint64_t timestamp); void fused_worker(); diff --git a/vkdispatch_native/queue/signal.cpp b/vkdispatch_native/queue/signal.cpp index d4c33eab..aceecdd7 100644 --- a/vkdispatch_native/queue/signal.cpp +++ b/vkdispatch_native/queue/signal.cpp @@ -5,16 +5,21 @@ #include "../context/context.hh" +#define NULL_TIMESTAMP ((uint64_t)0xFFFFFFFFFFFFFFFF) Signal::Signal(struct Context* context) : state(false) { this->ctx = context; + this->timestamp = NULL_TIMESTAMP; + this->timestamp_queue_index = -1; } /* * This function sets the state of the signal to true, indicating that the condition has occurred. 
*/ -void Signal::notify() { +void Signal::notify(int queue_index, uint64_t timestamp) { std::unique_lock lock(mutex); + this->timestamp = timestamp; + this->timestamp_queue_index = queue_index; state.store(true, std::memory_order_release); cv.notify_all(); } @@ -28,32 +33,67 @@ void Signal::reset() { state.store(false, std::memory_order_release); } +bool Signal::try_host_wait() { + std::unique_lock lock(mutex); + + bool notified = cv.wait_for(lock, std::chrono::seconds(1), [this] { + LOG_VERBOSE("Checking signal"); + + if(ctx->running.load(std::memory_order_acquire) == false) { + set_error("Context is not running, cannot wait for signal"); + return true; + } + + return state.load(std::memory_order_acquire); + }); + + return notified; +} + +bool Signal::try_device_wait(int queue_index) { + if(this->timestamp == NULL_TIMESTAMP) { + set_error("Signal timestamp is NULL, cannot wait for device"); + return false; + } + + if(queue_index < 0 || queue_index >= ctx->queues.size()) { + set_error("Invalid queue index %d for device wait", queue_index); + return false; + } + + return ctx->queues[queue_index]->try_wait_for_timestamp(timestamp); +} + /* * This function blocks the calling thread until the signal is notified. 
*/ -void Signal::wait() { +bool Signal::try_wait(bool wait_for_timestamp, int queue_index) { + LOG_VERBOSE("Trying to wait on signal %p (wait_for_timestamp=%d, queue_index=%d)...", this, wait_for_timestamp, queue_index); + if (state.load(std::memory_order_acquire)) { - return; // If the signal is already notified, return immediately - } + LOG_VERBOSE("Signal %p already notified", this); - std::unique_lock lock(mutex); - - while(true) { - bool ready = cv.wait_for(lock, std::chrono::seconds(1), [this] { - LOG_VERBOSE("Checking signal"); - - if(ctx->running.load(std::memory_order_acquire) == false) { - set_error("Context is not running, cannot wait for signal"); - return true; - } - - return state.load(std::memory_order_acquire); - }); - - if (ready) { - return; + if (!wait_for_timestamp) { + LOG_VERBOSE("No need to wait for timestamp, returning"); + return true; } - LOG_VERBOSE("Timeout expired, rechecking..."); + LOG_VERBOSE("Waiting for timestamp %llu on queue %d", this->timestamp, queue_index); + + return try_device_wait(queue_index); } + + LOG_VERBOSE("Waiting for host notification on signal %p...", this); + if(!try_host_wait()) { + LOG_VERBOSE("Host wait for signal %p timed out", this); + return false; + } + + if(!wait_for_timestamp) { + LOG_VERBOSE("No need to wait for timestamp, returning"); + return true; + } + + LOG_VERBOSE("Waiting for timestamp %llu on queue %d", this->timestamp, queue_index); + return try_device_wait(queue_index); } \ No newline at end of file diff --git a/vkdispatch_native/queue/signal.hh b/vkdispatch_native/queue/signal.hh index 9aa8b5b3..d9aaa0f2 100644 --- a/vkdispatch_native/queue/signal.hh +++ b/vkdispatch_native/queue/signal.hh @@ -26,7 +26,7 @@ public: * This function sets the state of the signal to true, indicating that the condition has occurred. * It wakes up any waiting threads. */ - void notify(); + void notify(int queue_index, uint64_t timestamp); /** * @brief Resets the signal to the initial state. 
@@ -41,10 +41,21 @@ public: * * This function blocks the calling thread until the signal is notified. * If the signal is already in the notified state, the function returns immediately. + * + * This function will return after one second even if the signal is not notified, to prevent deadlocks. + * @return true if the signal was notified, false if the wait timed out. */ - void wait(); + bool try_wait(bool wait_for_timestamp, int queue_index); + +private: + bool try_host_wait(); + bool try_device_wait(int queue_index); + +public: struct Context* ctx; + uint64_t timestamp; + int timestamp_queue_index; std::mutex mutex; std::condition_variable cv; std::atomic<bool> state; diff --git a/vkdispatch_native/queue/work_queue.cpp b/vkdispatch_native/queue/work_queue.cpp index 7b75ca2b..9ce61626 100644 --- a/vkdispatch_native/queue/work_queue.cpp +++ b/vkdispatch_native/queue/work_queue.cpp @@ -21,6 +21,7 @@ WorkQueue::WorkQueue(int max_work_items, int max_programs) { memset(work_infos[i].header, 0, sizeof(struct WorkHeader) + 16 * 1024); work_infos[i].header->array_size = 16 * 1024; work_infos[i].header->info_index = i; + work_infos[i].header->name = nullptr; } for(int i = 0; i < max_programs; i++) { @@ -36,124 +37,141 @@ void WorkQueue::stop() { this->cv_push.notify_all(); } -void WorkQueue::push(struct CommandList* command_list, void* instance_buffer, unsigned int instance_count, int queue_index, int record_type) { - std::unique_lock lock(this->mutex); - - auto start = std::chrono::high_resolution_clock::now(); +int WorkQueue::get_program_index(struct CommandList* command_list) { + int program_index = -1; - int found_indicies[2] = {-1, -1}; - - this->cv_pop.wait(lock, [this, start, command_list, &found_indicies] () { - if(!running) { - return true; + for(int i = 0; i < this->program_info_count; i++) { + // Sanity check + if(this->program_infos[i].ref_count < 0) { + set_error("Program reference count (%d) is negative!", this->program_infos[i].ref_count); + return -2; } - auto end 
= std::chrono::high_resolution_clock::now(); - std::chrono::duration elapsed = end - start; - - if(elapsed.count() > 500) { - set_error("Timed out waiting for room in queue"); - return true; + // Program already exists, return its index + if(this->program_infos[i].program_id == command_list->program_id) { + return i; } - int program_index = -1; - - for(int i = 0; i < this->program_info_count; i++) { - if(this->program_infos[i].ref_count < 0) { - set_error("Program reference count (%d) is negative!!!!", this->program_infos[i].ref_count); - return true; - } - - if(this->program_infos[i].program_id == command_list->program_id) { - program_index = i; - break; - } - - if(this->program_infos[i].ref_count == 0) { - program_index = i; - } - } - - if(program_index == -1) { - return false; + // Found an available slot + if(this->program_infos[i].ref_count == 0) { + program_index = i; } + } - int work_index = -1; - - for(int i = 0; i < this->work_info_count; i++) { - if(!this->work_infos[i].dirty) { - work_index = i; - break; - } - } + return program_index; +} - if(work_index == -1) { - return false; +int WorkQueue::get_work_index() { + for(int i = 0; i < this->work_info_count; i++) { + if(!this->work_infos[i].dirty) { + return i; } - - found_indicies[0] = program_index; - found_indicies[1] = work_index; - - return true; - }); - - if(!running) { - return; } - RETURN_ON_ERROR(;) - - auto end = std::chrono::high_resolution_clock::now(); - std::chrono::duration elapsed = end - start; - - if(elapsed.count() >= 5) { - return; - } + return -1; +} - work_infos[found_indicies[1]].program_index = found_indicies[0]; - work_infos[found_indicies[1]].queue_index = queue_index; - work_infos[found_indicies[1]].dirty = true; - work_infos[found_indicies[1]].state = WORK_STATE_PENDING; - work_infos[found_indicies[1]].work_id = __work_id; +void WorkQueue::prepare_work(int work_index, int program_index, struct CommandList* command_list, void* instance_buffer, unsigned int instance_count, int 
queue_index, int record_type, const char* name) { + // Setup work info + work_infos[work_index].program_index = program_index; + work_infos[work_index].queue_index = queue_index; + work_infos[work_index].dirty = true; + work_infos[work_index].state = WORK_STATE_PENDING; + work_infos[work_index].work_id = __work_id; __work_id += 1; - struct WorkHeader* work_header = this->work_infos[found_indicies[1]].header; + struct WorkHeader* work_header = this->work_infos[work_index].header; - if(this->program_infos[found_indicies[0]].program_id != command_list->program_id) { - if(this->program_infos[found_indicies[0]].ref_count != 0) { + // Update the program if needed + if(this->program_infos[program_index].program_id != command_list->program_id) { + // Sanity check + if(this->program_infos[program_index].ref_count != 0) { set_error("Program ID mismatch!!"); return; } - this->program_infos[found_indicies[0]].commands->clear(); + // Update program commands + this->program_infos[program_index].commands->clear(); for(CommandInfo command : command_list->commands) { - this->program_infos[found_indicies[0]].commands->push_back(command); + this->program_infos[program_index].commands->push_back(command); } - this->program_infos[found_indicies[0]].program_id = command_list->program_id; + // Update program ID + this->program_infos[program_index].program_id = command_list->program_id; } size_t work_size = command_list_get_instance_size_extern(command_list) * instance_count; + // Resize work header if needed if(work_size > work_header->array_size) { work_header = (struct WorkHeader*)realloc(work_header, sizeof(struct WorkHeader) + work_size); work_header->array_size = work_size; - work_header->info_index = found_indicies[1]; - this->work_infos[found_indicies[1]].header = work_header; + work_header->info_index = work_index; + this->work_infos[work_index].header = work_header; } + // Setup work header work_header->instance_count = instance_count; work_header->instance_size = 
command_list_get_instance_size_extern(command_list); - work_header->commands = this->program_infos[found_indicies[0]].commands; - work_header->program_info_index = found_indicies[0]; - work_header->record_type = (RecordType)record_type; + work_header->commands = this->program_infos[program_index].commands; + work_header->program_info_index = program_index; + work_header->record_type = (RecordType)record_type; + work_header->name = name; + // Copy instance data if needed if(work_size > 0) memcpy(&work_header[1], instance_buffer, work_size); - this->program_infos[found_indicies[0]].ref_count += 1; + // Increment program reference count + this->program_infos[program_index].ref_count += 1; +} + +bool WorkQueue::push(struct CommandList* command_list, void* instance_buffer, unsigned int instance_count, int queue_index, int record_type, const char* name) { + std::unique_lock lock(this->mutex); + + int found_indicies[2] = {-1, -1}; + + bool ready = this->cv_pop.wait_for(lock, std::chrono::seconds(1), [this, command_list, &found_indicies] () { + if(!running) { + return true; + } + + int program_index = get_program_index(command_list); + + // Error occurred, return now and exit + if(program_index == -2) + return true; + + // No available program slots, try again later + if(program_index == -1) + return false; + + int work_index = get_work_index(); + + // No available work slots, try again later + if(work_index == -1) + return false; + + found_indicies[0] = program_index; + found_indicies[1] = work_index; + + return true; + }); + + if(!ready) + return false; + + if(!running) { + return true; + } + + RETURN_ON_ERROR(true) + + prepare_work(found_indicies[1], found_indicies[0], command_list, instance_buffer, instance_count, queue_index, record_type, name); this->cv_push.notify_all(); + + return true; } bool WorkQueue::pop(struct WorkHeader** header, int queue_index) { diff --git a/vkdispatch_native/queue/work_queue.hh b/vkdispatch_native/queue/work_queue.hh index 
b1186c78..7277b310 100644 --- a/vkdispatch_native/queue/work_queue.hh +++ b/vkdispatch_native/queue/work_queue.hh @@ -21,6 +21,7 @@ struct WorkHeader { unsigned int instance_count; unsigned int instance_size; RecordType record_type; + const char* name; }; enum WorkState { @@ -43,7 +44,10 @@ public: WorkQueue(int max_work_items, int max_programs); void stop(); - void push(struct CommandList* command_list, void* instance_buffer, unsigned int instance_count, int queue_index, int record_type); + int get_program_index(struct CommandList* command_list); + int get_work_index(); + void prepare_work(int work_index, int program_index, struct CommandList* command_list, void* instance_buffer, unsigned int instance_count, int queue_index, int record_type, const char* name); + bool push(struct CommandList* command_list, void* instance_buffer, unsigned int instance_count, int queue_index, int record_type, const char* name); bool pop(struct WorkHeader** header, int queue_index); void finish(struct WorkHeader* header); diff --git a/vkdispatch_native/stages/stage_fft.cpp b/vkdispatch_native/stages/stage_fft.cpp index e182d307..f0b98bc2 100644 --- a/vkdispatch_native/stages/stage_fft.cpp +++ b/vkdispatch_native/stages/stage_fft.cpp @@ -17,38 +17,80 @@ struct FFTPlan { }; void print_vkfft_config(VkFFTConfiguration* config) { - LOG_INFO(R"( - VkConfig: - Size: (%d, %d, %d) - Omit Dimention: (%d, %d, %d) - Input Buffer Size: %d - Is Input Formatted: %d - Frequency Zero Padding: %d - Kernel Convolution: %d - Perform Convolution: %d - Coordinate Features: %d - Number Kernels: %d - Kernel Size: %d - Normalize: %d - Buffer Size: %d - Perform R2C: %d - Number Batches: %d - )", - config->size[0], config->size[1], config->size[2], - config->omitDimension[0], config->omitDimension[1], config->omitDimension[2], - *config->inputBufferSize, - config->isInputFormatted, - config->frequencyZeroPadding, - config->kernelConvolution, - config->performConvolution, - config->coordinateFeatures, - 
config->numberKernels, - *config->kernelSize, - config->normalize, - *config->bufferSize, - config->performR2C, - config->numberBatches); - //config->singleKernelMultipleBatches); + LOG_INFO(R"( +VkConfig: + FFTDim: %d + size[0]: %d + size[1]: %d + size[2]: %d + bufferSize: %llu + inputBufferSize: %llu + kernelSize: %llu + numberBatches: %d + omitDimension[0]: %d + omitDimension[1]: %d + omitDimension[2]: %d + normalize: %d + performR2C: %d + isInputFormatted: %d + performZeropadding[0]: %d + performZeropadding[1]: %d + performZeropadding[2]: %d + fft_zeropad_left[0]: %llu + fft_zeropad_left[1]: %llu + fft_zeropad_left[2]: %llu + fft_zeropad_right[0]: %llu + fft_zeropad_right[1]: %llu + fft_zeropad_right[2]: %llu + frequencyZeroPadding: %d + performConvolution: %d + conjugateConvolution: %d + coordinateFeatures: %d + numberKernels: %d + kernelConvolution: %d + maxComputeWorkGroupCount[0]: %d + maxComputeWorkGroupCount[1]: %d + maxComputeWorkGroupCount[2]: %d + maxComputeWorkGroupSize[0]: %d + maxComputeWorkGroupSize[1]: %d + maxComputeWorkGroupSize[2]: %d + )", + config->FFTdim, + config->size[0], + config->size[1], + config->size[2], + config->bufferSize ? *config->bufferSize : 0, + config->inputBufferSize ? *config->inputBufferSize : 0, + config->kernelSize ? 
*config->kernelSize : 0, + config->numberBatches, + config->omitDimension[0], + config->omitDimension[1], + config->omitDimension[2], + config->normalize, + config->performR2C, + config->isInputFormatted, + config->performZeropadding[0], + config->performZeropadding[1], + config->performZeropadding[2], + config->fft_zeropad_left[0], + config->fft_zeropad_left[1], + config->fft_zeropad_left[2], + config->fft_zeropad_right[0], + config->fft_zeropad_right[1], + config->fft_zeropad_right[2], + config->frequencyZeroPadding, + config->performConvolution, + config->conjugateConvolution, + config->coordinateFeatures, + config->numberKernels, + config->kernelConvolution, + config->maxComputeWorkGroupCount[0], + config->maxComputeWorkGroupCount[1], + config->maxComputeWorkGroupCount[2], + config->maxComputeWorkGroupSize[0], + config->maxComputeWorkGroupSize[1], + config->maxComputeWorkGroupSize[2] + ); } struct FFTPlan* stage_fft_plan_create_extern( @@ -111,6 +153,18 @@ struct FFTPlan* stage_fft_plan_create_extern( (VkCommandBuffer cmd_buffer, ExecIndicies indicies, void* pc_data, BarrierManager* barrier_manager, uint64_t timestamp) { LOG_VERBOSE("Initializing FFT on device %d, queue %d, recorder %d", indicies.device_index, indicies.queue_index, indicies.recorder_index); + unsigned long long true_rows = rows; + + if(do_r2c) { + true_rows = (rows / 2) + 1; + } + + int convolution_multiplier = 1; + + if(kernel_num > 0) { + convolution_multiplier = kernel_num * convolution_features; + } + VkFFTConfiguration config = {}; config.FFTdim = dims; @@ -118,12 +172,25 @@ struct FFTPlan* stage_fft_plan_create_extern( config.size[1] = cols; config.size[2] = depth; - config.disableSetLocale = 1; + config.bufferSize = (uint64_t*)malloc(sizeof(uint64_t)); + config.inputBufferSize = (uint64_t*)malloc(sizeof(uint64_t)); + config.kernelSize = (uint64_t*)malloc(sizeof(uint64_t)); + + *config.bufferSize = num_batches * convolution_multiplier * true_rows * cols * depth * sizeof(float) * 2; + 
*config.inputBufferSize = input_buffer_size; + *config.kernelSize = 2 * sizeof(float) * num_batches * kernel_num * convolution_features * true_rows * config.size[1] * config.size[2]; + config.numberBatches = num_batches; config.omitDimension[0] = omit_rows; config.omitDimension[1] = omit_cols; config.omitDimension[2] = omit_depth; + config.normalize = normalize; + config.performR2C = do_r2c; + config.isInputFormatted = input_buffer_size > 0; + config.keepShaderCode = keep_shader_code; + config.disableSetLocale = 1; + config.performZeropadding[0] = pad_right_rows != 0; config.performZeropadding[1] = pad_right_cols != 0; config.performZeropadding[2] = pad_right_depth != 0; @@ -135,31 +202,14 @@ struct FFTPlan* stage_fft_plan_create_extern( config.fft_zeropad_right[0] = pad_right_rows; config.fft_zeropad_right[1] = pad_right_cols; config.fft_zeropad_right[2] = pad_right_depth; - - config.keepShaderCode = keep_shader_code; - - config.inputBufferSize = (uint64_t*)malloc(sizeof(uint64_t)); - *config.inputBufferSize = input_buffer_size; - config.isInputFormatted = input_buffer_size > 0; - + config.frequencyZeroPadding = frequency_zeropadding; - unsigned long long true_rows = rows; - - if(do_r2c) { - true_rows = (rows / 2) + 1; - } - - config.kernelConvolution = kernel_convolution; - config.performConvolution = kernel_num > 0; config.conjugateConvolution = conjugate_convolution; config.coordinateFeatures = convolution_features; config.numberKernels = kernel_num; - config.kernelSize = (uint64_t*)malloc(sizeof(uint64_t)); - *config.kernelSize = 2 * sizeof(float) * kernel_num * convolution_features * true_rows * config.size[1] * config.size[2]; - - //config.singleKernelMultipleBatches = single_kernel_multiple_batches; + config.kernelConvolution = kernel_convolution; glslang_resource_t* resource = reinterpret_cast<glslang_resource_t*>(ctx->glslang_resource_limits); @@ -171,20 +221,6 @@ struct FFTPlan* stage_fft_plan_create_extern( config.maxComputeWorkGroupSize[1] = 
resource->max_compute_work_group_size_y; config.maxComputeWorkGroupSize[2] = resource->max_compute_work_group_size_z; - config.normalize = normalize; - - int convolution_multiplier = 1; - - if(kernel_num > 0) { - convolution_multiplier = kernel_num * convolution_features; - } - - config.bufferSize = (uint64_t*)malloc(sizeof(uint64_t)); - *config.bufferSize = num_batches * convolution_multiplier * true_rows * cols * depth * sizeof(float) * 2; - config.performR2C = do_r2c; - - config.numberBatches = num_batches; - config.isCompilerInitialized = true; config.glslang_mutex = &ctx->glslang_mutex; config.queue_mutex = &ctx->queues[indicies.queue_index]->queue_usage_mutex;