Skip to content

Commit

Permalink
Change keyword argument asset to include in spec.from_file and …
Browse files Browse the repository at this point in the history
…`spec.from_string`.

PiperOrigin-RevId: 720611881
Change-Id: I92f5e72cece8d3223bc552ea4f2ef742c47cb8c9
  • Loading branch information
quagla authored and copybara-github committed Jan 28, 2025
1 parent f4096bc commit 5b330ca
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 33 deletions.
55 changes: 23 additions & 32 deletions python/mujoco/specs.cc
Original file line number Diff line number Diff line change
Expand Up @@ -71,13 +71,7 @@ using MjDoubleRefVec = Eigen::Ref<const Eigen::VectorXd>;

struct MjSpec {
MjSpec() : ptr(mj_makeSpec()) {}
MjSpec(raw::MjSpec* ptr,
const std::unordered_map<std::string, py::bytes>& assets_ = {})
: ptr(ptr) {
for (const auto& asset : assets_) {
assets[asset.first.c_str()] = asset.second;
}
}
MjSpec(raw::MjSpec* ptr) : ptr(ptr) {}

// copy constructor and assignment
MjSpec(const MjSpec& other) : ptr(mj_copySpec(other.ptr)) {
Expand Down Expand Up @@ -269,15 +263,15 @@ PYBIND11_MODULE(_specs, m) {
mjSpec.def_static(
"from_file",
[](std::string& filename,
std::optional<std::unordered_map<std::string, py::bytes>>& assets)
std::optional<std::unordered_map<std::string, py::bytes>>& include)
-> MjSpec {
const auto converted_assets = _impl::ConvertAssetsDict(assets);
const auto files = _impl::ConvertAssetsDict(include);
raw::MjSpec* spec;
{
py::gil_scoped_release no_gil;
char error[1024];
spec = LoadSpecFileImpl(
filename, converted_assets,
filename, files,
[&error](const char* filename, const mjVFS* vfs) {
return InterceptMjErrors(mj_parseXML)(
filename, vfs, error, sizeof(error));
Expand All @@ -286,43 +280,42 @@ PYBIND11_MODULE(_specs, m) {
throw py::value_error(error);
}
}
if (assets.has_value()) {
return MjSpec(spec, assets.value());
}
return MjSpec(spec);
},
py::arg("filename"), py::arg("assets") = py::none(), R"mydelimiter(
py::arg("filename"), py::arg("include") = py::none(), R"mydelimiter(
Creates a spec from an XML file.
Parameters
----------
filename : str
Path to the XML file.
assets : dict, optional
A dictionary of assets to be used by the spec. The keys are asset names
and the values are asset contents.
)mydelimiter", py::return_value_policy::move);
include : dict, optional
A dictionary of xml files included by the model. The keys are file names
and the values are file contents.
)mydelimiter",
py::return_value_policy::move);
mjSpec.def_static(
"from_string",
[](std::string& xml,
std::optional<std::unordered_map<std::string, py::bytes>>& assets)
std::optional<std::unordered_map<std::string, py::bytes>>& include)
-> MjSpec {
auto converted_assets = _impl::ConvertAssetsDict(assets);
auto files = _impl::ConvertAssetsDict(include);
raw::MjSpec* spec;
{
py::gil_scoped_release no_gil;
std::string model_filename = "model_.xml";
if (assets.has_value()) {
while (assets->find(model_filename) != assets->end()) {
if (include.has_value()) {
while (include->find(model_filename) !=
include->end()) {
model_filename =
model_filename.substr(0, model_filename.size() - 4) + "_.xml";
}
}
converted_assets.emplace_back(
files.emplace_back(
model_filename.c_str(), xml.c_str(), xml.length());
char error[1024];
spec = LoadSpecFileImpl(
model_filename, converted_assets,
model_filename, files,
[&error](const char* filename, const mjVFS* vfs) {
return InterceptMjErrors(mj_parseXML)(
filename, vfs, error, sizeof(error));
Expand All @@ -331,22 +324,20 @@ PYBIND11_MODULE(_specs, m) {
throw py::value_error(error);
}
}
if (assets.has_value()) {
return MjSpec(spec, assets.value());
}
return MjSpec(spec);
},
py::arg("xml"), py::arg("assets") = py::none(), R"mydelimiter(
py::arg("xml"), py::arg("include") = py::none(), R"mydelimiter(
Creates a spec from an XML string.
Parameters
----------
xml : str
XML string.
assets : dict, optional
A dictionary of assets to be used by the spec. The keys are asset names
and the values are asset contents.
)mydelimiter", py::return_value_policy::move);
include : dict, optional
A dictionary of xml files included by the model. The keys are file names
and the values are file contents.
)mydelimiter",
py::return_value_policy::move);
mjSpec.def("recompile", [mjmodel_mjdata_from_spec_ptr](
const MjSpec& self, py::object m, py::object d) {
return mjmodel_mjdata_from_spec_ptr(reinterpret_cast<uintptr_t>(self.ptr),
Expand Down
2 changes: 1 addition & 1 deletion python/mujoco/specs_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -765,7 +765,7 @@ def test_include(self):
<include file="included.xml"/>
</mujoco>
"""),
{'included.xml': included_xml.encode('utf-8')},
include={'included.xml': included_xml.encode('utf-8')},
)
self.assertEqual(
spec.worldbody.first_body().first_geom().type, mujoco.mjtGeom.mjGEOM_BOX
Expand Down

0 comments on commit 5b330ca

Please sign in to comment.