From 321d981e4b0a7072637aed0088ad3b567489b66e Mon Sep 17 00:00:00 2001 From: Wonhyuk Yang Date: Mon, 26 Aug 2024 08:37:46 +0000 Subject: [PATCH] [Script] update conv script --- scripts/generate_conv_onnx.py | 31 +++++++++++++++++-------------- 1 file changed, 17 insertions(+), 14 deletions(-) diff --git a/scripts/generate_conv_onnx.py b/scripts/generate_conv_onnx.py index 5083d50..182d66c 100644 --- a/scripts/generate_conv_onnx.py +++ b/scripts/generate_conv_onnx.py @@ -5,14 +5,21 @@ size_list = [128]#64, 256, 1024] dtype = torch.float32 - +C_in = 128 +C_out = 128 +K_sz = 3 +padding = 1 +H = 14 * 4 +W = 14 * 4 +stride=2 HOME = os.getenv("ONNXIM_HOME", default="../") +size_name = f"{C_in}_{C_out}_{K_sz}_{H}_{W}" # Test Convolution model class size_conv(torch.nn.Module): - def __init__(self, C_in, C_out, K_sz): + def __init__(self, C_in, C_out, K_sz, padding=padding): super().__init__() - self.fc = torch.nn.Conv2d(C_in, C_out, K_sz, padding=1, bias=False, dtype=dtype) + self.fc = torch.nn.Conv2d(C_in, C_out, K_sz, stride=stride, padding=padding, bias=False, dtype=dtype) def forward(self, x): return self.fc(x) @@ -20,26 +27,22 @@ def forward(self, x): # Create output folder Path(f"{HOME}/model_lists").mkdir(parents=True, exist_ok=True) for size in size_list: - C_in = size//2 - C_out = size - K_sz = 3 # Export PyTorch model to onnx - Path(f"{HOME}/models/conv_{size}").mkdir(parents=True, exist_ok=True) - m = size_conv(C_in, C_out, K_sz) - A = torch.zeros([1,C_in, 28, 28], dtype=dtype) - onnx_path = Path(f"{HOME}/models/conv_{size}/conv_{size}.onnx") - if not onnx_path.is_file(): - torch.onnx.export(m, A, onnx_path, export_params=True, input_names = ['input'], output_names=['output']) + Path(f"{HOME}/models/conv_{size_name}").mkdir(parents=True, exist_ok=True) + m = size_conv(C_in, C_out, K_sz, padding) + A = torch.zeros([1,C_in, H, W], dtype=dtype) + onnx_path = Path(f"{HOME}/models/conv_{size_name}/conv_{size_name}.onnx") + torch.onnx.export(m, A, onnx_path, export_params=True, input_names = ['input'], output_names=['output']) # Generate model_list json file config = { "models": [ { - "name": f"conv_{size}", + "name": f"conv_{size_name}", "request_time": 0 } ] } - with open(f"{HOME}/model_lists/conv_{size}.json", "w") as json_file: + with open(f"{HOME}/model_lists/conv_{size_name}.json", "w") as json_file: json.dump(config, json_file, indent=4)