14 void Init(wgpu::Device &device);
16 void Solve(wgpu::Buffer data_keys, wgpu::Buffer data_vals,
size_t array_size);
23 wgpu::ComputePipeline cmp_pl_radix_count;
24 wgpu::ComputePipeline cmp_pl_prefix_sum_sub;
25 wgpu::ComputePipeline cmp_pl_add_block_sums;
26 wgpu::ComputePipeline cmp_pl_radix_reorder;
/// Creates the bind group layout(s) shared by the compute pipelines.
/// NOTE(review): presumably the layouts later consumed by GetBindGroup0 and
/// GetBindGroup1; the definition is not visible here — confirm.
void InitBindGroupLayout();
// Builds bind group 0 from the key/value buffers, their "_bak" counterparts
// and the `bit_id` buffer (presumably the current radix digit position and
// the ping-pong copies used between passes — confirm).
// NOTE(review): this declaration is truncated in this copy of the file — the
// listing jumps from line 33 to line 36, so any remaining parameters and the
// closing `);` are missing. Restore it from the out-of-line definition.
31 wgpu::BindGroup GetBindGroup0(wgpu::Buffer data_keys, wgpu::Buffer data_vals,
32 wgpu::Buffer data_keys_bak,
33 wgpu::Buffer data_vals_bak, wgpu::Buffer bit_id,
36 wgpu::BindGroup GetBindGroup1(
37 wgpu::Buffer buffer_data_src_0, wgpu::Buffer buffer_presum_src_0,
38 wgpu::Buffer buffer_dst_0, wgpu::Buffer buffer_data_src_1,
39 wgpu::Buffer buffer_presum_src_1, wgpu::Buffer buffer_dst_1,
40 size_t offset_data_src,
size_t offset_presum_src,
size_t offset_dst,
41 size_t array_size_src,
size_t array_size_dst);
45 void RadixCount(wgpu::Buffer buffer_keys, wgpu::Buffer buffer_vals,
46 wgpu::Buffer buffer_keys_bak, wgpu::Buffer buffer_vals_bak,
47 wgpu::Buffer buffer_bit_id, wgpu::Buffer buffer_data_src_0,
48 wgpu::Buffer buffer_presum_src_0, wgpu::Buffer buffer_dst_0,
49 wgpu::Buffer buffer_data_src_1,
50 wgpu::Buffer buffer_presum_src_1, wgpu::Buffer buffer_dst_1,
51 size_t array_size_src,
size_t array_size_dst,
52 wgpu::ComputePassEncoder cmp_pass);
54 void PrefixSumSub(wgpu::Buffer buffer_data_src_0,
55 wgpu::Buffer buffer_presum_src_0, wgpu::Buffer buffer_dst_0,
56 wgpu::Buffer buffer_data_src_1,
57 wgpu::Buffer buffer_presum_src_1, wgpu::Buffer buffer_dst_1,
58 size_t offset_data_src,
size_t offset_presum_src,
59 size_t offset_dst,
size_t array_size_src,
60 size_t array_size_dst, wgpu::ComputePassEncoder cmp_pass);
62 void AddBlockSums(wgpu::Buffer buffer_data_src_0,
63 wgpu::Buffer buffer_presum_src_0, wgpu::Buffer buffer_dst_0,
64 wgpu::Buffer buffer_data_src_1,
65 wgpu::Buffer buffer_presum_src_1, wgpu::Buffer buffer_dst_1,
66 size_t offset_data_src,
size_t offset_presum_src,
67 size_t offset_dst,
size_t array_size_src,
68 size_t array_size_dst, wgpu::ComputePassEncoder cmp_pass);
70 void RadixReorder(wgpu::Buffer buffer_data_src_0,
71 wgpu::Buffer buffer_presum_src_0, wgpu::Buffer buffer_dst_0,
72 wgpu::Buffer buffer_data_src_1,
73 wgpu::Buffer buffer_presum_src_1, wgpu::Buffer buffer_dst_1,
74 size_t array_size_src,
size_t array_size_dst,
75 wgpu::ComputePassEncoder cmp_pass);
77 int GetMaxBitID(wgpu::Buffer data_vals,
size_t array_size);
// Threads per compute workgroup. Presumably this must match the
// @workgroup_size declared in the WGSL shaders — confirm.
const size_t workgroup_size = 256;
// Size in bytes of one buffer element. Uses the standard uint32_t (in scope via
// the WebGPU headers) rather than the non-portable BSD/glibc u_int32_t typedef,
// which MSVC does not provide.
const size_t ele_size = sizeof(uint32_t);
// Element count that corresponds to 256 bytes. 256 is presumably WebGPU's
// default minStorageBufferOffsetAlignment: element offsets that are multiples
// of align_size yield byte offsets that are valid for buffer bindings.
const size_t align_size = 256 / ele_size;