
Fix: pd_ee_pose controller ik, env step cpu sim output, utils, fetch pose
arth-shukla committed Jan 22, 2024
1 parent a473a0d commit 3ea9c06
Showing 4 changed files with 44 additions and 10 deletions.
10 changes: 6 additions & 4 deletions mani_skill2/agents/controllers/pd_ee_pose.py
@@ -8,7 +8,7 @@
 from scipy.spatial.transform import Rotation
 
 from mani_skill2.utils.common import clip_and_scale_action
-from mani_skill2.utils.sapien_utils import get_obj_by_name
+from mani_skill2.utils.sapien_utils import get_obj_by_name, to_numpy, to_tensor
 from mani_skill2.utils.structs.pose import vectorize_pose
 
 from .base_controller import BaseController, ControllerConfig
@@ -64,15 +64,17 @@ def reset(self):
 
     def compute_ik(self, target_pose, max_iterations=100):
         # Assume the target pose is defined in the base frame
+        # TODO (arth): currently ik only supports cpu, so input/output is managed as such
+        # in future, need to change input/output processing per gpu implementation
         result, success, error = self.pmodel.compute_inverse_kinematics(
             self.ee_link_idx,
-            target_pose,
-            initial_qpos=self.articulation.get_qpos(),
+            target_pose.sp,
+            initial_qpos=to_numpy(self.articulation.get_qpos()).squeeze(0),
             active_qmask=self.qmask,
             max_iterations=max_iterations,
         )
         if success:
-            return result[self.joint_indices]
+            return to_tensor([result[self.joint_indices]])
         else:
             return None
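For context, the CPU-only IK path now follows a convert, solve, re-batch pattern: the batched sim state is squeezed down to plain numpy for the solver, and the solution is wrapped back into a batched tensor. A minimal sketch of that pattern, with a hypothetical solve_ik_numpy standing in for pmodel.compute_inverse_kinematics:

import numpy as np
import torch

def solve_ik_numpy(qpos_np: np.ndarray) -> np.ndarray:
    # Hypothetical stand-in for the CPU-only solver: takes an unbatched
    # (n_dof,) numpy qpos and returns an unbatched numpy solution.
    return qpos_np + 0.1

def compute_ik_batched(qpos: torch.Tensor) -> torch.Tensor:
    qpos_np = qpos.detach().cpu().numpy().squeeze(0)  # (1, n_dof) tensor -> (n_dof,) numpy
    result = solve_ik_numpy(qpos_np)                  # solve on the CPU
    return torch.as_tensor(result).unsqueeze(0)       # re-batch as a (1, n_dof) tensor

print(compute_ik_batched(torch.zeros(1, 7)).shape)  # torch.Size([1, 7])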

2 changes: 1 addition & 1 deletion mani_skill2/agents/robots/fetch/fetch.py
@@ -385,4 +385,4 @@ def build_grasp_pose(approaching, closing, center):
     def tcp_pose(self) -> Pose:
         p = (self.finger1_link.pose.p + self.finger2_link.pose.p) / 2
         q = (self.finger1_link.pose.q + self.finger2_link.pose.q) / 2
-        return sapien.Pose(p=p, q=q)
+        return Pose.create_from_pq(p=p, q=q)
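The switch matters because the finger link poses can now be batched tensors: sapien.Pose expects a single unbatched position and quaternion, while the mani_skill2 Pose struct can hold a batch. A rough sketch of the shape convention being assumed here (the (1, 3) batching is illustrative, not taken from this file):

import torch

p1 = torch.tensor([[0.10, 0.00, 0.50]])  # assumed per-env batched positions, shape (1, 3)
p2 = torch.tensor([[0.10, 0.04, 0.50]])

p = (p1 + p2) / 2
print(p.shape)  # torch.Size([1, 3]); still batched

# A plain sapien.Pose would need an unbatched (3,) numpy position, e.g.
# p.squeeze(0).numpy(), whereas Pose.create_from_pq keeps the batched tensor.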
8 changes: 6 additions & 2 deletions mani_skill2/envs/sapien_env.py
@@ -32,7 +32,7 @@
     get_component_meshes,
     merge_meshes,
 )
-from mani_skill2.utils.sapien_utils import get_obj_by_type, to_numpy, to_tensor
+from mani_skill2.utils.sapien_utils import get_obj_by_type, to_numpy, to_tensor, unbatch
 from mani_skill2.utils.structs.types import Array
 from mani_skill2.utils.visualization.misc import observations_to_images, tile_images

@@ -627,7 +627,11 @@ def step(self, action: Union[None, np.ndarray, Dict]):
         if self.num_envs == 1:
             terminated = terminated[0]
             reward = reward[0]
-        return obs, reward, terminated, False, info
+
+        if physx.is_gpu_enabled():
+            return obs, reward, terminated, torch.Tensor(False), info
+        else:
+            return unbatch(obs, reward, terminated.item(), False, to_numpy(info))
 
     def step_action(self, action):
         set_action = False
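In practice this means the CPU sim returns an unbatched, numpy-based Gymnasium tuple while the GPU sim keeps batched torch tensors. A hedged usage sketch; the environment id and reset call below are illustrative, not taken from this diff:

import gymnasium as gym
import mani_skill2.envs  # registers the environments

env = gym.make("PickCube-v0")  # hypothetical CPU-sim environment id
obs, info = env.reset(seed=0)
obs, reward, terminated, truncated, info = env.step(env.action_space.sample())

# With the CPU sim the leading batch dimension of 1 has been stripped:
# reward and terminated are scalar values and array observations are plain
# numpy arrays, so the tuple drops into standard single-env training loops.
print(type(reward), type(terminated), truncated)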
34 changes: 31 additions & 3 deletions mani_skill2/utils/sapien_utils.py
@@ -35,13 +35,17 @@ def to_tensor(array: Union[torch.Tensor, np.array, Sequence]):
     elif get_backend_name() == "numpy":
         if isinstance(array, np.ndarray):
             return torch.from_numpy(array)
+        # TODO (arth): better way to address torch "UserWarning: Creating a tensor from a list of numpy.ndarrays is extremely slow" ?
+        elif isinstance(array, list) and isinstance(array[0], np.ndarray):
+            return torch.from_numpy(np.array(array))
         elif np.iterable(array):
             return torch.Tensor(array)
         else:
             return torch.tensor(array)
 
 
-def to_numpy(array: Union[Array, Sequence]):
+def _to_numpy(array: Union[Array, Sequence]) -> np.ndarray:
     if isinstance(array, (dict)):
-        return {k: to_numpy(v) for k, v in array.items()}
+        return {k: _to_numpy(v) for k, v in array.items()}
     if isinstance(array, str):
         return array
     if torch is not None:
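The new list branch in to_tensor avoids torch's slow path when converting a python list of numpy arrays by stacking through numpy first. A standalone illustration of the difference (it does not use the module's backend check):

import numpy as np
import torch

frames = [np.zeros(3, dtype=np.float32) for _ in range(4)]

# torch.Tensor(frames) / torch.tensor(frames) emit
# "UserWarning: Creating a tensor from a list of numpy.ndarrays is extremely slow"
# and copy element by element; stacking through numpy first avoids that:
fast = torch.from_numpy(np.array(frames))
print(fast.shape)  # torch.Size([4, 3])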
@@ -52,6 +56,30 @@ def to_numpy(array: Union[Array, Sequence]):
     else:
         return np.array(array)
 
+def to_numpy(array: Union[Array, Sequence], dtype=None) -> np.ndarray:
+    array = _to_numpy(array)
+    if dtype is not None:
+        return array.astype(dtype)
+    return array
+
+def _unbatch(array: Union[Array, Sequence]):
+    if isinstance(array, (dict)):
+        return {k: _unbatch(v) for k, v in array.items()}
+    if isinstance(array, str):
+        return array
+    if torch is not None:
+        if isinstance(array, torch.Tensor):
+            return array.squeeze(0)
+    if isinstance(array, np.ndarray):
+        if np.iterable(array) and array.shape[0] == 1:
+            return array.squeeze(0)
+    if isinstance(array, list):
+        if len(array) == 1:
+            return array[0]
+    return array
+
+def unbatch(*args: Tuple[Union[Array, Sequence]]):
+    return tuple([_unbatch(x) for x in args])
 
 def clone_tensor(array: Array):
     if torch is not None and isinstance(array, torch.Tensor):
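For illustration, unbatch strips a leading batch dimension of 1 from each argument independently and recurses into dicts, which is what the CPU-sim step() return above relies on. A quick sketch with assumed shapes:

import numpy as np
from mani_skill2.utils.sapien_utils import unbatch

obs = {"qpos": np.zeros((1, 9)), "task": "pick-cube"}  # dict values are unbatched recursively
reward = np.array([0.5])                               # leading batch dimension of 1
done = False                                           # plain python values pass through unchanged

obs, reward, done = unbatch(obs, reward, done)
print(obs["qpos"].shape, reward.shape, done)  # (9,) () False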