diff --git a/.github/workflows/format.yml b/.github/workflows/format.yml new file mode 100644 index 0000000..891053f --- /dev/null +++ b/.github/workflows/format.yml @@ -0,0 +1,29 @@ +# Ultralytics 🚀 - AGPL-3.0 license +# Ultralytics Actions https://github.com/ultralytics/actions +# This workflow automatically formats code and documentation in PRs to official Ultralytics standards + +name: Ultralytics Actions + +on: + push: + branches: [main] + pull_request_target: + branches: [main] + types: [opened, closed, synchronize] + +jobs: + format: + runs-on: ubuntu-latest + steps: + - name: Run Ultralytics Formatting + uses: ultralytics/actions@main + with: + token: ${{ secrets.GITHUB_TOKEN }} # automatically generated, do not modify + python: true # format Python code and docstrings + markdown: true # format Markdown + prettier: true # format YAML + spelling: true # check spelling + links: false # check broken links + summary: true # print PR summary with GPT4 (requires 'openai_api_key' or 'openai_azure_api_key' and 'openai_azure_endpoint') + openai_azure_api_key: ${{ secrets.OPENAI_AZURE_API_KEY }} + openai_azure_endpoint: ${{ secrets.OPENAI_AZURE_ENDPOINT }} diff --git a/README.md b/README.md index 8a2da2e..82c878a 100644 --- a/README.md +++ b/README.md @@ -1,43 +1,47 @@ # THOP: PyTorch-OpCounter -## How to install - -`pip install thop` (now continously intergrated on [Github actions](https://github.com/features/actions)) +## How to install + +`pip install thop` (now continuously integrated on [Github actions](https://github.com/features/actions)) OR `pip install --upgrade git+https://github.com/Lyken17/pytorch-OpCounter.git` - -## How to use -* Basic usage + +## How to use + +- Basic usage + ```python from torchvision.models import resnet50 from thop import profile model = resnet50() input = torch.randn(1, 3, 224, 224) macs, params = profile(model, inputs=(input, )) - ``` + ``` + +- Define the rule for 3rd party module. -* Define the rule for 3rd party module. 
```python class YourModule(nn.Module): # your definition def count_your_model(model, x, y): # your rule here - + input = torch.randn(1, 3, 224, 224) - macs, params = profile(model, inputs=(input, ), + macs, params = profile(model, inputs=(input, ), custom_ops={YourModule: count_your_model}) ``` - -* Improve the output readability + +- Improve the output readability Call `thop.clever_format` to give a better format of the output. + ```python from thop import clever_format macs, params = clever_format([macs, params], "%.3f") - ``` - + ``` + ## Results of Recent Models The implementation are adapted from `torchvision`. Following results can be obtained using [benchmark/evaluate_famous_models.py](benchmark/evaluate_famous_models.py). @@ -47,48 +51,48 @@ The implementation are adapted from `torchvision`. Following results can be obta -Model | Params(M) | MACs(G) ----|---|--- -alexnet | 61.10 | 0.77 -vgg11 | 132.86 | 7.74 -vgg11_bn | 132.87 | 7.77 -vgg13 | 133.05 | 11.44 -vgg13_bn | 133.05 | 11.49 -vgg16 | 138.36 | 15.61 -vgg16_bn | 138.37 | 15.66 -vgg19 | 143.67 | 19.77 -vgg19_bn | 143.68 | 19.83 -resnet18 | 11.69 | 1.82 -resnet34 | 21.80 | 3.68 -resnet50 | 25.56 | 4.14 -resnet101 | 44.55 | 7.87 -resnet152 | 60.19 | 11.61 -wide_resnet101_2 | 126.89 | 22.84 -wide_resnet50_2 | 68.88 | 11.46 +| Model | Params(M) | MACs(G) | +| ---------------- | --------- | ------- | +| alexnet | 61.10 | 0.77 | +| vgg11 | 132.86 | 7.74 | +| vgg11_bn | 132.87 | 7.77 | +| vgg13 | 133.05 | 11.44 | +| vgg13_bn | 133.05 | 11.49 | +| vgg16 | 138.36 | 15.61 | +| vgg16_bn | 138.37 | 15.66 | +| vgg19 | 143.67 | 19.77 | +| vgg19_bn | 143.68 | 19.83 | +| resnet18 | 11.69 | 1.82 | +| resnet34 | 21.80 | 3.68 | +| resnet50 | 25.56 | 4.14 | +| resnet101 | 44.55 | 7.87 | +| resnet152 | 60.19 | 11.61 | +| wide_resnet101_2 | 126.89 | 22.84 | +| wide_resnet50_2 | 68.88 | 11.46 | -Model | Params(M) | MACs(G) ----|---|--- -resnext50_32x4d | 25.03 | 4.29 -resnext101_32x8d | 88.79 | 16.54 -densenet121 | 7.98 | 
2.90 -densenet161 | 28.68 | 7.85 -densenet169 | 14.15 | 3.44 -densenet201 | 20.01 | 4.39 -squeezenet1_0 | 1.25 | 0.82 -squeezenet1_1 | 1.24 | 0.35 -mnasnet0_5 | 2.22 | 0.14 -mnasnet0_75 | 3.17 | 0.24 -mnasnet1_0 | 4.38 | 0.34 -mnasnet1_3 | 6.28 | 0.53 -mobilenet_v2 | 3.50 | 0.33 -shufflenet_v2_x0_5 | 1.37 | 0.05 -shufflenet_v2_x1_0 | 2.28 | 0.15 -shufflenet_v2_x1_5 | 3.50 | 0.31 -shufflenet_v2_x2_0 | 7.39 | 0.60 -inception_v3 | 27.16 | 5.75 +| Model | Params(M) | MACs(G) | +| ------------------ | --------- | ------- | +| resnext50_32x4d | 25.03 | 4.29 | +| resnext101_32x8d | 88.79 | 16.54 | +| densenet121 | 7.98 | 2.90 | +| densenet161 | 28.68 | 7.85 | +| densenet169 | 14.15 | 3.44 | +| densenet201 | 20.01 | 4.39 | +| squeezenet1_0 | 1.25 | 0.82 | +| squeezenet1_1 | 1.24 | 0.35 | +| mnasnet0_5 | 2.22 | 0.14 | +| mnasnet0_75 | 3.17 | 0.24 | +| mnasnet1_0 | 4.38 | 0.34 | +| mnasnet1_3 | 6.28 | 0.53 | +| mobilenet_v2 | 3.50 | 0.33 | +| shufflenet_v2_x0_5 | 1.37 | 0.05 | +| shufflenet_v2_x1_0 | 2.28 | 0.15 | +| shufflenet_v2_x1_5 | 3.50 | 0.31 | +| shufflenet_v2_x2_0 | 7.39 | 0.60 | +| inception_v3 | 27.16 | 5.75 | diff --git a/TODO.md b/TODO.md index f56334c..0b7a187 100644 --- a/TODO.md +++ b/TODO.md @@ -1,7 +1,7 @@ TODOs. -1. A more user-friendly warning for un-defined modules. [Done] +1. A more user-friendly warning for un-defined modules. \[Done\] 2. Supports for models in torchvision (e.g., residual add). 3. Layer wise printing -Integration with torchprofile? \ No newline at end of file +Integration with torchprofile? diff --git a/benchmark/README.md b/benchmark/README.md index 55ae9e1..89b834c 100644 --- a/benchmark/README.md +++ b/benchmark/README.md @@ -1,33 +1,30 @@ # MACs, FLOPs, what is the difference? -`FLOPs` is abbreviation of **floating operations** which includes mul / add / div ... etc. +`FLOPs` is abbreviation of **floating operations** which includes mul / add / div ... etc. 
-`MACs` stands for **multiply–accumulate operation** that performs `a <- a + (b x c)`. +`MACs` stands for **multiply–accumulate operation** that performs `a <- a + (b x c)`. As shown in the text, one `MACs` has one `mul` and one `add`. That is why in many places `FLOPs` is nearly two times as `MACs`. -However, the application in real world is far more complex. Let's consider a matrix multiplication example. -`A` is an matrix of dimension `mxn` and `B` is an vector of `nx1`. +However, the application in the real world is far more complex. Let's consider a matrix multiplication example. `A` is a matrix of dimension `mxn` and `B` is a vector of dimension `nx1`. ```python for i in range(m): for j in range(n): C[i][j] += A[i][j] * B[j] # one mul-add -``` +``` It would be `mn` `MACs` and `2mn` `FLOPs`. But such implementation is slow and parallelization is necessary to run faster - - ```python +```python for i in range(m): - parallelfor j in range(n): - d[j] = A[i][j] * B[j] # one mul - C[i][j] = sum(d) # n adds + parallelfor j in range(n): + d[j] = A[i][j] * B[j] # one mul + C[i][j] = sum(d) # n adds ``` -Then the number of `MACs` is no longer `mn` . - +Then the number of `MACs` is no longer `mn`. -When comparing MACs /FLOPs, we want the number to be implementation-agnostic and as general as possible. Therefore in THOP, **we only consider the number of multiplications** and ignore all other operations. +When comparing MACs / FLOPs, we want the number to be implementation-agnostic and as general as possible. Therefore in THOP, **we only consider the number of multiplications** and ignore all other operations. -PS: The FLOPs is approximated by multiplying two. \ No newline at end of file +PS: FLOPs are approximated by multiplying the MACs by two. 
diff --git a/benchmark/evaluate_famous_models.py b/benchmark/evaluate_famous_models.py index 7589526..5fb1cbf 100644 --- a/benchmark/evaluate_famous_models.py +++ b/benchmark/evaluate_famous_models.py @@ -1,5 +1,6 @@ import torch from torchvision import models + from thop.profile import profile model_names = sorted( @@ -24,6 +25,4 @@ dsize = (1, 3, 299, 299) inputs = torch.randn(dsize).to(device) total_ops, total_params = profile(model, (inputs,), verbose=False) - print( - "%s | %.2f | %.2f" % (name, total_params / (1000 ** 2), total_ops / (1000 ** 3)) - ) + print("%s | %.2f | %.2f" % (name, total_params / (1000**2), total_ops / (1000**3))) diff --git a/benchmark/evaluate_rnn_models.py b/benchmark/evaluate_rnn_models.py index 6def4ce..44cd3fb 100644 --- a/benchmark/evaluate_rnn_models.py +++ b/benchmark/evaluate_rnn_models.py @@ -1,5 +1,6 @@ import torch import torch.nn as nn + from thop.profile import profile input_size = 160 @@ -18,15 +19,9 @@ "BiRNN": nn.Sequential(nn.RNN(input_size, hidden_size, bidirectional=True)), "BiGRU": nn.Sequential(nn.GRU(input_size, hidden_size, bidirectional=True)), "BiLSTM": nn.Sequential(nn.LSTM(input_size, hidden_size, bidirectional=True)), - "stacked-BiRNN": nn.Sequential( - nn.RNN(input_size, hidden_size, bidirectional=True, num_layers=4) - ), - "stacked-BiGRU": nn.Sequential( - nn.GRU(input_size, hidden_size, bidirectional=True, num_layers=4) - ), - "stacked-BiLSTM": nn.Sequential( - nn.LSTM(input_size, hidden_size, bidirectional=True, num_layers=4) - ), + "stacked-BiRNN": nn.Sequential(nn.RNN(input_size, hidden_size, bidirectional=True, num_layers=4)), + "stacked-BiGRU": nn.Sequential(nn.GRU(input_size, hidden_size, bidirectional=True, num_layers=4)), + "stacked-BiLSTM": nn.Sequential(nn.LSTM(input_size, hidden_size, bidirectional=True, num_layers=4)), } print("{} | {} | {}".format("Model", "Params(M)", "FLOPs(G)")) @@ -49,9 +44,7 @@ # validate batch_first support inputs = torch.randn(100, 32, input_size) -ops_time_first = 
profile( - nn.Sequential(nn.LSTM(input_size, hidden_size)), (inputs,), verbose=False -)[0] +ops_time_first = profile(nn.Sequential(nn.LSTM(input_size, hidden_size)), (inputs,), verbose=False)[0] ops_batch_first = profile( nn.Sequential(nn.LSTM(input_size, hidden_size, batch_first=True)), (inputs.transpose(0, 1),), diff --git a/setup.py b/setup.py index 100bd16..0e59925 100644 --- a/setup.py +++ b/setup.py @@ -1,9 +1,10 @@ #!/usr/bin/env python -import os, sys -import shutil import datetime +import os +import shutil +import sys -from setuptools import setup, find_packages +from setuptools import find_packages, setup from setuptools.command.install import install readme = open("README.md").read() diff --git a/tests/test_conv2d.py b/tests/test_conv2d.py index 5dd22a0..5eb3978 100644 --- a/tests/test_conv2d.py +++ b/tests/test_conv2d.py @@ -1,13 +1,14 @@ -from jinja2 import StrictUndefined import pytest import torch import torch.nn as nn +from jinja2 import StrictUndefined + from thop import profile class TestUtils: def test_conv2d_no_bias(self): - n, in_c, ih, iw = 1, 3, 32, 32 # torch.randint(1, 10, (4,)).tolist() + n, in_c, ih, iw = 1, 3, 32, 32 # torch.randint(1, 10, (4,)).tolist() out_c, kh, kw = 12, 5, 5 s, p, d, g = 1, 1, 1, 1 @@ -17,11 +18,11 @@ def test_conv2d_no_bias(self): _, _, oh, ow = out.shape - flops, params = profile(net, inputs=(data, )) + flops, params = profile(net, inputs=(data,)) assert flops == 810000, f"{flops} v.s. {810000}" def test_conv2d(self): - n, in_c, ih, iw = 1, 3, 32, 32 # torch.randint(1, 10, (4,)).tolist() + n, in_c, ih, iw = 1, 3, 32, 32 # torch.randint(1, 10, (4,)).tolist() out_c, kh, kw = 12, 5, 5 s, p, d, g = 1, 1, 1, 1 @@ -31,14 +32,14 @@ def test_conv2d(self): _, _, oh, ow = out.shape - flops, params = profile(net, inputs=(data, )) + flops, params = profile(net, inputs=(data,)) assert flops == 810000, f"{flops} v.s. 
{810000}" - + def test_conv2d_random(self): for i in range(10): - out_c, kh, kw = torch.randint(1, 20, (3,)).tolist() - n, in_c, ih, iw = torch.randint(1, 20, (4,)).tolist() # torch.randint(1, 10, (4,)).tolist() - ih += kh + out_c, kh, kw = torch.randint(1, 20, (3,)).tolist() + n, in_c, ih, iw = torch.randint(1, 20, (4,)).tolist() # torch.randint(1, 10, (4,)).tolist() + ih += kh iw += kw s, p, d, g = 1, 1, 1, 1 @@ -48,6 +49,8 @@ def test_conv2d_random(self): _, _, oh, ow = out.shape - flops, params = profile(net, inputs=(data, )) + flops, params = profile(net, inputs=(data,)) print(flops, params) - assert flops == n * out_c * oh * ow // g * in_c * kh * kw , f"{flops} v.s. {n * out_c * oh * ow // g * in_c * kh * kw}" \ No newline at end of file + assert ( + flops == n * out_c * oh * ow // g * in_c * kh * kw + ), f"{flops} v.s. {n * out_c * oh * ow // g * in_c * kh * kw}" diff --git a/tests/test_matmul.py b/tests/test_matmul.py index ae31b65..cdd8083 100644 --- a/tests/test_matmul.py +++ b/tests/test_matmul.py @@ -1,6 +1,7 @@ import pytest import torch import torch.nn as nn + from thop import profile @@ -8,7 +9,7 @@ class TestUtils: def test_matmul_case2(self): n, in_c, out_c = 1, 100, 200 net = nn.Linear(in_c, out_c) - flops, params = profile(net, inputs=(torch.randn(n, in_c), )) + flops, params = profile(net, inputs=(torch.randn(n, in_c),)) print(flops, params) assert flops == n * in_c * out_c @@ -16,14 +17,13 @@ def test_matmul_case2(self): for i in range(10): n, in_c, out_c = torch.randint(1, 500, (3,)).tolist() net = nn.Linear(in_c, out_c) - flops, params = profile(net, inputs=(torch.randn(n, in_c), )) + flops, params = profile(net, inputs=(torch.randn(n, in_c),)) print(flops, params) assert flops == n * in_c * out_c - + def test_conv2d(self): n, in_c, out_c = torch.randint(1, 500, (3,)).tolist() net = nn.Linear(in_c, out_c) - flops, params = profile(net, inputs=(torch.randn(n, in_c), )) + flops, params = profile(net, inputs=(torch.randn(n, in_c),)) print(flops, 
params) assert flops == n * in_c * out_c - diff --git a/tests/test_relu.py b/tests/test_relu.py index 5e2a221..727e0bd 100644 --- a/tests/test_relu.py +++ b/tests/test_relu.py @@ -1,6 +1,7 @@ import pytest import torch import torch.nn as nn + from thop import profile @@ -9,8 +10,6 @@ def test_relu(self): n, in_c, out_c = 1, 100, 200 data = torch.randn(n, in_c) net = nn.ReLU() - flops, params = profile(net, inputs=(torch.randn(n, in_c), )) + flops, params = profile(net, inputs=(torch.randn(n, in_c),)) print(flops, params) assert flops == 0 - - diff --git a/tests/test_utils.py b/tests/test_utils.py index 3afdbfd..1b5682e 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,6 +1,7 @@ -from thop import utils import pytest +from thop import utils + class TestUtils: def test_clever_format_returns_formatted_number(self): @@ -14,4 +15,3 @@ def test_clever_format_returns_formatted_numbers(self): format = "%.2f" clever_nums = utils.clever_format(nums, format) assert clever_nums == ("1.00B", "2.00B") - \ No newline at end of file diff --git a/thop/__init__.py b/thop/__init__.py index 1ce3290..7a022e4 100644 --- a/thop/__init__.py +++ b/thop/__init__.py @@ -1,7 +1,8 @@ -from .utils import clever_format -from .profile import profile, profile_origin # from .onnx_profile import OnnxProfile import torch +from .profile import profile, profile_origin +from .utils import clever_format + default_dtype = torch.float64 -from .__version__ import __version__ \ No newline at end of file +from .__version__ import __version__ diff --git a/thop/__version__.py b/thop/__version__.py index d1f2e39..485f44a 100644 --- a/thop/__version__.py +++ b/thop/__version__.py @@ -1 +1 @@ -__version__ = "0.1.1" \ No newline at end of file +__version__ = "0.1.1" diff --git a/thop/fx_profile.py b/thop/fx_profile.py index 8faadf7..fe15f02 100644 --- a/thop/fx_profile.py +++ b/thop/fx_profile.py @@ -1,8 +1,9 @@ import logging +from distutils.version import LooseVersion + import torch import torch as 
th import torch.nn as nn -from distutils.version import LooseVersion if LooseVersion(torch.__version__) < LooseVersion("1.8.0"): logging.warning( @@ -50,9 +51,7 @@ def count_fn_conv2d(input_shapes, output_shapes, *args, **kwargs): bias_op = 0 # check it later in_channel = x_shape[1] - total_ops = calculate_conv( - bias_op, kernel_parameters, out_shape.numel(), in_channel, groups - ).item() + total_ops = calculate_conv(bias_op, kernel_parameters, out_shape.numel(), in_channel, groups).item() return int(total_ops) @@ -71,9 +70,7 @@ def count_nn_conv2d(module: nn.Conv2d, input_shapes, output_shapes): in_channel = module.in_channels groups = module.groups kernel_ops = module.weight.shape[2:].numel() - total_ops = calculate_conv( - bias_op, kernel_ops, out_shape.numel(), in_channel, groups - ).item() + total_ops = calculate_conv(bias_op, kernel_ops, out_shape.numel(), in_channel, groups).item() return int(total_ops) @@ -114,6 +111,7 @@ def count_nn_bn2d(module: nn.BatchNorm2d, input_shapes, output_shapes): from torch.fx import symbolic_trace from torch.fx.passes.shape_prop import ShapeProp + from .utils import prGreen, prRed, prYellow @@ -135,9 +133,7 @@ def fx_profile(mod: nn.Module, input: th.Tensor, verbose=False): for node in gm.graph.nodes: # print(f"{node.target},\t{node.op},\t{node.meta['tensor_meta'].dtype},\t{node.meta['tensor_meta'].shape}") - fprint( - f"NodeOP:{node.op},\tTarget:{node.target},\tNodeName:{node.name},\tNodeArgs:{node.args}" - ) + fprint(f"NodeOP:{node.op},\tTarget:{node.target},\tNodeName:{node.name},\tNodeArgs:{node.args}") # node_op_type = str(node.target).split(".")[-1] node_flops = None @@ -157,17 +153,9 @@ def fx_profile(mod: nn.Module, input: th.Tensor, verbose=False): node_flops = 0 elif node.op == "call_function": # torch internal functions - key = ( - str(node.target) - .split("at")[0] - .replace("<", "") - .replace(">", "") - .strip() - ) + key = str(node.target).split("at")[0].replace("<", "").replace(">", "").strip() if key in 
count_map: - node_flops = count_map[key]( - input_shapes, output_shapes, *node.args, **node.kwargs - ) + node_flops = count_map[key](input_shapes, output_shapes, *node.args, **node.kwargs) else: missing_maps[key] = (node.op, key) prRed(f"|{key}| is missing") @@ -196,9 +184,7 @@ def fx_profile(mod: nn.Module, input: th.Tensor, verbose=False): print(f"weight_shape: None") else: print(type(m)) - print( - f"weight_shape: {mod.state_dict()[node.target + '.weight'].shape}" - ) + print(f"weight_shape: {mod.state_dict()[node.target + '.weight'].shape}") v_maps[str(node.name)] = node.meta["tensor_meta"].shape if node_flops is not None: @@ -208,6 +194,7 @@ def fx_profile(mod: nn.Module, input: th.Tensor, verbose=False): if len(missing_maps.keys()) > 0: from pprint import pprint + print("Missing operators: ") pprint(missing_maps) return total_flops diff --git a/thop/onnx_profile.py b/thop/onnx_profile.py index 10da68c..0442f27 100644 --- a/thop/onnx_profile.py +++ b/thop/onnx_profile.py @@ -1,8 +1,9 @@ +import numpy as np +import onnx import torch import torch.nn -import onnx from onnx import numpy_helper -import numpy as np + from thop.vision.onnx_counter import onnx_operators diff --git a/thop/profile.py b/thop/profile.py index 6b15d27..d87afd6 100644 --- a/thop/profile.py +++ b/thop/profile.py @@ -1,19 +1,15 @@ from distutils.version import LooseVersion -from thop.vision.basic_hooks import * from thop.rnn_hooks import * - +from thop.vision.basic_hooks import * # logger = logging.getLogger(__name__) # logger.setLevel(logging.INFO) - from .utils import prGreen, prRed, prYellow if LooseVersion(torch.__version__) < LooseVersion("1.0.0"): logging.warning( - "You are using an old version PyTorch {version}, which THOP does NOT support.".format( - version=torch.__version__ - ) + "You are using an old version PyTorch {version}, which THOP does NOT support.".format(version=torch.__version__) ) default_dtype = torch.float64 @@ -96,9 +92,7 @@ def add_hooks(m): m_type = type(m) fn = 
None - if ( - m_type in custom_ops - ): # if defined both op maps, use custom_ops to overwrite. + if m_type in custom_ops: # if defined both op maps, use custom_ops to overwrite. fn = custom_ops[m_type] if m_type not in types_collection and verbose: print("[INFO] Customize rule %s() %s." % (fn.__qualname__, m_type)) @@ -108,10 +102,7 @@ def add_hooks(m): print("[INFO] Register %s() for %s." % (fn.__qualname__, m_type)) else: if m_type not in types_collection and report_missing: - prRed( - "[WARN] Cannot find rule for %s. Treat it as zero Macs and zero Params." - % m_type - ) + prRed("[WARN] Cannot find rule for %s. Treat it as zero Macs and zero Params." % m_type) if fn is not None: handler = m.register_forward_hook(fn) @@ -191,10 +182,7 @@ def add_hooks(m: nn.Module): print("[INFO] Register %s() for %s." % (fn.__qualname__, m_type)) else: if m_type not in types_collection and report_missing: - prRed( - "[WARN] Cannot find rule for %s. Treat it as zero Macs and zero Params." - % m_type - ) + prRed("[WARN] Cannot find rule for %s. Treat it as zero Macs and zero Params." 
% m_type) if fn is not None: handler_collection[m] = ( @@ -220,9 +208,7 @@ def dfs_count(module: nn.Module, prefix="\t") -> (int, int): # else: # m_ops, m_params = m.total_ops, m.total_params next_dict = {} - if m in handler_collection and not isinstance( - m, (nn.Sequential, nn.ModuleList) - ): + if m in handler_collection and not isinstance(m, (nn.Sequential, nn.ModuleList)): m_ops, m_params = m.total_ops.item(), m.total_params.item() else: m_ops, m_params, next_dict = dfs_count(m, prefix=prefix + "\t") diff --git a/thop/utils.py b/thop/utils.py index 5f1c4bb..993761c 100644 --- a/thop/utils.py +++ b/thop/utils.py @@ -4,13 +4,16 @@ COLOR_GREEN = "92m" COLOR_YELLOW = "93m" + def colorful_print(fn_print, color=COLOR_RED): def actual_call(*args, **kwargs): print(f"\033[{color}", end="") fn_print(*args, **kwargs) print("\033[00m", end="") + return actual_call + prRed = colorful_print(print, color=COLOR_RED) prGreen = colorful_print(print, color=COLOR_GREEN) prYellow = colorful_print(print, color=COLOR_YELLOW) @@ -50,4 +53,4 @@ def clever_format(nums, format="%.2f"): if __name__ == "__main__": prRed("hello", "world") prGreen("hello", "world") - prYellow("hello", "world") \ No newline at end of file + prYellow("hello", "world") diff --git a/thop/vision/basic_hooks.py b/thop/vision/basic_hooks.py index fb864b6..e8abcf6 100644 --- a/thop/vision/basic_hooks.py +++ b/thop/vision/basic_hooks.py @@ -1,10 +1,12 @@ import argparse import logging -from .calc_func import * + import torch import torch.nn as nn from torch.nn.modules.conv import _ConvNd +from .calc_func import * + multiply_adds = 1 @@ -26,11 +28,11 @@ def count_convNd(m: _ConvNd, x, y: torch.Tensor): bias_ops = 1 if m.bias is not None else 0 m.total_ops += calculate_conv2d_flops( - input_size = list(x.shape), - output_size = list(y.shape), - kernel_size = list(m.weight.shape), - groups = m.groups, - bias = m.bias + input_size=list(x.shape), + output_size=list(y.shape), + kernel_size=list(m.weight.shape), + 
groups=m.groups, + bias=m.bias, ) # N x Cout x H x W x (Cin x Kw x Kh + bias) # m.total_ops += calculate_conv( @@ -64,7 +66,7 @@ def count_normalization(m: nn.modules.batchnorm._BatchNorm, x, y): x = x[0] # bn is by default fused in inference flops = calculate_norm(x.numel()) - if (getattr(m, 'affine', False) or getattr(m, 'elementwise_affine', False)): + if getattr(m, "affine", False) or getattr(m, "elementwise_affine", False): flops *= 2 m.total_ops += flops @@ -112,10 +114,7 @@ def count_avgpool(m, x, y): def count_adap_avgpool(m, x, y): - kernel = torch.div( - torch.DoubleTensor([*(x[0].shape[2:])]), - torch.DoubleTensor([*(y.shape[2:])]) - ) + kernel = torch.div(torch.DoubleTensor([*(x[0].shape[2:])]), torch.DoubleTensor([*(y.shape[2:])])) total_add = torch.prod(kernel) num_elements = y.numel() m.total_ops += calculate_adaptive_avg(total_add, num_elements) diff --git a/thop/vision/calc_func.py b/thop/vision/calc_func.py index f6af5ce..2e2f169 100644 --- a/thop/vision/calc_func.py +++ b/thop/vision/calc_func.py @@ -1,13 +1,16 @@ -import torch -import numpy as np import warnings +import numpy as np +import torch + + def l_prod(in_list): res = 1 for _ in in_list: res *= _ return res + def l_sum(in_list): res = 0 for _ in in_list: @@ -25,6 +28,7 @@ def calculate_parameters(param_list): def calculate_zero_ops(): return torch.DoubleTensor([int(0)]) + def calculate_conv2d_flops(input_size: list, output_size: list, kernel_size: list, groups: int, bias: bool = False): # n, out_c, oh, ow = output_size # n, in_c, ih, iw = input_size @@ -36,18 +40,19 @@ def calculate_conv2d_flops(input_size: list, output_size: list, kernel_size: lis def calculate_conv(bias, kernel_size, output_size, in_channel, group): warnings.warn("This API is being deprecated.") - """inputs are all numbers!""" + """Inputs are all numbers!""" return torch.DoubleTensor([output_size * (in_channel / group * kernel_size + bias)]) def calculate_norm(input_size): - """input is a number not a array or 
tensor""" + """Input is a number not a array or tensor.""" return torch.DoubleTensor([2 * input_size]) + def calculate_relu_flops(input_size): # x[x < 0] = 0 return 0 - + def calculate_relu(input_size: torch.Tensor): warnings.warn("This API is being deprecated") diff --git a/thop/vision/efficientnet.py b/thop/vision/efficientnet.py index b2fe162..8de88d1 100644 --- a/thop/vision/efficientnet.py +++ b/thop/vision/efficientnet.py @@ -3,8 +3,7 @@ import torch import torch.nn as nn -from torch.nn.modules.conv import _ConvNd - from efficientnet_pytorch.utils import Conv2dDynamicSamePadding, Conv2dStaticSamePadding +from torch.nn.modules.conv import _ConvNd register_hooks = {} diff --git a/thop/vision/onnx_counter.py b/thop/vision/onnx_counter.py index 8beb39a..52bd1e5 100644 --- a/thop/vision/onnx_counter.py +++ b/thop/vision/onnx_counter.py @@ -1,18 +1,20 @@ -import torch import numpy as np +import torch from onnx import numpy_helper + from thop.vision.basic_hooks import zero_ops + from .calc_func import ( - counter_matmul, - calculate_zero_ops, + calculate_avgpool, calculate_conv, - counter_mul, calculate_norm, + calculate_softmax, + calculate_zero_ops, + counter_div, + counter_matmul, + counter_mul, counter_pow, counter_sqrt, - counter_div, - calculate_softmax, - calculate_avgpool, ) @@ -65,20 +67,12 @@ def onnx_counter_conv(diction, node): group = attr.i # print(dim_dil) dim_input = diction[node.input[0]] - output_size = np.append( - dim_input[0 : -np.array(dim_kernel).size - 1], dim_weight[0] - ) + output_size = np.append(dim_input[0 : -np.array(dim_kernel).size - 1], dim_weight[0]) hw = np.array(dim_input[-np.array(dim_kernel).size :]) for i in range(hw.size): - hw[i] = int( - (hw[i] + 2 * dim_pad[i] - dim_dil[i] * (dim_kernel[i] - 1) - 1) - / dim_stride[i] - + 1 - ) + hw[i] = int((hw[i] + 2 * dim_pad[i] - dim_dil[i] * (dim_kernel[i] - 1) - 1) / dim_stride[i] + 1) output_size = np.append(output_size, hw) - macs = calculate_conv( - dim_bias, np.prod(dim_kernel), 
np.prod(output_size), dim_weight[1], group - ) + macs = calculate_conv(dim_bias, np.prod(dim_kernel), np.prod(output_size), dim_weight[1], group) output_name = node.output[0] # if '140' in diction: