Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
116 changes: 116 additions & 0 deletions examples/compute_workgroups.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
"""
A simple compute example demonstrating GPU workgroups and invocation IDs.

Each thread writes its global, local, and workgroup IDs into a storage buffer
so the relationship between them can be inspected.
"""

import numpy as np
import wgpu

# Workgroup configuration: the number of workgroups that get dispatched and
# the number of invocations (threads) each workgroup runs along x.

workgroups = 3
workgroup_size = 4
total_threads = workgroups * workgroup_size

# Every thread stores three uint32 values, and a uint32 occupies four bytes.
output_elements = 3 * total_threads
output_bytes = 4 * output_elements

# compute shader
#
# WGSL source, built as an f-string so the Python-side ``workgroup_size`` is
# substituted into the ``@workgroup_size`` attribute; the literal WGSL braces
# of the function body are therefore written as ``{{`` and ``}}``.
#
# Each invocation writes three consecutive u32 values into ``out``, starting
# at index ``global_id.x * 3``: its global x index, its local x index, and its
# workgroup x index. Only the x components are used.

shader_source = f"""
@group(0) @binding(0)
var<storage, read_write> out: array<u32>;

@compute
@workgroup_size({workgroup_size}, 1, 1)
fn main(
@builtin(global_invocation_id) global_id : vec3<u32>,
@builtin(local_invocation_id) local_id : vec3<u32>,
@builtin(workgroup_id) wg_id : vec3<u32>,
) {{
let base: u32 = global_id.x * 3u;

out[base] = global_id.x;
out[base + 1] = local_id.x;
out[base + 2] = wg_id.x;
}}
"""

# adapter and device
#
# Ask for a high-performance adapter and open a device on it.

adapter = wgpu.gpu.request_adapter_sync(
    power_preference="high-performance",
)
device = adapter.request_device_sync()

# storage buffer
#
# STORAGE lets the compute shader bind and write the buffer; COPY_SRC allows
# its contents to be copied out so they can be read back on the CPU.

output_buffer = device.create_buffer(
    usage=wgpu.BufferUsage.COPY_SRC | wgpu.BufferUsage.STORAGE,
    size=output_bytes,
)

# Shader module and compute pipeline
#
# The pipeline is created with layout="auto"; the layout of bind group 0 is
# fetched back from the pipeline further down when the bind group is built.

shader_module = device.create_shader_module(code=shader_source)

compute_stage = {"module": shader_module, "entry_point": "main"}
compute_pipeline = device.create_compute_pipeline(
    layout="auto",
    compute=compute_stage,
)

# bind group
#
# Expose the whole output buffer to the shader at @group(0) @binding(0).

output_binding = {
    "binding": 0,
    "resource": {
        "buffer": output_buffer,
        "offset": 0,
        "size": output_buffer.size,
    },
}
bind_group = device.create_bind_group(
    layout=compute_pipeline.get_bind_group_layout(0),
    entries=[output_binding],
)

# encode, dispatch and submit
#
# Record a single compute pass that launches `workgroups` workgroups along x,
# then hand the finished command buffer to the queue.

command_encoder = device.create_command_encoder()

compute_pass = command_encoder.begin_compute_pass()
compute_pass.set_pipeline(compute_pipeline)
compute_pass.set_bind_group(0, bind_group)
compute_pass.dispatch_workgroups(workgroups, 1, 1)
compute_pass.end()

command_buffer = command_encoder.finish()
device.queue.submit([command_buffer])

# results
#
# Read the storage buffer back and view it as one row per thread, each row
# holding (global_id, local_id, workgroup_id) in the order the shader wrote
# them.

raw = device.queue.read_buffer(output_buffer)
values = np.frombuffer(raw, dtype=np.uint32)
records = values.reshape(total_threads, 3)

print(
    f"Dispatched {workgroups} workgroup(s) of {workgroup_size} thread(s) each "
    f"({total_threads} threads total).\n"
)

print(f"{'Thread':>6} {'global_id':>9} {'local_id':>8} {'workgroup_id':>12} ")

# tolist() yields plain Python ints, so formatting and comparison below do not
# depend on numpy scalar behavior.
for i, (global_id, local_id, workgroup_id) in enumerate(records.tolist()):
    print(f"{i:>6} {global_id:>9} {local_id:>8} {workgroup_id:>12}")

    # Verify invocation ID relationships. This raises explicitly instead of
    # using `assert`, because assert statements are removed under `python -O`
    # and the success message below would then be printed unchecked.
    expected = (i, i % workgroup_size, i // workgroup_size)
    actual = (global_id, local_id, workgroup_id)
    if actual != expected:
        raise AssertionError(
            f"Thread {i}: expected (global, local, workgroup) IDs {expected}, "
            f"got {actual}"
        )

print("Invocation ID mapping verified")
Loading