From 9900664f43e673e99c5c33dee5f7634b9fc14adf Mon Sep 17 00:00:00 2001 From: twata Date: Tue, 4 Jul 2023 03:09:26 +0000 Subject: [PATCH 1/2] [WIP] [pfto] Add ppe.map test case --- .../onnx_tests/test_export.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/tests/pytorch_pfn_extras_tests/onnx_tests/test_export.py b/tests/pytorch_pfn_extras_tests/onnx_tests/test_export.py index 201b025b6..438f85995 100644 --- a/tests/pytorch_pfn_extras_tests/onnx_tests/test_export.py +++ b/tests/pytorch_pfn_extras_tests/onnx_tests/test_export.py @@ -416,3 +416,18 @@ def forward(self, x, y): output_names=["out"], dynamic_axes={"x": {1: "A"}, "y": {0: "B"}}, ) + + +def test_ppe_map(): + torch.manual_seed(100) + + class Net(torch.nn.Module): + def __init__(self): + super(Net, self).__init__() + self.conv = torch.nn.Conv2d(1, 1, 3) + + def forward(self, x): + y = self.conv(x) + return list(ppe.map(lambda u: u + 1, y))[0] + + run_model_test(Net(), (torch.rand(1, 1, 112, 112),), rtol=1e-03) From d8e2ae383dbfbd4e782367abdcadc2dcf68c664f Mon Sep 17 00:00:00 2001 From: twata Date: Fri, 14 Jul 2023 09:27:39 +0000 Subject: [PATCH 2/2] Make test run --- .../onnx_tests/test_export.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/tests/pytorch_pfn_extras_tests/onnx_tests/test_export.py b/tests/pytorch_pfn_extras_tests/onnx_tests/test_export.py index 438f85995..1bc007d0a 100644 --- a/tests/pytorch_pfn_extras_tests/onnx_tests/test_export.py +++ b/tests/pytorch_pfn_extras_tests/onnx_tests/test_export.py @@ -426,8 +426,16 @@ def __init__(self): super(Net, self).__init__() self.conv = torch.nn.Conv2d(1, 1, 3) + def map_f(self, u): + return u + 1 + def forward(self, x): - y = self.conv(x) - return list(ppe.map(lambda u: u + 1, y))[0] + y1 = self.conv(x) + y2 = self.conv(x) + y = [{"u" : y1}, {"u": y2}] + return list(ppe.map(self.map_f, y))[0] - run_model_test(Net(), (torch.rand(1, 1, 112, 112),), rtol=1e-03) + model = Net() + ppe.to(model, device="cpu") + + run_model_test(model, (torch.rand(1, 1, 112, 112),), rtol=1e-03)