Skip to content

Commit

Permalink
AbsGS, fix some rotations de-normalizing over training (#114)
Browse files Browse the repository at this point in the history
  • Loading branch information
ArthurBrussee authored Feb 23, 2025
1 parent 8c14a69 commit d7b7280
Show file tree
Hide file tree
Showing 14 changed files with 219 additions and 251 deletions.
3 changes: 1 addition & 2 deletions crates/brush-dataset/src/splat_export.rs
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,7 @@ async fn read_splat_data<B: Backend>(splats: Splats<B>) -> Result<Vec<GaussianDa
}

pub async fn splat_to_ply<B: Backend>(splats: Splats<B>) -> anyhow::Result<Vec<u8>> {
let mut splats = splats;
splats.norm_rotations();
let splats = splats.with_normed_rotations();

let data = read_splat_data(splats.clone())
.await
Expand Down
6 changes: 3 additions & 3 deletions crates/brush-dataset/src/splat_import.rs
Original file line number Diff line number Diff line change
Expand Up @@ -411,14 +411,14 @@ pub fn load_splat_from_ply<T: AsyncRead + Unpin + 'static, B: Backend>(
splats.log_scales.val()
};

let mut new_splat = Splats::from_tensor_data(
let new_splat = Splats::from_tensor_data(
means,
rotations,
log_scales,
splats.sh_coeffs.val(),
splats.raw_opacity.val(),
);
new_splat.norm_rotations();
)
.with_normed_rotations();

// Emit newly animated splat.
emitter
Expand Down
11 changes: 8 additions & 3 deletions crates/brush-render/src/gaussian_splats.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,11 @@ pub struct Splats<B: Backend> {
}

fn norm_vec<B: Backend>(vec: Tensor<B, 2>) -> Tensor<B, 2> {
vec.clone() / Tensor::clamp_min(Tensor::sum_dim(vec.powf_scalar(2.0), 1).sqrt(), 1e-12)
let magnitudes = Tensor::clamp_min(
Tensor::sum_dim(vec.clone().powf_scalar(2.0), 1).sqrt(),
1e-32,
);
vec / magnitudes
}

pub fn inverse_sigmoid(x: f32) -> f32 {
Expand Down Expand Up @@ -228,8 +232,9 @@ impl<B: Backend> Splats<B> {
norm_vec(self.rotation.val())
}

pub fn norm_rotations(&mut self) {
self.rotation = self.rotation.clone().map(|r| norm_vec(r));
pub fn with_normed_rotations(mut self) -> Self {
self.rotation = self.rotation.map(|r| norm_vec(r));
self
}

pub fn sh_degree(&self) -> u32 {
Expand Down
16 changes: 13 additions & 3 deletions crates/brush-render/src/shaders/project_forward.wgsl
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@

@group(0) @binding(7) var<storage, read_write> radii: array<f32>;

const INV_SIGMOID_THRESH: f32 = -5.537334267018537;

@compute
@workgroup_size(helpers::MAIN_WG, 1, 1)
fn main(@builtin(global_invocation_id) global_id: vec3u) {
Expand All @@ -37,11 +39,19 @@ fn main(@builtin(global_invocation_id) global_id: vec3u) {
}

let scale = exp(helpers::as_vec(log_scales[global_gid]));
let quat = normalize(quats[global_gid]);
var quat = quats[global_gid];

// Skip any invalid rotations. This will mean overtime
// these gaussians just die off while optimizing. For the viewer, the importer
// atm always normalizes the quaternions.
if length(quat) < 1e-32 {
return;
}
quat = normalize(quat);

let raw_opac = raw_opacities[global_gid];

// inv_sigmoid(1.0 / 255.0);
if raw_opac <= -5.537 {
if raw_opac < INV_SIGMOID_THRESH {
return;
}

Expand Down
1 change: 1 addition & 0 deletions crates/brush-render/src/shaders/project_visible.wgsl
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,7 @@ fn main(@builtin(global_invocation_id) gid: vec3u) {
// Project world space to camera space.
let mean = helpers::as_vec(means[global_gid]);
let scale = exp(helpers::as_vec(log_scales[global_gid]));
// Safe to normalize, splats with length(quat) == 0 are invisible.
let quat = normalize(quats[global_gid]);
let opac = helpers::sigmoid(raw_opacities[global_gid]);

Expand Down
24 changes: 12 additions & 12 deletions crates/brush-train/src/burn_glue.rs
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ impl<B: Backend + SplatBackwardOps<B>> Backward<B, NUM_ARGS> for RenderBackwards

// Register gradients for parent nodes (This code is already skipped entirely
// if no parent nodes require gradients).
let [mean_parent, xys_parent, log_scales_parent, quats_parent, coeffs_parent, raw_opacity_parent] =
let [mean_parent, refine_weight, log_scales_parent, quats_parent, coeffs_parent, raw_opacity_parent] =
ops.parents;

let v_tens = B::render_splats_bwd(state, v_output);
Expand All @@ -135,8 +135,8 @@ impl<B: Backend + SplatBackwardOps<B>> Backward<B, NUM_ARGS> for RenderBackwards
}

// Register the gradients for the dummy xy input.
if let Some(node) = xys_parent {
grads.register::<B>(node.id, v_tens.v_xy);
if let Some(node) = refine_weight {
grads.register::<B>(node.id, v_tens.v_refine_weight);
}

if let Some(node) = log_scales_parent {
Expand All @@ -160,7 +160,7 @@ impl<B: Backend + SplatBackwardOps<B>> Backward<B, NUM_ARGS> for RenderBackwards
pub struct SplatOutputDiff<B: Backend> {
pub img: FloatTensor<B>,
pub aux: RenderAuxPrimitive<B>,
pub xy_grad_holder: Tensor<B, 2>,
pub refine_weight_holder: Tensor<B, 1>,
}

// Implement
Expand All @@ -180,13 +180,13 @@ impl<B: Backend + SplatBackwardOps<B> + SplatForward<B>, C: CheckpointStrategy>
// in the future.
let device =
Tensor::<Self, 2>::from_primitive(TensorPrimitive::Float(means.clone())).device();
let xy_grad_holder = Tensor::<Self, 2>::zeros([1, 2], &device).require_grad();
let refine_weight_holder = Tensor::<Self, 1>::zeros([1], &device).require_grad();

// Prepare backward pass, and check if we even need to do it. Store nodes that need gradients.
let prep_nodes = RenderBackwards
.prepare::<C>([
means.node.clone(),
xy_grad_holder.clone().into_primitive().tensor().node,
refine_weight_holder.clone().into_primitive().tensor().node,
log_scales.node.clone(),
quats.node.clone(),
sh_coeffs.node.clone(),
Expand Down Expand Up @@ -245,7 +245,7 @@ impl<B: Backend + SplatBackwardOps<B> + SplatForward<B>, C: CheckpointStrategy>
SplatOutputDiff {
img: out_img,
aux: wrapped_aux,
xy_grad_holder,
refine_weight_holder,
}
}
OpsKind::UnTracked(prep) => {
Expand All @@ -254,7 +254,7 @@ impl<B: Backend + SplatBackwardOps<B> + SplatForward<B>, C: CheckpointStrategy>
SplatOutputDiff {
img: prep.finish(out_img),
aux: wrapped_aux,
xy_grad_holder,
refine_weight_holder,
}
}
}
Expand All @@ -277,7 +277,7 @@ impl<F: FloatElement, I: IntElement, BT: BoolElement> SplatBackwardOps<Self>
Operation<FusionJitRuntime<WgpuRuntime, BT>> for CustomOp<F, I, BT>
{
fn execute(self: Box<Self>, h: &mut HandleContainer<JitFusionHandle<WgpuRuntime>>) {
let ([v_output], [v_means, v_quats, v_scales, v_coeffs, v_raw_opac, v_xy]) =
let ([v_output], [v_means, v_quats, v_scales, v_coeffs, v_raw_opac, v_refine]) =
self.desc.consume();

let state = self.state;
Expand Down Expand Up @@ -315,7 +315,7 @@ impl<F: FloatElement, I: IntElement, BT: BoolElement> SplatBackwardOps<Self>
h.register_float_tensor::<BBase<F, I, BT>>(&v_scales.id, grads.v_scales);
h.register_float_tensor::<BBase<F, I, BT>>(&v_coeffs.id, grads.v_coeffs);
h.register_float_tensor::<BBase<F, I, BT>>(&v_raw_opac.id, grads.v_raw_opac);
h.register_float_tensor::<BBase<F, I, BT>>(&v_xy.id, grads.v_xy);
h.register_float_tensor::<BBase<F, I, BT>>(&v_refine.id, grads.v_refine_weight);
}
}

Expand All @@ -332,7 +332,7 @@ impl<F: FloatElement, I: IntElement, BT: BoolElement> SplatBackwardOps<Self>
v_scales: client.tensor_uninitialized(vec![num_points, 3], DType::F32),
v_coeffs: client.tensor_uninitialized(vec![num_points, coeffs, 3], DType::F32),
v_raw_opac: client.tensor_uninitialized(vec![num_points], DType::F32),
v_xy: client.tensor_uninitialized(vec![num_points, 2], DType::F32),
v_refine_weight: client.tensor_uninitialized(vec![num_points, 2], DType::F32),
};

let desc = CustomOpIr::new(
Expand All @@ -344,7 +344,7 @@ impl<F: FloatElement, I: IntElement, BT: BoolElement> SplatBackwardOps<Self>
grads.v_scales.to_ir_out(),
grads.v_coeffs.to_ir_out(),
grads.v_raw_opac.to_ir_out(),
grads.v_xy.to_ir_out(),
grads.v_refine_weight.to_ir_out(),
],
);

Expand Down
22 changes: 9 additions & 13 deletions crates/brush-train/src/kernels.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@ pub struct SplatGrads<B: Backend> {
pub v_scales: FloatTensor<B>,
pub v_coeffs: FloatTensor<B>,
pub v_raw_opac: FloatTensor<B>,
pub v_xy: FloatTensor<B>,

pub v_refine_weight: FloatTensor<B>,
}

#[allow(clippy::too_many_arguments)]
Expand Down Expand Up @@ -51,8 +52,6 @@ pub(crate) fn render_backward<F: FloatElement, I: IntElement, BT: BoolElement>(

// Nb: these are packed vec3 values, special care is taken in the kernel to respect alignment.
// Nb: These have to be zeroed out - as we only write to visible splats.
//
let v_xys_local = BBase::<F, I, BT>::float_zeros([num_points, 2].into(), device);
let v_means = BBase::<F, I, BT>::float_zeros([num_points, 3].into(), device);
let v_scales = BBase::<F, I, BT>::float_zeros([num_points, 3].into(), device);
let v_quats = BBase::<F, I, BT>::float_zeros([num_points, 4].into(), device);
Expand All @@ -74,8 +73,8 @@ pub(crate) fn render_backward<F: FloatElement, I: IntElement, BT: BoolElement>(
let invocations = tile_bounds.x * tile_bounds.y;

// These gradients are atomically added to so important to zero them.
let v_conics = BBase::<F, I, BT>::float_zeros([num_points, 3].into(), device);
let v_colors = BBase::<F, I, BT>::float_zeros([num_points, 4].into(), device);
let v_grads = BBase::<F, I, BT>::float_zeros([num_points, 9].into(), device);
let v_refine_weight = BBase::<F, I, BT>::float_zeros([num_points, 2].into(), device);

let hard_floats = client
.properties()
Expand All @@ -95,13 +94,11 @@ pub(crate) fn render_backward<F: FloatElement, I: IntElement, BT: BoolElement>(
final_index.handle.binding(),
out_img.handle.binding(),
v_output.handle.binding(),
v_xys_local.clone().handle.binding(),
v_conics.clone().handle.binding(),
v_colors.clone().handle.binding(),
v_grads.clone().handle.binding(),
v_refine_weight.clone().handle.binding(),
],
);
});

let _span = tracing::trace_span!("GatherGrads", sync_burn = true).entered();

// SAFETY: Kernel has to contain no OOB indexing.
Expand All @@ -114,7 +111,7 @@ pub(crate) fn render_backward<F: FloatElement, I: IntElement, BT: BoolElement>(
global_from_compact_gid.clone().handle.binding(),
raw_opac.handle.binding(),
means.clone().handle.binding(),
v_colors.handle.binding(),
v_grads.clone().handle.binding(),
v_coeffs.handle.clone().binding(),
v_raw_opac.handle.clone().binding(),
],
Expand All @@ -133,8 +130,7 @@ pub(crate) fn render_backward<F: FloatElement, I: IntElement, BT: BoolElement>(
log_scales.handle.binding(),
quats.handle.binding(),
global_from_compact_gid.handle.binding(),
v_xys_local.handle.clone().binding(),
v_conics.handle.binding(),
v_grads.handle.binding(),
v_means.handle.clone().binding(),
v_scales.handle.clone().binding(),
v_quats.handle.clone().binding(),
Expand All @@ -148,6 +144,6 @@ pub(crate) fn render_backward<F: FloatElement, I: IntElement, BT: BoolElement>(
v_scales,
v_coeffs,
v_raw_opac,
v_xy: v_xys_local,
v_refine_weight,
}
}
12 changes: 7 additions & 5 deletions crates/brush-train/src/shaders/gather_grads.wgsl
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@

@group(0) @binding(2) var<storage, read> raw_opacities: array<f32>;
@group(0) @binding(3) var<storage, read> means: array<helpers::PackedVec3>;
@group(0) @binding(4) var<storage, read> v_colors: array<vec4f>;

@group(0) @binding(4) var<storage, read> v_grads: array<f32>;

@group(0) @binding(5) var<storage, read_write> v_coeffs: array<f32>;
@group(0) @binding(6) var<storage, read_write> v_opacs: array<f32>;
Expand Down Expand Up @@ -170,7 +171,8 @@ fn main(@builtin(global_invocation_id) gid: vec3u) {
}

// Load colors gradients.
var v_color = v_colors[compact_gid];
let v_color = vec3f(v_grads[compact_gid * 9 + 5], v_grads[compact_gid * 9 + 6], v_grads[compact_gid * 9 + 7]);
let v_opac = v_grads[compact_gid * 9 + 8];

// Convert RGB to global SH gradients.
let global_gid = global_from_compact_gid[compact_gid];
Expand All @@ -179,7 +181,7 @@ fn main(@builtin(global_invocation_id) gid: vec3u) {
let viewdir = normalize(mean - uniforms.camera_position.xyz);

let sh_degree = uniforms.sh_degree;
let v_coeff = sh_coeffs_to_color_fast_vjp(sh_degree, viewdir, v_color.xyz);
let v_coeff = sh_coeffs_to_color_fast_vjp(sh_degree, viewdir, v_color);
let num_coeffs = num_sh_coeffs(sh_degree);
var base_id = global_gid * i32(num_coeffs) * 3;

Expand Down Expand Up @@ -219,6 +221,6 @@ fn main(@builtin(global_invocation_id) gid: vec3u) {

// Transform alpha gradient to opacity gradient.
let raw_opac = raw_opacities[global_gid];
let v_opac = v_color.w * v_sigmoid(raw_opac);
v_opacs[global_gid] = v_opac;
let v_opac_raw = v_opac * v_sigmoid(raw_opac);
v_opacs[global_gid] = v_opac_raw;
}
20 changes: 10 additions & 10 deletions crates/brush-train/src/shaders/project_backwards.wgsl
Original file line number Diff line number Diff line change
Expand Up @@ -9,23 +9,22 @@

@group(0) @binding(4) var<storage, read> global_from_compact_gid: array<i32>;

@group(0) @binding(5) var<storage, read> v_xys: array<vec2f>;
@group(0) @binding(6) var<storage, read> v_conics: array<helpers::PackedVec3>;

@group(0) @binding(7) var<storage, read_write> v_means: array<helpers::PackedVec3>;
@group(0) @binding(8) var<storage, read_write> v_scales: array<helpers::PackedVec3>;
@group(0) @binding(9) var<storage, read_write> v_quats: array<vec4f>;
@group(0) @binding(5) var<storage, read> v_grads: array<f32>;

@group(0) @binding(6) var<storage, read_write> v_means: array<helpers::PackedVec3>;
@group(0) @binding(7) var<storage, read_write> v_scales: array<helpers::PackedVec3>;
@group(0) @binding(8) var<storage, read_write> v_quats: array<vec4f>;

// TODO: What do for quat len == 0.0?
fn normalize_vjp(quat: vec4f) -> mat4x4f {
let quat_sqr = quat * quat;
let quat_len_sqr = dot(quat, quat);
let quat_len = length(quat_len_sqr);
let quat_len = sqrt(quat_len_sqr);

let cross_complex = -quat.xyz * quat.yzx;
let cross_scalar = -quat.xyz * quat.w;

return mat4x4<f32>(
return mat4x4f(
vec4f(quat_len_sqr - quat_sqr.x, cross_complex.x, cross_complex.z, cross_scalar.x),
vec4f(cross_complex.x, quat_len_sqr - quat_sqr.y, cross_complex.y, cross_scalar.y),
vec4f(cross_complex.z, cross_complex.y, quat_len_sqr - quat_sqr.z, cross_scalar.z),
Expand Down Expand Up @@ -165,10 +164,11 @@ fn main(@builtin(global_invocation_id) gid: vec3u) {
let mean = helpers::as_vec(means[global_gid]);
let scale = exp(helpers::as_vec(log_scales[global_gid]));
let quat_unorm = quats[global_gid];
// Safe to normalize, quats with norm 0 are invisible.
let quat = normalize(quat_unorm);

let v_conics = helpers::as_vec(v_conics[compact_gid]);
let v_mean2d = v_xys[compact_gid];
let v_mean2d = vec2f(v_grads[compact_gid * 9 + 0], v_grads[compact_gid * 9 + 1]);
let v_conics = vec3f(v_grads[compact_gid * 9 + 2], v_grads[compact_gid * 9 + 3], v_grads[compact_gid * 9 + 4]);

let R = mat3x3f(viewmat[0].xyz, viewmat[1].xyz, viewmat[2].xyz);
let mean_c = R * mean + viewmat[3].xyz;
Expand Down
Loading

0 comments on commit d7b7280

Please sign in to comment.