diff --git a/notebooks/gm/gen3d/debugging.ipynb b/notebooks/gm/gen3d/debugging.ipynb index 32bf477c..81c5a281 100644 --- a/notebooks/gm/gen3d/debugging.ipynb +++ b/notebooks/gm/gen3d/debugging.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "code", - "execution_count": 1, + "execution_count": 2, "metadata": {}, "outputs": [], "source": [ @@ -17,20 +17,50 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "problematic_scene_specs = [(2, 5), (3, 0), (3, 1), (3, 2), (4, 1), (4, 3), (5, 0)]\n", - "spec = problematic_scene_specs[1]\n", + "spec = problematic_scene_specs[2]\n", "scene_id, object_idx = spec" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 4, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 0%| | 0/46 [00:00" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "all_data, meshes, renderer, intrinsics, initial_object_poses = gen3d.dataloading.load_scene(\n", " scene_id, FRAME_RATE=50, subdir=\"train_real\"\n", @@ -40,7 +70,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 5, "metadata": {}, "outputs": [], "source": [ @@ -51,9 +81,65 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 6, "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/html": [ + "
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "" + ], + "text/plain": [ + "{'pose_kernel': GaussianVMFPoseDriftKernel(...),\n", + " 'color_kernel': RenormalizedLaplaceColorDriftKernel(...),\n", + " 'visibility_prob_kernel': DiscreteFlipKernel(...),\n", + " 'depth_nonreturn_prob_kernel': DiscreteFlipKernel(...),\n", + " 'depth_scale_kernel': DiscreteFlipKernel(...),\n", + " 'color_scale_kernel': DiscreteFlipKernel(...),\n", + " 'image_kernel': UniquePixelsImageKernel(...),\n", + " 'unexplained_depth_nonreturn_prob': 0.02,\n", + " 'intrinsics': {'fx': ,\n", + " 'fy': ,\n", + " 'cx': ,\n", + " 'cy': ,\n", + " 'image_height': Const(...),\n", + " 'image_width': Const(...),\n", + " 'near': 0.01,\n", + " 'far': 3.0},\n", + " 'vertices': \n", + " >}" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "hyperparams = gen3d.settings.hyperparams\n", "hyperparams[\"intrinsics\"] = intrinsics\n", @@ -63,30 +149,47 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 7, "metadata": {}, - "outputs": [], - "source": [ - "all_data" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/html": [ + "
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "" + ], + "text/plain": [ + "InferenceHyperparams(...)" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "inference_hyperparams = gen3d.settings.inference_hyperparams\n", "inference_hyperparams = gen3d.hyperparams.InferenceHyperparams(**{\n", " **inference_hyperparams.attributes_dict(),\n", - " \"pose_proposal_args\": [(0.04, 1000.0)]\n", + " # \"pose_proposal_args\": [(0.04, 1000.0)]\n", "})\n", "inference_hyperparams" ] }, { "cell_type": "code", - "execution_count": 25, + "execution_count": 8, "metadata": {}, "outputs": [], "source": [ @@ -99,7 +202,7 @@ }, { "cell_type": "code", - "execution_count": 26, + "execution_count": 9, "metadata": {}, "outputs": [], "source": [ @@ -110,9 +213,18 @@ }, { "cell_type": "code", - "execution_count": 27, + "execution_count": 10, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/georgematheos/b3d/src/b3d/modeling_utils.py:87: UserWarning: RenormalizedLaplace sampling is currently not implemented perfectly.\n", + " warnings.warn(\n" + ] + } + ], "source": [ "key = jax.random.PRNGKey(156)\n", "og_trace = gen3d.inference.get_initial_trace(\n", @@ -122,18 +234,38 @@ }, { "cell_type": "code", - "execution_count": 28, + "execution_count": 11, "metadata": {}, "outputs": [], "source": [ - "b3d.rr_init(\"inference_debugging-2\")\n" + "b3d.rr_init(\"inference_debugging-3-1-2\")\n" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 12, "metadata": {}, "outputs": [], + "source": [ + "b3d.reload(gen3d.inference)\n", + "b3d.reload(b3d.chisight.gen3d.point_attribute_proposals)" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 0%| | 0/20 [00:00 0.02:\n", - " b3d.rr_init(\"gt_pose_trace\")\n", + " b3d.rr_init(\"gt_pose_trace-3-1\")\n", " gen3d.model.viz_trace(gt_trace, T, ground_truth_vertices=meshes[object_idx].vertices, ground_truth_pose=get_gt_pose(T))\n", " break" ] }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "" + ], + "text/plain": [ + "" + ] + }, + "execution_count": 14, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "trace.get_score()" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "" + ], + "text/plain": [ + "" + ] + }, + "execution_count": 15, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "gt_trace.get_score()" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "" + ], + "text/plain": [ + "DistributionTrace(...)" + ] + }, + "execution_count": 16, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "gt_trace.subtraces[6]" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [], + "source": [ + "observed_rgbd = all_data[6][\"rgbd\"]\n", + "state = gen3d.model.get_new_state(gt_trace)\n", + "intrinsics = hyperparams[\"intrinsics\"]\n", + "(\n", + " observed_rgbd_per_point,\n", + " latent_rgbd_per_point,\n", + " is_valid,\n", + " _,\n", + " point_indices_for_observed_rgbds,\n", + ") = gen3d.image_kernel.calculate_latent_and_observed_correspondences(\n", + " observed_rgbd, state, hyperparams\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "" + ], + "text/plain": [ + "" + ] + }, + "execution_count": 18, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "state[\"depth_nonreturn_prob\"][point_indices_for_observed_rgbds[0]]" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "" + ], + "text/plain": [ + "# jax.Array bool(10000,) true:7_894 false:2_106\n", + " Array([False, False, False, ..., True, True, True], dtype=bool)\n" + ] + }, + "execution_count": 19, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "scores = jax.vmap(\n", + " hyperparams[\"image_kernel\"].get_rgbd_vertex_kernel().logpdf,\n", + " in_axes=(0, 0, None, None, 0, 0, None, None),\n", + ")(\n", + " observed_rgbd_per_point,\n", + " latent_rgbd_per_point,\n", + " state[\"color_scale\"],\n", + " state[\"depth_scale\"],\n", + " state[\"visibility_prob\"][point_indices_for_observed_rgbds],\n", + " state[\"depth_nonreturn_prob\"][point_indices_for_observed_rgbds],\n", + " hyperparams[\"intrinsics\"],\n", + " hyperparams[\"unexplained_depth_nonreturn_prob\"]\n", + ")\n", + "jnp.isinf(scores)" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "" + ], + "text/plain": [ + "\n", + ">" + ] + }, + "execution_count": 20, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "observed_rgbd_per_point[-1]" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "" + ], + "text/plain": [ + "" + ] + }, + "execution_count": 21, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "is_valid[42]" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "" + ], + "text/plain": [ + "\n", + ">" + ] + }, + "execution_count": 22, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "latent_rgbd_per_point[42]" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "" + ], + "text/plain": [ + "\n", + ">" + ] + }, + "execution_count": 23, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "observed_rgbd_per_point[42]" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "" + ], + "text/plain": [ + "" + ] + }, + "execution_count": 24, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "state[\"visibility_prob\"][point_indices_for_observed_rgbds][42]" + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "" + ], + "text/plain": [ + "" + ] + }, + "execution_count": 25, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "state[\"depth_nonreturn_prob\"][point_indices_for_observed_rgbds][42]" + ] + }, + { + "cell_type": "code", + "execution_count": 26, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "" + ], + "text/plain": [ + "" + ] + }, + "execution_count": 26, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "observed_rgbd = observed_rgbd_per_point[0]\n", + "latent_rgbd = latent_rgbd_per_point[0]\n", + "color_scale = state[\"color_scale\"]\n", + "depth_scale = state[\"depth_scale\"]\n", + "visibility_prob = state[\"visibility_prob\"][point_indices_for_observed_rgbds][0]\n", + "depth_nonreturn_prob = state[\"depth_nonreturn_prob\"][point_indices_for_observed_rgbds][0]\n", + "intrinsics = hyperparams[\"intrinsics\"]\n", + "\n", + "hyperparams[\"image_kernel\"].get_rgbd_vertex_kernel().logpdf(\n", + " observed_rgbd,\n", + " latent_rgbd,\n", + " color_scale,\n", + " depth_scale,\n", + " visibility_prob,\n", + " 0.,\n", + " intrinsics,\n", + " hyperparams[\"unexplained_depth_nonreturn_prob\"]\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 27, + "metadata": {}, + "outputs": [], + "source": [ + "def get_trace_generated_during_inference(\n", + " key, trace, pose, inference_hyperparams,\n", + " do_advance_time=True,\n", + " observed_rgbd=None,\n", + " just_return_trace=True\n", + "):\n", + " \"\"\"\n", + " Get the trace generated at pose `pose` with key `key` by inference_step,\n", + " when it was given `trace`, `do_advance_time`, `inference_hyperparams`,\n", + " and `observed_rgbd` as input.\n", + " \"\"\"\n", + " if do_advance_time:\n", + " assert observed_rgbd is not None\n", + " trace = gen3d.inference.advance_time(key, trace, observed_rgbd)\n", + " vals = gen3d.inference.propose_other_latents_given_pose(key, trace, pose, inference_hyperparams)\n", + " if just_return_trace:\n", + " return vals[0]\n", + " else:\n", + " return vals\n" + ] + }, + { + "cell_type": "code", + "execution_count": 28, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "" + ], + "text/plain": [ + "{'all_visprob_dnrprob_pairs': \n", + " >,\n", + " 'index': ,\n", + " 'log_normalized_scores': \n", + " >,\n", + " 'log_q_score': ,\n", + " 'log_qs_rgb': \n", + " >,\n", + " 'rgb_proposal_metadata': {'isvalid': ,\n", + " 'log_q_if_invalid': \n", + " >,\n", + " 'metadata_if_valid': {'log_K_score': \n", + " >,\n", + " 'log_L_score': \n", + " >,\n", + " 'log_qs': \n", + " >,\n", + " 'normalized_scores': \n", + " >,\n", + " 'overall_score': \n", + " >,\n", + " 'proposed_rgbs': \n", + " >,\n", + " 'sampled_index': \n", + " >,\n", + " 'sampled_rgb': \n", + " >},\n", + " 'value_if_observed_is_invalid': \n", + " >}}" + ] + }, + "execution_count": 28, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "_, _, metadata = get_trace_generated_during_inference(\n", + " gt_key, prev_trace, gt_pose, inference_hyperparams,\n", + " observed_rgbd=all_data[T][\"rgbd\"],\n", + " just_return_trace=False\n", + ")\n", + "jax.tree.map(lambda x: x[0], metadata[\"point_attribute_proposal_metadata\"])" + ] + }, + { + "cell_type": "code", + "execution_count": 29, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "" + ], + "text/plain": [ + "{'all_visprob_dnrprob_pairs': \n", + " >,\n", + " 'index': \n", + " >,\n", + " 'log_normalized_scores': \n", + " >,\n", + " 'log_q_score': \n", + " >,\n", + " 'log_qs_rgb': \n", + " >,\n", + " 'rgb_proposal_metadata': {'isvalid': # jax.Array bool(10000, 4) true:15_892 false:24_108\n", + " Array([[False, False, False, False],\n", + " [ True, True, True, True],\n", + " [False, False, False, False],\n", + " ...,\n", + " [False, False, False, False],\n", + " [False, False, False, False],\n", + " [ True, True, True, True]], dtype=bool)\n", + " ,\n", + " 'log_q_if_invalid': \n", + " >,\n", + " 'metadata_if_valid': {'log_K_score': \n", + " >,\n", + " 'log_L_score': \n", + " >,\n", + " 'log_qs': \n", + " >,\n", + " 'normalized_scores': \n", + " >,\n", + " 'overall_score': \n", + " >,\n", + " 'proposed_rgbs': \n", + " >,\n", + " 'sampled_index': \n", + " >,\n", + " 'sampled_rgb': \n", + " >},\n", + " 'value_if_observed_is_invalid': \n", + " >}}" + ] + }, + "execution_count": 29, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "metadata[\"point_attribute_proposal_metadata\"]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, { "cell_type": "code", "execution_count": null, @@ -178,6 +1159,276 @@ "outputs": [], "source": [] }, + { + "cell_type": "code", + "execution_count": 30, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "" + ], + "text/plain": [ + "" + ] + }, + "execution_count": 30, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "hyperparams[\"image_kernel\"].get_rgbd_vertex_kernel().outlier_color_distribution.logpdf(\n", + " observed_rgbd[:3], latent_rgbd[:3], color_scale\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 31, + "metadata": {}, + "outputs": [], + "source": [ + "s = hyperparams[\"image_kernel\"].get_rgbd_vertex_kernel()" + ] + }, + { + "cell_type": "code", + "execution_count": 32, + "metadata": {}, + "outputs": [ + { + "ename": "SyntaxError", + "evalue": "invalid syntax (3129153674.py, line 1)", + "output_type": "error", + "traceback": [ + "\u001b[0;36m Cell \u001b[0;32mIn[32], line 1\u001b[0;36m\u001b[0m\n\u001b[0;31m .outlier_depth_distribution.logpdf(\u001b[0m\n\u001b[0m ^\u001b[0m\n\u001b[0;31mSyntaxError\u001b[0m\u001b[0;31m:\u001b[0m invalid syntax\n" + ] + } + ], + "source": [ + ".outlier_depth_distribution.logpdf(\n", + " observed_rgbd[3], latent_rgbd[3], depth_scale, intrinsics[\"near\"], intrinsics[\"far\"]\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 86, + "metadata": {}, + "outputs": [], + "source": [ + "total_log_prob = 0.0\n", + "is_depth_non_return = observed_rgbd[3] == 0.0\n", + "\n", + "# Is visible\n", + "total_visible_log_prob = 0.0\n", + "# color term\n", + "total_visible_log_prob += s.inlier_color_distribution.logpdf(\n", + " observed_rgbd[:3], latent_rgbd[:3], color_scale\n", + ")\n", + "# depth term\n", + "total_visible_log_prob += jnp.where(\n", + " is_depth_non_return,\n", + " jnp.log(depth_nonreturn_prob),\n", + " jnp.log(1 - depth_nonreturn_prob)\n", + " + s.inlier_depth_distribution.logpdf(\n", + " observed_rgbd[3],\n", + " latent_rgbd[3],\n", + " depth_scale,\n", + " intrinsics[\"near\"],\n", + " intrinsics[\"far\"],\n", + " ),\n", + ")\n" + ] + }, + { + "cell_type": "code", + "execution_count": 89, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "" + ], + "text/plain": [ + "" + ] + }, + "execution_count": 89, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "jnp.log(1 - depth_nonreturn_prob)" + ] + }, + { + "cell_type": "code", + "execution_count": 91, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "" + ], + "text/plain": [ + "" + ] + }, + "execution_count": 91, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "total_not_visible_log_prob = 0.0\n", + "# color term\n", + "outlier_color_log_prob = s.outlier_color_distribution.logpdf(\n", + " observed_rgbd[:3],\n", + " latent_rgbd[:3],\n", + " color_scale,\n", + ")\n", + "outlier_depth_log_prob = s.outlier_depth_distribution.logpdf(\n", + " observed_rgbd[3],\n", + " latent_rgbd[3],\n", + " depth_scale,\n", + " intrinsics[\"near\"],\n", + " intrinsics[\"far\"],\n", + ")\n", + "outlier_depth_log_prob + outlier_color_log_prob" + ] + }, + { + "cell_type": "code", + "execution_count": 93, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "" + ], + "text/plain": [ + "" + ] + }, + "execution_count": 93, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "total_not_visible_log_prob += outlier_color_log_prob\n", + "# depth term\n", + "total_not_visible_log_prob += jnp.where(\n", + " is_depth_non_return,\n", + " jnp.log(depth_nonreturn_prob),\n", + " jnp.log(1 - depth_nonreturn_prob) + outlier_depth_log_prob,\n", + ")\n", + "total_not_visible_log_prob" + ] + }, { "cell_type": "code", "execution_count": null, diff --git a/src/b3d/chisight/gen3d/image_kernel.py b/src/b3d/chisight/gen3d/image_kernel.py index 5905c1b0..ca6b1d0b 100644 --- a/src/b3d/chisight/gen3d/image_kernel.py +++ b/src/b3d/chisight/gen3d/image_kernel.py @@ -259,8 +259,8 @@ def sample(self, key: PRNGKey, state: Mapping, hyperparams: Mapping) -> FloatArr ), ) return jax.vmap( - jax.vmap(vertex_kernel.sample, in_axes=(0, 0, None, None, 0, 0, None)), - in_axes=(0, 0, None, None, 0, 0, None), + jax.vmap(vertex_kernel.sample, in_axes=(0, 0, None, None, 0, 0, None, None)), + in_axes=(0, 0, None, None, 0, 0, None, None), )( keys, pixel_latent_rgbd, @@ -269,6 +269,7 @@ def sample(self, key: PRNGKey, state: Mapping, hyperparams: Mapping) -> FloatArr pixel_visibility_prob, pixel_depth_nonreturn_prob, hyperparams["intrinsics"], + hyperparams["unexplained_depth_nonreturn_prob"] ) def logpdf( @@ -284,7 +285,7 @@ def logpdf( ) vertex_kernel = self.get_rgbd_vertex_kernel() - scores = jax.vmap(vertex_kernel.logpdf, in_axes=(0, 0, None, None, 0, 0, None))( + scores = jax.vmap(vertex_kernel.logpdf, in_axes=(0, 0, None, None, 0, 0, None, None))( observed_rgbd_per_point, latent_rgbd_per_point, state["color_scale"], @@ -292,6 +293,7 @@ def logpdf( state["visibility_prob"], state["depth_nonreturn_prob"], hyperparams["intrinsics"], + hyperparams["unexplained_depth_nonreturn_prob"] ) # Points that don't hit the camera plane should not contribute to the score. @@ -420,7 +422,7 @@ def logpdf( # Score the collided pixels scores = jax.vmap( hyperparams["image_kernel"].get_rgbd_vertex_kernel().logpdf, - in_axes=(0, 0, None, None, 0, 0, None), + in_axes=(0, 0, None, None, 0, 0, None, None), )( observed_rgbd_per_point, latent_rgbd_per_point, @@ -429,6 +431,7 @@ def logpdf( state["visibility_prob"][point_indices_for_observed_rgbds], state["depth_nonreturn_prob"][point_indices_for_observed_rgbds], hyperparams["intrinsics"], + hyperparams["unexplained_depth_nonreturn_prob"] ) total_score_for_explained_pixels = jnp.where(is_valid, scores, 0.0).sum() diff --git a/src/b3d/chisight/gen3d/image_kernel_new.py b/src/b3d/chisight/gen3d/image_kernel_new.py new file mode 100644 index 00000000..59764a1d --- /dev/null +++ b/src/b3d/chisight/gen3d/image_kernel_new.py @@ -0,0 +1,89 @@ +from abc import abstractmethod +from functools import cached_property +from typing import Mapping + +import genjax +import jax +import jax.numpy as jnp +from genjax import Pytree +from genjax.typing import FloatArray, IntArray, PRNGKey + +import b3d.utils +from b3d.chisight.gen3d.pixel_kernels.pixel_rgbd_kernels import ( + PixelRGBDDistribution, + is_unexplained, +) + +# using this in combination with mode="drop" in the .at[] +# methods can help filter out vertices that are not visible in the image +INVALID_IDX = jnp.iinfo(jnp.int32).min # -2147483648 + +class PixelsPointsAssociation(Pytree): + observed_pixel_indices : IntArray + + def from_pose_intrinsics_vertices(pose, intrinsics, vertices): + image_height, image_width = ( + intrinsics["image_height"].unwrap(), + intrinsics["image_width"].unwrap(), + ) + transformed_points = pose.apply(vertices) + + # Sort the vertices by depth. + sort_order = jnp.argsort(transformed_points[..., 2]) + transformed_points_sorted = transformed_points[sort_order] + + # Project the vertices to the image plane. + projected_coords = jnp.rint( + b3d.utils.xyz_to_pixel_coordinates( + transformed_points_sorted, + intrinsics["fx"], + intrinsics["fy"], + intrinsics["cx"], + intrinsics["cy"], + ) + - 0.5 + ).astype(jnp.int32) + projected_coords = jnp.nan_to_num(projected_coords, nan=INVALID_IDX) + # handle the case where the projected coordinates are outside the image + projected_coords = jnp.where(projected_coords > 0, projected_coords, INVALID_IDX) + projected_coords = jnp.where( + projected_coords < jnp.array([image_height, image_width]), + projected_coords, + INVALID_IDX, + ) + + # Compute the unique pixel coordinates and the indices of the first vertex that hit that pixel. + unique_pixel_coordinates, unique_indices = jnp.unique( + projected_coords, + axis=0, + return_index=True, + size=projected_coords.shape[0], + fill_value=INVALID_IDX, + ) + + # Reorder the unique pixel coordinates, to the original point array indexing scheme + observed_pixel_coordinates_per_point = -jnp.ones((transformed_points.shape[0], 2), dtype=jnp.int32) + observed_pixel_coordinates_per_point = observed_pixel_coordinates_per_point.at[ + sort_order[unique_indices] + ].set(unique_pixel_coordinates) + + return PixelsPointsAssociation(observed_pixel_coordinates_per_point) + + def get_pixel_index(self, point_index): + return self.observed_pixel_indices[point_index] + +@Pytree.dataclass +class UniquePixelsImageKernel(genjax.ExactDensity): + rgbd_vertex_kernel: PixelRGBDDistribution + + def sample(self, key: PRNGKey, state: Mapping, hyperparams: Mapping): + ppa = PixelsPointsAssociation.from_pose_intrinsics_vertices( + state["pose"], hyperparams["intrinsics"], state["vertices"] + ) + return jax.vmap( + jax.vmap( + lambda i, j: self.rgbd_vertex_kernel.sample( + key, + ) + ) + ) \ No newline at end of file diff --git a/src/b3d/chisight/gen3d/inference/inference.py b/src/b3d/chisight/gen3d/inference/inference.py index ab30b584..fdc482b3 100644 --- a/src/b3d/chisight/gen3d/inference/inference.py +++ b/src/b3d/chisight/gen3d/inference/inference.py @@ -127,8 +127,25 @@ def inference_step( return (trace, weight) -def get_trace_generated_during_inference(key, trace, pose, inference_hyperparams): - return propose_other_latents_given_pose(key, trace, pose, inference_hyperparams)[0] +def get_trace_generated_during_inference( + key, trace, pose, inference_hyperparams, + do_advance_time=True, + observed_rgbd=None, + just_return_trace=True +): + """ + Get the trace generated at pose `pose` with key `key` by inference_step, + when it was given `trace`, `do_advance_time`, `inference_hyperparams`, + and `observed_rgbd` as input. + """ + if do_advance_time: + assert observed_rgbd is not None + trace = advance_time(key, trace, observed_rgbd) + vals = propose_other_latents_given_pose(key, trace, pose, inference_hyperparams) + if just_return_trace: + return vals[0] + else: + return vals def maybe_swap_in_gt_pose( @@ -238,6 +255,12 @@ def propose_other_latents_given_pose(key, advanced_trace, pose, inference_hyperp k1, k2, k3, k4 = split(key, 4) trace = update_field(k1, advanced_trace, "pose", pose) + + sup = get_hypers(trace)["color_scale_kernel"].support + val = get_prev_state(advanced_trace)["color_scale"] + idx = jnp.argmin(jnp.abs(sup - val)) + newidx = jnp.minimum(idx+1, sup.shape[0]-1) + trace = update_field(k1, trace, "color_scale", sup[newidx]) k2a, k2b = split(k2) ( diff --git a/src/b3d/chisight/gen3d/inference/point_attribute_proposals.py b/src/b3d/chisight/gen3d/inference/point_attribute_proposals.py index 462413f4..7b0b3aac 100644 --- a/src/b3d/chisight/gen3d/inference/point_attribute_proposals.py +++ b/src/b3d/chisight/gen3d/inference/point_attribute_proposals.py @@ -86,14 +86,7 @@ def propose_a_points_attributes( return _propose_a_points_attributes( key, observed_rgbd_for_point=observed_rgbd_for_point, - latent_rgbd_for_point=jnp.array( - [ - 0.0, - 0.0, - 0.0, - new_state["pose"].apply(hyperparams["vertices"][vertex_index])[2], - ] - ), + latent_depth=new_state["pose"].apply(hyperparams["vertices"][vertex_index])[2], previous_color=prev_state["colors"][vertex_index], previous_visibility_prob=prev_state["visibility_prob"][vertex_index], previous_dnrp=prev_state["depth_nonreturn_prob"][vertex_index], @@ -107,7 +100,7 @@ def propose_a_points_attributes( def _propose_a_points_attributes( key, observed_rgbd_for_point, - latent_rgbd_for_point, + latent_depth, previous_color, previous_visibility_prob, previous_dnrp, @@ -121,7 +114,6 @@ def _propose_a_points_attributes( visibility_transition_kernel = hyperparams["visibility_prob_kernel"] color_kernel = hyperparams["color_kernel"] obs_rgbd_kernel = hyperparams["image_kernel"].get_rgbd_vertex_kernel() - latent_depth = latent_rgbd_for_point[3] intrinsics = hyperparams["intrinsics"] def score_attribute_assignment(color, visprob, dnrprob): @@ -138,6 +130,7 @@ def score_attribute_assignment(color, visprob, dnrprob): visibility_prob=visprob, depth_nonreturn_prob=dnrprob, intrinsics=intrinsics, + invisible_depth_nonreturn_prob=hyperparams["unexplained_depth_nonreturn_prob"], ) return ( visprob_transition_score diff --git a/src/b3d/chisight/gen3d/pixel_kernels/pixel_color_kernels.py b/src/b3d/chisight/gen3d/pixel_kernels/pixel_color_kernels.py index dc9a2f0f..3dfcfacc 100644 --- a/src/b3d/chisight/gen3d/pixel_kernels/pixel_color_kernels.py +++ b/src/b3d/chisight/gen3d/pixel_kernels/pixel_color_kernels.py @@ -30,7 +30,6 @@ class PixelColorDistribution(genjax.ExactDensity): Distribuiton args: - latent_rgb - rgb_scale - - visibility_prob Support: - An RGB value in [0, 1]^3. @@ -183,61 +182,61 @@ def logpdf_per_channel( return self._base_dist.log_prob(observed_color) -@Pytree.dataclass -class MixturePixelColorDistribution(PixelColorDistribution): - """A distribution that generates the color of a pixel from a mixture of a - truncated Laplace distribution centered around the latent color (inlier - branch) and a uniform distribution (occluded branch). The mixture is - controlled by the occluded_prob parameter. The support of the - distribution is ([0, 1]^3). - """ - - @property - def _occluded_dist(self) -> PixelColorDistribution: - return UniformPixelColorDistribution() - - @property - def _inlier_dist(self) -> PixelColorDistribution: - return TruncatedLaplacePixelColorDistribution() - - @property - def _mixture_dists(self) -> tuple[PixelColorDistribution, PixelColorDistribution]: - return (self._occluded_dist, self._inlier_dist) - - def _get_mix_ratio(self, visibility_prob: float) -> FloatArray: - return jnp.array((1 - visibility_prob, visibility_prob)) - - def sample( - self, - key: PRNGKey, - latent_color: FloatArray, - color_scale: FloatArray, - visibility_prob: float, - *args, - **kwargs, - ) -> FloatArray: - return PythonMixtureDistribution(self._mixture_dists).sample( - key, self._get_mix_ratio(visibility_prob), [(), (latent_color, color_scale)] - ) - - def logpdf_per_channel( - self, - observed_color: FloatArray, - latent_color: FloatArray, - color_scale: FloatArray, - visibility_prob: float, - *args, - **kwargs, - ) -> FloatArray: - # Since the mixture model class does not keep the per-channel information, - # we have to redefine this method to allow for testing - logprobs = [] - for dist, prob in zip( - self._mixture_dists, self._get_mix_ratio(visibility_prob) - ): - logprobs.append( - dist.logpdf_per_channel(observed_color, latent_color, color_scale) - + jnp.log(prob) - ) - - return jnp.logaddexp(*logprobs) +# @Pytree.dataclass +# class MixturePixelColorDistribution(PixelColorDistribution): +# """A distribution that generates the color of a pixel from a mixture of a +# truncated Laplace distribution centered around the latent color (inlier +# branch) and a uniform distribution (occluded branch). The mixture is +# controlled by the occluded_prob parameter. The support of the +# distribution is ([0, 1]^3). +# """ + +# @property +# def _occluded_dist(self) -> PixelColorDistribution: +# return UniformPixelColorDistribution() + +# @property +# def _inlier_dist(self) -> PixelColorDistribution: +# return TruncatedLaplacePixelColorDistribution() + +# @property +# def _mixture_dists(self) -> tuple[PixelColorDistribution, PixelColorDistribution]: +# return (self._occluded_dist, self._inlier_dist) + +# def _get_mix_ratio(self, visibility_prob: float) -> FloatArray: +# return jnp.array((1 - visibility_prob, visibility_prob)) + +# def sample( +# self, +# key: PRNGKey, +# latent_color: FloatArray, +# color_scale: FloatArray, +# visibility_prob: float, +# *args, +# **kwargs, +# ) -> FloatArray: +# return PythonMixtureDistribution(self._mixture_dists).sample( +# key, self._get_mix_ratio(visibility_prob), [(), (latent_color, color_scale)] +# ) + +# def logpdf_per_channel( +# self, +# observed_color: FloatArray, +# latent_color: FloatArray, +# color_scale: FloatArray, +# visibility_prob: float, +# *args, +# **kwargs, +# ) -> FloatArray: +# # Since the mixture model class does not keep the per-channel information, +# # we have to redefine this method to allow for testing +# logprobs = [] +# for dist, prob in zip( +# self._mixture_dists, self._get_mix_ratio(visibility_prob) +# ): +# logprobs.append( +# dist.logpdf_per_channel(observed_color, latent_color, color_scale) +# + jnp.log(prob) +# ) + +# return jnp.logaddexp(*logprobs) diff --git a/src/b3d/chisight/gen3d/pixel_kernels/pixel_depth_kernels.py b/src/b3d/chisight/gen3d/pixel_kernels/pixel_depth_kernels.py index e87c9504..a44a59bb 100644 --- a/src/b3d/chisight/gen3d/pixel_kernels/pixel_depth_kernels.py +++ b/src/b3d/chisight/gen3d/pixel_kernels/pixel_depth_kernels.py @@ -26,8 +26,8 @@ class PixelDepthDistribution(genjax.ExactDensity): Distribution args: - latent_depth - depth_scale - - visibility_prob - - depth_nonreturn_prob + - near + - far Support: depth value in [near, far], or DEPTH_NONRETURN_VAL. """ diff --git a/src/b3d/chisight/gen3d/pixel_kernels/pixel_rgbd_kernels.py b/src/b3d/chisight/gen3d/pixel_kernels/pixel_rgbd_kernels.py index 4de3d9da..fe06ca94 100644 --- a/src/b3d/chisight/gen3d/pixel_kernels/pixel_rgbd_kernels.py +++ b/src/b3d/chisight/gen3d/pixel_kernels/pixel_rgbd_kernels.py @@ -3,6 +3,7 @@ import genjax import jax import jax.numpy as jnp +from jax.random import split from b3d.chisight.gen3d.pixel_kernels.pixel_color_kernels import PixelColorDistribution from b3d.chisight.gen3d.pixel_kernels.pixel_depth_kernels import PixelDepthDistribution from genjax import Pytree @@ -67,6 +68,60 @@ def logpdf( ) -> float: raise NotImplementedError +@Pytree.dataclass +class RGBDDist(genjax.ExactDensity): + """ + Distribution on an RGBD pixel. + + Args: + - latent_rgbd + - color_scale + - depth_scale + - depth_nonreturn_prob + - intrinsics + + Calls a color distribution and a "valid depth return" depth distribution to sample the pixel. + """ + color_distribution: PixelColorDistribution + depth_distribution: PixelDepthDistribution + + def sample( + self, + key: PRNGKey, + latent_rgbd: FloatArray, + color_scale: float, + depth_scale: float, + depth_nonreturn_prob: float, + intrinsics: dict, + ) -> FloatArray: + k1, k2, k3 = split(key, 3) + color = self.color_distribution.sample( + k1, latent_rgbd[:3], color_scale + ) + depth_if_return = self.depth_distribution.sample( + k2, latent_rgbd[3], depth_scale, intrinsics["near"], intrinsics["far"] + ) + depth = jnp.where( + jax.random.bernoulli(k3, depth_nonreturn_prob), + 0.0, + depth_if_return + ) + + return jnp.concatenate([color, depth]) + + def logpdf(self, obs, latent, color_scale, depth_scale, depth_nonreturn_prob, intrinsics): + color_logpdf = self.color_distribution.logpdf( + obs[:3], latent[:3], color_scale + ) + depth_logpdf_if_return = self.depth_distribution.logpdf( + obs[3], latent[3], depth_scale, intrinsics["near"], intrinsics["far"] + ) + depth_logpdf = jnp.where( + obs[3] == 0.0, + jnp.log(depth_nonreturn_prob), + jnp.log(1 - depth_nonreturn_prob) + depth_logpdf_if_return + ) + return color_logpdf + depth_logpdf @Pytree.dataclass class FullPixelRGBDDistribution(PixelRGBDDistribution): @@ -86,6 +141,14 @@ class FullPixelRGBDDistribution(PixelRGBDDistribution): inlier_depth_distribution: PixelDepthDistribution outlier_depth_distribution: PixelDepthDistribution + @property + def inlier_distribution(self): + return RGBDDist(self.inlier_color_distribution, self.inlier_depth_distribution) + + @property + def outlier_distribution(self): + return RGBDDist(self.outlier_color_distribution, self.outlier_depth_distribution) + def sample( self, key: PRNGKey, @@ -95,9 +158,18 @@ def sample( visibility_prob: float, depth_nonreturn_prob: float, intrinsics: dict, + depth_nonreturn_prob_for_invisible: float ) -> FloatArray: - # TODO: Implement this - return jnp.ones((4,)) * 0.5 + k1, k2, k3 = split(key, 3) + return jnp.where( + jax.random.bernoulli(k1, visibility_prob), + self.inlier_distribution.sample( + k2, latent_rgbd, color_scale, depth_scale, depth_nonreturn_prob, intrinsics + ), + self.outlier_distribution.sample( + k3, latent_rgbd, color_scale, depth_scale, depth_nonreturn_prob_for_invisible, intrinsics + ), + ) @jax.jit def logpdf( @@ -109,63 +181,15 @@ def logpdf( visibility_prob: float, depth_nonreturn_prob: float, intrinsics: dict, + invisible_depth_nonreturn_prob: float ) -> float: - total_log_prob = 0.0 - - is_depth_non_return = observed_rgbd[3] == 0.0 - - # Is visible - total_visible_log_prob = 0.0 - # color term - total_visible_log_prob += self.inlier_color_distribution.logpdf( - observed_rgbd[:3], latent_rgbd[:3], color_scale - ) - # depth term - total_visible_log_prob += jnp.where( - is_depth_non_return, - jnp.log(depth_nonreturn_prob), - jnp.log(1 - depth_nonreturn_prob) - + self.inlier_depth_distribution.logpdf( - observed_rgbd[3], - latent_rgbd[3], - depth_scale, - intrinsics["near"], - intrinsics["far"], + return jnp.logaddexp( + jnp.log(visibility_prob) + self.inlier_distribution.logpdf( + observed_rgbd, latent_rgbd, color_scale, depth_scale, depth_nonreturn_prob, intrinsics + ), + jnp.log(1 - visibility_prob) + self.outlier_distribution.logpdf( + observed_rgbd, latent_rgbd, color_scale, depth_scale, invisible_depth_nonreturn_prob, intrinsics ), - ) - - # Is not visible - total_not_visible_log_prob = 0.0 - # color term - outlier_color_log_prob = self.outlier_color_distribution.logpdf( - observed_rgbd[:3], - latent_rgbd[:3], - color_scale, - ) - outlier_depth_log_prob = self.outlier_depth_distribution.logpdf( - observed_rgbd[3], - latent_rgbd[3], - depth_scale, - intrinsics["near"], - intrinsics["far"], - ) - - total_not_visible_log_prob += outlier_color_log_prob - # depth term - total_not_visible_log_prob += jnp.where( - is_depth_non_return, - jnp.log(depth_nonreturn_prob), - jnp.log(1 - depth_nonreturn_prob) + outlier_depth_log_prob, - ) - - total_log_prob += jnp.logaddexp( - jnp.log(visibility_prob) + total_visible_log_prob, - jnp.log(1 - visibility_prob) + total_not_visible_log_prob, - ) - return jnp.where( - jnp.any(is_unexplained(latent_rgbd)), - outlier_color_log_prob + outlier_depth_log_prob, - total_log_prob, )