Skip to content

Commit

Permalink
[Script] update conv script
Browse files Browse the repository at this point in the history
  • Loading branch information
YWHyuk committed Aug 26, 2024
1 parent 3b41e6a commit 321d981
Showing 1 changed file with 17 additions and 14 deletions.
31 changes: 17 additions & 14 deletions scripts/generate_conv_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,41 +5,44 @@

size_list = [128]#64, 256, 1024]
dtype = torch.float32

C_in = 128
C_out = 128
K_sz = 3
padding = 1
H = 14 * 4
W = 14 * 4
stride=2
HOME = os.getenv("ONNXIM_HOME", default="../")

size_name = f"{C_in}_{C_out}_{K_sz}_{H}_{W}"
# Test Convolution model
class size_conv(torch.nn.Module):
def __init__(self, C_in, C_out, K_sz):
def __init__(self, C_in, C_out, K_sz, padding=padding):
super().__init__()
self.fc = torch.nn.Conv2d(C_in, C_out, K_sz, padding=1, bias=False, dtype=dtype)
self.fc = torch.nn.Conv2d(C_in, C_out, K_sz, stride=stride, padding=padding, bias=False, dtype=dtype)

def forward(self, x):
return self.fc(x)

# Create output folder
Path(f"{HOME}/model_lists").mkdir(parents=True, exist_ok=True)
for size in size_list:
C_in = size//2
C_out = size
K_sz = 3

# Export PyTorch model to onnx
Path(f"{HOME}/models/conv_{size}").mkdir(parents=True, exist_ok=True)
m = size_conv(C_in, C_out, K_sz)
A = torch.zeros([1,C_in, 28, 28], dtype=dtype)
onnx_path = Path(f"{HOME}/models/conv_{size}/conv_{size}.onnx")
if not onnx_path.is_file():
torch.onnx.export(m, A, onnx_path, export_params=True, input_names = ['input'], output_names=['output'])
Path(f"{HOME}/models/conv_{size_name}").mkdir(parents=True, exist_ok=True)
m = size_conv(C_in, C_out, K_sz, padding)
A = torch.zeros([1,C_in, H, W], dtype=dtype)
onnx_path = Path(f"{HOME}/models/conv_{size_name}/conv_{size_name}.onnx")
torch.onnx.export(m, A, onnx_path, export_params=True, input_names = ['input'], output_names=['output'])

# Generate model_list json file
config = {
"models": [
{
"name": f"conv_{size}",
"name": f"conv_{size_name}",
"request_time": 0
}
]
}
with open(f"{HOME}/model_lists/conv_{size}.json", "w") as json_file:
with open(f"{HOME}/model_lists/conv_{size_name}.json", "w") as json_file:
json.dump(config, json_file, indent=4)

0 comments on commit 321d981

Please sign in to comment.