diff --git a/mujoco_robotics_environments/config/dataset/default.yaml b/mujoco_robotics_environments/config/dataset/default.yaml index ee3d92f..1854387 100644 --- a/mujoco_robotics_environments/config/dataset/default.yaml +++ b/mujoco_robotics_environments/config/dataset/default.yaml @@ -1,3 +1,3 @@ -num_episodes: 100 +num_episodes: 10 max_steps: 10 max_episodes_per_file: 10 diff --git a/mujoco_robotics_environments/hf_scripts/hf_data_upload.py b/mujoco_robotics_environments/hf_scripts/hf_data_upload.py new file mode 100644 index 0000000..bc1b398 --- /dev/null +++ b/mujoco_robotics_environments/hf_scripts/hf_data_upload.py @@ -0,0 +1,25 @@ +import os +import glob +import tarfile + +import tensorflow_datasets as tfds +from huggingface_hub import HfApi + +LOCAL_FILEPATH="/home/peter/Code/mujoco_robotics_environments/mujoco_robotics_environments/data" + +if __name__=="__main__": + + for folder_name in os.listdir(LOCAL_FILEPATH): + if os.path.isdir(os.path.join(LOCAL_FILEPATH, folder_name)): + OUTPUT_FILENAME = folder_name + '.tar.gz' + with tarfile.open(OUTPUT_FILENAME, "w:gz") as tar: + tar.add(os.path.join(LOCAL_FILEPATH, folder_name), arcname=".") + + # upload to huggingface + api = HfApi() + api.upload_file( + repo_id="peterdavidfagan/transporter_networks_mujoco", + repo_type="dataset", + path_or_fileobj=f"./{OUTPUT_FILENAME}", + path_in_repo=f"/{OUTPUT_FILENAME}", + ) \ No newline at end of file diff --git a/mujoco_robotics_environments/tasks/rearrangement.py b/mujoco_robotics_environments/tasks/rearrangement.py index 0de5844..1da8d6d 100644 --- a/mujoco_robotics_environments/tasks/rearrangement.py +++ b/mujoco_robotics_environments/tasks/rearrangement.py @@ -1,6 +1,6 @@ """Mujoco environment for interactive task learning.""" from abc import abstractmethod -from typing import Optional +from typing import Optional, Dict from copy import deepcopy from pathlib import Path import random @@ -53,7 +53,7 @@ def __init__( print("Viewer not requested, running headless.") else: self.has_viewer = self._cfg.viewer - # TODO: read arena from xml + # create arena self._arena = empty.Arena() table = Rectangle( @@ -148,6 +148,8 @@ def __init__( ) self.prop_random_state = np.random.RandomState(seed=self._cfg.task.initializers.seed) self.prop_place_random_state = np.random.RandomState(seed=self._cfg.task.initializers.seed+1) + + self.mode = None def close(self) -> None: if self.passive_view is not None: @@ -276,6 +278,9 @@ def reset(self) -> dm_env.TimeStep: # set the initial eef pose to home self.eef_home_pose = self._robot.eef_pose.copy() self.eef_home_pose[0] -= 0.1 # move back 10cm so it is out of view of camera + + # start in pick mode + self.mode="pick" return dm_env.TimeStep( step_type=dm_env.StepType.FIRST, @@ -284,36 +289,43 @@ def reset(self) -> dm_env.TimeStep: observation=self._compute_observation(), ) - def step(self, poses) -> dm_env.TimeStep: - """Updates the environment according to the action and returns a - `TimeStep`. + def step(self, action_dict) -> dm_env.TimeStep: + """ + Updates the environment according to the action and returns a `TimeStep`.
+ """ + + if self.mode == "pick": + self.pick(action_dict['pose']) + self.mode="place" + else: + self.place(action_dict['pose']) + self.mode="pick" + + return dm_env.TimeStep( + step_type=dm_env.StepType.MID, + reward=0.0, + discount=0.0, + observation=self._compute_observation(), + ) + + def pick(self, pose): """ - # split into pick/place poses - pick_pose = poses[:7] - place_pose = poses[7:] - - # parse action into poses - pick_pos = pick_pose[:3] - pick_pos[-1] = 0.575 # hardcode for now - pick_orn = pick_pose[3:] - place_pos = place_pose[:3] - place_pos[-1] = 0.575 # hardcode for now - place_orn = place_pose[3:] - - # move arm to pre pick position - pre_pick = pick_pos.copy() + Scripted pick behaviour. + """ + pose[2] = 0.575 # hardcode for now :( + pre_pick = pose.copy() pre_pick[2] = 0.9 self._robot.arm_controller.set_target( - position=pre_pick, + position=pre_pick[:3], velocity=np.zeros(3), - quat=pick_orn, + quat=pre_pick[3:], angular_velocity=np.zeros(3), ) if not self._robot.run_controller(2.0): raise RuntimeError("Failed to move arm to pre pick position") # move arm to pick position - self._robot.arm_controller.set_target(position=pick_pos) + self._robot.arm_controller.set_target(position=pose[:3]) if not self._robot.run_controller(2.0): raise RuntimeError("Failed to move arm to pick position") @@ -324,22 +336,38 @@ def step(self, poses) -> dm_env.TimeStep: # move arm to pre grasp position - self._robot.arm_controller.set_target(position=pre_pick) + self._robot.arm_controller.set_target(position=pre_pick[:3]) if not self._robot.run_controller(2.0): raise RuntimeError("Failed to move arm to pre grasp position") + # move arm to home position + quat = np.zeros(4,) + rot_mat = R.from_euler('xyz', [0, 180, 0], degrees=True).as_matrix().flatten() + mujoco.mju_mat2Quat(quat, rot_mat) + self._robot.arm_controller.set_target( + position=self.eef_home_pose, + quat=quat, + ) + if not self._robot.run_controller(2.0): + raise RuntimeError("Failed to move arm to home position") + + def place(self, pose): + """ + Scripted place behaviour. 
+ """ + pose[2] = 0.575 # hardcode for now :( # move arm to pre place position - pre_place = place_pos.copy() + pre_place = pose.copy() pre_place[2] = 0.9 self._robot.arm_controller.set_target( - position=pre_place, - quat=place_orn, + position=pre_place[:3], + quat=pre_place[3:], ) if not self._robot.run_controller(2.0): raise RuntimeError("Failed to move arm to pre place position") # move arm to place position - self._robot.arm_controller.set_target(position=place_pos) + self._robot.arm_controller.set_target(position=pose[:3]) if not self._robot.run_controller(2.0): raise RuntimeError("Failed to move arm to place position") @@ -349,7 +377,7 @@ def step(self, poses) -> dm_env.TimeStep: raise RuntimeError("Failed to open gripper") # move arm to pre place position - self._robot.arm_controller.set_target(position=pre_place) + self._robot.arm_controller.set_target(position=pre_place[:3]) if not self._robot.run_controller(2.0): raise RuntimeError("Failed to move arm to pre place position") @@ -364,13 +392,6 @@ def step(self, poses) -> dm_env.TimeStep: if not self._robot.run_controller(2.0): raise RuntimeError("Failed to move arm to home position") - return dm_env.TimeStep( - step_type=dm_env.StepType.MID, - reward=0.0, - discount=0.0, - observation=self._compute_observation(), - ) - def observation_spec(self) -> dm_env.specs.Array: """Returns the observation spec.""" # get shape of overhead camera @@ -381,12 +402,13 @@ def observation_spec(self) -> dm_env.specs.Array: "overhead_camera/rgb": dm_env.specs.Array(shape=camera_shape, dtype=np.float32), } - def action_spec(self) -> dm_env.specs.Array: + def action_spec(self) -> Dict[str, dm_env.specs.Array]: """Returns the action spec.""" - return dm_env.specs.Array( - shape=(14,), - dtype=np.float32, - ) + return { + "pose": dm_env.specs.Array(shape=(7,), dtype=np.float64), # [x, y, z, qx, qy, qz, qw] + "pixel_coords": dm_env.specs.Array(shape=(2,), dtype=np.int64), # [u, v] + "gripper_rot": dm_env.specs.Array(shape=(1,), dtype=np.float64), + } def _compute_observation(self) -> np.ndarray: """Returns the observation.""" @@ -478,7 +500,7 @@ def world_2_pixel(self, camera_name, coords): image_coords = intrinsics @ camera_coords image_coords = image_coords[:2] / image_coords[2] - return image_coords + return jnp.round(image_coords).astype(jnp.int32) def get_camera_params(self, camera_name): """Returns the camera parameters.""" @@ -486,7 +508,31 @@ def get_camera_params(self, camera_name): extrinsics = self._get_camera_extrinsics(camera_name) return {"intrinsics": intrinsics, "extrinsics": extrinsics} + def get_camera_metadata(self): + """Returns the camera parameters.""" + intrinsics = self._get_camera_intrinsics("overhead_camera/overhead_camera", self.overhead_camera_height, self.overhead_camera_width) + extrinsics = self._get_camera_extrinsics("overhead_camera/overhead_camera") + # convert rotation to quat + quat = R.from_matrix(extrinsics[:3, :3]).as_quat() + return { + "intrinsics": { + "fx": intrinsics[0, 0], + "fy": intrinsics[1, 1], + "cx": intrinsics[0, 2], + "cy": intrinsics[1, 2], + }, + "extrinsics": { + "x": extrinsics[3, 0], + "y": extrinsics[3, 1], + "z": extrinsics[3, 2], + "qx": quat[0], + "qy": quat[1], + "qz": quat[2], + "qw": quat[3], + }} + def prop_pick(self, prop_id): + """Returns pick pose for a given prop.""" # get prop pose information obj_pose = self.props_info[prop_id]["position"] obj_quat = self.props_info[prop_id]["orientation"] @@ -502,7 +548,7 @@ def prop_pick(self, prop_id): return np.concatenate([obj_pose, grasp_quat]) 
def prop_place(self, prop_id, min_pose=None, max_pose=None): - """Returns collisiin free place pose for a given prop.""" + """Returns collision free place pose for a given prop.""" # don't want to mess with actual physics dummy_physics = deepcopy(self._physics) prop_name = f"{prop_id}/{prop_id}" @@ -601,39 +647,7 @@ def random_pick_and_place(self): return pick_pose, place_pose - - # def generate_task(self) -> RearrangementTask: - # """Generate a task given current environment state""" - # while True: - # # sample subject - # s_quant = random.choice(GenRefExp.quant) - # s_refexp, s_obj = GenRefExp.generate(s_quant, self.props_info) - - # # sample spatial relationship - # rel, rel_fun = EvalSpatial.sample() - - # props = GenRefExp.filter_props(self.props_info, s_obj, rel_fun) - # if props: - # o_quant = random.choice(GenRefExp.quant) - # o_refexp, _ = GenRefExp.generate(o_quant, props) - - # task = RearrangementTask( - # instruction=f"Move {s_refexp} {rel} {o_refexp}", - # s_refexp=s_refexp, - # rel=rel, - # o_refexp=o_refexp, - # ) - # if task.s_refexp_lf is not None and task.o_refexp_lf is not None: - # return task - - # def update_task(self, task: RearrangementTask) -> None: - # self.task = task - - # def success_score(self) -> float: - # """get the succes score of task. Implementation based on RAVENS""" - # raise NotImplementedError - def sort_colours(self): """Generates pick/place action for sorting blue/green coloured blocks""" colours = ["blue", "green"] @@ -671,7 +685,7 @@ def sort_colours(self): continue return False, None, None - + if __name__=="__main__": # clear hydra global state to avoid conflicts with other hydra instances hydra.core.global_hydra.GlobalHydra.instance().clear() @@ -690,52 +704,23 @@ def sort_colours(self): # instantiate color separation task env = RearrangementEnv(COLOR_SEPARATING_CONFIG, True) _, _, _, obs = env.reset() - - ### checking that domain model and prop info are consistent - ### TODO move to test - # pro_info = env.props_info - # print(pro_info) - # domain_model = props_info2domain_model(pro_info) - # print(domain_model) - # props_info_new = domain_model2props_info(domain_model) - # print(props_info_new) - # # generate task - # task = env.generate_task() - # print(f"Task: {task.instruction}") - - # s_entities = Semantics.eval_refexp(task.s_refexp_lf, domain_model).entities - # o_entities = Semantics.eval_refexp(task.o_refexp_lf, domain_model).entities - - # print(s_entities) - # print(o_entities) - - # perform reference resolution with a domain model - - # complete singe loop of sorting colours + while env.sort_colours()[0]: _, pick_pose, place_pose = env.sort_colours() + + pick_action = { + "pose": pick_pose, + "pixel_coords": env.world_2_pixel("overhead_camera/overhead_camera", pick_pose[:3]), + "gripper_rot": None, + } + + place_action = { + "pose": place_pose, + "pixel_coords": env.world_2_pixel("overhead_camera/overhead_camera", place_pose[:3]), + "gripper_rot": None, + } - # TODO: delete below commented code when done debugging - #pixel = env.world_2_pixel("overhead_camera/overhead_camera", pick_pose[:3]) - #print("pixel", pixel) - # overlay pixel on image - #from PIL import Image, ImageDraw - #print(obs["overhead_camera/rgb"].shape) - #image = Image.fromarray(obs["overhead_camera/rgb"]) - #print(image.size) - #draw = ImageDraw.Draw(image) - #draw.ellipse((pixel[0]-10, pixel[1]-10, pixel[0]+10, pixel[1]+10), fill=(255, 0, 0, 0)) - #image.show() - # convert pixel coords back to world coords - #world_coords = 
env.pixel_2_world("overhead_camera/overhead_camera", pixel) - #print("world", world_coords) - - _, _, _, obs = env.step(np.concatenate([pick_pose, place_pose])) + _, _, _, obs = env.step(pick_action) + _, _, _, obs = env.step(place_action) env.close() - # close passive viewer and check obs render - #rgb = obs["overhead_camera/rgb"] - #depth = obs["overhead_camera/depth"] - #from PIL import Image - #img = Image.fromarray(depth * 255).convert("L") - #img.show() diff --git a/mujoco_robotics_environments/tasks/utils.py b/mujoco_robotics_environments/tasks/utils.py deleted file mode 100644 index b07f79e..0000000 --- a/mujoco_robotics_environments/tasks/utils.py +++ /dev/null @@ -1,301 +0,0 @@ -from dataclasses import dataclass -from typing import Any, Callable, Iterable -from collections import defaultdict -import random - -import inflect -from logic_toolkit import RefExpParser, Entity, RefExp, DomainModel, Symbol, Denotation - -from mujoco_robotics_environments.models.llm.refexp_parser import RefExpLLMParser - - -@dataclass -class RearrangementTask: - - def __init__(self, - instruction: str, - s_refexp: str, - rel: str, - o_refexp: str,): - """Container for a rearrangment task specification""" - self.instruction = instruction - self.s_refexp_str = s_refexp - self.rel = rel - self.o_refexp_str = o_refexp - - # explicitly add logical expressions - self.lang_parser = RefExpLLMParser() - self.logic_parser = RefExpParser() - - self.s_refexp_lf = self.safe_parse(self.s_refexp_str) - self.o_refexp_lf = self.safe_parse(self.o_refexp_str) - - def safe_parse(self, refexp: str) -> RefExp|None: - """parse the refexp and return None if parsing fails.""" - - try: - return self.logic_parser(self.lang_parser(refexp)) - except: - return None - - def __str__(self): - - return f"Task: {self.instruction}" - - def __repr__(self) -> str: - - return f"Task: {self.instruction}" - - -class PropEntity(Entity): - """entity with metada information from the environment""" - - def __init__(self, - name: str, - position: Iterable[float], - orientation: Iterable[float], - rgba: Iterable[float], - bbox: Iterable[int]): - super().__init__(name) - self.position = position - self.orientation = orientation - self.rgba = rgba - self.bbox = bbox - - def __str__(self): - - return f"{self.name} with bbox {self.bbox}" - - def __repr__(self) -> str: - - return f"{self.name} with bbox {self.bbox}" - - -class GenRefExp: - """Utilities to generate reference expressions for objects in a scene""" - - quant = [ - "exists","unique", "every", - ] - - @staticmethod - def get_symbols(symbols:list[str], sample:bool =True) -> list[str]: - """get a subset of symbols if sample is True, else return all symbols.""" - if sample: - start = random.randint(0, len(symbols)-1) - end = random.randint(1, len(symbols)-start) - sub_symbols = symbols[start:start + end] - # add object if no shape is given for grammaticality - if symbols[-1] != sub_symbols[-1]: - sub_symbols.append("object") - symbols = sub_symbols - return symbols - - @staticmethod - def get_same(props_info: dict, target_props:list[str]) -> list[str]: - """get object_names of the objects sharing the same props""" - entities = [] - for entity, prop_info in props_info.items(): - # adding "objects" to handle props with and without "object" - object_props = GenRefExp.get_symbols(prop_info['symbols'], sample=False)+["object"] - target_props = target_props + ["object"] - if all(prop in object_props for prop in target_props): - entities.append(entity) - - return entities - - @staticmethod - def 
filter_props(props_info:dict, targets:list[str], filter: Callable) -> dict: - """filter our props by the filter function""" - props_filter = {} - for entity, prop_info in props_info.items(): - if entity in targets: - continue - - if all(not filter(prop_info['bbox'], props_info[target]['bbox']) for target in targets): - props_filter[entity] = prop_info - - return props_filter - - @staticmethod - def make_exists(props_info:dict, sample:bool=True) -> tuple[str, list[str]]: - """create exists (a/an) description""" - entity = random.choice(list(props_info.keys())) - symbols = GenRefExp.get_symbols(props_info[entity]['symbols'], sample) - return inflect.engine().a(" ".join(symbols)), [entity] - - @staticmethod - def make_every(props_info:dict, sample=True) -> tuple[str, list[str]]: - """create every description""" - entity = random.choice(list(props_info.keys())) - symbols = GenRefExp.get_symbols(props_info[entity]['symbols'], sample) - ref_entities = GenRefExp.get_same(props_info, symbols) - - return f"every {' '.join(symbols)}", ref_entities - - @staticmethod - def make_unique(props_info:dict, sample:bool=True) -> tuple[str, list[str]]: - """create unique (the N) description""" - entity = random.choice(list(props_info.keys())) - symbols = GenRefExp.get_symbols(props_info[entity]['symbols'], sample) - ref_entities = GenRefExp.get_same(props_info, symbols) - - article = f"the {inflect.engine().number_to_words(len(ref_entities))}" - - if len(ref_entities) > 1: # plural - symbols[-1] = inflect.engine().plural_noun(symbols[-1]) - - return f"{article} {' '.join(symbols)}", ref_entities - - @staticmethod - def generate(quant:str, - props_info: dict, - sample:bool=True) -> tuple[str, list[str]]: - - match quant: - - case "exists": - return GenRefExp.make_exists(props_info, sample) - case "unique": - return GenRefExp.make_unique(props_info, sample) - case "every": - return GenRefExp.make_every(props_info, sample) - case _ : - raise NotImplementedError(f"{quant} not implemented") - - -class EvalSpatial: - """Utilities to evaluate spatial relationships between objects""" - - rel_str2sym = { - "to the left of": Symbol("left", 2), - "to the right of": Symbol("right", 2), - "in front of": Symbol("front", 2), - "behind": Symbol("behind", 2), - "above": Symbol("above", 2), - } - - def is_left(s_bbox: Iterable[int],o_bbox:Iterable[int], threshold:float=0) -> bool: - return s_bbox[2] + threshold < o_bbox[0] - - def is_right(s_bbox: Iterable[int], o_bbox:Iterable[int], threshold:float=0) -> bool: - return o_bbox[2] + threshold < s_bbox[0] - - def is_front(s_bbox: Iterable[int], o_bbox:Iterable[int], threshold:float=0) -> bool: - return s_bbox[3] + threshold < o_bbox[1] - - def is_behind(s_bbox: Iterable[int], o_bbox:Iterable[int], threshold:float=0) -> bool: - return o_bbox[3] + threshold < s_bbox[1] - - def is_above(s_bbox: Iterable[int], o_bbox:Iterable[int], threshold:float=0) -> bool: - - min_x1, min_y1, max_x1, max_y1 = s_bbox - min_x2, min_y2, max_x2, max_y2 = o_bbox - - if (min_x1 >= min_x2 + threshold and - min_y1 >= min_y2 + threshold and - max_x1 <= max_x2 + threshold and - max_y1 <= max_y2 + threshold): - return True - else: - return False - - def get_rel_func(name:str) -> Callable[[Iterable[int], Iterable[int], float], bool]: - - assert name in EvalSpatial.rel_str2sym.keys(), "not supported relationship" - - match name: - - case "to the left of": - return EvalSpatial.is_left - - case "to the right of": - return EvalSpatial.is_right - - case "in front of": - return EvalSpatial.is_front - - case 
"behind": - return EvalSpatial.is_behind - - case "above": - return EvalSpatial.is_above - - case _: - raise NotImplementedError(f"{name} not supported") - - @staticmethod - def sample() -> tuple[str,Callable[[Iterable[int], Iterable[int], float], bool]]: - - rel = random.choice(list(EvalSpatial.rel_str2sym.keys())) - fun = EvalSpatial.get_rel_func(rel) - - return rel, fun - - @staticmethod - def eval(name:str, - s_entity: PropEntity, - o_entity: PropEntity, - threshold:float=0) -> bool: - """evaluate the spatial relationship between two objects""" - rel_func = EvalSpatial.get_rel_func(name) - return rel_func(s_entity.bbox, o_entity.bbox, threshold) - - -def props_info2domain_model(props_info:dict) -> DomainModel: - """""" - - entities = [PropEntity(name, - prop['position'], - prop['orientation'], - prop['rgba'], - prop['bbox']) for name, prop in props_info.items()] - - extensions = defaultdict(set) - - for entity_name, prop in props_info.items(): - - entity = PropEntity(entity_name, - prop['position'], - prop['orientation'], - prop['rgba'], - prop['bbox']) - - for symbol_name in prop['symbols']: - symbol = Symbol(symbol_name, 1) - extensions[symbol].add(Denotation(entity)) - - # add spatial relationships to the domain - for rel_str, sym in EvalSpatial.rel_str2sym.items(): - for s_entity, o_entity in zip(entities, entities): - if EvalSpatial.eval(rel_str, s_entity, o_entity): - extensions[sym].add(Denotation((s_entity, o_entity))) - - - extensions = {k: frozenset(v) for k, v in extensions.items()} - - return DomainModel(entities, extensions) - - -def domain_model2props_info(domain_model:DomainModel) -> dict: - - props_info = {} - - for entity in domain_model.entities: - - props_info[entity.name] = { - "position": entity.position, - "orientation": entity.orientation, - "rgba": entity.rgba, - "bbox": entity.bbox, - "symbols": set() - } - - for symbol, denotations in domain_model.extensions.items(): - - for denotation in list(denotations): - if len(denotation) != 1 or symbol.name == "object": - break - props_info[denotation[0].name]['symbols'].add(symbol.name) - - return props_info \ No newline at end of file diff --git a/mujoco_robotics_environments/transporter_network_data_generation.py b/mujoco_robotics_environments/transporter_network_data_generation.py index 6cb41da..a59336f 100644 --- a/mujoco_robotics_environments/transporter_network_data_generation.py +++ b/mujoco_robotics_environments/transporter_network_data_generation.py @@ -2,6 +2,7 @@ import os from absl import logging +import time import numpy as np import tensorflow as tf @@ -36,19 +37,14 @@ TASKS = [COLOR_SEPERATOR_TASK_CONFIG] for task_config in TASKS: + current_timestamp = time.localtime() + human_readable_timestamp = time.strftime("%Y-%m-%d-%H:%M:%S", current_timestamp) + DATASET_NAME = f"{task_config.name}_{human_readable_timestamp}" + # set up the data folder - DATA_DIR = os.path.join(os.getcwd(), "data", task_config.name) + DATA_DIR = os.path.join(os.getcwd(), "data", DATASET_NAME) if not os.path.exists(DATA_DIR): os.makedirs(DATA_DIR) - - def record_camera_params(timestamp, action, env): - """Record camera parameters to metadata.""" - camera_params = env.get_camera_params("overhead_camera/overhead_camera") - return { - "camera_name": "overhead_camera/overhead_camera", - "camera_intrinsics": camera_params["intrinsics"], - "camera_extrinsics": camera_params["extrinsics"] - } # define base dataset configuration across all transporter tasks for camera in task_config.arena.cameras: @@ -57,21 +53,45 @@ def 
record_camera_params(timestamp, action, env): camera_width = camera.width ds_config = tfds.rlds.rlds_base.DatasetConfig( - name="rearrangement_{}".format(task_config.name), + name=DATASET_NAME, observation_info=tfds.features.FeaturesDict({ "overhead_camera/rgb": tfds.features.Tensor(shape=(camera_height, camera_width, 3), dtype=tf.uint8), "overhead_camera/depth": tfds.features.Tensor(shape=(camera_height, camera_width), dtype=tf.float32), }), - action_info=tfds.features.Tensor(shape=(14,), dtype=np.float64), + action_info=tfds.features.FeaturesDict({ + "pose": tfds.features.Tensor(shape=(7,), dtype=np.float64), + "pixel_coords": tfds.features.Tensor(shape=(2,), dtype=np.int32), + "gripper_rot": np.float64, + }), reward_info=np.float64, discount_info=np.float64, - step_metadata_info={ - "camera_name": tfds.features.Text(), - "camera_intrinsics": tfds.features.Tensor(shape=(3, 3), dtype=np.float64), - "camera_extrinsics": tfds.features.Tensor(shape=(4, 4), dtype=np.float64), + episode_metadata_info={ + "intrinsics":{ + "fx": tf.float64, + "fy": tf.float64, + "cx": tf.float64, + "cy": tf.float64, + }, + "extrinsics":{ + "x": tf.float64, + "y": tf.float64, + "z": tf.float64, + "qx": tf.float64, + "qy": tf.float64, + "qz": tf.float64, + "qw": tf.float64, }, + }, ) + def calibration_metadata(timestep, unused_action, env): + """ + Store camera calibration params as episode metadata. + """ + if timestep.first(): + return env.get_camera_metadata() + else: + return None # instantiate task environment env = RearrangementEnv(task_config) @@ -79,30 +99,37 @@ def record_camera_params(timestamp, action, env): # collect data with envlogger with envlogger.EnvLogger( env, + episode_fn=calibration_metadata, backend=tfds_backend_writer.TFDSBackendWriter( data_directory=DATA_DIR, split_name="train", # for now default to train and eval on environment directly max_episodes_per_file=task_config.dataset.max_episodes_per_file, - ds_config=ds_config) , step_fn=record_camera_params + ds_config=ds_config), ) as env: for _ in range(task_config.dataset.num_episodes): try: _, _, _, obs = env.reset() - #from PIL import Image - #from matplotlib import cm - #Image.fromarray(obs["overhead_camera/rgb"]).show() - #depth = obs["overhead_camera/depth"] - #depth -= np.min(depth) - #depth /= np.max(depth) - #Image.fromarray(np.uint8(cm.Greys(depth)*255)).show() for _ in range(task_config.dataset.max_steps): in_progress, pick_pose, place_pose = env.sort_colours() if not in_progress: print("Task demonstration is complete") break - _, _, _, obs = env.step(np.concatenate([pick_pose, place_pose])) - #Image.fromarray(obs["overhead_camera/rgb"]).show() + + pick_action = { + "pose": pick_pose, + "pixel_coords": env.world_2_pixel("overhead_camera/overhead_camera", pick_pose[:3]), + "gripper_rot": 0.0, + } + + place_action = { + "pose": place_pose, + "pixel_coords": env.world_2_pixel("overhead_camera/overhead_camera", place_pose[:3]), + "gripper_rot": 0.0, + } + + _, _, _, obs = env.step(pick_action) + _, _, _, obs = env.step(place_action) except Exception as e: print("Task demonstration failed with exception: {}".format(e)) continue diff --git a/pyproject.toml b/pyproject.toml index 884f3bd..b561c86 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -44,12 +44,13 @@ wandb = "^0.16.1" # linux only deps dm-reverb = {version="0.13.0", markers = "sys_platform == 'linux'"} tensorflow-cpu = {version="^2.14.0", markers = "sys_platform == 'linux'"} -envlogger = {version="^1.1", extras=["tfds"], markers = "sys_platform == 'linux'"}
+envlogger = {extras = ["tfds"], version = "^1.2"} rlds = {version="^0.1.7", markers = "sys_platform == 'linux'"} # submodule deps mujoco_controllers = {path="./mujoco_robotics_environments/mujoco_pkgs/mujoco_controllers", develop=true} logic_toolkit = {git = "git@github.com:ipab-rad/logic_toolkit.git"} +huggingface-hub = "^0.23.0" [tool.poetry.extras] control_tuning = [