NetDEM v1.0
Loading...
Searching...
No Matches
prefix_sum_sub.wgsl.h
Go to the documentation of this file.
1#pragma once
2
3#include <string>
4
// Invocations per workgroup; each workgroup scans a tile of this many
// elements held in workgroup memory.
const WORKGROUP_SIZE = 256u;

// Scratch tile shared by all invocations of one workgroup.
var<workgroup> shared_data: array<u32, WORKGROUP_SIZE>;

// Bind group 0 resources.
// binding 0: input values (declared read_write, but PrefixSumSub only reads it).
@group(0) @binding(0) var<storage, read_write> data: array<u32>;
// binding 1: per-element output written by PrefixSumSub.
@group(0) @binding(1) var<storage, read_write> result: array<u32>;
// binding 2: per-workgroup output, indexed by workgroup id in PrefixSumSub.
@group(0) @binding(2) var<storage, read_write> aux: array<u32>;

// PrefixSumSub: per-workgroup exclusive prefix sum of `data`, computed with the
// work-efficient up-sweep / down-sweep (Blelloch) scheme in workgroup memory.
//
// Each workgroup handles one tile of WORKGROUP_SIZE consecutive elements:
//   result[gid]     = sum of the elements of this tile that precede gid
//                     (exclusive scan; restarts at 0 in every tile).
//   aux[group_id.x] = sum of all elements of this tile.
// Arithmetic is u32 and wraps on overflow.
//
// Elements past the end of `data` (padding lanes of a partially filled last
// tile) count as 0 and are not written to `result`; the tile total in `aux`
// is still written for that last tile.
//
// NOTE(review): only the x components of the builtins are used, so this
// presumably expects a 1-D dispatch of ceil(arrayLength(&data) / WORKGROUP_SIZE)
// workgroups, with `aux` sized accordingly. Confirm with the host code.
// NOTE(review): the name suggests this is the per-tile pass of a multi-pass
// scan whose tile totals in `aux` are combined by another kernel. Confirm.
@compute @workgroup_size(WORKGROUP_SIZE)
fn PrefixSumSub(
    @builtin(global_invocation_id) global_id: vec3<u32>,
    @builtin(local_invocation_id) local_id: vec3<u32>,
    @builtin(workgroup_id) group_id: vec3<u32>
) {
    let tid = local_id.x;
    let gid = global_id.x;
    let n = arrayLength(&data);

    // Load one element per invocation. Padding lanes are zeroed explicitly so
    // they never contribute to the sums, rather than relying on implicit
    // zero-initialization of workgroup memory.
    var value = 0u;
    if gid < n {
        value = data[gid];
    }
    shared_data[tid] = value;
    workgroupBarrier();

    // Up-sweep (reduce): accumulate partial sums in place. Afterwards the last
    // slot holds the total of the whole tile.
    var offset = 1u;
    while offset < WORKGROUP_SIZE {
        let idx = (tid + 1u) * offset * 2u - 1u;
        if idx < WORKGROUP_SIZE {
            shared_data[idx] += shared_data[idx - offset];
        }
        offset *= 2u;
        // Each level reads values produced by the previous level.
        workgroupBarrier();
    }

    // Publish the tile total. This runs for every workgroup, including a
    // partially filled last one whose final lane is out of range. Then clear
    // the root to seed the exclusive down-sweep.
    if tid == 0u {
        aux[group_id.x] = shared_data[WORKGROUP_SIZE - 1u];
        shared_data[WORKGROUP_SIZE - 1u] = 0u;
    }
    workgroupBarrier();

    // Down-sweep: push prefixes back down the tree. Each active invocation
    // swaps its left child into the right slot and adds the old left value
    // to the right.
    offset = WORKGROUP_SIZE / 2u;
    while offset > 0u {
        let idx = (tid + 1u) * offset * 2u - 1u;
        if idx < WORKGROUP_SIZE {
            let left = shared_data[idx - offset];
            shared_data[idx - offset] = shared_data[idx];
            shared_data[idx] += left;
        }
        offset /= 2u;
        workgroupBarrier();
    }

    // Store the exclusive prefix for in-range elements only.
    if gid < n {
        result[gid] = shared_data[tid];
    }
}
69)";
shader_source
Definition prefix_sum_sub.wgsl.h:5