6// Define the binding group layout: bind group 0 holds three read_write storage buffers of u32.
// data: input values. Each invocation copies data[gid] into shared_data, and the last
// invocation of each workgroup reads data[gid] again to form that workgroup's total.
// NOTE(review): only read in the visible code; a read-only access mode may suffice — confirm
// no other entry point writes this buffer.
7@group(0) @binding(0) var<storage, read_write> data: array<u32>;
// result: per-element output. Receives shared_data[tid] after the scan's down-sweep, i.e. the
// exclusive prefix sum of each element within its own workgroup (the scan zeroes the last
// element before the down-sweep).
8@group(0) @binding(1) var<storage, read_write> result: array<u32>;
// aux: one entry per workgroup, indexed by workgroup_id.x. Holds the workgroup's total, computed
// by its last invocation as data[gid] + result[gid] (last input + its exclusive prefix).
9@group(0) @binding(2) var<storage, read_write> aux: array<u32>;
11// Shared memory for the workgroup: one u32 slot per invocation, indexed by local invocation id.
// WORKGROUP_SIZE is used both as the @workgroup_size and as the shared array length.
// The up-sweep/down-sweep tree indexing in the scan assumes it is a power of two.
12const WORKGROUP_SIZE : u32 = 256u;
13var<workgroup> shared_data: array<u32, WORKGROUP_SIZE>;
15@compute @workgroup_size(WORKGROUP_SIZE)
17 @builtin(global_invocation_id) global_id: vec3<u32>,
18 @builtin(local_invocation_id) local_id: vec3<u32>,
19 @builtin(workgroup_id) group_id: vec3<u32>
23 let gid = global_id.x;
25 // Load data into shared memory
26 if gid < arrayLength(&data) {
27 shared_data[tid] = data[gid];
31 // Up-sweep (reduce) phase
33 while offset < WORKGROUP_SIZE {
34 let idx = (tid + 1u) * offset * 2u - 1u;
35 if idx < WORKGROUP_SIZE {
36 shared_data[idx] += shared_data[idx - offset];
42 // Clear the last element
44 shared_data[WORKGROUP_SIZE - 1u] = 0u;
49 offset = WORKGROUP_SIZE / 2u;
51 let idx = (tid + 1u) * offset * 2u - 1u;
52 if idx < WORKGROUP_SIZE {
53 let t = shared_data[idx - offset];
54 shared_data[idx - offset] = shared_data[idx];
55 shared_data[idx] += t;
61 // Write the result back to the global memory
62 if gid < arrayLength(&data) {
63 result[gid] = shared_data[tid];
64 if tid == (WORKGROUP_SIZE - 1u) {
65 aux[group_id.x] = data[gid] + result[gid];
shader_source
Definition prefix_sum_sub.wgsl.h:5