Skip to content

Commit

Permalink
progress
Browse files Browse the repository at this point in the history
  • Loading branch information
georgematheos committed Sep 20, 2024
1 parent ac490b3 commit 341f690
Show file tree
Hide file tree
Showing 8 changed files with 1,545 additions and 163 deletions.
1,309 changes: 1,280 additions & 29 deletions notebooks/gm/gen3d/debugging.ipynb

Large diffs are not rendered by default.

11 changes: 7 additions & 4 deletions src/b3d/chisight/gen3d/image_kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,8 +259,8 @@ def sample(self, key: PRNGKey, state: Mapping, hyperparams: Mapping) -> FloatArr
),
)
return jax.vmap(
jax.vmap(vertex_kernel.sample, in_axes=(0, 0, None, None, 0, 0, None)),
in_axes=(0, 0, None, None, 0, 0, None),
jax.vmap(vertex_kernel.sample, in_axes=(0, 0, None, None, 0, 0, None, None)),
in_axes=(0, 0, None, None, 0, 0, None, None),
)(
keys,
pixel_latent_rgbd,
Expand All @@ -269,6 +269,7 @@ def sample(self, key: PRNGKey, state: Mapping, hyperparams: Mapping) -> FloatArr
pixel_visibility_prob,
pixel_depth_nonreturn_prob,
hyperparams["intrinsics"],
hyperparams["unexplained_depth_nonreturn_prob"]
)

def logpdf(
Expand All @@ -284,14 +285,15 @@ def logpdf(
)

vertex_kernel = self.get_rgbd_vertex_kernel()
scores = jax.vmap(vertex_kernel.logpdf, in_axes=(0, 0, None, None, 0, 0, None))(
scores = jax.vmap(vertex_kernel.logpdf, in_axes=(0, 0, None, None, 0, 0, None, None))(
observed_rgbd_per_point,
latent_rgbd_per_point,
state["color_scale"],
state["depth_scale"],
state["visibility_prob"],
state["depth_nonreturn_prob"],
hyperparams["intrinsics"],
hyperparams["unexplained_depth_nonreturn_prob"]
)

# Points that don't hit the camera plane should not contribute to the score.
Expand Down Expand Up @@ -420,7 +422,7 @@ def logpdf(
# Score the collided pixels
scores = jax.vmap(
hyperparams["image_kernel"].get_rgbd_vertex_kernel().logpdf,
in_axes=(0, 0, None, None, 0, 0, None),
in_axes=(0, 0, None, None, 0, 0, None, None),
)(
observed_rgbd_per_point,
latent_rgbd_per_point,
Expand All @@ -429,6 +431,7 @@ def logpdf(
state["visibility_prob"][point_indices_for_observed_rgbds],
state["depth_nonreturn_prob"][point_indices_for_observed_rgbds],
hyperparams["intrinsics"],
hyperparams["unexplained_depth_nonreturn_prob"]
)
total_score_for_explained_pixels = jnp.where(is_valid, scores, 0.0).sum()

Expand Down
89 changes: 89 additions & 0 deletions src/b3d/chisight/gen3d/image_kernel_new.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
from abc import abstractmethod

Check failure on line 1 in src/b3d/chisight/gen3d/image_kernel_new.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (F401)

src/b3d/chisight/gen3d/image_kernel_new.py:1:17: F401 `abc.abstractmethod` imported but unused
from functools import cached_property

Check failure on line 2 in src/b3d/chisight/gen3d/image_kernel_new.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (F401)

src/b3d/chisight/gen3d/image_kernel_new.py:2:23: F401 `functools.cached_property` imported but unused
from typing import Mapping

import genjax
import jax
import jax.numpy as jnp
from genjax import Pytree
from genjax.typing import FloatArray, IntArray, PRNGKey

Check failure on line 9 in src/b3d/chisight/gen3d/image_kernel_new.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (F401)

src/b3d/chisight/gen3d/image_kernel_new.py:9:27: F401 `genjax.typing.FloatArray` imported but unused

import b3d.utils
from b3d.chisight.gen3d.pixel_kernels.pixel_rgbd_kernels import (
PixelRGBDDistribution,
is_unexplained,

Check failure on line 14 in src/b3d/chisight/gen3d/image_kernel_new.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (F401)

src/b3d/chisight/gen3d/image_kernel_new.py:14:5: F401 `b3d.chisight.gen3d.pixel_kernels.pixel_rgbd_kernels.is_unexplained` imported but unused
)

# using this in combination with mode="drop" in the .at[]
# methods can help filter out vertices that are not visible in the image
INVALID_IDX = jnp.iinfo(jnp.int32).min # -2147483648

@Pytree.dataclass
class PixelsPointsAssociation(Pytree):
    """Per-point record of which image pixel (if any) each 3D point lands on.

    Build instances with `from_pose_intrinsics_vertices`; look up a single
    point's pixel with `get_pixel_index`.
    """

    # Integer array with one 2-vector per point (shape `(num_points, 2)` as
    # built by `from_pose_intrinsics_vertices`). The two entries are compared
    # against `(image_height, image_width)` respectively. Points that were not
    # assigned a pixel keep the initial value -1; coordinates that fell
    # outside the image are INVALID_IDX.
    observed_pixel_indices: IntArray

    @staticmethod
    def from_pose_intrinsics_vertices(pose, intrinsics, vertices):
        """Project `vertices` under `pose` and associate each hit pixel with one point.

        Args:
            pose: object whose `.apply(vertices)` returns the transformed points.
            intrinsics: mapping providing "image_height" and "image_width"
                (each exposing `.unwrap()`), and "fx", "fy", "cx", "cy".
            vertices: array of 3D points; the last axis is indexed at 2 for depth.

        Returns:
            A `PixelsPointsAssociation` indexed in the original vertex order.
        """
        image_height, image_width = (
            intrinsics["image_height"].unwrap(),
            intrinsics["image_width"].unwrap(),
        )
        transformed_points = pose.apply(vertices)

        # Sort the vertices by depth, so that the "first" vertex found for a
        # pixel below is the one with the smallest depth.
        sort_order = jnp.argsort(transformed_points[..., 2])
        transformed_points_sorted = transformed_points[sort_order]

        # Project the vertices to the image plane and round to pixel indices
        # (still floating point at this stage).
        rounded_coords = jnp.rint(
            b3d.utils.xyz_to_pixel_coordinates(
                transformed_points_sorted,
                intrinsics["fx"],
                intrinsics["fy"],
                intrinsics["cx"],
                intrinsics["cy"],
            )
            - 0.5
        )

        # Decide validity while the values are still floats: once cast to
        # int32 a NaN can no longer be detected, so NaN handling after the
        # cast would be a no-op. Index 0 is a legitimate pixel, hence `>= 0`.
        image_shape = jnp.array([image_height, image_width])
        in_bounds = (
            jnp.isfinite(rounded_coords)
            & (rounded_coords >= 0)
            & (rounded_coords < image_shape)
        )
        # INVALID_IDX (-2**31) is exactly representable as a float, so the
        # cast below yields INVALID_IDX exactly for every rejected coordinate.
        projected_coords = jnp.where(in_bounds, rounded_coords, INVALID_IDX).astype(
            jnp.int32
        )

        # Compute the unique pixel coordinates and the indices of the first
        # vertex that hit that pixel. `size` keeps the output shape static;
        # unused slots are padded with INVALID_IDX.
        unique_pixel_coordinates, unique_indices = jnp.unique(
            projected_coords,
            axis=0,
            return_index=True,
            size=projected_coords.shape[0],
            fill_value=INVALID_IDX,
        )

        # Scatter the unique pixel coordinates back to the original point
        # indexing scheme; points that were not selected stay at -1.
        observed_pixel_coordinates_per_point = jnp.full(
            (transformed_points.shape[0], 2), -1, dtype=jnp.int32
        )
        observed_pixel_coordinates_per_point = observed_pixel_coordinates_per_point.at[
            sort_order[unique_indices]
        ].set(unique_pixel_coordinates)

        return PixelsPointsAssociation(observed_pixel_coordinates_per_point)

    def get_pixel_index(self, point_index):
        """Return the stored pixel entry for `point_index`."""
        return self.observed_pixel_indices[point_index]

@Pytree.dataclass
class UniquePixelsImageKernel(genjax.ExactDensity):
    """Image kernel built on a per-pixel RGBD distribution.

    NOTE(review): this class looks unfinished. Only `sample` is defined here,
    and it does not yet draw a sample (see the notes in the method).
    """

    # Per-pixel distribution; `sample` below calls its `.sample` with `key` only.
    rgbd_vertex_kernel: PixelRGBDDistribution

    def sample(self, key: PRNGKey, state: Mapping, hyperparams: Mapping):
        """Build the pixel/point association and return a doubly-vmapped sampler.

        Reads `state["pose"]`, `state["vertices"]` and
        `hyperparams["intrinsics"]`.
        """
        # NOTE(review): CI (ruff F841) reports that `ppa` is assigned but
        # never used; presumably it is meant to drive the sampling below.
        ppa = PixelsPointsAssociation.from_pose_intrinsics_vertices(
            state["pose"], hyperparams["intrinsics"], state["vertices"]
        )
        # NOTE(review): the nested `jax.vmap` result is returned without being
        # applied to any arrays, so this returns a callable, not a sample.
        # The lambda also ignores `i` and `j` and passes the same `key` to
        # every call. TODO: confirm the intended arguments.
        return jax.vmap(
            jax.vmap(
                lambda i, j: self.rgbd_vertex_kernel.sample(
                    key,
                )
            )
        )
27 changes: 25 additions & 2 deletions src/b3d/chisight/gen3d/inference/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,8 +127,25 @@ def inference_step(
return (trace, weight)


def get_trace_generated_during_inference(key, trace, pose, inference_hyperparams):
return propose_other_latents_given_pose(key, trace, pose, inference_hyperparams)[0]
def get_trace_generated_during_inference(
    key,
    trace,
    pose,
    inference_hyperparams,
    do_advance_time=True,
    observed_rgbd=None,
    just_return_trace=True,
):
    """
    Get the trace generated at pose `pose` with key `key` by inference_step,
    when it was given `trace`, `do_advance_time`, `inference_hyperparams`,
    and `observed_rgbd` as input.

    Args:
        key: PRNG key, passed to both `advance_time` and
            `propose_other_latents_given_pose`.
        trace: the input trace.
        pose: the pose at which the other latents are proposed.
        inference_hyperparams: forwarded to `propose_other_latents_given_pose`.
        do_advance_time: if True, first advance `trace` with `observed_rgbd`.
        observed_rgbd: required when `do_advance_time` is True.
        just_return_trace: if True, return only the first element of the
            proposal result; otherwise return the whole result.

    Raises:
        ValueError: if `do_advance_time` is True and `observed_rgbd` is None.
    """
    if do_advance_time:
        # Explicit check rather than `assert`, which is stripped under `-O`.
        if observed_rgbd is None:
            raise ValueError(
                "observed_rgbd must be provided when do_advance_time is True"
            )
        trace = advance_time(key, trace, observed_rgbd)
    # NOTE(review): `key` is reused for both calls; presumably this mirrors
    # what inference_step does so the same trace is reproduced -- confirm.
    vals = propose_other_latents_given_pose(key, trace, pose, inference_hyperparams)
    return vals[0] if just_return_trace else vals


def maybe_swap_in_gt_pose(
Expand Down Expand Up @@ -238,6 +255,12 @@ def propose_other_latents_given_pose(key, advanced_trace, pose, inference_hyperp
k1, k2, k3, k4 = split(key, 4)

trace = update_field(k1, advanced_trace, "pose", pose)

sup = get_hypers(trace)["color_scale_kernel"].support
val = get_prev_state(advanced_trace)["color_scale"]
idx = jnp.argmin(jnp.abs(sup - val))
newidx = jnp.minimum(idx+1, sup.shape[0]-1)
trace = update_field(k1, trace, "color_scale", sup[newidx])

k2a, k2b = split(k2)
(
Expand Down
13 changes: 3 additions & 10 deletions src/b3d/chisight/gen3d/inference/point_attribute_proposals.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,14 +86,7 @@ def propose_a_points_attributes(
return _propose_a_points_attributes(
key,
observed_rgbd_for_point=observed_rgbd_for_point,
latent_rgbd_for_point=jnp.array(
[
0.0,
0.0,
0.0,
new_state["pose"].apply(hyperparams["vertices"][vertex_index])[2],
]
),
latent_depth=new_state["pose"].apply(hyperparams["vertices"][vertex_index])[2],
previous_color=prev_state["colors"][vertex_index],
previous_visibility_prob=prev_state["visibility_prob"][vertex_index],
previous_dnrp=prev_state["depth_nonreturn_prob"][vertex_index],
Expand All @@ -107,7 +100,7 @@ def propose_a_points_attributes(
def _propose_a_points_attributes(
key,
observed_rgbd_for_point,
latent_rgbd_for_point,
latent_depth,
previous_color,
previous_visibility_prob,
previous_dnrp,
Expand All @@ -121,7 +114,6 @@ def _propose_a_points_attributes(
visibility_transition_kernel = hyperparams["visibility_prob_kernel"]
color_kernel = hyperparams["color_kernel"]
obs_rgbd_kernel = hyperparams["image_kernel"].get_rgbd_vertex_kernel()
latent_depth = latent_rgbd_for_point[3]
intrinsics = hyperparams["intrinsics"]

def score_attribute_assignment(color, visprob, dnrprob):
Expand All @@ -138,6 +130,7 @@ def score_attribute_assignment(color, visprob, dnrprob):
visibility_prob=visprob,
depth_nonreturn_prob=dnrprob,
intrinsics=intrinsics,
invisible_depth_nonreturn_prob=hyperparams["unexplained_depth_nonreturn_prob"],
)
return (
visprob_transition_score
Expand Down
117 changes: 58 additions & 59 deletions src/b3d/chisight/gen3d/pixel_kernels/pixel_color_kernels.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@ class PixelColorDistribution(genjax.ExactDensity):
Distribution args:
- latent_rgb
- rgb_scale
- visibility_prob
Support:
- An RGB value in [0, 1]^3.
Expand Down Expand Up @@ -183,61 +182,61 @@ def logpdf_per_channel(
return self._base_dist.log_prob(observed_color)


@Pytree.dataclass
class MixturePixelColorDistribution(PixelColorDistribution):
    """Two-component mixture over pixel colors.

    With weight `visibility_prob` the color comes from the inlier branch (a
    truncated Laplace distribution parameterized by the latent color and the
    color scale); with weight `1 - visibility_prob` it comes from the
    occluded branch (a uniform distribution). The support is [0, 1]^3.
    """

    @property
    def _occluded_dist(self) -> PixelColorDistribution:
        # Branch taking no distribution arguments.
        return UniformPixelColorDistribution()

    @property
    def _inlier_dist(self) -> PixelColorDistribution:
        # Branch parameterized by (latent_color, color_scale).
        return TruncatedLaplacePixelColorDistribution()

    @property
    def _mixture_dists(self) -> tuple[PixelColorDistribution, PixelColorDistribution]:
        # Order matters: it must line up with `_get_mix_ratio`.
        return (self._occluded_dist, self._inlier_dist)

    def _get_mix_ratio(self, visibility_prob: float) -> FloatArray:
        # Weights in the same (occluded, inlier) order as `_mixture_dists`.
        return jnp.array((1 - visibility_prob, visibility_prob))

    def sample(
        self,
        key: PRNGKey,
        latent_color: FloatArray,
        color_scale: FloatArray,
        visibility_prob: float,
        *args,
        **kwargs,
    ) -> FloatArray:
        """Draw a color from the mixture; extra arguments are ignored."""
        mixture = PythonMixtureDistribution(self._mixture_dists)
        weights = self._get_mix_ratio(visibility_prob)
        # One argument tuple per component: none for the occluded branch,
        # (latent_color, color_scale) for the inlier branch.
        component_args = [(), (latent_color, color_scale)]
        return mixture.sample(key, weights, component_args)

    def logpdf_per_channel(
        self,
        observed_color: FloatArray,
        latent_color: FloatArray,
        color_scale: FloatArray,
        visibility_prob: float,
        *args,
        **kwargs,
    ) -> FloatArray:
        """Return the per-channel mixture log density of `observed_color`.

        The mixture model class does not keep per-channel information, so the
        weighted log-sum is redone here to allow for testing.
        """
        occluded_logprob, inlier_logprob = (
            component.logpdf_per_channel(observed_color, latent_color, color_scale)
            + jnp.log(weight)
            for component, weight in zip(
                self._mixture_dists, self._get_mix_ratio(visibility_prob)
            )
        )
        return jnp.logaddexp(occluded_logprob, inlier_logprob)
# @Pytree.dataclass
# class MixturePixelColorDistribution(PixelColorDistribution):
# """A distribution that generates the color of a pixel from a mixture of a
# truncated Laplace distribution centered around the latent color (inlier
# branch) and a uniform distribution (occluded branch). The mixture is
# controlled by the occluded_prob parameter. The support of the
# distribution is ([0, 1]^3).
# """

# @property
# def _occluded_dist(self) -> PixelColorDistribution:
# return UniformPixelColorDistribution()

# @property
# def _inlier_dist(self) -> PixelColorDistribution:
# return TruncatedLaplacePixelColorDistribution()

# @property
# def _mixture_dists(self) -> tuple[PixelColorDistribution, PixelColorDistribution]:
# return (self._occluded_dist, self._inlier_dist)

# def _get_mix_ratio(self, visibility_prob: float) -> FloatArray:
# return jnp.array((1 - visibility_prob, visibility_prob))

# def sample(
# self,
# key: PRNGKey,
# latent_color: FloatArray,
# color_scale: FloatArray,
# visibility_prob: float,
# *args,
# **kwargs,
# ) -> FloatArray:
# return PythonMixtureDistribution(self._mixture_dists).sample(
# key, self._get_mix_ratio(visibility_prob), [(), (latent_color, color_scale)]
# )

# def logpdf_per_channel(
# self,
# observed_color: FloatArray,
# latent_color: FloatArray,
# color_scale: FloatArray,
# visibility_prob: float,
# *args,
# **kwargs,
# ) -> FloatArray:
# # Since the mixture model class does not keep the per-channel information,
# # we have to redefine this method to allow for testing
# logprobs = []
# for dist, prob in zip(
# self._mixture_dists, self._get_mix_ratio(visibility_prob)
# ):
# logprobs.append(
# dist.logpdf_per_channel(observed_color, latent_color, color_scale)
# + jnp.log(prob)
# )

# return jnp.logaddexp(*logprobs)
4 changes: 2 additions & 2 deletions src/b3d/chisight/gen3d/pixel_kernels/pixel_depth_kernels.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,8 @@ class PixelDepthDistribution(genjax.ExactDensity):
Distribution args:
- latent_depth
- depth_scale
- visibility_prob
- depth_nonreturn_prob
- near
- far
Support: depth value in [near, far], or DEPTH_NONRETURN_VAL.
"""
Expand Down
Loading

0 comments on commit 341f690

Please sign in to comment.