Simplify backwards rasterization #115

Merged
merged 5 commits on Feb 23, 2025
242 changes: 101 additions & 141 deletions crates/brush-train/src/shaders/rasterize_backwards.wgsl
@@ -25,30 +25,31 @@ const BATCH_SIZE = helpers::TILE_SIZE;
var<workgroup> local_batch: array<helpers::ProjectedSplat, BATCH_SIZE>;
var<workgroup> local_id: array<i32, BATCH_SIZE>;

// This kernel uses a new technique to reduce the overhead of atomic gradient accumulation; this helps performance
// a lot, especially when software CAS loops are used. Originally, each thread calculated a gradient, the gradients
// were summed within a subgroup, and one thread of each subgroup then atomically added that sum to the global
// gradient. Instead, we push each subgroup gradient to a buffer until it holds N threads' worth of gradients,
// which are then written to the global gradients all at once.

// Current queue of gradients to be flushed.
var<workgroup> grad_count: atomic<i32>;

const TOTAL_GRADS = BATCH_SIZE * 11;
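// 11 floats per queued gradient: 2 for xy, 3 for the conic, 4 for color, and 2 for the refine gradient.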
var<workgroup> gather_grads: array<f32, TOTAL_GRADS>;
var<workgroup> gather_grad_id: array<i32, BATCH_SIZE>;

fn add_bitcast(cur: u32, add: f32) -> u32 {
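// Reinterpret the packed u32 bits as f32, add the gradient, and bitcast back. Used by the CAS loops below
// when hardware float atomics (HARD_FLOAT) are unavailable.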
return bitcast<u32>(bitcast<f32>(cur) + add);
}

fn write_grads_atomic(grads: f32, id: i32) {
fn write_grads_atomic(id: i32, grads: f32) {
let p = &v_splats[id];
#ifdef HARD_FLOAT
atomicAdd(&v_splats[id], grads);
atomicAdd(p, grads);
#else
var old_value = atomicLoad(&v_splats[id]);
var old_value = atomicLoad(p);
loop {
let cas = atomicCompareExchangeWeak(&v_splats[id], old_value, add_bitcast(old_value, grads));
let cas = atomicCompareExchangeWeak(p, old_value, add_bitcast(old_value, grads));
if cas.exchanged { break; } else { old_value = cas.old_value; }
}
#endif
}

fn write_refine_atomic(id: i32, grads: f32) {
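// Accumulate a refine gradient, mirroring write_grads_atomic: hardware atomicAdd under HARD_FLOAT,
// otherwise a compare-exchange loop on the bitcast value.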
let p = &v_refine_grad[id];
#ifdef HARD_FLOAT
atomicAdd(p, grads);
#else
var old_value = atomicLoad(p);
loop {
let cas = atomicCompareExchangeWeak(p, old_value, add_bitcast(old_value, grads));
if cas.exchanged { break; } else { old_value = cas.old_value; }
}
#endif
@@ -106,163 +107,122 @@ fn main(
v_out = v_output[pix_id];
}

// Make sure all groups start with an empty gradient queue.
atomicStore(&grad_count, 0);

let sg_per_tile = helpers::ceil_div(i32(helpers::TILE_SIZE), i32(subgroup_size));
let microbatch_size = i32(helpers::TILE_SIZE) / sg_per_tile;

for (var b = 0; b < num_batches; b++) {
// each thread fetches 1 gaussian from back to front
// index 0 will be furthest back in the batch
// index of gaussian to load
let batch_end = range.y - b * i32(BATCH_SIZE);
let remaining = min(i32(BATCH_SIZE), batch_end - range.x);

// Gather N gaussians.
var load_compact_gid = 0;
// Each thread first gathers one gaussian.
if i32(local_idx) < remaining {
let load_isect_id = batch_end - 1 - i32(local_idx);
load_compact_gid = compact_gid_from_isect[load_isect_id];
let load_compact_gid = compact_gid_from_isect[load_isect_id];
local_id[local_idx] = load_compact_gid;
local_batch[local_idx] = projected_splats[load_compact_gid];
}

for (var tb = 0; tb < remaining; tb += microbatch_size) {
if local_idx == 0 {
atomicStore(&grad_count, 0);
}
workgroupBarrier();

for(var tt = 0; tt < microbatch_size; tt++) {
let t = tb + tt;

if t >= remaining {
break;
}

let isect_id = batch_end - 1 - t;
// Wait for all threads to finish loading.
workgroupBarrier();

var v_xy = vec2f(0.0);
var v_conic = vec3f(0.0);
var v_colors = vec4f(0.0);
var v_refine = vec2f(0.0);
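// Walk the batch back to front; every thread in the workgroup processes the same splat t, so the
// subgroup reduction below sums the per-pixel gradients of a single gaussian.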
for (var t = 0; t < remaining; t += 1) {
let isect_id = batch_end - 1 - t;

var splat_active = false;
var v_xy = vec2f(0.0);
var v_conic = vec3f(0.0);
var v_colors = vec4f(0.0);
var v_refine = vec2f(0.0);

if inside && isect_id < final_isect {
let projected = local_batch[t];
var splat_active = false;

let xy = vec2f(projected.xy_x, projected.xy_y);
let conic = vec3f(projected.conic_x, projected.conic_y, projected.conic_z);
let color = vec4f(projected.color_r, projected.color_g, projected.color_b, projected.color_a);
if inside && isect_id < final_isect {
let projected = local_batch[t];

let delta = xy - pixel_coord;
let sigma = 0.5f * (conic.x * delta.x * delta.x + conic.z * delta.y * delta.y) + conic.y * delta.x * delta.y;
let vis = exp(-sigma);
let alpha = min(0.99f, color.w * vis);
let xy = vec2f(projected.xy_x, projected.xy_y);
let conic = vec3f(projected.conic_x, projected.conic_y, projected.conic_z);
let color = vec4f(projected.color_r, projected.color_g, projected.color_b, projected.color_a);

// Nb: Don't continue; here - local_idx == 0 always
// needs to write out gradients.
// compute the current T for this gaussian
if (sigma >= 0.0 && alpha >= 1.0 / 255.0) {
splat_active = true;
let delta = xy - pixel_coord;
let sigma = 0.5f * (conic.x * delta.x * delta.x + conic.z * delta.y * delta.y) + conic.y * delta.x * delta.y;
let vis = exp(-sigma);
let alpha = min(0.99f, color.w * vis);

let ra = 1.0 / (1.0 - alpha);
T *= ra;
// update v_colors for this gaussian
let fac = alpha * T;
// Nb: Don't continue; here - local_idx == 0 always
// needs to write out gradients.
// compute the current T for this gaussian
if (sigma >= 0.0 && alpha >= 1.0 / 255.0) {
splat_active = true;

// contribution from this pixel
let clamped_rgb = max(color.rgb, vec3f(0.0));
var v_alpha = dot(clamped_rgb * T - buffer * ra, v_out.rgb);
v_alpha += T_final * ra * v_out.a;
let ra = 1.0 / (1.0 - alpha);
T *= ra;
// update v_colors for this gaussian
let fac = alpha * T;

// update the running sum
buffer += clamped_rgb * fac;
// contribution from this pixel
let clamped_rgb = max(color.rgb, vec3f(0.0));
var v_alpha = dot(clamped_rgb * T - buffer * ra, v_out.rgb);
v_alpha += T_final * ra * v_out.a;

let v_sigma = -color.a * vis * v_alpha;
// update the running sum
buffer += clamped_rgb * fac;

v_xy = v_sigma * vec2f(
conic.x * delta.x + conic.y * delta.y,
conic.y * delta.x + conic.z * delta.y
);
let v_sigma = -color.a * vis * v_alpha;

v_conic = vec3f(0.5f * v_sigma * delta.x * delta.x,
v_sigma * delta.x * delta.y,
0.5f * v_sigma * delta.y * delta.y);
v_xy = v_sigma * vec2f(
conic.x * delta.x + conic.y * delta.y,
conic.y * delta.x + conic.z * delta.y
);

let v_rgb = select(vec3f(0.0), fac * v_out.rgb, color.rgb > vec3f(0.0));
v_colors = vec4f(v_rgb, vis * v_alpha);
v_conic = vec3f(0.5f * v_sigma * delta.x * delta.x,
v_sigma * delta.x * delta.y,
0.5f * v_sigma * delta.y * delta.y);

v_refine = abs(v_xy);
}
}
let v_rgb = select(vec3f(0.0), fac * v_out.rgb, color.rgb > vec3f(0.0));
v_colors = vec4f(v_rgb, vis * v_alpha);

// Queue a new gradient if this subgroup has any.
// The gradient is the sum of all gradients in the subgroup.
if subgroupAny(splat_active) {
let v_xy_sum = subgroupAdd(v_xy);
let v_conic_sum = subgroupAdd(v_conic);
let v_colors_sum = subgroupAdd(v_colors);
let v_refine_sum = subgroupAdd(v_refine);

// First thread of subgroup writes the gradient. This should be a
// subgroupBallot() when it's supported.
if subgroup_invocation_id == 0 {
let grad_idx = atomicAdd(&grad_count, 1);
gather_grads[grad_idx * 11 + 0] = v_xy_sum.x;
gather_grads[grad_idx * 11 + 1] = v_xy_sum.y;
gather_grads[grad_idx * 11 + 2] = v_conic_sum.x;
gather_grads[grad_idx * 11 + 3] = v_conic_sum.y;
gather_grads[grad_idx * 11 + 4] = v_conic_sum.z;
gather_grads[grad_idx * 11 + 5] = v_colors_sum.x;
gather_grads[grad_idx * 11 + 6] = v_colors_sum.y;
gather_grads[grad_idx * 11 + 7] = v_colors_sum.z;
gather_grads[grad_idx * 11 + 8] = v_colors_sum.w;

gather_grads[grad_idx * 11 + 9] = v_refine_sum.x;
gather_grads[grad_idx * 11 + 10] = v_refine_sum.y;

gather_grad_id[grad_idx] = local_id[t];
}
v_refine = abs(v_xy);
}
}

// Make sure all threads are done, and flush a batch of gradients.
workgroupBarrier();
if local_idx < u32(grad_count) {
let compact_gid = gather_grad_id[local_idx];
write_grads_atomic(gather_grads[local_idx * 11 + 0], compact_gid * 9 + 0);
write_grads_atomic(gather_grads[local_idx * 11 + 1], compact_gid * 9 + 1);
write_grads_atomic(gather_grads[local_idx * 11 + 2], compact_gid * 9 + 2);
write_grads_atomic(gather_grads[local_idx * 11 + 3], compact_gid * 9 + 3);
write_grads_atomic(gather_grads[local_idx * 11 + 4], compact_gid * 9 + 4);
write_grads_atomic(gather_grads[local_idx * 11 + 5], compact_gid * 9 + 5);
write_grads_atomic(gather_grads[local_idx * 11 + 6], compact_gid * 9 + 6);
write_grads_atomic(gather_grads[local_idx * 11 + 7], compact_gid * 9 + 7);
write_grads_atomic(gather_grads[local_idx * 11 + 8], compact_gid * 9 + 8);

let refine_grad_x = gather_grads[local_idx * 11 + 9];
let refine_grad_y = gather_grads[local_idx * 11 + 10];

#ifdef HARD_FLOAT
atomicAdd(&v_refine_grad[compact_gid * 2 + 0], refine_grad_x);
atomicAdd(&v_refine_grad[compact_gid * 2 + 1], refine_grad_y);
#else
var old_value = atomicLoad(&v_refine_grad[compact_gid * 2 + 0]);
loop {
let cas = atomicCompareExchangeWeak(&v_refine_grad[compact_gid * 2 + 0], old_value, add_bitcast(old_value, refine_grad_x));
if cas.exchanged { break; } else { old_value = cas.old_value; }
}
old_value = atomicLoad(&v_refine_grad[compact_gid * 2 + 1]);
loop {
let cas = atomicCompareExchangeWeak(&v_refine_grad[compact_gid * 2 + 1], old_value, add_bitcast(old_value, refine_grad_y));
if cas.exchanged { break; } else { old_value = cas.old_value; }
// Write out this subgroup's gradient directly if any of its threads were active.
// The gradient is the sum of all gradients in the subgroup.
if subgroupAny(splat_active) {
let compact_gid = local_id[t];

let v_xy_sum = subgroupAdd(v_xy);
let v_conic_sum = subgroupAdd(v_conic);
let v_colors_sum = subgroupAdd(v_colors);
let v_refine_sum = subgroupAdd(v_refine);

switch subgroup_invocation_id {
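// Spread the writes over lanes: lane k commits one of the 9 splat gradient components or one of the
// 2 refine components, instead of a single lane issuing all 11 atomics itself.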
case 0u: { write_grads_atomic(compact_gid * 9 + 0, v_xy_sum.x); }
case 1u: { write_grads_atomic(compact_gid * 9 + 1, v_xy_sum.y); }
case 2u: { write_grads_atomic(compact_gid * 9 + 2, v_conic_sum.x); }
case 3u: { write_grads_atomic(compact_gid * 9 + 3, v_conic_sum.y); }
case 4u: { write_grads_atomic(compact_gid * 9 + 4, v_conic_sum.z); }
case 5u: { write_grads_atomic(compact_gid * 9 + 5, v_colors_sum.x); }
case 6u: { write_grads_atomic(compact_gid * 9 + 6, v_colors_sum.y); }
case 7u: {
write_grads_atomic(compact_gid * 9 + 7, v_colors_sum.z);

// Subgroups of size 8 need to be handled separately, as there aren't enough threads to write
// all 11 gradient fields from separate lanes. The next size up (16) is fine.
if subgroup_size == 8u {
write_grads_atomic(compact_gid * 9 + 8, v_colors_sum.w);
write_refine_atomic(compact_gid * 2 + 0, v_refine_sum.x);
write_refine_atomic(compact_gid * 2 + 1, v_refine_sum.y);
}
}
#endif

case 8u: { write_grads_atomic(compact_gid * 9 + 8, v_colors_sum.w); }
case 9u: { write_refine_atomic(compact_gid * 2 + 0, v_refine_sum.x); }
case 10u: { write_refine_atomic(compact_gid * 2 + 1, v_refine_sum.y); }
default: {}
}
}
workgroupBarrier();
}

// Wait for all gradients to be written.
workgroupBarrier();
}
}
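
For reference, a minimal standalone sketch of the subgroup-sum plus lane-distributed atomic-write pattern the new code relies on. The buffer name, bindings, workgroup size, and component count here are illustrative only, not part of this PR; it assumes the WGSL subgroups extension and the software CAS path.

enable subgroups;

// Hypothetical gradient buffer: one bitcast f32 per component, accumulated atomically.
@group(0) @binding(0) var<storage, read_write> out_grads: array<atomic<u32>>;

// Software float atomic add via compare-exchange, as in the shader's non-HARD_FLOAT path.
fn cas_add(id: u32, add: f32) {
    var old_value = atomicLoad(&out_grads[id]);
    loop {
        let cas = atomicCompareExchangeWeak(&out_grads[id], old_value, bitcast<u32>(bitcast<f32>(old_value) + add));
        if cas.exchanged { break; } else { old_value = cas.old_value; }
    }
}

@compute @workgroup_size(64)
fn main(@builtin(local_invocation_index) local_idx: u32,
        @builtin(subgroup_invocation_id) lane: u32) {
    // Each thread computes some per-pixel gradient components (dummy values here).
    let v_x = f32(local_idx) * 0.01;
    let v_y = f32(local_idx) * 0.02;

    // Reduce once across the subgroup; every lane receives the same sums.
    let sum_x = subgroupAdd(v_x);
    let sum_y = subgroupAdd(v_y);

    // Distribute the atomic writes over lanes: lane 0 commits component 0, lane 1 component 1.
    switch lane {
        case 0u: { cas_add(0u, sum_x); }
        case 1u: { cas_add(1u, sum_y); }
        default: {}
    }
}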