diff --git a/examples/compute_workgroups.py b/examples/compute_workgroups.py new file mode 100644 index 00000000..16ec58c9 --- /dev/null +++ b/examples/compute_workgroups.py @@ -0,0 +1,116 @@ +""" +A simple compute example demonstrating GPU workgroups and invocation IDs. + +Each thread writes its global, local, and workgroup IDs into a storage buffer +so the relationship between them can be inspected. +""" + +import numpy as np +import wgpu + +# define workgroup configuration + +workgroup_size = 4 +workgroups = 3 +total_threads = workgroup_size * workgroups + +# Each thread writes 3 uint32 values +output_elements = total_threads * 3 +output_bytes = output_elements * 4 + +# compute shader + +shader_source = f""" +@group(0) @binding(0) +var out: array; + +@compute +@workgroup_size({workgroup_size}, 1, 1) +fn main( + @builtin(global_invocation_id) global_id : vec3, + @builtin(local_invocation_id) local_id : vec3, + @builtin(workgroup_id) wg_id : vec3, +) {{ + let base: u32 = global_id.x * 3u; + + out[base] = global_id.x; + out[base + 1] = local_id.x; + out[base + 2] = wg_id.x; +}} +""" + +# adapter and device + +adapter = wgpu.gpu.request_adapter_sync(power_preference="high-performance") +device = adapter.request_device_sync() + +# storage buffer + +output_buffer = device.create_buffer( + size=output_bytes, + usage=wgpu.BufferUsage.STORAGE | wgpu.BufferUsage.COPY_SRC, +) + +# Shader module and compute pipeline + +shader_module = device.create_shader_module(code=shader_source) + +compute_pipeline = device.create_compute_pipeline( + layout="auto", + compute={"module": shader_module, "entry_point": "main"}, +) + +# bind group + +bind_group = device.create_bind_group( + layout=compute_pipeline.get_bind_group_layout(0), + entries=[ + { + "binding": 0, + "resource": { + "buffer": output_buffer, + "offset": 0, + "size": output_buffer.size, + }, + } + ], +) + +# encode, dispatch and submit + +command_encoder = device.create_command_encoder() +compute_pass = command_encoder.begin_compute_pass() + +compute_pass.set_pipeline(compute_pipeline) +compute_pass.set_bind_group(0, bind_group) +compute_pass.dispatch_workgroups(workgroups, 1, 1) + +compute_pass.end() + +device.queue.submit([command_encoder.finish()]) + +# results + +raw = device.queue.read_buffer(output_buffer) +values = np.frombuffer(raw, dtype=np.uint32) + +print( + f"Dispatched {workgroups} workgroup(s) of {workgroup_size} thread(s) each " + f"({total_threads} threads total).\n" +) + +print(f"{'Thread':>6} {'global_id':>9} {'local_id':>8} {'workgroup_id':>12} ") + +for i in range(total_threads): + global_id = values[i * 3] + local_id = values[i * 3 + 1] + workgroup_id = values[i * 3 + 2] + + print(f"{i:>6} {global_id:>9} {local_id:>8} {workgroup_id:>12}") + + # verify invocation ID relationships + assert global_id == i + assert local_id == i % workgroup_size + assert workgroup_id == i // workgroup_size + +print("Invocation ID mapping verified")