From 5b330ca19c03c50b0b2d7d179b12099f5d05c33d Mon Sep 17 00:00:00 2001 From: Alessio Quaglino Date: Tue, 28 Jan 2025 09:54:48 -0800 Subject: [PATCH] Change keyword argument `asset` to `include` in `spec.from_file` and `spec.from_string`. PiperOrigin-RevId: 720611881 Change-Id: I92f5e72cece8d3223bc552ea4f2ef742c47cb8c9 --- python/mujoco/specs.cc | 55 ++++++++++++++++--------------------- python/mujoco/specs_test.py | 2 +- 2 files changed, 24 insertions(+), 33 deletions(-) diff --git a/python/mujoco/specs.cc b/python/mujoco/specs.cc index 00dbfc89cb..a8de5cfe4b 100644 --- a/python/mujoco/specs.cc +++ b/python/mujoco/specs.cc @@ -71,13 +71,7 @@ using MjDoubleRefVec = Eigen::Ref; struct MjSpec { MjSpec() : ptr(mj_makeSpec()) {} - MjSpec(raw::MjSpec* ptr, - const std::unordered_map& assets_ = {}) - : ptr(ptr) { - for (const auto& asset : assets_) { - assets[asset.first.c_str()] = asset.second; - } - } + MjSpec(raw::MjSpec* ptr) : ptr(ptr) {} // copy constructor and assignment MjSpec(const MjSpec& other) : ptr(mj_copySpec(other.ptr)) { @@ -269,15 +263,15 @@ PYBIND11_MODULE(_specs, m) { mjSpec.def_static( "from_file", [](std::string& filename, - std::optional>& assets) + std::optional>& include) -> MjSpec { - const auto converted_assets = _impl::ConvertAssetsDict(assets); + const auto files = _impl::ConvertAssetsDict(include); raw::MjSpec* spec; { py::gil_scoped_release no_gil; char error[1024]; spec = LoadSpecFileImpl( - filename, converted_assets, + filename, files, [&error](const char* filename, const mjVFS* vfs) { return InterceptMjErrors(mj_parseXML)( filename, vfs, error, sizeof(error)); @@ -286,43 +280,42 @@ PYBIND11_MODULE(_specs, m) { throw py::value_error(error); } } - if (assets.has_value()) { - return MjSpec(spec, assets.value()); - } return MjSpec(spec); }, - py::arg("filename"), py::arg("assets") = py::none(), R"mydelimiter( + py::arg("filename"), py::arg("include") = py::none(), R"mydelimiter( Creates a spec from an XML file. Parameters ---------- filename : str Path to the XML file. - assets : dict, optional - A dictionary of assets to be used by the spec. The keys are asset names - and the values are asset contents. - )mydelimiter", py::return_value_policy::move); + include : dict, optional + A dictionary of xml files included by the model. The keys are file names + and the values are file contents. + )mydelimiter", + py::return_value_policy::move); mjSpec.def_static( "from_string", [](std::string& xml, - std::optional>& assets) + std::optional>& include) -> MjSpec { - auto converted_assets = _impl::ConvertAssetsDict(assets); + auto files = _impl::ConvertAssetsDict(include); raw::MjSpec* spec; { py::gil_scoped_release no_gil; std::string model_filename = "model_.xml"; - if (assets.has_value()) { - while (assets->find(model_filename) != assets->end()) { + if (include.has_value()) { + while (include->find(model_filename) != + include->end()) { model_filename = model_filename.substr(0, model_filename.size() - 4) + "_.xml"; } } - converted_assets.emplace_back( + files.emplace_back( model_filename.c_str(), xml.c_str(), xml.length()); char error[1024]; spec = LoadSpecFileImpl( - model_filename, converted_assets, + model_filename, files, [&error](const char* filename, const mjVFS* vfs) { return InterceptMjErrors(mj_parseXML)( filename, vfs, error, sizeof(error)); @@ -331,22 +324,20 @@ PYBIND11_MODULE(_specs, m) { throw py::value_error(error); } } - if (assets.has_value()) { - return MjSpec(spec, assets.value()); - } return MjSpec(spec); }, - py::arg("xml"), py::arg("assets") = py::none(), R"mydelimiter( + py::arg("xml"), py::arg("include") = py::none(), R"mydelimiter( Creates a spec from an XML string. Parameters ---------- xml : str XML string. - assets : dict, optional - A dictionary of assets to be used by the spec. The keys are asset names - and the values are asset contents. - )mydelimiter", py::return_value_policy::move); + include : dict, optional + A dictionary of xml files included by the model. The keys are file names + and the values are file contents. + )mydelimiter", + py::return_value_policy::move); mjSpec.def("recompile", [mjmodel_mjdata_from_spec_ptr]( const MjSpec& self, py::object m, py::object d) { return mjmodel_mjdata_from_spec_ptr(reinterpret_cast(self.ptr), diff --git a/python/mujoco/specs_test.py b/python/mujoco/specs_test.py index 1302afc26e..08dbf3912d 100644 --- a/python/mujoco/specs_test.py +++ b/python/mujoco/specs_test.py @@ -765,7 +765,7 @@ def test_include(self): """), - {'included.xml': included_xml.encode('utf-8')}, + include={'included.xml': included_xml.encode('utf-8')}, ) self.assertEqual( spec.worldbody.first_body().first_geom().type, mujoco.mjtGeom.mjGEOM_BOX