Skip to content

Commit

Permalink
Introduce Separate Builder (+ Moved Material files)
Browse files Browse the repository at this point in the history
  • Loading branch information
kwesiRutledge committed Feb 12, 2024
1 parent 93a7235 commit 7ffda13
Show file tree
Hide file tree
Showing 3 changed files with 332 additions and 195 deletions.
253 changes: 253 additions & 0 deletions obj2mjcf/MJCFBuilder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,253 @@
"""
MJCFBuilder.py
Description:
This file contains the class that is used to build MJCF files.
"""
import logging
import os
from lxml import etree
from pathlib import Path

from typing import List, Tuple, Union, Any

import mujoco
import trimesh
from termcolor import cprint

from obj2mjcf.Material import Material


# 2-space indentation for the generated XML.
_XML_INDENTATION = " "


class MJCFBuilder:
def __init__(
self,
filename: Path,
mesh: Union[trimesh.base.Trimesh,Any],
materials: List[Material],
work_dir: Path = None,
):
self.filename = filename
self.mesh = mesh
self.materials = materials

self.work_dir = work_dir
if self.work_dir is None:
self.work_dir = filename.parent / filename.stem

# Define variables that will be defined later
self.tree = None

def add_visual_and_collision_default_classes(
self,
root: etree.Element,
):
# Define the default element
default_elem = etree.SubElement(root, "default")

# Define visual defaults
visual_default_elem = etree.SubElement(default_elem, "default")
visual_default_elem.attrib["class"] = "visual"
etree.SubElement(
visual_default_elem,
"geom",
group="2",
type="mesh",
contype="0",
conaffinity="0",
)

# Define collision defaults
collision_default_elem = etree.SubElement(default_elem, "default")
collision_default_elem.attrib["class"] = "collision"
etree.SubElement(collision_default_elem, "geom", group="3", type="mesh")

def add_assets(self, root: etree.Element, mtls: List[Material]) -> etree.Element:
# Define the assets element
asset_elem = etree.SubElement(root, "asset")

for material in mtls:
if material.map_Kd is not None:
# Create the texture asset.
texture = Path(material.map_Kd)
etree.SubElement(
asset_elem,
"texture",
type="2d",
name=texture.stem,
file=texture.name,
)
# Reference the texture asset in a material asset.
etree.SubElement(
asset_elem,
"material",
name=material.name,
texture=texture.stem,
specular=material.mjcf_specular(),
shininess=material.mjcf_shininess(),
)
else:
etree.SubElement(
asset_elem,
"material",
name=material.name,
specular=material.mjcf_specular(),
shininess=material.mjcf_shininess(),
rgba=material.mjcf_rgba(),
)

return asset_elem

def add_visual_geometries(
self,
obj_body: etree.Element,
asset_elem: etree.Element,
):
# Constants
filename = self.filename
mesh = self.mesh
materials = self.materials

process_mtl = len(materials) > 0

# Add visual geometries to object body
if isinstance(mesh, trimesh.base.Trimesh):
meshname = Path(f"{filename.stem}.obj")
# Add the mesh to assets.
etree.SubElement(asset_elem, "mesh", file=str(meshname))
# Add the geom to the worldbody.
if process_mtl:
e_ = etree.SubElement(
obj_body, "geom", material=materials[0].name, mesh=str(meshname.stem)
)
e_.attrib["class"] = "visual"
else:
e_ = etree.SubElement(obj_body, "geom", mesh=meshname.stem)
e_.attrib["class"] = "visual"
else:
for i, (name, geom) in enumerate(mesh.geometry.items()):
meshname = Path(f"{filename.stem}_{i}.obj")
# Add the mesh to assets.
etree.SubElement(asset_elem, "mesh", file=str(meshname))
# Add the geom to the worldbody.
if process_mtl:
e_ = etree.SubElement(
obj_body, "geom", mesh=meshname.stem, material=name
)
e_.attrib["class"] = "visual"
else:
e_ = etree.SubElement(obj_body, "geom", mesh=meshname.stem)
e_.attrib["class"] = "visual"

def add_collision_geometries(
self,
obj_body: etree.Element,
asset_elem: etree.Element,
decomp_success: bool = False,
):
# Constants
filename = self.filename
mesh = self.mesh

work_dir = self.work_dir

if decomp_success:
# Find collision files from the decomposed convex hulls.
collisions = [
x for x in work_dir.glob("**/*") if x.is_file() and "collision" in x.name
]
collisions.sort(key=lambda x: int(x.stem.split("_")[-1]))

for collision in collisions:
etree.SubElement(asset_elem, "mesh", file=collision.name)
e_ = etree.SubElement(obj_body, "geom", mesh=collision.stem)
e_.attrib["class"] = "collision"
else:
# If no decomposed convex hulls were created, use the original mesh as the
# collision mesh.
if isinstance(mesh, trimesh.base.Trimesh):
meshname = Path(f"{filename.stem}.obj")
e_ = etree.SubElement(obj_body, "geom", mesh=meshname.stem)
e_.attrib["class"] = "collision"
else:
for i, (name, geom) in enumerate(mesh.geometry.items()):
meshname = Path(f"{filename.stem}_{i}.obj")
e_ = etree.SubElement(obj_body, "geom", mesh=meshname.stem)
e_.attrib["class"] = "collision"

def build(
self,
add_free_joint: bool = False,
):
# Constants
filename = self.filename
mesh = self.mesh
mtls = self.materials

# Start assembling xml tree
root = etree.Element("mujoco", model=filename.stem)

# Add Defaults + Assets
self.add_visual_and_collision_default_classes(root)
asset_elem = self.add_assets(root, mtls)

# Add Worldbody
worldbody_elem = etree.SubElement(root, "worldbody")
obj_body = etree.SubElement(worldbody_elem, "body", name=filename.stem)
if add_free_joint:
etree.SubElement(obj_body, "freejoint")

# Add visual and collision geometries to object body
self.add_visual_geometries(obj_body, asset_elem)
self.add_collision_geometries(obj_body, asset_elem)

# Collect Tree
tree = etree.ElementTree(root)
etree.indent(tree, space=_XML_INDENTATION, level=0)

self.tree = tree

def compile_model(self):
# Constants
filename = self.filename
work_dir = self.work_dir

# Pull up tree if possible
tree = self.tree
if tree is None:
raise ValueError("Tree has not been defined yet.")

# Create the work directory if it does not exist.
try:
tmp_path = work_dir / "tmp.xml"
tree.write(tmp_path, encoding="utf-8")
model = mujoco.MjModel.from_xml_path(str(tmp_path))
data = mujoco.MjData(model)
mujoco.mj_step(model, data)
cprint(f"{filename} compiled successfully!", "green")
except Exception as e:
cprint(f"Error compiling model: {e}", "red")
finally:
if tmp_path.exists():
tmp_path.unlink()

def save_mjcf(
self,
):
# Constants
filename = self.filename
work_dir = self.work_dir

# Input Processing

# Pull up tree if possible
tree = self.tree
if tree is None:
raise ValueError("Tree has not been defined yet.")

# Save the MJCF file.
xml_path = str(work_dir / f"{filename.stem}.xml")
tree.write(xml_path, encoding="utf-8")
logging.info(f"Saved MJCF to {xml_path}")
72 changes: 72 additions & 0 deletions obj2mjcf/Material.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
from typing import Optional, Sequence
from dataclasses import dataclass, field


# MTL fields relevant to MuJoCo.
_MTL_FIELDS = (
# Ambient, diffuse and specular colors.
"Ka",
"Kd",
"Ks",
# d or Tr are used for the rgba transparency.
"d",
"Tr",
# Shininess.
"Ns",
# References a texture file.
"map_Kd",
)

# Character used to denote a comment in an MTL file.
_MTL_COMMENT_CHAR = "#"

@dataclass
class Material:
name: str
Ka: Optional[str] = None
Kd: Optional[str] = None
Ks: Optional[str] = None
d: Optional[str] = None
Tr: Optional[str] = None
Ns: Optional[str] = None
map_Kd: Optional[str] = None

@staticmethod
def from_string(lines: Sequence[str]) -> "Material":
"""Construct a Material object from a string."""
attrs = {"name": lines[0].split(" ")[1].strip()}
for line in lines[1:]:
for attr in _MTL_FIELDS:
if line.startswith(attr):
elems = line.split(" ")[1:]
elems = [elem for elem in elems if elem != ""]
attrs[attr] = " ".join(elems)
break
return Material(**attrs)

def mjcf_rgba(self) -> str:
Kd = self.Kd or "1.0 1.0 1.0"
if self.d is not None: # alpha
alpha = self.d
elif self.Tr is not None: # 1 - alpha
alpha = str(1.0 - float(self.Tr))
else:
alpha = "1.0"
# TODO(kevin): Figure out how to use Ka for computing rgba.
return f"{Kd} {alpha}"

def mjcf_shininess(self) -> str:
if self.Ns is not None:
# Normalize Ns value to [0, 1]. Ns values normally range from 0 to 1000.
Ns = float(self.Ns) / 1_000
else:
Ns = 0.5
return f"{Ns}"

def mjcf_specular(self) -> str:
if self.Ks is not None:
# Take the average of the specular RGB values.
Ks = sum(list(map(float, self.Ks.split(" ")))) / 3
else:
Ks = 0.5
return f"{Ks}"
Loading

0 comments on commit 7ffda13

Please sign in to comment.