Skip to content

Commit

Permalink
swap in previous pose at best location in array
Browse files Browse the repository at this point in the history
  • Loading branch information
HaoliangWang committed Jan 11, 2025
1 parent 3765561 commit 90ff4af
Show file tree
Hide file tree
Showing 5 changed files with 26,335 additions and 46 deletions.
26,296 changes: 26,296 additions & 0 deletions analysis.ipynb

Large diffs are not rendered by default.

4 changes: 1 addition & 3 deletions run.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
"physion_hdf5",
)

for scenario in ['collide', 'drop', 'roll', 'dominoes', 'support', 'link', 'contain']:
for scenario in ['collide', 'link', 'drop', 'roll', 'dominoes', 'support', 'contain']:
scenario_path = join(hdf5_file_path, scenario + "_all_movies")
onlyhdf5 = [
f
Expand All @@ -26,8 +26,6 @@
viz_index = 0
for trial_index, hdf5_file in enumerate(onlyhdf5):
trial_name = hdf5_file[:-5]
if scenario != 'roll' or trial_index < 20:
continue
print(trial_index + 1, "\t", trial_name)
os.system(f"python /home/haoliangwang/b3d/test_b3d_tracking_hmm_single.py --scenario {scenario} --trial_name {trial_name} --recording_id {recording_id} --viz_index {viz_index}")
viz_index += FINAL_T+1
4 changes: 1 addition & 3 deletions src/b3d/chisight/gen3d/datawriting.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,9 +230,7 @@ def write_json(pred_file, hyperparams, posterior_across_frames, save_path, scena
[
{
"x": pose._position[0].astype(float).item(),
"y": pose._position[1].astype(float).item()
if pose._position[1].astype(float).item() >= 0
else 0,
"y": pose._position[1].astype(float).item(),
"z": pose._position[2].astype(float).item(),
}
for pose in poses[0]
Expand Down
45 changes: 8 additions & 37 deletions src/b3d/chisight/gen3d/inference/inference.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import time
import itertools
# import jax.scipy.stats as ss

import jax
import jax.numpy as jnp
Expand Down Expand Up @@ -48,25 +49,21 @@ def c2f_step(
proposed_poses, log_q_poses = jax.vmap(
propose_pose, in_axes=(0, None, None, None)
)(pose_generation_keys, trace, addr, pose_proposal_args)
# jax.debug.print("before proposed_poses: {v}", v=proposed_poses)
# jax.debug.print("rank before: {v}", v=ss.rankdata(log_q_poses))
# jax.debug.print("score before: {v}", v=log_q_poses)

proposed_poses, log_q_poses = maybe_swap_in_previous_pose(
proposed_poses, log_q_poses, trace, addr, include_previous_pose, pose_proposal_args
)
# proposed_poses, log_q_poses = filter_floor_penetration(
# proposed_poses, log_q_poses, trace, addr, pose_proposal_args
# )
# jax.debug.print("after proposed_poses: {v}", v=proposed_poses)
# jax.debug.print("rank after: {v}", v=ss.rankdata(log_q_poses))
# jax.debug.print("score after: {v}", v=log_q_poses)

def update_and_get_scores(key, proposed_pose, trace, addr):
key, subkey = split(key)
updated_trace = update_field(subkey, trace, addr, proposed_pose)
return updated_trace, updated_trace.get_score()

param_generation_keys = split(k2, inference_hyperparams.n_poses)
# _, p_scores = jax.lax.map(
# lambda x: update_and_get_scores(x[0], x[1], trace, addr),
# (param_generation_keys, proposed_poses),
# )
_, p_scores = jax.vmap(update_and_get_scores, in_axes=(0, 0, None, None))(
param_generation_keys, proposed_poses, trace, addr
)
Expand Down Expand Up @@ -122,8 +119,9 @@ def maybe_swap_in_previous_pose(
):
previous_pose = get_prev_state(trace)[addr]
log_q = assess_previous_pose(trace, addr, previous_pose, pose_proposal_args)
chosen_index = log_q_poses.argmin()
proposed_poses = jax.tree.map(
lambda x, y: x.at[0].set(jnp.where(include_previous_pose, y, x[0])),
lambda x, y: x.at[chosen_index].set(jnp.where(include_previous_pose, y, x[chosen_index])),
proposed_poses,
previous_pose,
)
Expand All @@ -139,33 +137,6 @@ def maybe_swap_in_previous_pose(
return proposed_poses, log_q_poses


# def filter_floor_penetration(
# proposed_poses, log_q_poses, trace, addr, pose_proposal_args
# ):
# previous_pose = get_prev_state(trace)[addr]
# log_q = assess_previous_pose(trace, addr, previous_pose, pose_proposal_args)

# def replace_if_not_above_zero(proposed_poses: Pose, prev_pose: Pose, log_q_poses, log_q) -> Pose:
# mask = proposed_poses._position[:, 1] > 0
# broadcasted_pos = jnp.broadcast_to(prev_pose._position, proposed_poses._position.shape)
# broadcasted_quat = jnp.broadcast_to(prev_pose._quaternion, proposed_poses._quaternion.shape)
# new_position = jnp.where(mask[:, None], proposed_poses._position, broadcasted_pos)
# new_quaternion = jnp.where(mask[:, None], proposed_poses._quaternion, broadcasted_quat)

# broadcasted_log_q = jnp.broadcast_to(log_q, log_q_poses.shape)
# log_q_poses = jnp.where(mask[:, None], log_q_poses, broadcasted_log_q)
# return Pose(new_position, new_quaternion), log_q_poses

# proposed_poses, log_q_poses = replace_if_not_above_zero(
# proposed_poses,
# previous_pose,
# log_q_poses,
# log_q
# )

# return proposed_poses, log_q_poses


def assess_previous_pose(advanced_trace, addr, previous_pose, args):
"""
Returns the log proposal density of the given pose, conditional upon the previous pose.
Expand Down
32 changes: 29 additions & 3 deletions test_combine_json.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
"cells": [
{
"cell_type": "code",
"execution_count": 6,
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -11,6 +11,32 @@
"from os.path import isfile, join"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'pilot_it2_collision_yeet_box_0029.hdf5'"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"scenario_path = '/home/haoliangwang/data/physion_hdf5/collide_all_movies/'\n",
"onlyjson = [\n",
" f\n",
" for f in listdir(scenario_path)\n",
" if isfile(join(scenario_path, f)) and join(scenario_path, f).endswith(\".hdf5\")\n",
"]\n",
"onlyjson[34]"
]
},
{
"cell_type": "code",
"execution_count": 7,
Expand Down Expand Up @@ -193,7 +219,7 @@
],
"metadata": {
"kernelspec": {
"display_name": "gpu",
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
Expand All @@ -207,7 +233,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.5"
"version": "3.9.2"
}
},
"nbformat": 4,
Expand Down

0 comments on commit 90ff4af

Please sign in to comment.