From 13c7aff6347e6bbf8833902c57ef5d8c416ed047 Mon Sep 17 00:00:00 2001
From: Arthur Brussee
Date: Sun, 23 Feb 2025 22:24:25 +0000
Subject: [PATCH 1/5] Simplify backwards rasterization

---
 .../src/shaders/rasterize_backwards.wgsl      | 250 ++++++++----------
 1 file changed, 109 insertions(+), 141 deletions(-)

diff --git a/crates/brush-train/src/shaders/rasterize_backwards.wgsl b/crates/brush-train/src/shaders/rasterize_backwards.wgsl
index d2ad0f6..c08ef8a 100644
--- a/crates/brush-train/src/shaders/rasterize_backwards.wgsl
+++ b/crates/brush-train/src/shaders/rasterize_backwards.wgsl
@@ -25,30 +25,31 @@ const BATCH_SIZE = helpers::TILE_SIZE;
 var<workgroup> local_batch: array<helpers::ProjectedSplat, BATCH_SIZE>;
 var<workgroup> local_id: array<i32, BATCH_SIZE>;
 
-// This kernel use a new technique to reduce the overhead of atomic gradient accumulation, especially when
-// using software CAS loops this helps performance a lot. Originally, each thread calculated
-// a gradient, summed them together in a subgroup, and one thread of these subgroups would then atomically add
-// this gradient to the global gradient. Instead, we push each subgroup gradient to a buffer
-// until it has N threads gradients, which are then written to the global gradients all at once.
-
-// Current queue of gradients to be flushed.
-var<workgroup> grad_count: atomic<i32>;
-
-const TOTAL_GRADS = BATCH_SIZE * 11;
-var<workgroup> gather_grads: array<f32, TOTAL_GRADS>;
-var<workgroup> gather_grad_id: array<i32, BATCH_SIZE>;
-
 fn add_bitcast(cur: u32, add: f32) -> u32 {
     return bitcast<u32>(bitcast<f32>(cur) + add);
 }
 
-fn write_grads_atomic(grads: f32, id: i32) {
+fn write_grads_atomic(id: i32, grads: f32) {
+    let p = &v_splats[id];
+#ifdef HARD_FLOAT
+    atomicAdd(p, grads);
+#else
+    var old_value = atomicLoad(p);
+    loop {
+        let cas = atomicCompareExchangeWeak(p, old_value, add_bitcast(old_value, grads));
+        if cas.exchanged { break; } else { old_value = cas.old_value; }
+    }
+#endif
+}
+
+fn write_refine_atomic(id: i32, grads: f32) {
+    let p = &v_refine_grad[id];
 #ifdef HARD_FLOAT
-    atomicAdd(&v_splats[id], grads);
+    atomicAdd(p, grads);
 #else
-    var old_value = atomicLoad(&v_splats[id]);
+    var old_value = atomicLoad(p);
     loop {
-        let cas = atomicCompareExchangeWeak(&v_splats[id], old_value, add_bitcast(old_value, grads));
+        let cas = atomicCompareExchangeWeak(p, old_value, add_bitcast(old_value, grads));
         if cas.exchanged { break; } else { old_value = cas.old_value; }
     }
 #endif
 }
@@ -106,12 +107,6 @@ fn main(
         v_out = v_output[pix_id];
     }
 
-    // Make sure all groups start with empty gradient queue.
-    atomicStore(&grad_count, 0);
-
-    let sg_per_tile = helpers::ceil_div(i32(helpers::TILE_SIZE), i32(subgroup_size));
-    let microbatch_size = i32(helpers::TILE_SIZE) / sg_per_tile;
-
     for (var b = 0; b < num_batches; b++) {
         // each thread fetch 1 gaussian from back to front
        // 0 index will be furthest back in batch
         let batch_end = range.y - b * i32(BATCH_SIZE);
         let remaining = min(i32(BATCH_SIZE), batch_end - range.x);
 
-        // Gather N gaussians.
-        var load_compact_gid = 0;
+        // Each thread first gathers one gaussian.
         if i32(local_idx) < remaining {
             let load_isect_id = batch_end - 1 - i32(local_idx);
-            load_compact_gid = compact_gid_from_isect[load_isect_id];
+            let load_compact_gid = compact_gid_from_isect[load_isect_id];
             local_id[local_idx] = load_compact_gid;
             local_batch[local_idx] = projected_splats[load_compact_gid];
         }
 
-        for (var tb = 0; tb < remaining; tb += microbatch_size) {
-            if local_idx == 0 {
-                atomicStore(&grad_count, 0);
-            }
-            workgroupBarrier();
+        // Wait for all threads to finish loading.
+        workgroupBarrier();
 
-            for(var tt = 0; tt < microbatch_size; tt++) {
-                let t = tb + tt;
+        for (var t = 0; t < remaining; t += 1) {
+            let isect_id = batch_end - 1 - t;
 
-                if t >= remaining {
-                    break;
-                }
+            var v_xy = vec2f(0.0);
+            var v_conic = vec3f(0.0);
+            var v_colors = vec4f(0.0);
+            var v_refine = vec2f(0.0);
 
-                let isect_id = batch_end - 1 - t;
+            var splat_active = false;
 
-                var v_xy = vec2f(0.0);
-                var v_conic = vec3f(0.0);
-                var v_colors = vec4f(0.0);
-                var v_refine = vec2f(0.0);
+            if inside && isect_id < final_isect {
+                let projected = local_batch[t];
 
-                var splat_active = false;
+                let xy = vec2f(projected.xy_x, projected.xy_y);
+                let conic = vec3f(projected.conic_x, projected.conic_y, projected.conic_z);
+                let color = vec4f(projected.color_r, projected.color_g, projected.color_b, projected.color_a);
 
-                if inside && isect_id < final_isect {
-                    let projected = local_batch[t];
+                let delta = xy - pixel_coord;
+                let sigma = 0.5f * (conic.x * delta.x * delta.x + conic.z * delta.y * delta.y) + conic.y * delta.x * delta.y;
+                let vis = exp(-sigma);
+                let alpha = min(0.99f, color.w * vis);
 
-                    let xy = vec2f(projected.xy_x, projected.xy_y);
-                    let conic = vec3f(projected.conic_x, projected.conic_y, projected.conic_z);
-                    let color = vec4f(projected.color_r, projected.color_g, projected.color_b, projected.color_a);
+                // Nb: Don't continue; here - local_idx == 0 always
+                // needs to write out gradients.
+                // compute the current T for this gaussian
+                if (sigma >= 0.0 && alpha >= 1.0 / 255.0) {
+                    splat_active = true;
 
-                    let delta = xy - pixel_coord;
-                    let sigma = 0.5f * (conic.x * delta.x * delta.x + conic.z * delta.y * delta.y) + conic.y * delta.x * delta.y;
-                    let vis = exp(-sigma);
-                    let alpha = min(0.99f, color.w * vis);
+                    let ra = 1.0 / (1.0 - alpha);
+                    T *= ra;
+                    // update v_colors for this gaussian
+                    let fac = alpha * T;
 
-                    // Nb: Don't continue; here - local_idx == 0 always
-                    // needs to write out gradients.
-                    // compute the current T for this gaussian
-                    if (sigma >= 0.0 && alpha >= 1.0 / 255.0) {
-                        splat_active = true;
+                    // contribution from this pixel
+                    let clamped_rgb = max(color.rgb, vec3f(0.0));
+                    var v_alpha = dot(clamped_rgb * T - buffer * ra, v_out.rgb);
+                    v_alpha += T_final * ra * v_out.a;
 
-                        let ra = 1.0 / (1.0 - alpha);
-                        T *= ra;
-                        // update v_colors for this gaussian
-                        let fac = alpha * T;
+                    // update the running sum
+                    buffer += clamped_rgb * fac;
 
-                        // contribution from this pixel
-                        let clamped_rgb = max(color.rgb, vec3f(0.0));
-                        var v_alpha = dot(clamped_rgb * T - buffer * ra, v_out.rgb);
-                        v_alpha += T_final * ra * v_out.a;
+                    let v_sigma = -color.a * vis * v_alpha;
 
-                        // update the running sum
-                        buffer += clamped_rgb * fac;
+                    v_xy = v_sigma * vec2f(
+                        conic.x * delta.x + conic.y * delta.y,
+                        conic.y * delta.x + conic.z * delta.y
+                    );
 
-                        let v_sigma = -color.a * vis * v_alpha;
+                    v_conic = vec3f(0.5f * v_sigma * delta.x * delta.x,
+                                    v_sigma * delta.x * delta.y,
+                                    0.5f * v_sigma * delta.y * delta.y);
 
-                        v_xy = v_sigma * vec2f(
-                            conic.x * delta.x + conic.y * delta.y,
-                            conic.y * delta.x + conic.z * delta.y
-                        );
+                    let v_rgb = select(vec3f(0.0), fac * v_out.rgb, color.rgb > vec3f(0.0));
+                    v_colors = vec4f(v_rgb, vis * v_alpha);
 
-                        v_conic = vec3f(0.5f * v_sigma * delta.x * delta.x,
-                                        v_sigma * delta.x * delta.y,
-                                        0.5f * v_sigma * delta.y * delta.y);
+                    v_refine = abs(v_xy);
+                }
+            }
 
-                        let v_rgb = select(vec3f(0.0), fac * v_out.rgb, color.rgb > vec3f(0.0));
-                        v_colors = vec4f(v_rgb, vis * v_alpha);
+            // Queue a new gradient if this subgroup has any.
+            // The gradient is sum of all gradients in the subgroup.
+            if subgroupAny(splat_active) {
+                let compact_gid = local_id[t];
 
-                        v_refine = abs(v_xy);
-                    }
-                }
+                let v_xy_sum = subgroupAdd(v_xy);
+                let v_conic_sum = subgroupAdd(v_conic);
+                let v_colors_sum = subgroupAdd(v_colors);
+                let v_refine_sum = subgroupAdd(v_refine);
 
-                // Queue a new gradient if this subgroup has any.
-                // The gradient is sum of all gradients in the subgroup.
-                if subgroupAny(splat_active) {
-                    let v_xy_sum = subgroupAdd(v_xy);
-                    let v_conic_sum = subgroupAdd(v_conic);
-                    let v_colors_sum = subgroupAdd(v_colors);
-                    let v_refine_sum = subgroupAdd(v_refine);
-
-                    // First thread of subgroup writes the gradient. This should be a
-                    // subgroupBallot() when it's supported.
-                    if subgroup_invocation_id == 0 {
-                        let grad_idx = atomicAdd(&grad_count, 1);
-                        gather_grads[grad_idx * 11 + 0] = v_xy_sum.x;
-                        gather_grads[grad_idx * 11 + 1] = v_xy_sum.y;
-                        gather_grads[grad_idx * 11 + 2] = v_conic_sum.x;
-                        gather_grads[grad_idx * 11 + 3] = v_conic_sum.y;
-                        gather_grads[grad_idx * 11 + 4] = v_conic_sum.z;
-                        gather_grads[grad_idx * 11 + 5] = v_colors_sum.x;
-                        gather_grads[grad_idx * 11 + 6] = v_colors_sum.y;
-                        gather_grads[grad_idx * 11 + 7] = v_colors_sum.z;
-                        gather_grads[grad_idx * 11 + 8] = v_colors_sum.w;
-
-                        gather_grads[grad_idx * 11 + 9] = v_refine_sum.x;
-                        gather_grads[grad_idx * 11 + 10] = v_refine_sum.y;
-
-                        gather_grad_id[grad_idx] = local_id[t];
-                    }
+                if subgroup_invocation_id == 0 {
+                    write_grads_atomic(compact_gid * 9 + 0, v_xy_sum.x);
+                }
+                else if subgroup_invocation_id == 1 {
+                    write_grads_atomic(compact_gid * 9 + 1, v_xy_sum.y);
+                }
+                else if subgroup_invocation_id == 2 {
+                    write_grads_atomic(compact_gid * 9 + 2, v_conic_sum.x);
+                }
+                else if subgroup_invocation_id == 3 {
+                    write_grads_atomic(compact_gid * 9 + 3, v_conic_sum.y);
+                }
+                else if subgroup_invocation_id == 4 {
+                    write_grads_atomic(compact_gid * 9 + 4, v_conic_sum.z);
+                }
+                else if subgroup_invocation_id == 5 {
+                    write_grads_atomic(compact_gid * 9 + 5, v_colors_sum.x);
+                }
+                else if subgroup_invocation_id == 6 {
+                    write_grads_atomic(compact_gid * 9 + 6, v_colors_sum.y);
+                }
+                else if subgroup_invocation_id == 7 {
+                    write_grads_atomic(compact_gid * 9 + 7, v_colors_sum.z);
+                }
+                else if subgroup_invocation_id == 8 {
+                    write_grads_atomic(compact_gid * 9 + 8, v_colors_sum.w);
+                }
+                else if subgroup_invocation_id == 9 {
+                    write_refine_atomic(compact_gid * 2 + 0, v_refine_sum.x);
+                }
+                else if subgroup_invocation_id == 10 {
+                    write_refine_atomic(compact_gid * 2 + 1, v_refine_sum.y);
+                }
             }
-            }
 
-            // Make sure all threads are done, and flush a batch of gradients.
-            workgroupBarrier();
-            if local_idx < u32(grad_count) {
-                let compact_gid = gather_grad_id[local_idx];
-                write_grads_atomic(gather_grads[local_idx * 11 + 0], compact_gid * 9 + 0);
-                write_grads_atomic(gather_grads[local_idx * 11 + 1], compact_gid * 9 + 1);
-                write_grads_atomic(gather_grads[local_idx * 11 + 2], compact_gid * 9 + 2);
-                write_grads_atomic(gather_grads[local_idx * 11 + 3], compact_gid * 9 + 3);
-                write_grads_atomic(gather_grads[local_idx * 11 + 4], compact_gid * 9 + 4);
-                write_grads_atomic(gather_grads[local_idx * 11 + 5], compact_gid * 9 + 5);
-                write_grads_atomic(gather_grads[local_idx * 11 + 6], compact_gid * 9 + 6);
-                write_grads_atomic(gather_grads[local_idx * 11 + 7], compact_gid * 9 + 7);
-                write_grads_atomic(gather_grads[local_idx * 11 + 8], compact_gid * 9 + 8);
-
-                let refine_grad_x = gather_grads[local_idx * 11 + 9];
-                let refine_grad_y = gather_grads[local_idx * 11 + 10];
-
-                #ifdef HARD_FLOAT
-                atomicAdd(&v_refine_grad[compact_gid * 2 + 0], refine_grad_x);
-                atomicAdd(&v_refine_grad[compact_gid * 2 + 1], refine_grad_y);
-                #else
-                var old_value = atomicLoad(&v_refine_grad[compact_gid * 2 + 0]);
-                loop {
-                    let cas = atomicCompareExchangeWeak(&v_refine_grad[compact_gid * 2 + 0], old_value, add_bitcast(old_value, refine_grad_x));
-                    if cas.exchanged { break; } else { old_value = cas.old_value; }
-                }
-                old_value = atomicLoad(&v_refine_grad[compact_gid * 2 + 1]);
-                loop {
-                    let cas = atomicCompareExchangeWeak(&v_refine_grad[compact_gid * 2 + 1], old_value, add_bitcast(old_value, refine_grad_y));
-                    if cas.exchanged { break; } else { old_value = cas.old_value; }
-                }
-                #endif
-            }
-            workgroupBarrier();
         }
+
+        // Wait for all gradients to be written.
+        workgroupBarrier();
     }
 }

From bee0762cea3eb5fb1b7bd62271f55b3585af9ab9 Mon Sep 17 00:00:00 2001
From: Arthur Brussee
Date: Sun, 23 Feb 2025 22:35:37 +0000
Subject: [PATCH 2/5] Support small workgroups

---
 crates/brush-train/src/shaders/rasterize_backwards.wgsl | 9 ++++++---
 1 file changed, 6 insertions(+), 3 deletions(-)

diff --git a/crates/brush-train/src/shaders/rasterize_backwards.wgsl b/crates/brush-train/src/shaders/rasterize_backwards.wgsl
index c08ef8a..3d4e218 100644
--- a/crates/brush-train/src/shaders/rasterize_backwards.wgsl
+++ b/crates/brush-train/src/shaders/rasterize_backwards.wgsl
@@ -194,6 +194,9 @@ fn main(
                 let v_colors_sum = subgroupAdd(v_colors);
                 let v_refine_sum = subgroupAdd(v_refine);
 
+                // For super small subgroups, handle the last few elements seperately.
+                let is_last_in_sg = subgroup_invocation_id == subgroup_size - 1;
+
                 if subgroup_invocation_id == 0 {
                     write_grads_atomic(compact_gid * 9 + 0, v_xy_sum.x);
                 }
@@ -218,13 +221,13 @@ fn main(
                 else if subgroup_invocation_id == 7 {
                     write_grads_atomic(compact_gid * 9 + 7, v_colors_sum.z);
                 }
-                else if subgroup_invocation_id == 8 {
+                else if subgroup_invocation_id == 8 || subgroup_size <= 8 && is_last_in_sg {
                     write_grads_atomic(compact_gid * 9 + 8, v_colors_sum.w);
                 }
-                else if subgroup_invocation_id == 9 {
+                else if subgroup_invocation_id == 9 || subgroup_size <= 9 && is_last_in_sg {
                     write_refine_atomic(compact_gid * 2 + 0, v_refine_sum.x);
                 }
-                else if subgroup_invocation_id == 10 {
+                else if subgroup_invocation_id == 10 || subgroup_size <= 10 && is_last_in_sg {
                     write_refine_atomic(compact_gid * 2 + 1, v_refine_sum.y);
                 }

From ced0d459c8d62cf50aa07652efc7be52f4e4545d Mon Sep 17 00:00:00 2001
From: Arthur Brussee
Date: Sun, 23 Feb 2025 22:36:37 +0000
Subject: [PATCH 3/5] Typo

---
 crates/brush-train/src/shaders/rasterize_backwards.wgsl | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/crates/brush-train/src/shaders/rasterize_backwards.wgsl b/crates/brush-train/src/shaders/rasterize_backwards.wgsl
index 3d4e218..1e7c3bd 100644
--- a/crates/brush-train/src/shaders/rasterize_backwards.wgsl
+++ b/crates/brush-train/src/shaders/rasterize_backwards.wgsl
@@ -194,7 +194,7 @@ fn main(
                 let v_colors_sum = subgroupAdd(v_colors);
                 let v_refine_sum = subgroupAdd(v_refine);
 
-                // For super small subgroups, handle the last few elements seperately.
+                // For super small subgroups, handle the last few elements separately.
                 let is_last_in_sg = subgroup_invocation_id == subgroup_size - 1;
 
                 if subgroup_invocation_id == 0 {

From 9271b5e41345d2006b848ccaf4fcc2980767fe0d Mon Sep 17 00:00:00 2001
From: Arthur Brussee
Date: Sun, 23 Feb 2025 22:38:41 +0000
Subject: [PATCH 4/5] Only handle sg == 8

---
 .../src/shaders/rasterize_backwards.wgsl | 14 +++++++++-----
 1 file changed, 9 insertions(+), 5 deletions(-)

diff --git a/crates/brush-train/src/shaders/rasterize_backwards.wgsl b/crates/brush-train/src/shaders/rasterize_backwards.wgsl
index 1e7c3bd..862a616 100644
--- a/crates/brush-train/src/shaders/rasterize_backwards.wgsl
+++ b/crates/brush-train/src/shaders/rasterize_backwards.wgsl
@@ -107,6 +107,12 @@ fn main(
         v_out = v_output[pix_id];
     }
 
+    // For super small subgroups, handle the last few elements separately.
+    let is_last_in_sg = subgroup_invocation_id == subgroup_size - 1;
+    // Subgroups of size 8 need to be handled separately as there's not enough threads to write
+    // all the gaussian fields. The next size (16) is fine.
+    let small_sg = subgroup_size == 8;
+
     for (var b = 0; b < num_batches; b++) {
         // each thread fetch 1 gaussian from back to front
        // 0 index will be furthest back in batch
@@ -194,8 +200,6 @@ fn main(
                 let v_colors_sum = subgroupAdd(v_colors);
                 let v_refine_sum = subgroupAdd(v_refine);
 
-                // For super small subgroups, handle the last few elements separately.
-                let is_last_in_sg = subgroup_invocation_id == subgroup_size - 1;
 
                 if subgroup_invocation_id == 0 {
                     write_grads_atomic(compact_gid * 9 + 0, v_xy_sum.x);
                 }
-                else if subgroup_invocation_id == 8 || subgroup_size <= 8 && is_last_in_sg {
+                else if subgroup_invocation_id == 8 || small_sg && is_last_in_sg {
                     write_grads_atomic(compact_gid * 9 + 8, v_colors_sum.w);
                 }
-                else if subgroup_invocation_id == 9 || subgroup_size <= 9 && is_last_in_sg {
+                else if subgroup_invocation_id == 9 || small_sg && is_last_in_sg {
                     write_refine_atomic(compact_gid * 2 + 0, v_refine_sum.x);
                 }
-                else if subgroup_invocation_id == 10 || subgroup_size <= 10 && is_last_in_sg {
+                else if subgroup_invocation_id == 10 || small_sg && is_last_in_sg {
                     write_refine_atomic(compact_gid * 2 + 1, v_refine_sum.y);
                 }

From 76ec8d6be9529fa39c10f30f6d500dd07e36a539 Mon Sep 17 00:00:00 2001
From: Arthur Brussee
Date: Sun, 23 Feb 2025 22:56:37 +0000
Subject: [PATCH 5/5] Switch statement

---
 .../src/shaders/rasterize_backwards.wgsl | 63 +++++++++-------
 1 file changed, 24 insertions(+), 39 deletions(-)

diff --git a/crates/brush-train/src/shaders/rasterize_backwards.wgsl b/crates/brush-train/src/shaders/rasterize_backwards.wgsl
index 862a616..20c40bc 100644
--- a/crates/brush-train/src/shaders/rasterize_backwards.wgsl
+++ b/crates/brush-train/src/shaders/rasterize_backwards.wgsl
@@ -107,12 +107,6 @@ fn main(
         v_out = v_output[pix_id];
     }
 
-    // For super small subgroups, handle the last few elements separately.
-    let is_last_in_sg = subgroup_invocation_id == subgroup_size - 1;
-    // Subgroups of size 8 need to be handled separately as there's not enough threads to write
-    // all the gaussian fields. The next size (16) is fine.
-    let small_sg = subgroup_size == 8;
-
     for (var b = 0; b < num_batches; b++) {
         // each thread fetch 1 gaussian from back to front
        // 0 index will be furthest back in batch
@@ -200,39 +194,30 @@ fn main(
                 let v_colors_sum = subgroupAdd(v_colors);
                 let v_refine_sum = subgroupAdd(v_refine);
 
-
-                if subgroup_invocation_id == 0 {
-                    write_grads_atomic(compact_gid * 9 + 0, v_xy_sum.x);
-                }
-                else if subgroup_invocation_id == 1 {
-                    write_grads_atomic(compact_gid * 9 + 1, v_xy_sum.y);
-                }
-                else if subgroup_invocation_id == 2 {
-                    write_grads_atomic(compact_gid * 9 + 2, v_conic_sum.x);
-                }
-                else if subgroup_invocation_id == 3 {
-                    write_grads_atomic(compact_gid * 9 + 3, v_conic_sum.y);
-                }
-                else if subgroup_invocation_id == 4 {
-                    write_grads_atomic(compact_gid * 9 + 4, v_conic_sum.z);
-                }
-                else if subgroup_invocation_id == 5 {
-                    write_grads_atomic(compact_gid * 9 + 5, v_colors_sum.x);
-                }
-                else if subgroup_invocation_id == 6 {
-                    write_grads_atomic(compact_gid * 9 + 6, v_colors_sum.y);
-                }
-                else if subgroup_invocation_id == 7 {
-                    write_grads_atomic(compact_gid * 9 + 7, v_colors_sum.z);
-                }
-                else if subgroup_invocation_id == 8 || small_sg && is_last_in_sg {
-                    write_grads_atomic(compact_gid * 9 + 8, v_colors_sum.w);
-                }
-                else if subgroup_invocation_id == 9 || small_sg && is_last_in_sg {
-                    write_refine_atomic(compact_gid * 2 + 0, v_refine_sum.x);
-                }
-                else if subgroup_invocation_id == 10 || small_sg && is_last_in_sg {
-                    write_refine_atomic(compact_gid * 2 + 1, v_refine_sum.y);
+                switch subgroup_invocation_id {
+                    case 0u: { write_grads_atomic(compact_gid * 9 + 0, v_xy_sum.x); }
+                    case 1u: { write_grads_atomic(compact_gid * 9 + 1, v_xy_sum.y); }
+                    case 2u: { write_grads_atomic(compact_gid * 9 + 2, v_conic_sum.x); }
+                    case 3u: { write_grads_atomic(compact_gid * 9 + 3, v_conic_sum.y); }
+                    case 4u: { write_grads_atomic(compact_gid * 9 + 4, v_conic_sum.z); }
+                    case 5u: { write_grads_atomic(compact_gid * 9 + 5, v_colors_sum.x); }
+                    case 6u: { write_grads_atomic(compact_gid * 9 + 6, v_colors_sum.y); }
+                    case 7u: {
+                        write_grads_atomic(compact_gid * 9 + 7, v_colors_sum.z);
+
+                        // Subgroups of size 8 need to be handled separately as there's not enough threads to write
+                        // all the gaussian fields. The next size (16) is fine.
+                        if subgroup_size == 8u {
+                            write_grads_atomic(compact_gid * 9 + 8, v_colors_sum.w);
+                            write_refine_atomic(compact_gid * 2 + 0, v_refine_sum.x);
+                            write_refine_atomic(compact_gid * 2 + 1, v_refine_sum.y);
+                        }
+                    }
+
+                    case 8u: { write_grads_atomic(compact_gid * 9 + 8, v_colors_sum.w); }
+                    case 9u: { write_refine_atomic(compact_gid * 2 + 0, v_refine_sum.x); }
+                    case 10u: { write_refine_atomic(compact_gid * 2 + 1, v_refine_sum.y); }
+                    default: {}
                 }
             }
         }
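
A note on the accumulation primitive this series builds on: on backends without native
float atomics (the non-HARD_FLOAT path), write_grads_atomic emulates a floating-point
atomicAdd by storing f32 bits in an atomic<u32> and retrying a compare-exchange until
no other thread has raced in between. Below is a minimal standalone sketch of that CAS
loop; the buffer name and binding are hypothetical, for illustration only — the real
kernel accumulates into v_splats and v_refine_grad.

    // Hypothetical gradient buffer holding f32 values as u32 bit patterns.
    @group(0) @binding(0) var<storage, read_write> grads: array<atomic<u32>>;

    // Software float atomic-add: reinterpret the stored bits as f32, add,
    // and compare-exchange, retrying whenever another thread wrote the
    // word between our load and our exchange.
    fn atomic_add_f32(id: i32, add: f32) {
        let p = &grads[id];
        var old_value = atomicLoad(p);
        loop {
            let new_value = bitcast<u32>(bitcast<f32>(old_value) + add);
            let cas = atomicCompareExchangeWeak(p, old_value, new_value);
            if cas.exchanged { break; }
            // Lost the race; retry against the value the other thread wrote.
            old_value = cas.old_value;
        }
    }

The patches keep this loop but reduce contention around it: each subgroup first
collapses its per-thread gradients with subgroupAdd, then invocations 0..10 each write
one of the eleven summed components, so the CAS loops for the different fields proceed
in parallel instead of serially on a single thread — with the subgroup_size == 8 case
folded into invocation 7, as PATCH 5/5 does.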