// NOTE(review): this listing looks like a documentation-page extract of
// radix_count.wgsl. Every line carries a leading source line number, and the
// gaps in that numbering suggest lines are missing (the entry-point `fn`
// header, the declarations of `tid` and `offset`, the down-sweep loop header,
// closing braces and possibly barriers). Confirm against the original shader
// before compiling or modifying.
//
// Purpose (from the visible statements): for every element, back up the
// key/value pair, split bit `bit_id` of the value into a "bit is 0" flag and a
// "bit is 1" flag, scan both flag arrays within the workgroup, and publish the
// per-element scan results plus one total per workgroup.
6// Define the binding group layout
// Group 0: the key/value buffers, their backup copies (written below), and the
// index of the bit examined by this pass (used as the shift amount).
7@group(0) @binding(0) var<storage, read_write> data_keys: array<u32>;
8@group(0) @binding(1) var<storage, read_write> data_vals: array<u32>;
9@group(0) @binding(2) var<storage, read_write> data_keys_bak: array<u32>;
10@group(0) @binding(3) var<storage, read_write> data_vals_bak: array<u32>;
11@group(0) @binding(4) var<storage, read_write> bit_id: u32;
// Group 1, bit == 0 side: per-element flag, per-element scan result, and one
// entry per workgroup (indexed by group_id.x below).
13@group(1) @binding(0) var<storage, read_write> bit_0_flag: array<u32>;
14@group(1) @binding(1) var<storage, read_write> bit_0_prefix_sum: array<u32>;
15@group(1) @binding(2) var<storage, read_write> bit_0_aux_src: array<u32>;
// Group 1, bit == 1 side: same layout as the bit == 0 buffers.
17@group(1) @binding(3) var<storage, read_write> bit_1_flag: array<u32>;
18@group(1) @binding(4) var<storage, read_write> bit_1_prefix_sum: array<u32>;
19@group(1) @binding(5) var<storage, read_write> bit_1_aux_src: array<u32>;
21// Shared memory for the workgroup
// One scratch slot per invocation for each of the two scans; the array length
// and the workgroup size are both WORKGROUP_SIZE.
22const WORKGROUP_SIZE : u32 = 256u;
23var<workgroup> shared_data_0: array<u32, WORKGROUP_SIZE>;
24var<workgroup> shared_data_1: array<u32, WORKGROUP_SIZE>;
// Entry point. Only the .x components of the builtins are used in the visible
// code, i.e. the dispatch is treated as one-dimensional.
26@compute @workgroup_size(WORKGROUP_SIZE)
28 @builtin(global_invocation_id) global_id: vec3<u32>,
29 @builtin(local_invocation_id) local_id: vec3<u32>,
30 @builtin(workgroup_id) group_id: vec3<u32>
// NOTE(review): `tid` is used below but its declaration is not visible in this
// listing; presumably `let tid = local_id.x;` -- confirm.
34 let gid = global_id.x;
36 // Back up the inputs, build the bit flags and stage them in workgroup memory
// Invocations whose gid is past the end of data_keys skip this whole section.
37 if gid < arrayLength(&data_keys) {
38 data_keys_bak[gid] = data_keys[gid];
39 data_vals_bak[gid] = data_vals[gid];
// Isolate bit number `bit_id` of the value: the result is 0 or 1.
// NOTE(review): the bit is read from data_vals, not data_keys -- confirm which
// buffer is meant to hold the sort key.
41 let extract_bits: u32 = (data_vals[gid] >> bit_id) & 0x1;
// Complementary flags: bit_0_flag is 1 when the bit is clear, bit_1_flag is 1
// when it is set.
42 bit_0_flag[gid] = 1 - extract_bits;
43 bit_1_flag[gid] = extract_bits;
// Copy this element's flags into the per-invocation scratch slots.
45 shared_data_0[tid] = bit_0_flag[gid];
46 shared_data_1[tid] = bit_1_flag[gid];
50 // Up-sweep (reduce) phase
// Each step adds the element `offset` positions to the left into position
// `idx`, for both scratch arrays at once.
// NOTE(review): the initialisation and per-iteration update of `offset`, and
// any workgroupBarrier() between steps, are not visible in this listing --
// presumably offset starts at 1 and doubles each step; confirm.
52 while offset < WORKGROUP_SIZE {
53 let idx = (tid + 1u) * offset * 2u - 1u;
// Invocations whose idx falls outside the scratch arrays sit this step out.
54 if idx < WORKGROUP_SIZE {
55 shared_data_0[idx] += shared_data_0[idx - offset];
56 shared_data_1[idx] += shared_data_1[idx - offset];
62 // Clear the last element
// Zeroing the last slot before the down-sweep is what one would expect for an
// exclusive scan (each slot ends up holding the sum of the slots before it);
// this matches the `flag + prefix_sum` total computed at the end.
// NOTE(review): whether a single invocation performs this store is not visible
// here -- confirm.
64 shared_data_0[WORKGROUP_SIZE - 1u] = 0u;
65 shared_data_1[WORKGROUP_SIZE - 1u] = 0u;
// Down-sweep phase: start from the widest stride.
// NOTE(review): the loop header and the update of `offset` are not visible in
// this listing -- presumably offset halves each step until it reaches 0;
// confirm.
70 offset = WORKGROUP_SIZE / 2u;
72 let idx = (tid + 1u) * offset * 2u - 1u;
73 if idx < WORKGROUP_SIZE {
// For each array: move the value at idx into the left slot, and add the left
// slot's previous value (saved in t_0 / t_1) to idx.
74 let t_0 = shared_data_0[idx - offset];
75 shared_data_0[idx - offset] = shared_data_0[idx];
76 shared_data_0[idx] += t_0;
78 let t_1 = shared_data_1[idx - offset];
79 shared_data_1[idx - offset] = shared_data_1[idx];
80 shared_data_1[idx] += t_1;
86 // Write the result back to the global memory
// Only in-range invocations publish their scan result.
87 if gid < arrayLength(&data_keys) {
88 bit_0_prefix_sum[gid] = shared_data_0[tid];
89 bit_1_prefix_sum[gid] = shared_data_1[tid];
// The last invocation of the workgroup stores its own flag plus its scan
// result into the aux buffers, one entry per workgroup.
// NOTE(review): closing braces are not visible here. If this check is nested
// inside the bounds guard above, a partially filled final workgroup never
// writes its aux entry -- confirm that this is intended or handled elsewhere.
91 if tid == (WORKGROUP_SIZE - 1u) {
92 bit_0_aux_src[group_id.x] = bit_0_flag[gid] + bit_0_prefix_sum[gid];
93 bit_1_aux_src[group_id.x] = bit_1_flag[gid] + bit_1_prefix_sum[gid];
shader_source
Definition radix_count.wgsl.h:5