Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Sep 20, 2024
1 parent 341f690 commit d4c533e
Show file tree
Hide file tree
Showing 6 changed files with 79 additions and 46 deletions.
14 changes: 9 additions & 5 deletions src/b3d/chisight/gen3d/image_kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,7 +259,9 @@ def sample(self, key: PRNGKey, state: Mapping, hyperparams: Mapping) -> FloatArr
),
)
return jax.vmap(
jax.vmap(vertex_kernel.sample, in_axes=(0, 0, None, None, 0, 0, None, None)),
jax.vmap(
vertex_kernel.sample, in_axes=(0, 0, None, None, 0, 0, None, None)
),
in_axes=(0, 0, None, None, 0, 0, None, None),
)(
keys,
Expand All @@ -269,7 +271,7 @@ def sample(self, key: PRNGKey, state: Mapping, hyperparams: Mapping) -> FloatArr
pixel_visibility_prob,
pixel_depth_nonreturn_prob,
hyperparams["intrinsics"],
hyperparams["unexplained_depth_nonreturn_prob"]
hyperparams["unexplained_depth_nonreturn_prob"],
)

def logpdf(
Expand All @@ -285,15 +287,17 @@ def logpdf(
)

vertex_kernel = self.get_rgbd_vertex_kernel()
scores = jax.vmap(vertex_kernel.logpdf, in_axes=(0, 0, None, None, 0, 0, None, None))(
scores = jax.vmap(
vertex_kernel.logpdf, in_axes=(0, 0, None, None, 0, 0, None, None)
)(
observed_rgbd_per_point,
latent_rgbd_per_point,
state["color_scale"],
state["depth_scale"],
state["visibility_prob"],
state["depth_nonreturn_prob"],
hyperparams["intrinsics"],
hyperparams["unexplained_depth_nonreturn_prob"]
hyperparams["unexplained_depth_nonreturn_prob"],
)

# Points that don't hit the camera plane should not contribute to the score.
Expand Down Expand Up @@ -431,7 +435,7 @@ def logpdf(
state["visibility_prob"][point_indices_for_observed_rgbds],
state["depth_nonreturn_prob"][point_indices_for_observed_rgbds],
hyperparams["intrinsics"],
hyperparams["unexplained_depth_nonreturn_prob"]
hyperparams["unexplained_depth_nonreturn_prob"],
)
total_score_for_explained_pixels = jnp.where(is_valid, scores, 0.0).sum()

Expand Down
25 changes: 14 additions & 11 deletions src/b3d/chisight/gen3d/image_kernel_new.py
Original file line number Diff line number Diff line change
@@ -1,33 +1,31 @@
from abc import abstractmethod
from functools import cached_property
from typing import Mapping

import genjax
import jax
import jax.numpy as jnp
from genjax import Pytree
from genjax.typing import FloatArray, IntArray, PRNGKey
from genjax.typing import IntArray, PRNGKey

import b3d.utils
from b3d.chisight.gen3d.pixel_kernels.pixel_rgbd_kernels import (
PixelRGBDDistribution,
is_unexplained,
)

# using this in combination with mode="drop" in the .at[]
# methods can help filter out vertices that are not visible in the image
INVALID_IDX = jnp.iinfo(jnp.int32).min # -2147483648


class PixelsPointsAssociation(Pytree):
observed_pixel_indices : IntArray
observed_pixel_indices: IntArray

def from_pose_intrinsics_vertices(pose, intrinsics, vertices):
image_height, image_width = (
intrinsics["image_height"].unwrap(),
intrinsics["image_width"].unwrap(),
)
transformed_points = pose.apply(vertices)

# Sort the vertices by depth.
sort_order = jnp.argsort(transformed_points[..., 2])
transformed_points_sorted = transformed_points[sort_order]
Expand All @@ -45,7 +43,9 @@ def from_pose_intrinsics_vertices(pose, intrinsics, vertices):
).astype(jnp.int32)
projected_coords = jnp.nan_to_num(projected_coords, nan=INVALID_IDX)
# handle the case where the projected coordinates are outside the image
projected_coords = jnp.where(projected_coords > 0, projected_coords, INVALID_IDX)
projected_coords = jnp.where(
projected_coords > 0, projected_coords, INVALID_IDX
)
projected_coords = jnp.where(
projected_coords < jnp.array([image_height, image_width]),
projected_coords,
Expand All @@ -62,7 +62,9 @@ def from_pose_intrinsics_vertices(pose, intrinsics, vertices):
)

# Reorder the unique pixel coordinates, to the original point array indexing scheme
observed_pixel_coordinates_per_point = -jnp.ones((transformed_points.shape[0], 2), dtype=jnp.int32)
observed_pixel_coordinates_per_point = -jnp.ones(
(transformed_points.shape[0], 2), dtype=jnp.int32
)
observed_pixel_coordinates_per_point = observed_pixel_coordinates_per_point.at[
sort_order[unique_indices]
].set(unique_pixel_coordinates)
Expand All @@ -71,7 +73,8 @@ def from_pose_intrinsics_vertices(pose, intrinsics, vertices):

def get_pixel_index(self, point_index):
return self.observed_pixel_indices[point_index]



@Pytree.dataclass
class UniquePixelsImageKernel(genjax.ExactDensity):
rgbd_vertex_kernel: PixelRGBDDistribution
Expand All @@ -83,7 +86,7 @@ def sample(self, key: PRNGKey, state: Mapping, hyperparams: Mapping):
return jax.vmap(
jax.vmap(
lambda i, j: self.rgbd_vertex_kernel.sample(
key,
key,
)
)
)
)
11 changes: 7 additions & 4 deletions src/b3d/chisight/gen3d/inference/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,10 +128,13 @@ def inference_step(


def get_trace_generated_during_inference(
key, trace, pose, inference_hyperparams,
key,
trace,
pose,
inference_hyperparams,
do_advance_time=True,
observed_rgbd=None,
just_return_trace=True
just_return_trace=True,
):
"""
Get the trace generated at pose `pose` with key `key` by inference_step,
Expand Down Expand Up @@ -255,11 +258,11 @@ def propose_other_latents_given_pose(key, advanced_trace, pose, inference_hyperp
k1, k2, k3, k4 = split(key, 4)

trace = update_field(k1, advanced_trace, "pose", pose)

sup = get_hypers(trace)["color_scale_kernel"].support
val = get_prev_state(advanced_trace)["color_scale"]
idx = jnp.argmin(jnp.abs(sup - val))
newidx = jnp.minimum(idx+1, sup.shape[0]-1)
newidx = jnp.minimum(idx + 1, sup.shape[0] - 1)
trace = update_field(k1, trace, "color_scale", sup[newidx])

k2a, k2b = split(k2)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,9 @@ def score_attribute_assignment(color, visprob, dnrprob):
visibility_prob=visprob,
depth_nonreturn_prob=dnrprob,
intrinsics=intrinsics,
invisible_depth_nonreturn_prob=hyperparams["unexplained_depth_nonreturn_prob"],
invisible_depth_nonreturn_prob=hyperparams[
"unexplained_depth_nonreturn_prob"
],
)
return (
visprob_transition_score
Expand Down
2 changes: 0 additions & 2 deletions src/b3d/chisight/gen3d/pixel_kernels/pixel_color_kernels.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,8 @@

import genjax
import jax
import jax.numpy as jnp
from b3d.modeling_utils import (
_FIXED_COLOR_UNIFORM_WINDOW,
PythonMixtureDistribution,
renormalized_laplace,
truncated_laplace,
)
Expand Down
69 changes: 46 additions & 23 deletions src/b3d/chisight/gen3d/pixel_kernels/pixel_rgbd_kernels.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,11 @@
import genjax
import jax
import jax.numpy as jnp
from jax.random import split
from b3d.chisight.gen3d.pixel_kernels.pixel_color_kernels import PixelColorDistribution
from b3d.chisight.gen3d.pixel_kernels.pixel_depth_kernels import PixelDepthDistribution
from genjax import Pytree
from genjax.typing import FloatArray, PRNGKey
from jax.random import split


def is_unexplained(latent_value: FloatArray) -> bool:
Expand Down Expand Up @@ -68,6 +68,7 @@ def logpdf(
) -> float:
raise NotImplementedError


@Pytree.dataclass
class RGBDDist(genjax.ExactDensity):
"""
Expand All @@ -82,6 +83,7 @@ class RGBDDist(genjax.ExactDensity):
Calls a color distribution and a "valid depth return" depth distribution to sample the pixel.
"""

color_distribution: PixelColorDistribution
depth_distribution: PixelDepthDistribution

Expand All @@ -95,34 +97,31 @@ def sample(
intrinsics: dict,
) -> FloatArray:
k1, k2, k3 = split(key, 3)
color = self.color_distribution.sample(
k1, latent_rgbd[:3], color_scale
)
color = self.color_distribution.sample(k1, latent_rgbd[:3], color_scale)
depth_if_return = self.depth_distribution.sample(
k2, latent_rgbd[3], depth_scale, intrinsics["near"], intrinsics["far"]
)
depth = jnp.where(
jax.random.bernoulli(k3, depth_nonreturn_prob),
0.0,
depth_if_return
jax.random.bernoulli(k3, depth_nonreturn_prob), 0.0, depth_if_return
)

return jnp.concatenate([color, depth])
def logpdf(self, obs, latent, color_scale, depth_scale, depth_nonreturn_prob, intrinsics):
color_logpdf = self.color_distribution.logpdf(
obs[:3], latent[:3], color_scale
)

def logpdf(
self, obs, latent, color_scale, depth_scale, depth_nonreturn_prob, intrinsics
):
color_logpdf = self.color_distribution.logpdf(obs[:3], latent[:3], color_scale)
depth_logpdf_if_return = self.depth_distribution.logpdf(
obs[3], latent[3], depth_scale, intrinsics["near"], intrinsics["far"]
)
depth_logpdf = jnp.where(
obs[3] == 0.0,
jnp.log(depth_nonreturn_prob),
jnp.log(1 - depth_nonreturn_prob) + depth_logpdf_if_return
jnp.log(1 - depth_nonreturn_prob) + depth_logpdf_if_return,
)
return color_logpdf + depth_logpdf


@Pytree.dataclass
class FullPixelRGBDDistribution(PixelRGBDDistribution):
"""
Expand All @@ -144,10 +143,12 @@ class FullPixelRGBDDistribution(PixelRGBDDistribution):
@property
def inlier_distribution(self):
return RGBDDist(self.inlier_color_distribution, self.inlier_depth_distribution)

@property
def outlier_distribution(self):
return RGBDDist(self.outlier_color_distribution, self.outlier_depth_distribution)
return RGBDDist(
self.outlier_color_distribution, self.outlier_depth_distribution
)

def sample(
self,
Expand All @@ -158,16 +159,26 @@ def sample(
visibility_prob: float,
depth_nonreturn_prob: float,
intrinsics: dict,
depth_nonreturn_prob_for_invisible: float
depth_nonreturn_prob_for_invisible: float,
) -> FloatArray:
k1, k2, k3 = split(key, 3)
return jnp.where(
jax.random.bernoulli(k1, visibility_prob),
self.inlier_distribution.sample(
k2, latent_rgbd, color_scale, depth_scale, depth_nonreturn_prob, intrinsics
k2,
latent_rgbd,
color_scale,
depth_scale,
depth_nonreturn_prob,
intrinsics,
),
self.outlier_distribution.sample(
k3, latent_rgbd, color_scale, depth_scale, depth_nonreturn_prob_for_invisible, intrinsics
k3,
latent_rgbd,
color_scale,
depth_scale,
depth_nonreturn_prob_for_invisible,
intrinsics,
),
)

Expand All @@ -181,14 +192,26 @@ def logpdf(
visibility_prob: float,
depth_nonreturn_prob: float,
intrinsics: dict,
invisible_depth_nonreturn_prob: float
invisible_depth_nonreturn_prob: float,
) -> float:
return jnp.logaddexp(
jnp.log(visibility_prob) + self.inlier_distribution.logpdf(
observed_rgbd, latent_rgbd, color_scale, depth_scale, depth_nonreturn_prob, intrinsics
jnp.log(visibility_prob)
+ self.inlier_distribution.logpdf(
observed_rgbd,
latent_rgbd,
color_scale,
depth_scale,
depth_nonreturn_prob,
intrinsics,
),
jnp.log(1 - visibility_prob) + self.outlier_distribution.logpdf(
observed_rgbd, latent_rgbd, color_scale, depth_scale, invisible_depth_nonreturn_prob, intrinsics
jnp.log(1 - visibility_prob)
+ self.outlier_distribution.logpdf(
observed_rgbd,
latent_rgbd,
color_scale,
depth_scale,
invisible_depth_nonreturn_prob,
intrinsics,
),
)

Expand Down

0 comments on commit d4c533e

Please sign in to comment.