NetDEM v1.0
Loading...
Searching...
No Matches
prefix_sum_sub.wgsl.h
Go to the documentation of this file.
1#pragma once
2
3#include <string>
4
// Invocations per workgroup. It is also the length of each workgroup scratch array.
const WORKGROUP_SIZE: u32 = 256u;

// Resources for two independent u32 streams, k = 0 and k = 1:
//   data_k   - input values, indexed by global invocation id
//   result_k - per-element output, indexed by global invocation id
//   aux_k    - one entry per workgroup, indexed by workgroup id
@group(1) @binding(0) var<storage, read_write> data_0: array<u32>;
@group(1) @binding(1) var<storage, read_write> result_0: array<u32>;
@group(1) @binding(2) var<storage, read_write> aux_0: array<u32>;
@group(1) @binding(3) var<storage, read_write> data_1: array<u32>;
@group(1) @binding(4) var<storage, read_write> result_1: array<u32>;
@group(1) @binding(5) var<storage, read_write> aux_1: array<u32>;

// Per-workgroup scratch space: one slot per invocation, one array per stream.
var<workgroup> shared_data_0: array<u32, WORKGROUP_SIZE>;
var<workgroup> shared_data_1: array<u32, WORKGROUP_SIZE>;
19
// Workgroup-local exclusive prefix sum (up-sweep / down-sweep) over two
// independent u32 streams at once.
//
// Each workgroup scans WORKGROUP_SIZE consecutive elements of data_0 and data_1:
//   result_k[i]  = sum of data_k[j] for j from the group's first index up to,
//                  but not including, i
//   aux_k[group] = sum of all data_k elements covered by this group
// Lanes past the end of data_0 are padded with zero. A partial last group
// therefore still produces a correct scan and a correct aux total.
//
// NOTE(review): bounds are checked against data_0 only. data_1 and result_*
// are assumed to be at least as long as data_0; confirm on the host side.
@compute @workgroup_size(WORKGROUP_SIZE)
fn PrefixSumSub(
    @builtin(global_invocation_id) global_id: vec3<u32>,
    @builtin(local_invocation_id) local_id: vec3<u32>,
    @builtin(workgroup_id) group_id: vec3<u32>
) {
    let tid = local_id.x;
    let gid = global_id.x;
    let in_range = gid < arrayLength(&data_0);

    // Load this lane's inputs. Padding lanes contribute 0, written explicitly
    // rather than relying on implicit workgroup-memory zero-initialisation.
    // The values stay in per-invocation locals for the aux total at the end.
    var value_0 = 0u;
    var value_1 = 0u;
    if in_range {
        value_0 = data_0[gid];
        value_1 = data_1[gid];
    }
    shared_data_0[tid] = value_0;
    shared_data_1[tid] = value_1;
    workgroupBarrier();

    // Up-sweep: accumulate partial sums in place. After the final step the
    // last slot holds the sum of the whole group.
    for (var stride = 1u; stride < WORKGROUP_SIZE; stride *= 2u) {
        let idx = (tid + 1u) * stride * 2u - 1u;
        if idx < WORKGROUP_SIZE {
            shared_data_0[idx] += shared_data_0[idx - stride];
            shared_data_1[idx] += shared_data_1[idx - stride];
        }
        // Uniform control flow: stride depends only on constants.
        workgroupBarrier();
    }

    // Seed the down-sweep with the additive identity at the root. This is
    // what makes the scan exclusive.
    if tid == 0u {
        shared_data_0[WORKGROUP_SIZE - 1u] = 0u;
        shared_data_1[WORKGROUP_SIZE - 1u] = 0u;
    }
    workgroupBarrier();

    // Down-sweep: the left child receives the parent's prefix, and the right
    // child receives the parent's prefix plus the old left-child value.
    for (var stride = WORKGROUP_SIZE / 2u; stride > 0u; stride /= 2u) {
        let idx = (tid + 1u) * stride * 2u - 1u;
        if idx < WORKGROUP_SIZE {
            let left_0 = shared_data_0[idx - stride];
            shared_data_0[idx - stride] = shared_data_0[idx];
            shared_data_0[idx] += left_0;

            let left_1 = shared_data_1[idx - stride];
            shared_data_1[idx - stride] = shared_data_1[idx];
            shared_data_1[idx] += left_1;
        }
        workgroupBarrier();
    }

    // Write each in-range lane's exclusive prefix back to global memory.
    if in_range {
        result_0[gid] = shared_data_0[tid];
        result_1[gid] = shared_data_1[tid];
    }

    // Group total = last lane's exclusive prefix + that lane's own input.
    // This is written even when the last lane is padding, so a partial final
    // group still fills its aux entry. The length checks keep the write in
    // bounds, since out-of-bounds storage writes are not guaranteed to be
    // discarded.
    if tid == WORKGROUP_SIZE - 1u {
        if group_id.x < arrayLength(&aux_0) {
            aux_0[group_id.x] = shared_data_0[tid] + value_0;
        }
        if group_id.x < arrayLength(&aux_1) {
            aux_1[group_id.x] = shared_data_1[tid] + value_1;
        }
    }
}
84)";
shader_source
Definition prefix_sum_sub.wgsl.h:5