diff --git a/crates/brush-process/src/rerun_tools.rs b/crates/brush-process/src/rerun_tools.rs index b296133..95d9659 100644 --- a/crates/brush-process/src/rerun_tools.rs +++ b/crates/brush-process/src/rerun_tools.rs @@ -261,14 +261,14 @@ impl VisualizeTools { if let Some(rec) = self.rec.as_ref() { if rec.is_enabled() { rec.set_time_sequence("iterations", iter); let _ = rec.log( - "refine/num_transparent_pruned", - &rerun::Scalar::new(refine.num_transparent_pruned as f64), + "refine/num_added", + &rerun::Scalar::new(refine.num_added as f64), ); let _ = rec.log( - "refine/num_scale_pruned", - &rerun::Scalar::new(refine.num_scale_pruned as f64), + "refine/num_pruned", + &rerun::Scalar::new(refine.num_pruned as f64), ); } } diff --git a/crates/brush-train/src/train.rs b/crates/brush-train/src/train.rs index a585ec9..4cf76ed 100644 --- a/crates/brush-train/src/train.rs +++ b/crates/brush-train/src/train.rs @@ -1,3 +1,5 @@ +use std::f32::consts::SQRT_2; + use anyhow::Result; use brush_render::gaussian_splats::{Splats, inverse_sigmoid}; use brush_render::render::sh_coeffs_for_degree; @@ -59,8 +61,8 @@ pub struct TrainConfig { lr_coeffs_dc: f64, /// How much to divide the learning rate by for higher SH orders. - #[config(default = 20.0)] - #[arg(long, help_heading = "Training options", default_value = "20.0")] + #[config(default = 40.0)] + #[arg(long, help_heading = "Training options", default_value = "40.0")] lr_coeffs_sh_scale: f32, /// Learning rate for the opacity. 
@@ -89,8 +91,8 @@ pub struct TrainConfig { opac_refine_subtract: f32, /// Threshold for positional gradient norm - #[config(default = 0.0006)] - #[arg(long, help_heading = "Refine options", default_value = "0.0006")] + #[config(default = 0.0008)] + #[arg(long, help_heading = "Refine options", default_value = "0.0008")] densify_grad_thresh: f32, /// Gaussians bigger than this size in screenspace radius are split @@ -98,11 +100,6 @@ pub struct TrainConfig { #[arg(long, help_heading = "Refine options", default_value = "0.1")] densify_radius_threshold: f32, - /// Below this size, gaussians are cloned, otherwise split - #[config(default = 0.01)] - #[arg(long, help_heading = "Refine options", default_value = "0.01")] - densify_size_threshold: f32, - /// Gaussians bigger than this size in percent of the scene extent are culled #[config(default = 0.8)] #[arg(long, help_heading = "Refine options", default_value = "0.8")] @@ -145,10 +142,8 @@ pub struct SceneBatch { #[derive(Clone)] pub struct RefineStats { - pub num_split: u32, - pub num_cloned: u32, - pub num_transparent_pruned: u32, - pub num_scale_pruned: u32, + pub num_added: u32, + pub num_pruned: u32, } #[derive(Clone)] @@ -444,8 +439,15 @@ impl SplatTrainer { // This is slightly wrong wrt to adam gradients, but that's fine. let splats = splats.with_normed_rotations(); + // Skip a refine after every reset. + let time_per_reset = self.config.reset_alpha_every_refine * self.config.refine_every; + let time_since_reset = iter % time_per_reset; + // If not refining, update splat to step with gradients applied. 
- if iter >= self.config.refine_start_iter && iter < self.config.refine_stop_iter { + if iter >= self.config.refine_start_iter + && iter < self.config.refine_stop_iter + && time_since_reset > self.config.refine_every + { let (splats, refine) = self.refine_splats(iter, splats, scene_extent).await; (splats, Some(refine)) } else { @@ -462,6 +464,8 @@ impl SplatTrainer { splats: Splats, scene_extent: f32, ) -> (Splats, RefineStats) { + let device = splats.means.device(); + let mut record = self .optim .take() @@ -474,86 +478,24 @@ impl SplatTrainer { .expect("Can only refin if refin stats are initialized"); // Otherwise, do refinement, but do the split/clone on gaussians with no grads applied. - let avg_grad = refiner.refine_weight_norm / refiner.visible_counts.clamp_min(1).float(); + let avg_refine_grad = + refiner.refine_weight_norm / refiner.visible_counts.clamp_min(1).float(); let mut splats = splats; - let device = splats.means.device(); - - let is_grad_high = avg_grad.greater_equal_elem(self.config.densify_grad_thresh); - let split_clone_size_mask = splats - .scales() - .inner() - .max_dim(1) - .squeeze(1) - .lower_elem(self.config.densify_size_threshold * scene_extent); - - let mut append_means = vec![]; - let mut append_rots = vec![]; - let mut append_coeffs = vec![]; - let mut append_opac = vec![]; - let mut append_scales = vec![]; - - let clone_mask = - Tensor::stack::<2>(vec![is_grad_high.clone(), split_clone_size_mask.clone()], 1) - .all_dim(1) - .squeeze::<1>(1); - - let clone_inds = clone_mask.clone().argwhere_async().await; - - // Clone splats - let clone_count = clone_inds.dims()[0] as u32; - if clone_count > 0 { - let clone_inds = clone_inds.squeeze(1); - let cur_means = splats.means.val().inner().select(0, clone_inds.clone()); - let cur_rots = splats.rotation.val().inner().select(0, clone_inds.clone()); - let cur_scale = splats - .log_scales - .val() - .inner() - .select(0, clone_inds.clone()); - let cur_coeff = 
splats.sh_coeffs.val().inner().select(0, clone_inds.clone()); - let cur_raw_opac = splats.raw_opacity.val().inner().select(0, clone_inds); - - let samples = quaternion_vec_multiply( - cur_rots.clone(), - Tensor::random( - [clone_count as usize, 3], - Distribution::Normal(0.0, 1.0), - &device, - ) * cur_scale.clone().exp(), - ); - - append_means.push(cur_means + samples); - append_rots.push(cur_rots); - append_scales.push(cur_scale); - append_coeffs.push(cur_coeff); - append_opac.push(cur_raw_opac); - } - - // Split splats. - let split_mask = Tensor::stack::<2>( - vec![is_grad_high.clone(), split_clone_size_mask.bool_not()], - 1, - ) - .all_dim(1) - .squeeze::<1>(1); + let refine_over_threshold = + avg_refine_grad.greater_equal_elem(self.config.densify_grad_thresh); let radii_grow = refiner .max_radii .greater_elem(self.config.densify_radius_threshold); - - let split_mask = Tensor::stack::<2>(vec![split_mask, radii_grow], 1) - .any_dim(1) - .squeeze::<1>(1); - + let split_mask = refine_over_threshold.bool_or(radii_grow); let split_inds = split_mask.clone().argwhere_async().await; - let split_count = split_inds.dims()[0] as u32; + let split_count = split_inds.dims()[0]; if split_count > 0 { let split_inds = split_inds.squeeze(1); - // Some parts can be straightforwardly copied to the new splats. let cur_means = splats.means.val().inner().select(0, split_inds.clone()); let cur_coeff = splats.sh_coeffs.val().inner().select(0, split_inds.clone()); let cur_raw_opac = splats @@ -562,72 +504,90 @@ impl SplatTrainer { .inner() .select(0, split_inds.clone()); let cur_rots = splats.rotation.val().inner().select(0, split_inds.clone()); - let cur_scale = splats.log_scales.val().inner().select(0, split_inds); + let cur_log_scale = splats.log_scales.val().inner().select(0, split_inds); + // The amount to offset the scale and opacity should maybe depend on how far away we have sampled these gaussians. 
let samples = quaternion_vec_multiply( cur_rots.clone(), - Tensor::random( - [split_count as usize, 3], - Distribution::Normal(0.0, 1.0), - &device, - ) * cur_scale.clone().exp(), + Tensor::random([split_count, 3], Distribution::Normal(0.0, 1.0), &device) + * cur_log_scale.clone().exp(), ); - let scale_div: f32 = 1.6; + // Delete the current points + (splats, _) = prune_points(splats, &mut record, split_mask).await; - append_means.push(cur_means.clone() + samples.clone()); - append_rots.push(cur_rots.clone()); - append_scales.push(cur_scale.clone() - scale_div.ln()); - append_coeffs.push(cur_coeff.clone()); - append_opac.push(cur_raw_opac.clone()); + // Add in the new points. + let scale_div = cur_log_scale + .clone() + .exp() + .greater_elem(0.01 * scene_extent) + .any_dim(1) + .float() + * 0.6 + + 1.0; - append_means.push(cur_means - samples); - append_rots.push(cur_rots); - append_scales.push(cur_scale - scale_div.ln()); - append_coeffs.push(cur_coeff); - append_opac.push(cur_raw_opac); - } + let new_log_scale = cur_log_scale.clone() - scale_div.log(); - (splats, _) = prune_points(splats, &mut record, split_mask.clone()).await; + // let one = Tensor::ones([1], &device); + // let new_opac = inv_sigmoid( + // (one.clone() - (one - sigmoid(cur_raw_opac))) + // .clamp(1e-16, 1.0 - 1e-16) + // .sqrt(), + // ); + + let cur_count = splats.means.dims()[0]; + let sh_dim = splats.sh_coeffs.dims()[1]; + + splats = map_splats_and_opt( + splats, + &mut record, + |x| Tensor::cat(vec![x, cur_means.clone(), cur_means + samples], 0), + |x| Tensor::cat(vec![x, cur_rots.clone(), cur_rots], 0), + |x| Tensor::cat(vec![x, new_log_scale.clone(), new_log_scale], 0), + |x| Tensor::cat(vec![x, cur_coeff.clone(), cur_coeff], 0), + |x| Tensor::cat(vec![x, cur_raw_opac.clone(), cur_raw_opac], 0), + |x| { + Tensor::zeros([cur_count + split_count, 3], &device) + .slice_assign([0..cur_count, 0..3], x) + }, + |x| { + Tensor::zeros([cur_count + split_count, 4], &device) + 
.slice_assign([0..cur_count, 0..4], x) + }, + |x| { + Tensor::zeros([cur_count + split_count, 3], &device) + .slice_assign([0..cur_count, 0..3], x) + }, + |x| { + Tensor::zeros([cur_count + split_count, sh_dim, 3], &device) + .slice_assign([0..cur_count, 0..sh_dim, 0..3], x) + }, + |x| { + Tensor::zeros([cur_count + split_count], &device) + .slice_assign([0..cur_count], x) + }, + ); + } // Do some more processing. Important to do this last as otherwise you might mess up the correspondence // of gradient <-> splat. - + // // Remove barely visible gaussians. let alpha_mask = splats .raw_opacity .val() .inner() .lower_elem(inverse_sigmoid(MIN_OPACITY)); - let (splats, alpha_pruned) = prune_points(splats, &mut record, alpha_mask).await; // Delete Gaussians with too large of a radius in world-units. - let scale_big = splats + let scale_mask = splats .log_scales .val() .inner() .greater_elem((self.config.cull_scale3d_percentage_threshold * scene_extent).ln()); - let scale_mask = Tensor::any_dim(scale_big, 1).squeeze(1); - let (mut splats, scale_pruned) = prune_points(splats, &mut record, scale_mask).await; - - if !append_means.is_empty() { - let append_means = Tensor::cat(append_means, 0); - let append_rots = Tensor::cat(append_rots, 0); - let append_coeffs = Tensor::cat(append_coeffs, 0); - let append_opac = Tensor::cat(append_opac, 0); - let append_scales = Tensor::cat(append_scales, 0); - - splats = concat_splats( - splats, - &mut record, - append_means, - append_rots, - append_scales, - append_coeffs, - append_opac, - ); - } + let prune_mask = alpha_mask.bool_or(Tensor::any_dim(scale_mask, 1).squeeze(1)); + let (mut splats, pruned) = prune_points(splats, &mut record, prune_mask).await; let refine_step = iter / self.config.refine_every; if refine_step % self.config.reset_alpha_every_refine == 0 { @@ -638,13 +598,8 @@ impl SplatTrainer { Tensor::zeros_like(&s) }); } else { - // Skip a refine after every reset. 
- let time_per_reset = self.config.reset_alpha_every_refine * self.config.refine_every; - let time_since_reset = iter % time_per_reset; - // Slowly lower opacity. - if self.config.opac_refine_subtract > 0.0 && time_since_reset > self.config.refine_every - { + if self.config.opac_refine_subtract > 0.0 { splats.raw_opacity = splats.raw_opacity.map(|op| { let op = op.inner(); Tensor::from_inner(inv_sigmoid( @@ -659,10 +614,8 @@ impl SplatTrainer { self.optim = Some(create_default_optimizer().load_record(record)); let stats = RefineStats { - num_split: split_count, - num_cloned: clone_count, - num_transparent_pruned: alpha_pruned, - num_scale_pruned: scale_pruned, + num_added: split_count as u32, + num_pruned: pruned, }; (splats, stats) @@ -781,49 +734,6 @@ pub async fn prune_points( (splats, start_splats - new_points) } -pub fn concat_splats( - splats: Splats, - record: &mut HashMap>, - means: Tensor, - rotations: Tensor, - log_scales: Tensor, - sh_coeffs: Tensor, - raw_opac: Tensor, -) -> Splats { - let device = splats.means.device(); - - let cur_count = splats.means.dims()[0]; - let append_count = means.dims()[0]; - let sh_dim = splats.sh_coeffs.dims()[1]; - - map_splats_and_opt( - splats, - record, - |x| Tensor::cat(vec![x, means], 0), - |x| Tensor::cat(vec![x, rotations], 0), - |x| Tensor::cat(vec![x, log_scales], 0), - |x| Tensor::cat(vec![x, sh_coeffs], 0), - |x| Tensor::cat(vec![x, raw_opac], 0), - |x| { - Tensor::zeros([cur_count + append_count, 3], &device) - .slice_assign([0..cur_count, 0..3], x) - }, - |x| { - Tensor::zeros([cur_count + append_count, 4], &device) - .slice_assign([0..cur_count, 0..4], x) - }, - |x| { - Tensor::zeros([cur_count + append_count, 3], &device) - .slice_assign([0..cur_count, 0..3], x) - }, - |x| { - Tensor::zeros([cur_count + append_count, sh_dim, 3], &device) - .slice_assign([0..cur_count, 0..sh_dim, 0..3], x) - }, - |x| Tensor::zeros([cur_count + append_count], &device).slice_assign([0..cur_count], x), - ) -} - #[cfg(test)] mod 
tests { use burn::{