Skip to content

Commit

Permalink
Merge branch 'haoliang_ipe_hmm' into gen3d
Browse files Browse the repository at this point in the history
  • Loading branch information
HaoliangWang authored Dec 16, 2024
2 parents 4273a6b + a014533 commit 8acfa48
Show file tree
Hide file tree
Showing 45 changed files with 1,316,238 additions and 9,980 deletions.
10 changes: 4 additions & 6 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,12 +1,7 @@
*.egg-info
*.pyc
*.png
assets/shared_data_bucket/*
assets/test_results/*
!assets/test_results/README.md
assets/bop/*
assets/nuscenes/*
assets/ycbineoat/*
assets/*
**/.ipynb_checkpoints
**/**.mp4
**/**.npz
Expand All @@ -25,3 +20,6 @@ __pycache__/
*.py[cod]
docs/*
test_results/
.pixi/*
saved_traces/*
.DS_Store
445 changes: 445 additions & 0 deletions debug.ipynb

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion demos/graphics_edits_demo/demo_visualize.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
# "shared_data_bucket/input_data/shout_on_desk.r3d.video_input.npz")
"shared_data_bucket/input_data/desk_ramen2_spray1.r3d.video_input.npz",
)
video_input = b3d.VideoInput.load(path)
video_input = b3d.io.VideoInput.load(path)


data, object_library = pickle.load(open("demo_data.dat", "rb"))
Expand Down
30 changes: 15 additions & 15 deletions demos/tracking_online_learning.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from functools import partial

import b3d
import b3d.bayes3d as bayes3d
import genjax
import jax
import jax.numpy as jnp
Expand Down Expand Up @@ -56,7 +57,7 @@
color_error, depth_error = (60.0, 0.01)
inlier_score, outlier_prob = (5.0, 0.00001)
color_multiplier, depth_multiplier = (10000.0, 500.0)
model_args = b3d.ModelArgs(
model_args = bayes3d.ModelArgs(
color_error,
depth_error,
inlier_score,
Expand Down Expand Up @@ -104,7 +105,7 @@

# Creates renderer and generative model.
renderer = b3d.Renderer(image_width, image_height, fx, fy, cx, cy, near, far)
model = b3d.model_multiobject_gl_factory(renderer)
model = bayes3d.model_multiobject_gl_factory(renderer)

# Arguments of the generative model.
# These control the inlier / outlier decision boundary for color error and depth error.
Expand Down Expand Up @@ -154,14 +155,14 @@
@partial(jax.jit, static_argnames=["addressses"])
def enumerative_proposal(trace, addressses, key, all_deltas):
addr = addressses.const[0]
current_pose = trace[addr]
current_pose = trace.get_choices()[addr]
for i in range(len(all_deltas)):
test_poses = current_pose @ all_deltas[i]
potential_scores = b3d.enumerate_choices_get_scores(
trace, jax.random.PRNGKey(0), addressses, test_poses
trace, addressses, test_poses
)
current_pose = test_poses[potential_scores.argmax()]
trace = b3d.update_choices(trace, key, addressses, current_pose)
trace = b3d.update_choices(trace, addressses, current_pose)
return trace, key


Expand All @@ -175,7 +176,7 @@ def enumerative_proposal(trace, addressses, key, all_deltas):
START_T = 0
trace, _ = importance_jit(
jax.random.PRNGKey(0),
genjax.choice_map(
genjax.ChoiceMap.d(
dict(
[
("camera_pose", Pose.identity()),
Expand All @@ -191,7 +192,7 @@ def enumerative_proposal(trace, addressses, key, all_deltas):
(jnp.arange(4), model_args, object_library),
)
# Visualize trace
b3d.rerun_visualize_trace_t(trace, 0)
bayes3d.rerun_visualize_trace_t(trace, 0)
key = jax.random.PRNGKey(0)

inference_data_over_time = []
Expand All @@ -202,22 +203,21 @@ def enumerative_proposal(trace, addressses, key, all_deltas):
)
):
# Constrain on new RGB and Depth data.
trace = b3d.update_choices_jit(
trace = b3d.update_choices(
trace,
key,
genjax.Pytree.const(["observed_rgb_depth"]),
(rgbs_resized[T_observed_image], xyzs[T_observed_image, ..., 2]),
)
# Enumerate, score, and update camera pose
trace, key = enumerative_proposal(
trace, genjax.Pytree.const(["camera_pose"]), key, all_deltas
trace, genjax.Pytree.const(("camera_pose",)), key, all_deltas
)
for i in range(1, len(object_library.ranges)):
# Enumerate, score, update each objects pose
trace, key = enumerative_proposal(
trace, genjax.Pytree.const([f"object_pose_{i}"]), key, all_deltas
trace, genjax.Pytree.const((f"object_pose_{i}",)), key, all_deltas
)
b3d.rerun_visualize_trace_t(trace, T_observed_image)
bayes3d.rerun_visualize_trace_t(trace, T_observed_image)
inference_data_over_time.append(
(
b3d.get_poses_from_trace(trace),
Expand Down Expand Up @@ -286,7 +286,7 @@ def enumerative_proposal(trace, addressses, key, all_deltas):
next_object_id = len(object_library.ranges) - 1
trace = trace.update(
key,
genjax.choice_map(
genjax.ChoiceMap.d(
{
f"object_{next_object_id}": next_object_id, # Add identity of new object to trace.
f"object_pose_{next_object_id}": trace["camera_pose"]
Expand All @@ -302,7 +302,7 @@ def enumerative_proposal(trace, addressses, key, all_deltas):
(jnp.arange(4), model_args, object_library)
),
)[0]
b3d.rerun_visualize_trace_t(trace, REAQUISITION_T)
bayes3d.rerun_visualize_trace_t(trace, REAQUISITION_T)
inference_data_over_time.append(
(
b3d.get_poses_from_trace(trace),
Expand Down Expand Up @@ -332,7 +332,7 @@ def enumerative_proposal(trace, addressses, key, all_deltas):
(jnp.arange(4), model_args, object_library)
),
)[0]
b3d.rerun_visualize_trace_t(trace, t)
bayes3d.rerun_visualize_trace_t(trace, t)
rr.set_time_sequence("frame", t)

rgb_inliers, rgb_outliers = b3d.get_rgb_inlier_outlier_from_trace(trace)
Expand Down
49 changes: 49 additions & 0 deletions physion_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
# FLEX_MASSES = {
# "triangular_prism": 0.46650617387911586,
# "torus": 0.13529902128268745,
# "sphere": 0.5058756912820712,
# "pyramid": 0.3749996282020332,
# "platonic": 0.3170116821661435,
# "pipe": 0.334848007433621,
# "pentagon": 0.5944103105417653,
# "octahedron": 0.4999997582702834,
# "dumbbell": 0.8701171939416522,
# "cylinder": 0.7653673769319456,
# "cube": 1.0,
# "cone": 0.2582547675685456,
# "bowl": 0.0550834555398029,
# }

# dynamic_friction = 0.25
# static_friction = 0.4
# bounciness = 0.4
# density = 5

# TRIMESH_MESHES = {}

# for key in FLEX_MASSES.keys():
# TRIMESH_MESHES[key] = None

# TRIMESH_MESHES[key] = trimesh.load(os.path.join(mesh_folder, key + ".obj"))
# mesh = TRIMESH_MESHES[record.name]
# vertices = np.copy(mesh.vertices)
# faces = mesh.faces

# vertices[:, 0] *= scale_factor["x"]
# vertices[:, 1] *= scale_factor["y"]
# vertices[:, 2] *= scale_factor["z"]
# scaled_mesh_volume = trimesh.Trimesh(vertices, faces).volume
# mass = scaled_mesh_volume * density

# commands.extend(
# [
# {"$type": "set_mass", "mass": mass, "id": object_id},
# {
# "$type": "set_physic_material",
# "dynamic_friction": dynamic_friction,
# "static_friction": static_friction,
# "bounciness": bounciness,
# "id": object_id,
# },
# ]
# )
Loading

0 comments on commit 8acfa48

Please sign in to comment.