Skip to content

Commit

Permalink
Add asset argument to MjSpec constructors.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 720628689
Change-Id: I189aa34973e58b9fb87951e34cb923d94e081a55
  • Loading branch information
quagla authored and copybara-github committed Jan 28, 2025
1 parent 5b330ca commit 0643017
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 11 deletions.
48 changes: 37 additions & 11 deletions python/mujoco/specs.cc
Original file line number Diff line number Diff line change
Expand Up @@ -71,32 +71,44 @@ using MjDoubleRefVec = Eigen::Ref<const Eigen::VectorXd>;

struct MjSpec {
MjSpec() : ptr(mj_makeSpec()) {}
MjSpec(raw::MjSpec* ptr) : ptr(ptr) {}
MjSpec(raw::MjSpec* ptr, const py::dict& assets_ = {}) : ptr(ptr) {
for (const auto [key, value] : assets_) {
assets[key] = value;
}
}

// copy constructor and assignment
MjSpec(const MjSpec& other) : ptr(mj_copySpec(other.ptr)) {
override_assets = other.override_assets;
assets = other.assets;
for (const auto [key, value] : other.assets) {
assets[key] = value;
}
}
MjSpec& operator=(const MjSpec& other) {
override_assets = other.override_assets;
ptr = mj_copySpec(other.ptr);
assets = other.assets;
for (const auto [key, value] : other.assets) {
assets[key] = value;
}
return *this;
}

// move constructor and move assignment
MjSpec(MjSpec&& other) : ptr(other.ptr) {
override_assets = other.override_assets;
other.ptr = nullptr;
assets = other.assets;
for (const auto [key, value] : other.assets) {
assets[key] = value;
}
other.assets.clear();
}
MjSpec& operator=(MjSpec&& other) {
override_assets = other.override_assets;
ptr = other.ptr;
other.ptr = nullptr;
assets = other.assets;
for (const auto [key, value] : other.assets) {
assets[key] = value;
}
other.assets.clear();
return *this;
}
Expand Down Expand Up @@ -263,8 +275,8 @@ PYBIND11_MODULE(_specs, m) {
mjSpec.def_static(
"from_file",
[](std::string& filename,
std::optional<std::unordered_map<std::string, py::bytes>>& include)
-> MjSpec {
std::optional<std::unordered_map<std::string, py::bytes>>& include,
std::optional<py::dict>& assets) -> MjSpec {
const auto files = _impl::ConvertAssetsDict(include);
raw::MjSpec* spec;
{
Expand All @@ -280,9 +292,13 @@ PYBIND11_MODULE(_specs, m) {
throw py::value_error(error);
}
}
if (assets.has_value()) {
return MjSpec(spec, assets.value());
}
return MjSpec(spec);
},
py::arg("filename"), py::arg("include") = py::none(), R"mydelimiter(
py::arg("filename"), py::arg("include") = py::none(),
py::arg("assets") = py::none(), R"mydelimiter(
Creates a spec from an XML file.
Parameters
Expand All @@ -292,13 +308,16 @@ PYBIND11_MODULE(_specs, m) {
include : dict, optional
A dictionary of xml files included by the model. The keys are file names
and the values are file contents.
assets : dict, optional
A dictionary of assets to be used by the spec. The keys are asset names
and the values are asset contents.
)mydelimiter",
py::return_value_policy::move);
mjSpec.def_static(
"from_string",
[](std::string& xml,
std::optional<std::unordered_map<std::string, py::bytes>>& include)
-> MjSpec {
std::optional<std::unordered_map<std::string, py::bytes>>& include,
std::optional<py::dict>& assets) -> MjSpec {
auto files = _impl::ConvertAssetsDict(include);
raw::MjSpec* spec;
{
Expand All @@ -324,9 +343,13 @@ PYBIND11_MODULE(_specs, m) {
throw py::value_error(error);
}
}
if (assets.has_value()) {
return MjSpec(spec, assets.value());
}
return MjSpec(spec);
},
py::arg("xml"), py::arg("include") = py::none(), R"mydelimiter(
py::arg("xml"), py::arg("include") = py::none(),
py::arg("assets") = py::none(), R"mydelimiter(
Creates a spec from an XML string.
Parameters
Expand All @@ -336,6 +359,9 @@ PYBIND11_MODULE(_specs, m) {
include : dict, optional
A dictionary of xml files included by the model. The keys are file names
and the values are file contents.
assets : dict, optional
A dictionary of assets to be used by the spec. The keys are asset names
and the values are asset contents.
)mydelimiter",
py::return_value_policy::move);
mjSpec.def("recompile", [mjmodel_mjdata_from_spec_ptr](
Expand Down
16 changes: 16 additions & 0 deletions python/mujoco/specs_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -749,6 +749,22 @@ def test_assets(self):
self.assertEqual(model.nmeshvert, 8)
self.assertEqual(spec.assets['cube.obj'], cube)

xml = """
<mujoco model="test">
<asset>
<mesh name="cube" file="cube.obj"/>
</asset>
<worldbody>
<geom mesh="cube"/>
</worldbody>
</mujoco>
"""
assets = {'cube.obj': cube}
spec = mujoco.MjSpec.from_string(xml, assets=assets)
model = spec.compile()
self.assertEqual(model.nmeshvert, 8)
self.assertEqual(spec.assets['cube.obj'], cube)

def test_include(self):
included_xml = """
<mujoco>
Expand Down

0 comments on commit 0643017

Please sign in to comment.