From b479f8d4e5f9f84027b897a806e393bd87f2027f Mon Sep 17 00:00:00 2001 From: Wenhao Hu Date: Sat, 13 Oct 2018 20:01:17 +0900 Subject: [PATCH] refactor handler_helper --- onnx_tf/common/handler_helper.py | 60 +++++++++++++++++++++++++------- 1 file changed, 47 insertions(+), 13 deletions(-) diff --git a/onnx_tf/common/handler_helper.py b/onnx_tf/common/handler_helper.py index 16af2374b..d81f7c781 100644 --- a/onnx_tf/common/handler_helper.py +++ b/onnx_tf/common/handler_helper.py @@ -9,6 +9,32 @@ from . import op_name_to_lower +class DomainHandlerDict(dict): + + def __init__(self, domain, unknown_message="", failed_message=""): + self.unknown = {} + self.failed = {} + self._domain = domain + self._unknown_message = unknown_message + self._failed_message = failed_message + + def _warn(self, k): + if k in self.unknown: + warnings.warn( + self._unknown_message.format(self._domain, self.unknown.pop(k))) + if k in self.failed: + warnings.warn( + self._failed_message.format(self._domain, k, self.failed.pop(k))) + + def __getitem__(self, k): + self._warn(k) + return super(DomainHandlerDict, self).__getitem__(k) + + def get(self, k, d=None): + self._warn(k) + return super(DomainHandlerDict, self).get(k, d) + + def get_all_frontend_handlers(opset_dict): """ Get a dict of all frontend handler classes. e.g. {'domain': {'Abs': Abs handler class}, ...}, }. @@ -23,6 +49,14 @@ def get_all_frontend_handlers(opset_dict): domain = handler.DOMAIN version = opset_dict[domain] handler.VERSION = version + domain_handler_dict = handlers.setdefault( + domain, + DomainHandlerDict( + domain or "ai.onnx", + unknown_message="Unknown op {1} in domain `{0}`. " + "Can't check specification by ONNX. " + "Please set should_check flag to False " + "when call make_node method in handler.")) since_version = 1 if handler.ONNX_OP and defs.has(handler.ONNX_OP, domain=handler.DOMAIN): @@ -30,16 +64,12 @@ def get_all_frontend_handlers(opset_dict): handler.ONNX_OP, domain=handler.DOMAIN, max_inclusive_version=version).since_version else: - warnings.warn("Unknown op {} in domain `{}`. " - "Can't check specification by ONNX. " - "Please set should_check flag to False " - "when call make_node method in handler.".format( - handler.ONNX_OP or "Undefined", handler.DOMAIN or - "ai.onnx")) + for tf_op in handler.TF_OP: + domain_handler_dict.unknown[tf_op] = handler.ONNX_OP or tf_op handler.SINCE_VERSION = since_version for tf_op in handler.TF_OP: - handlers.setdefault(domain, {})[tf_op] = handler + domain_handler_dict[tf_op] = handler return handlers @@ -57,6 +87,13 @@ def get_all_backend_handlers(opset_dict): domain = handler.DOMAIN version = opset_dict[domain] handler.VERSION = version + domain_handler_dict = handlers.setdefault( + domain, + DomainHandlerDict( + domain or "ai.onnx", + failed_message="Fail to get since_version of {1} in domain `{0}` " + "with max_inclusive_version={2}. Set to 1.", + unknown_message="Unknown op {1} in domain `{0}`.")) since_version = 1 if defs.has(handler.ONNX_OP, domain=handler.DOMAIN): @@ -66,14 +103,11 @@ def get_all_backend_handlers(opset_dict): domain=handler.DOMAIN, max_inclusive_version=version).since_version except RuntimeError: - warnings.warn("Fail to get since_version of {} in domain `{}` " - "with max_inclusive_version={}. Set to 1.".format( - handler.ONNX_OP, handler.DOMAIN, version)) + domain_handler_dict.failed[handler.ONNX_OP] = version else: - warnings.warn("Unknown op {} in domain `{}`.".format( - handler.ONNX_OP, handler.DOMAIN or "ai.onnx")) + domain_handler_dict.unknown[handler.ONNX_OP] = handler.ONNX_OP handler.SINCE_VERSION = since_version - handlers.setdefault(domain, {})[handler.ONNX_OP] = handler + domain_handler_dict[handler.ONNX_OP] = handler return handlers