diff --git a/python/mujoco/specs.cc b/python/mujoco/specs.cc index 3943f797d8..0f7a4c2a98 100644 --- a/python/mujoco/specs.cc +++ b/python/mujoco/specs.cc @@ -233,6 +233,16 @@ py::list FindAllImpl(raw::MjsBody& body, mjtObj objtype, bool recursive) { return list; // list of pointers, so they can be copied } +void SetFrame(raw::MjsBody* body, mjtObj objtype, raw::MjsFrame* frame) { + mjsElement* el = mjs_firstChild(body, objtype, 0); + while (el) { + if (frame->element != el) { + mjs_setFrame(el, frame); + } + el = mjs_nextChild(body, el, 0); + } +} + PYBIND11_MODULE(_specs, m) { auto structs_m = py::module::import("mujoco._structs"); py::function mjmodel_from_spec_ptr = @@ -496,13 +506,21 @@ PYBIND11_MODULE(_specs, m) { throw pybind11::value_error( "Only one of frame or site can be specified."); } - auto world = mjs_findBody(child.ptr, "world"); - if (!world) { + auto worldbody = mjs_findBody(child.ptr, "world"); + if (!worldbody) { throw pybind11::value_error("Child does not have a world body."); } + auto worldframe = mjs_addFrame(worldbody, nullptr); + SetFrame(worldbody, mjOBJ_BODY, worldframe); + SetFrame(worldbody, mjOBJ_SITE, worldframe); + SetFrame(worldbody, mjOBJ_FRAME, worldframe); + SetFrame(worldbody, mjOBJ_JOINT, worldframe); + SetFrame(worldbody, mjOBJ_GEOM, worldframe); + SetFrame(worldbody, mjOBJ_LIGHT, worldframe); + SetFrame(worldbody, mjOBJ_CAMERA, worldframe); const char* p = prefix.has_value() ? prefix.value().c_str() : ""; const char* s = suffix.has_value() ? suffix.value().c_str() : ""; - raw::MjsBody* attached_world = nullptr; + raw::MjsFrame* attached_frame = nullptr; if (frame.has_value()) { raw::MjsFrame* frame_ptr = nullptr; try { @@ -518,7 +536,15 @@ PYBIND11_MODULE(_specs, m) { throw pybind11::value_error( "Frame spec does not match parent spec."); } - attached_world = mjs_attachBody(frame_ptr, world, p, s); + raw::MjsBody* parent_body = mjs_getParent(frame_ptr->element); + if (!parent_body) { + throw pybind11::value_error("Frame does not have a parent body."); + } + attached_frame = mjs_attachFrame(parent_body, worldframe, p, s); + if (!attached_frame) { + throw pybind11::value_error(mjs_getError(self.ptr)); + } + mjs_setFrame(attached_frame->element, frame_ptr); } if (site.has_value()) { raw::MjsSite* site_ptr = nullptr; @@ -535,10 +561,10 @@ PYBIND11_MODULE(_specs, m) { throw pybind11::value_error( "Site spec does not match parent spec."); } - attached_world = mjs_attachToSite(site_ptr, world, p, s); - } - if (!attached_world) { - throw pybind11::value_error(mjs_getError(self.ptr)); + attached_frame = mjs_attachFrameToSite(site_ptr, worldframe, p, s); + if (!attached_frame) { + throw pybind11::value_error(mjs_getError(self.ptr)); + } } for (const auto& asset : child.assets) { if (self.assets.contains(asset.first) && !self.override_assets) { @@ -549,7 +575,7 @@ PYBIND11_MODULE(_specs, m) { self.assets[asset.first] = asset.second; } child.parent = &self; - return mjs_bodyToFrame(&attached_world); + return attached_frame; }, py::arg("child"), py::arg("prefix") = py::none(), py::arg("suffix") = py::none(), py::arg("site") = py::none(), diff --git a/python/mujoco/specs_test.py b/python/mujoco/specs_test.py index afd6e68fb2..a958e59534 100644 --- a/python/mujoco/specs_test.py +++ b/python/mujoco/specs_test.py @@ -991,7 +991,7 @@ def test_attach_to_site(self): child2.assets = {'cube2.obj': 'cube2_content'} body2 = child2.worldbody.add_body(name='body') self.assertIsNotNone(parent.attach(child2, site=site, prefix='child2-')) - self.assertIsNone(child2.worldbody) + self.assertIsNotNone(child2.worldbody) self.assertEqual(child2.parent, parent) body2.pos = [-1, -1, -1] model2 = parent.compile() @@ -1009,7 +1009,7 @@ def test_attach_to_site(self): child3.assets = {'cube3.obj': 'cube3_content'} body3 = child3.worldbody.add_body(name='body') self.assertIsNotNone(parent.attach(child3, site='site', prefix='child3-')) - self.assertIsNone(child3.worldbody) + self.assertIsNotNone(child3.worldbody) self.assertEqual(child3.parent, parent) body3.pos = [-2, -2, -2] model3 = parent.compile() @@ -1062,7 +1062,7 @@ def test_attach_to_frame(self): child2.assets = {'cube2.obj': 'cube2_content'} body2 = child2.worldbody.add_body(name='body') self.assertIsNotNone(parent.attach(child2, frame=frame, prefix='child-')) - self.assertIsNone(child2.worldbody) + self.assertIsNotNone(child2.worldbody) self.assertEqual(child2.parent, parent) body2.pos = [-1, -1, -1] model2 = parent.compile() @@ -1080,7 +1080,7 @@ def test_attach_to_frame(self): child3.assets = {'cube2.obj': 'new_cube2_content'} body3 = child3.worldbody.add_body(name='body') self.assertIsNotNone(parent.attach(child3, frame='frame', prefix='child3-')) - self.assertIsNone(child3.worldbody) + self.assertIsNotNone(child3.worldbody) self.assertEqual(child3.parent, parent) body3.pos = [-2, -2, -2] model3 = parent.compile() diff --git a/test/user/user_api_test.cc b/test/user/user_api_test.cc index 6d22cdf5ea..5a3db2e8d9 100644 --- a/test/user/user_api_test.cc +++ b/test/user/user_api_test.cc @@ -1435,7 +1435,7 @@ TEST_F(MujocoTest, AttachFrameToSite) { mj_deleteModel(expected); } -TEST_F(MujocoTest, AttachWorld) { +TEST_F(MujocoTest, BodyToFrame) { std::array er; mjtNum tol = 0; std::string field = ""; @@ -1524,6 +1524,166 @@ TEST_F(MujocoTest, AttachWorld) { mj_deleteModel(expected); } +TEST_F(MujocoTest, AttachSpecToSite) { + std::array er; + mjtNum tol = 0; + std::string field = ""; + + static constexpr char xml_parent[] = R"( + + + + + + + )"; + + static constexpr char xml_child[] = R"( + + + + + + + + + )"; + + static constexpr char xml_result[] = R"( + + + + + + + + + + + + + + )"; + + mjSpec* parent = mj_parseXMLString(xml_parent, 0, er.data(), er.size()); + EXPECT_THAT(parent, NotNull()) << er.data(); + mjSpec* child = mj_parseXMLString(xml_child, 0, er.data(), er.size()); + EXPECT_THAT(child, NotNull()) << er.data(); + mjsSite* site = mjs_asSite(mjs_findElement(parent, mjOBJ_SITE, "site")); + EXPECT_THAT(site, NotNull()); + + // add a frame to the child + mjsBody* world = mjs_findBody(child, "world"); + EXPECT_THAT(world, NotNull()); + mjsFrame* frame = mjs_addFrame(world, 0); + EXPECT_THAT(frame, NotNull()); + mjs_setString(frame->name, "world"); + mjs_setFrame(mjs_firstChild(world, mjOBJ_BODY, 0), frame); + mjs_setFrame(mjs_firstChild(world, mjOBJ_CAMERA, 0), frame); + + // attach the entire spec to the site + mjsFrame* worldframe = mjs_attachFrameToSite(site, frame, "attached-", "-1"); + EXPECT_THAT(worldframe, NotNull()); + + // compile and compare + mjModel* model = mj_compile(parent, 0); + EXPECT_THAT(model, NotNull()); + mjModel* expected = LoadModelFromString(xml_result, er.data(), er.size()); + EXPECT_THAT(expected, NotNull()) << er.data(); + EXPECT_LE(CompareModel(model, expected, field), tol) + << "Expected and attached models are different!\n" + << "Different field: " << field << '\n'; + + // check that the child world still exists + mjsBody* child_world = mjs_findBody(child, "world"); + EXPECT_THAT(child_world, NotNull()); + + mj_deleteSpec(parent); + mj_deleteSpec(child); + mj_deleteModel(model); + mj_deleteModel(expected); +} + +TEST_F(MujocoTest, AttachSpecToBody) { + std::array er; + mjtNum tol = 0; + std::string field = ""; + + static constexpr char xml_parent[] = R"( + + + + + )"; + + static constexpr char xml_child[] = R"( + + + + + + + + + )"; + + static constexpr char xml_result[] = R"( + + + + + + + + + + + + + )"; + + mjSpec* parent = mj_parseXMLString(xml_parent, 0, er.data(), er.size()); + EXPECT_THAT(parent, NotNull()) << er.data(); + mjSpec* child = mj_parseXMLString(xml_child, 0, er.data(), er.size()); + EXPECT_THAT(child, NotNull()) << er.data(); + mjsBody* body = mjs_findBody(parent, "body"); + EXPECT_THAT(body, NotNull()); + + // add a frame to the child + mjsBody* world = mjs_findBody(child, "world"); + EXPECT_THAT(world, NotNull()); + mjsFrame* frame = mjs_addFrame(world, 0); + EXPECT_THAT(frame, NotNull()); + mjs_setString(frame->name, "world"); + mjs_setFrame(mjs_firstChild(world, mjOBJ_BODY, 0), frame); + mjs_setFrame(mjs_firstChild(world, mjOBJ_CAMERA, 0), frame); + + // attach the entire spec to the site + mjsFrame* worldframe = mjs_attachFrame(body, frame, "attached-", "-1"); + EXPECT_THAT(worldframe, NotNull()); + worldframe->pos[0] = 1; + worldframe->pos[1] = 2; + worldframe->pos[2] = 3; + + // compile and compare + mjModel* model = mj_compile(parent, 0); + EXPECT_THAT(model, NotNull()); + mjModel* expected = LoadModelFromString(xml_result, er.data(), er.size()); + EXPECT_THAT(expected, NotNull()) << er.data(); + EXPECT_LE(CompareModel(model, expected, field), tol) + << "Expected and attached models are different!\n" + << "Different field: " << field << '\n'; + + // check that the child world still exists + mjsBody* child_world = mjs_findBody(child, "world"); + EXPECT_THAT(child_world, NotNull()); + + mj_deleteSpec(parent); + mj_deleteSpec(child); + mj_deleteModel(model); + mj_deleteModel(expected); +} + TEST_F(MujocoTest, PreserveState) { std::array er; std::string field = "";