// Storage-buffer bindings. Everything lives in bind group 1 and is declared
// read_write; each buffer is a runtime-sized array of u32.
//
// Two independent data sets (suffixes _0 and _1) share the same layout:
//   data_N    bindings 0 / 3   values read by the shader
//   result_N  bindings 1 / 4   per-element output
//   aux_N     bindings 2 / 5   one entry per workgroup (indexed by workgroup id)
@group(1) @binding(0) var<storage, read_write> data_0: array<u32>;
@group(1) @binding(1) var<storage, read_write> result_0: array<u32>;
@group(1) @binding(2) var<storage, read_write> aux_0: array<u32>;

@group(1) @binding(3) var<storage, read_write> data_1: array<u32>;
@group(1) @binding(4) var<storage, read_write> result_1: array<u32>;
@group(1) @binding(5) var<storage, read_write> aux_1: array<u32>;
// Number of invocations per workgroup. The same constant sizes the
// workgroup-shared arrays below, so each invocation owns exactly one slot.
const WORKGROUP_SIZE : u32 = 256u;

// Workgroup-shared scratch arrays, one per data set (_0 and _1).
var<workgroup> shared_data_0: array<u32, WORKGROUP_SIZE>;
var<workgroup> shared_data_1: array<u32, WORKGROUP_SIZE>;
// Compute entry point (fragment). Each invocation copies one element of data_0 and
// data_1 into the workgroup arrays, the arrays are combined in place in two phases,
// and the per-element values are written to result_0 / result_1, with one extra
// value per workgroup written to aux_0 / aux_1.
//
// NOTE(review): this listing is incomplete. The `fn` header line, the declarations
// of `tid` and `offset`, every closing brace, the header of the second loop, and any
// barrier / loop-step statements are not visible here. The comments below describe
// only the statements that are shown; confirm against the original shader.
20@compute @workgroup_size(WORKGROUP_SIZE)
22 @builtin(global_invocation_id) global_id: vec3<u32>,
23 @builtin(local_invocation_id) local_id: vec3<u32>,
24 @builtin(workgroup_id) group_id: vec3<u32>
// Only the x component of the global id is used as the buffer index.
28 let gid = global_id.x;
30 // Load data into shared memory
// The guard checks gid against the length of data_0 only; data_1 is indexed with
// the same gid. NOTE(review): assumes data_1 is at least as long as data_0 — confirm.
// NOTE(review): what the shared slots hold when gid is out of range is not visible here.
31 if gid < arrayLength(&data_0) {
32 shared_data_0[tid] = data_0[gid];
33 shared_data_1[tid] = data_1[gid];
37 // Up-sweep (reduce) phase
39 while offset < WORKGROUP_SIZE {
// For the current offset, idx is the last slot of the (tid + 1)-th segment of
// width 2 * offset.
40 let idx = (tid + 1u) * offset * 2u - 1u;
// Invocations whose idx falls outside the workgroup array do nothing this step.
41 if idx < WORKGROUP_SIZE {
// Accumulate the value `offset` slots to the left into slot idx, for both data sets.
42 shared_data_0[idx] += shared_data_0[idx - offset];
43 shared_data_1[idx] += shared_data_1[idx - offset];
49 // Clear the last element
// NOTE(review): any condition restricting which invocation performs these two
// stores is not visible in this listing.
51 shared_data_0[WORKGROUP_SIZE - 1u] = 0u;
52 shared_data_1[WORKGROUP_SIZE - 1u] = 0u;
// Restart from half the workgroup size. Presumably this begins the down-sweep of a
// work-efficient exclusive scan; the loop header itself is not visible — confirm.
57 offset = WORKGROUP_SIZE / 2u;
59 let idx = (tid + 1u) * offset * 2u - 1u;
60 if idx < WORKGROUP_SIZE {
// Swap-and-add: save the left slot, copy slot idx into the left slot, then add the
// saved left value to slot idx.
61 let t_0 = shared_data_0[idx - offset];
62 shared_data_0[idx - offset] = shared_data_0[idx];
63 shared_data_0[idx] += t_0;
// Same swap-and-add for the second data set.
65 let t_1 = shared_data_1[idx - offset];
66 shared_data_1[idx - offset] = shared_data_1[idx];
67 shared_data_1[idx] += t_1;
73 // Write the result back to the global memory
// Same bounds guard as the load: only gid values inside data_0 are written.
74 if gid < arrayLength(&data_0) {
75 result_0[gid] = shared_data_0[tid];
76 result_1[gid] = shared_data_1[tid];
// The last invocation of the workgroup stores (input value + its result) into the
// aux slot for this workgroup. Presumably this is the workgroup total consumed by a
// later pass — confirm.
// NOTE(review): whether this branch is nested inside the bounds guard above is not
// visible. If it is not, data_0[gid] / data_1[gid] may be indexed past the end for
// a partially filled final workgroup; if it is, that workgroup writes no aux entry.
78 if tid == (WORKGROUP_SIZE - 1u) {
79 aux_0[group_id.x] = data_0[gid] + result_0[gid];
80 aux_1[group_id.x] = data_1[gid] + result_1[gid];
shader_source
Definition prefix_sum_sub.wgsl.h:5