diff --git a/src/b3d/chisight/gen3d/image_kernel.py b/src/b3d/chisight/gen3d/image_kernel.py index ca6b1d0b..936d6dbd 100644 --- a/src/b3d/chisight/gen3d/image_kernel.py +++ b/src/b3d/chisight/gen3d/image_kernel.py @@ -259,7 +259,9 @@ def sample(self, key: PRNGKey, state: Mapping, hyperparams: Mapping) -> FloatArr ), ) return jax.vmap( - jax.vmap(vertex_kernel.sample, in_axes=(0, 0, None, None, 0, 0, None, None)), + jax.vmap( + vertex_kernel.sample, in_axes=(0, 0, None, None, 0, 0, None, None) + ), in_axes=(0, 0, None, None, 0, 0, None, None), )( keys, @@ -269,7 +271,7 @@ def sample(self, key: PRNGKey, state: Mapping, hyperparams: Mapping) -> FloatArr pixel_visibility_prob, pixel_depth_nonreturn_prob, hyperparams["intrinsics"], - hyperparams["unexplained_depth_nonreturn_prob"] + hyperparams["unexplained_depth_nonreturn_prob"], ) def logpdf( @@ -285,7 +287,9 @@ def logpdf( ) vertex_kernel = self.get_rgbd_vertex_kernel() - scores = jax.vmap(vertex_kernel.logpdf, in_axes=(0, 0, None, None, 0, 0, None, None))( + scores = jax.vmap( + vertex_kernel.logpdf, in_axes=(0, 0, None, None, 0, 0, None, None) + )( observed_rgbd_per_point, latent_rgbd_per_point, state["color_scale"], @@ -293,7 +297,7 @@ def logpdf( state["visibility_prob"], state["depth_nonreturn_prob"], hyperparams["intrinsics"], - hyperparams["unexplained_depth_nonreturn_prob"] + hyperparams["unexplained_depth_nonreturn_prob"], ) # Points that don't hit the camera plane should not contribute to the score. @@ -431,7 +435,7 @@ def logpdf( state["visibility_prob"][point_indices_for_observed_rgbds], state["depth_nonreturn_prob"][point_indices_for_observed_rgbds], hyperparams["intrinsics"], - hyperparams["unexplained_depth_nonreturn_prob"] + hyperparams["unexplained_depth_nonreturn_prob"], ) total_score_for_explained_pixels = jnp.where(is_valid, scores, 0.0).sum() diff --git a/src/b3d/chisight/gen3d/image_kernel_new.py b/src/b3d/chisight/gen3d/image_kernel_new.py index 59764a1d..c1b5108f 100644 --- a/src/b3d/chisight/gen3d/image_kernel_new.py +++ b/src/b3d/chisight/gen3d/image_kernel_new.py @@ -1,25 +1,23 @@ -from abc import abstractmethod -from functools import cached_property from typing import Mapping import genjax import jax import jax.numpy as jnp from genjax import Pytree -from genjax.typing import FloatArray, IntArray, PRNGKey +from genjax.typing import IntArray, PRNGKey import b3d.utils from b3d.chisight.gen3d.pixel_kernels.pixel_rgbd_kernels import ( PixelRGBDDistribution, - is_unexplained, ) # using this in combination with mode="drop" in the .at[] # methods can help filter out vertices that are not visible in the image INVALID_IDX = jnp.iinfo(jnp.int32).min # -2147483648 + class PixelsPointsAssociation(Pytree): - observed_pixel_indices : IntArray + observed_pixel_indices: IntArray def from_pose_intrinsics_vertices(pose, intrinsics, vertices): image_height, image_width = ( @@ -27,7 +25,7 @@ def from_pose_intrinsics_vertices(pose, intrinsics, vertices): intrinsics["image_width"].unwrap(), ) transformed_points = pose.apply(vertices) - + # Sort the vertices by depth. sort_order = jnp.argsort(transformed_points[..., 2]) transformed_points_sorted = transformed_points[sort_order] @@ -45,7 +43,9 @@ def from_pose_intrinsics_vertices(pose, intrinsics, vertices): ).astype(jnp.int32) projected_coords = jnp.nan_to_num(projected_coords, nan=INVALID_IDX) # handle the case where the projected coordinates are outside the image - projected_coords = jnp.where(projected_coords > 0, projected_coords, INVALID_IDX) + projected_coords = jnp.where( + projected_coords > 0, projected_coords, INVALID_IDX + ) projected_coords = jnp.where( projected_coords < jnp.array([image_height, image_width]), projected_coords, @@ -62,7 +62,9 @@ def from_pose_intrinsics_vertices(pose, intrinsics, vertices): ) # Reorder the unique pixel coordinates, to the original point array indexing scheme - observed_pixel_coordinates_per_point = -jnp.ones((transformed_points.shape[0], 2), dtype=jnp.int32) + observed_pixel_coordinates_per_point = -jnp.ones( + (transformed_points.shape[0], 2), dtype=jnp.int32 + ) observed_pixel_coordinates_per_point = observed_pixel_coordinates_per_point.at[ sort_order[unique_indices] ].set(unique_pixel_coordinates) @@ -71,7 +73,8 @@ def from_pose_intrinsics_vertices(pose, intrinsics, vertices): def get_pixel_index(self, point_index): return self.observed_pixel_indices[point_index] - + + @Pytree.dataclass class UniquePixelsImageKernel(genjax.ExactDensity): rgbd_vertex_kernel: PixelRGBDDistribution @@ -83,7 +86,7 @@ def sample(self, key: PRNGKey, state: Mapping, hyperparams: Mapping): return jax.vmap( jax.vmap( lambda i, j: self.rgbd_vertex_kernel.sample( - key, + key, ) ) - ) \ No newline at end of file + ) diff --git a/src/b3d/chisight/gen3d/inference/inference.py b/src/b3d/chisight/gen3d/inference/inference.py index fdc482b3..a98a90af 100644 --- a/src/b3d/chisight/gen3d/inference/inference.py +++ b/src/b3d/chisight/gen3d/inference/inference.py @@ -128,10 +128,13 @@ def inference_step( def get_trace_generated_during_inference( - key, trace, pose, inference_hyperparams, + key, + trace, + pose, + inference_hyperparams, do_advance_time=True, observed_rgbd=None, - just_return_trace=True + just_return_trace=True, ): """ Get the trace generated at pose `pose` with key `key` by inference_step, @@ -255,11 +258,11 @@ def propose_other_latents_given_pose(key, advanced_trace, pose, inference_hyperp k1, k2, k3, k4 = split(key, 4) trace = update_field(k1, advanced_trace, "pose", pose) - + sup = get_hypers(trace)["color_scale_kernel"].support val = get_prev_state(advanced_trace)["color_scale"] idx = jnp.argmin(jnp.abs(sup - val)) - newidx = jnp.minimum(idx+1, sup.shape[0]-1) + newidx = jnp.minimum(idx + 1, sup.shape[0] - 1) trace = update_field(k1, trace, "color_scale", sup[newidx]) k2a, k2b = split(k2) diff --git a/src/b3d/chisight/gen3d/inference/point_attribute_proposals.py b/src/b3d/chisight/gen3d/inference/point_attribute_proposals.py index 7b0b3aac..8ffe0352 100644 --- a/src/b3d/chisight/gen3d/inference/point_attribute_proposals.py +++ b/src/b3d/chisight/gen3d/inference/point_attribute_proposals.py @@ -130,7 +130,9 @@ def score_attribute_assignment(color, visprob, dnrprob): visibility_prob=visprob, depth_nonreturn_prob=dnrprob, intrinsics=intrinsics, - invisible_depth_nonreturn_prob=hyperparams["unexplained_depth_nonreturn_prob"], + invisible_depth_nonreturn_prob=hyperparams[ + "unexplained_depth_nonreturn_prob" + ], ) return ( visprob_transition_score diff --git a/src/b3d/chisight/gen3d/pixel_kernels/pixel_color_kernels.py b/src/b3d/chisight/gen3d/pixel_kernels/pixel_color_kernels.py index 3dfcfacc..7ead7bff 100644 --- a/src/b3d/chisight/gen3d/pixel_kernels/pixel_color_kernels.py +++ b/src/b3d/chisight/gen3d/pixel_kernels/pixel_color_kernels.py @@ -3,10 +3,8 @@ import genjax import jax -import jax.numpy as jnp from b3d.modeling_utils import ( _FIXED_COLOR_UNIFORM_WINDOW, - PythonMixtureDistribution, renormalized_laplace, truncated_laplace, ) diff --git a/src/b3d/chisight/gen3d/pixel_kernels/pixel_rgbd_kernels.py b/src/b3d/chisight/gen3d/pixel_kernels/pixel_rgbd_kernels.py index fe06ca94..561fe307 100644 --- a/src/b3d/chisight/gen3d/pixel_kernels/pixel_rgbd_kernels.py +++ b/src/b3d/chisight/gen3d/pixel_kernels/pixel_rgbd_kernels.py @@ -3,11 +3,11 @@ import genjax import jax import jax.numpy as jnp -from jax.random import split from b3d.chisight.gen3d.pixel_kernels.pixel_color_kernels import PixelColorDistribution from b3d.chisight.gen3d.pixel_kernels.pixel_depth_kernels import PixelDepthDistribution from genjax import Pytree from genjax.typing import FloatArray, PRNGKey +from jax.random import split def is_unexplained(latent_value: FloatArray) -> bool: @@ -68,6 +68,7 @@ def logpdf( ) -> float: raise NotImplementedError + @Pytree.dataclass class RGBDDist(genjax.ExactDensity): """ @@ -82,6 +83,7 @@ class RGBDDist(genjax.ExactDensity): Calls a color distribution and a "valid depth return" depth distribution to sample the pixel. """ + color_distribution: PixelColorDistribution depth_distribution: PixelDepthDistribution @@ -95,34 +97,31 @@ def sample( intrinsics: dict, ) -> FloatArray: k1, k2, k3 = split(key, 3) - color = self.color_distribution.sample( - k1, latent_rgbd[:3], color_scale - ) + color = self.color_distribution.sample(k1, latent_rgbd[:3], color_scale) depth_if_return = self.depth_distribution.sample( k2, latent_rgbd[3], depth_scale, intrinsics["near"], intrinsics["far"] ) depth = jnp.where( - jax.random.bernoulli(k3, depth_nonreturn_prob), - 0.0, - depth_if_return + jax.random.bernoulli(k3, depth_nonreturn_prob), 0.0, depth_if_return ) return jnp.concatenate([color, depth]) - - def logpdf(self, obs, latent, color_scale, depth_scale, depth_nonreturn_prob, intrinsics): - color_logpdf = self.color_distribution.logpdf( - obs[:3], latent[:3], color_scale - ) + + def logpdf( + self, obs, latent, color_scale, depth_scale, depth_nonreturn_prob, intrinsics + ): + color_logpdf = self.color_distribution.logpdf(obs[:3], latent[:3], color_scale) depth_logpdf_if_return = self.depth_distribution.logpdf( obs[3], latent[3], depth_scale, intrinsics["near"], intrinsics["far"] ) depth_logpdf = jnp.where( obs[3] == 0.0, jnp.log(depth_nonreturn_prob), - jnp.log(1 - depth_nonreturn_prob) + depth_logpdf_if_return + jnp.log(1 - depth_nonreturn_prob) + depth_logpdf_if_return, ) return color_logpdf + depth_logpdf + @Pytree.dataclass class FullPixelRGBDDistribution(PixelRGBDDistribution): """ @@ -144,10 +143,12 @@ class FullPixelRGBDDistribution(PixelRGBDDistribution): @property def inlier_distribution(self): return RGBDDist(self.inlier_color_distribution, self.inlier_depth_distribution) - + @property def outlier_distribution(self): - return RGBDDist(self.outlier_color_distribution, self.outlier_depth_distribution) + return RGBDDist( + self.outlier_color_distribution, self.outlier_depth_distribution + ) def sample( self, @@ -158,16 +159,26 @@ def sample( visibility_prob: float, depth_nonreturn_prob: float, intrinsics: dict, - depth_nonreturn_prob_for_invisible: float + depth_nonreturn_prob_for_invisible: float, ) -> FloatArray: k1, k2, k3 = split(key, 3) return jnp.where( jax.random.bernoulli(k1, visibility_prob), self.inlier_distribution.sample( - k2, latent_rgbd, color_scale, depth_scale, depth_nonreturn_prob, intrinsics + k2, + latent_rgbd, + color_scale, + depth_scale, + depth_nonreturn_prob, + intrinsics, ), self.outlier_distribution.sample( - k3, latent_rgbd, color_scale, depth_scale, depth_nonreturn_prob_for_invisible, intrinsics + k3, + latent_rgbd, + color_scale, + depth_scale, + depth_nonreturn_prob_for_invisible, + intrinsics, ), ) @@ -181,14 +192,26 @@ def logpdf( visibility_prob: float, depth_nonreturn_prob: float, intrinsics: dict, - invisible_depth_nonreturn_prob: float + invisible_depth_nonreturn_prob: float, ) -> float: return jnp.logaddexp( - jnp.log(visibility_prob) + self.inlier_distribution.logpdf( - observed_rgbd, latent_rgbd, color_scale, depth_scale, depth_nonreturn_prob, intrinsics + jnp.log(visibility_prob) + + self.inlier_distribution.logpdf( + observed_rgbd, + latent_rgbd, + color_scale, + depth_scale, + depth_nonreturn_prob, + intrinsics, ), - jnp.log(1 - visibility_prob) + self.outlier_distribution.logpdf( - observed_rgbd, latent_rgbd, color_scale, depth_scale, invisible_depth_nonreturn_prob, intrinsics + jnp.log(1 - visibility_prob) + + self.outlier_distribution.logpdf( + observed_rgbd, + latent_rgbd, + color_scale, + depth_scale, + invisible_depth_nonreturn_prob, + intrinsics, ), )