NetDEM v1.0
Loading...
Searching...
No Matches
radix_count.wgsl.h
Go to the documentation of this file.
1#pragma once
2
3#include <string>
4
// ---------------------------------------------------------------------------
// Bind group 0: the key/value arrays, their backup copies, and the index of
// the bit that RadixCount inspects on this pass.
// ---------------------------------------------------------------------------
@group(0) @binding(0)
var<storage, read_write> data_keys : array<u32>;

@group(0) @binding(1)
var<storage, read_write> data_vals : array<u32>;

// Backup copies, filled element-by-element by RadixCount.
@group(0) @binding(2)
var<storage, read_write> data_keys_bak : array<u32>;

@group(0) @binding(3)
var<storage, read_write> data_vals_bak : array<u32>;

// Shift amount used to extract the current bit.
@group(0) @binding(4)
var<storage, read_write> bit_id : u32;

// ---------------------------------------------------------------------------
// Bind group 1: per-element 0/1 flags, their workgroup-local exclusive prefix
// sums, and one per-workgroup total (indexed by workgroup id), separately for
// elements whose bit is 0 and elements whose bit is 1.
// ---------------------------------------------------------------------------
@group(1) @binding(0)
var<storage, read_write> bit_0_flag : array<u32>;

@group(1) @binding(1)
var<storage, read_write> bit_0_prefix_sum : array<u32>;

@group(1) @binding(2)
var<storage, read_write> bit_0_aux_src : array<u32>;

@group(1) @binding(3)
var<storage, read_write> bit_1_flag : array<u32>;

@group(1) @binding(4)
var<storage, read_write> bit_1_prefix_sum : array<u32>;

@group(1) @binding(5)
var<storage, read_write> bit_1_aux_src : array<u32>;

// Invocations per workgroup; also the length of each scan tile below.
const WORKGROUP_SIZE = 256u;

// Workgroup-shared scratch: one scan tile per flag array.
var<workgroup> shared_data_0 : array<u32, WORKGROUP_SIZE>;
var<workgroup> shared_data_1 : array<u32, WORKGROUP_SIZE>;
25
// One pass of a single-bit radix split, processed one workgroup-sized tile at
// a time.
//
// For every in-range element `gid` (gid < arrayLength(&data_keys)) it:
//   1. copies data_keys / data_vals into data_keys_bak / data_vals_bak;
//   2. extracts bit `bit_id` and writes the complementary 0/1 flags to
//      bit_0_flag / bit_1_flag;
//   3. computes the workgroup-local *exclusive* prefix sum of both flag
//      arrays (work-efficient Blelloch up-sweep / down-sweep in workgroup
//      memory) and writes it to bit_0_prefix_sum / bit_1_prefix_sum.
// Finally, the last invocation of each workgroup writes that workgroup's flag
// totals to bit_0_aux_src[group_id.x] / bit_1_aux_src[group_id.x]. This also
// happens for a trailing, partially filled workgroup.
//
// The scan assumes WORKGROUP_SIZE is a power of two (it is 256 here).
@compute @workgroup_size(WORKGROUP_SIZE)
fn RadixCount(
    @builtin(global_invocation_id) global_id: vec3<u32>,
    @builtin(local_invocation_id) local_id: vec3<u32>,
    @builtin(workgroup_id) group_id: vec3<u32>
) {
    let tid = local_id.x;
    let gid = global_id.x;
    let n = arrayLength(&data_keys);

    // This invocation's flags. They stay 0 for lanes past the end of the data,
    // so those lanes contribute nothing to the scan or to the workgroup total.
    var flag_0 = 0u;
    var flag_1 = 0u;

    if gid < n {
        data_keys_bak[gid] = data_keys[gid];
        let val = data_vals[gid];
        data_vals_bak[gid] = val;

        // NOTE(review): the bit is taken from data_vals, not data_keys --
        // confirm this is the intended sort key.
        let bit = (val >> bit_id) & 1u;
        flag_0 = 1u - bit;
        flag_1 = bit;
        bit_0_flag[gid] = flag_0;
        bit_1_flag[gid] = flag_1;
    }

    // Every lane fills its own slot explicitly, so the scan does not depend
    // on implicit zero-initialization of workgroup memory.
    shared_data_0[tid] = flag_0;
    shared_data_1[tid] = flag_1;
    workgroupBarrier();

    // Up-sweep (reduce): build partial sums in place. Afterwards the last
    // slot holds the sum of the whole tile.
    var offset = 1u;
    while offset < WORKGROUP_SIZE {
        let idx = (tid + 1u) * offset * 2u - 1u;
        if idx < WORKGROUP_SIZE {
            shared_data_0[idx] += shared_data_0[idx - offset];
            shared_data_1[idx] += shared_data_1[idx - offset];
        }
        offset *= 2u;
        // `offset` is the same in every lane, so this barrier is reached
        // in uniform control flow.
        workgroupBarrier();
    }

    // Seed the exclusive scan: the identity goes into the root.
    if tid == 0u {
        shared_data_0[WORKGROUP_SIZE - 1u] = 0u;
        shared_data_1[WORKGROUP_SIZE - 1u] = 0u;
    }
    workgroupBarrier();

    // Down-sweep: swap-and-add turns the partial sums into an exclusive scan.
    offset = WORKGROUP_SIZE / 2u;
    while offset > 0u {
        let idx = (tid + 1u) * offset * 2u - 1u;
        if idx < WORKGROUP_SIZE {
            let t_0 = shared_data_0[idx - offset];
            shared_data_0[idx - offset] = shared_data_0[idx];
            shared_data_0[idx] += t_0;

            let t_1 = shared_data_1[idx - offset];
            shared_data_1[idx - offset] = shared_data_1[idx];
            shared_data_1[idx] += t_1;
        }
        offset /= 2u;
        workgroupBarrier();
    }

    if gid < n {
        bit_0_prefix_sum[gid] = shared_data_0[tid];
        bit_1_prefix_sum[gid] = shared_data_1[tid];
    }

    // Exclusive prefix of the last slot plus that slot's own flag is the tile
    // total. This is done regardless of `gid`, so a trailing, partially filled
    // workgroup also reports its total instead of leaving a stale aux entry.
    // The index is bounds-checked so an undersized aux buffer is never written
    // out of range.
    if tid == (WORKGROUP_SIZE - 1u) {
        if group_id.x < arrayLength(&bit_0_aux_src) {
            bit_0_aux_src[group_id.x] = shared_data_0[tid] + flag_0;
        }
        if group_id.x < arrayLength(&bit_1_aux_src) {
            bit_1_aux_src[group_id.x] = shared_data_1[tid] + flag_1;
        }
    }
}
97)";
shader_source
Definition radix_count.wgsl.h:5