diff --git a/.github/workflows/format.yml b/.github/workflows/format.yml
new file mode 100644
index 0000000..891053f
--- /dev/null
+++ b/.github/workflows/format.yml
@@ -0,0 +1,29 @@
+# Ultralytics 🚀 - AGPL-3.0 license
+# Ultralytics Actions https://github.com/ultralytics/actions
+# This workflow automatically formats code and documentation in PRs to official Ultralytics standards
+
+name: Ultralytics Actions
+
+on:
+ push:
+ branches: [main]
+ pull_request_target:
+ branches: [main]
+ types: [opened, closed, synchronize]
+
+jobs:
+ format:
+ runs-on: ubuntu-latest
+ steps:
+ - name: Run Ultralytics Formatting
+ uses: ultralytics/actions@main
+ with:
+ token: ${{ secrets.GITHUB_TOKEN }} # automatically generated, do not modify
+ python: true # format Python code and docstrings
+ markdown: true # format Markdown
+ prettier: true # format YAML
+ spelling: true # check spelling
+ links: false # check broken links
+ summary: true # print PR summary with GPT4 (requires 'openai_api_key' or 'openai_azure_api_key' and 'openai_azure_endpoint')
+ openai_azure_api_key: ${{ secrets.OPENAI_AZURE_API_KEY }}
+ openai_azure_endpoint: ${{ secrets.OPENAI_AZURE_ENDPOINT }}
diff --git a/README.md b/README.md
index 8a2da2e..82c878a 100644
--- a/README.md
+++ b/README.md
@@ -1,43 +1,47 @@
# THOP: PyTorch-OpCounter
-## How to install
-
-`pip install thop` (now continously intergrated on [Github actions](https://github.com/features/actions))
+## How to install
+
+`pip install thop` (now continuously integrated on [GitHub Actions](https://github.com/features/actions))
OR
`pip install --upgrade git+https://github.com/Lyken17/pytorch-OpCounter.git`
-
-## How to use
-* Basic usage
+
+## How to use
+
+- Basic usage
+
```python
from torchvision.models import resnet50
from thop import profile
model = resnet50()
input = torch.randn(1, 3, 224, 224)
macs, params = profile(model, inputs=(input, ))
- ```
+ ```
+
+- Define the rule for 3rd party module.
-* Define the rule for 3rd party module.
```python
class YourModule(nn.Module):
# your definition
def count_your_model(model, x, y):
# your rule here
-
+
input = torch.randn(1, 3, 224, 224)
- macs, params = profile(model, inputs=(input, ),
+ macs, params = profile(model, inputs=(input, ),
custom_ops={YourModule: count_your_model})
```
-
-* Improve the output readability
+
+- Improve the output readability
Call `thop.clever_format` to give a better format of the output.
+
```python
from thop import clever_format
macs, params = clever_format([macs, params], "%.3f")
- ```
-
+ ```
+
## Results of Recent Models
The implementation are adapted from `torchvision`. Following results can be obtained using [benchmark/evaluate_famous_models.py](benchmark/evaluate_famous_models.py).
@@ -47,48 +51,48 @@ The implementation are adapted from `torchvision`. Following results can be obta
-Model | Params(M) | MACs(G)
----|---|---
-alexnet | 61.10 | 0.77
-vgg11 | 132.86 | 7.74
-vgg11_bn | 132.87 | 7.77
-vgg13 | 133.05 | 11.44
-vgg13_bn | 133.05 | 11.49
-vgg16 | 138.36 | 15.61
-vgg16_bn | 138.37 | 15.66
-vgg19 | 143.67 | 19.77
-vgg19_bn | 143.68 | 19.83
-resnet18 | 11.69 | 1.82
-resnet34 | 21.80 | 3.68
-resnet50 | 25.56 | 4.14
-resnet101 | 44.55 | 7.87
-resnet152 | 60.19 | 11.61
-wide_resnet101_2 | 126.89 | 22.84
-wide_resnet50_2 | 68.88 | 11.46
+| Model | Params(M) | MACs(G) |
+| ---------------- | --------- | ------- |
+| alexnet | 61.10 | 0.77 |
+| vgg11 | 132.86 | 7.74 |
+| vgg11_bn | 132.87 | 7.77 |
+| vgg13 | 133.05 | 11.44 |
+| vgg13_bn | 133.05 | 11.49 |
+| vgg16 | 138.36 | 15.61 |
+| vgg16_bn | 138.37 | 15.66 |
+| vgg19 | 143.67 | 19.77 |
+| vgg19_bn | 143.68 | 19.83 |
+| resnet18 | 11.69 | 1.82 |
+| resnet34 | 21.80 | 3.68 |
+| resnet50 | 25.56 | 4.14 |
+| resnet101 | 44.55 | 7.87 |
+| resnet152 | 60.19 | 11.61 |
+| wide_resnet101_2 | 126.89 | 22.84 |
+| wide_resnet50_2 | 68.88 | 11.46 |
|
-Model | Params(M) | MACs(G)
----|---|---
-resnext50_32x4d | 25.03 | 4.29
-resnext101_32x8d | 88.79 | 16.54
-densenet121 | 7.98 | 2.90
-densenet161 | 28.68 | 7.85
-densenet169 | 14.15 | 3.44
-densenet201 | 20.01 | 4.39
-squeezenet1_0 | 1.25 | 0.82
-squeezenet1_1 | 1.24 | 0.35
-mnasnet0_5 | 2.22 | 0.14
-mnasnet0_75 | 3.17 | 0.24
-mnasnet1_0 | 4.38 | 0.34
-mnasnet1_3 | 6.28 | 0.53
-mobilenet_v2 | 3.50 | 0.33
-shufflenet_v2_x0_5 | 1.37 | 0.05
-shufflenet_v2_x1_0 | 2.28 | 0.15
-shufflenet_v2_x1_5 | 3.50 | 0.31
-shufflenet_v2_x2_0 | 7.39 | 0.60
-inception_v3 | 27.16 | 5.75
+| Model | Params(M) | MACs(G) |
+| ------------------ | --------- | ------- |
+| resnext50_32x4d | 25.03 | 4.29 |
+| resnext101_32x8d | 88.79 | 16.54 |
+| densenet121 | 7.98 | 2.90 |
+| densenet161 | 28.68 | 7.85 |
+| densenet169 | 14.15 | 3.44 |
+| densenet201 | 20.01 | 4.39 |
+| squeezenet1_0 | 1.25 | 0.82 |
+| squeezenet1_1 | 1.24 | 0.35 |
+| mnasnet0_5 | 2.22 | 0.14 |
+| mnasnet0_75 | 3.17 | 0.24 |
+| mnasnet1_0 | 4.38 | 0.34 |
+| mnasnet1_3 | 6.28 | 0.53 |
+| mobilenet_v2 | 3.50 | 0.33 |
+| shufflenet_v2_x0_5 | 1.37 | 0.05 |
+| shufflenet_v2_x1_0 | 2.28 | 0.15 |
+| shufflenet_v2_x1_5 | 3.50 | 0.31 |
+| shufflenet_v2_x2_0 | 7.39 | 0.60 |
+| inception_v3 | 27.16 | 5.75 |
|
diff --git a/TODO.md b/TODO.md
index f56334c..0b7a187 100644
--- a/TODO.md
+++ b/TODO.md
@@ -1,7 +1,7 @@
TODOs.
-1. A more user-friendly warning for un-defined modules. [Done]
+1. A more user-friendly warning for un-defined modules. \[Done\]
2. Supports for models in torchvision (e.g., residual add).
3. Layer wise printing
-Integration with torchprofile?
\ No newline at end of file
+Integration with torchprofile?
diff --git a/benchmark/README.md b/benchmark/README.md
index 55ae9e1..89b834c 100644
--- a/benchmark/README.md
+++ b/benchmark/README.md
@@ -1,33 +1,30 @@
# MACs, FLOPs, what is the difference?
-`FLOPs` is abbreviation of **floating operations** which includes mul / add / div ... etc.
+`FLOPs` is an abbreviation of **floating point operations**, which include mul / add / div ... etc.
-`MACs` stands for **multiply–accumulate operation** that performs `a <- a + (b x c)`.
+`MACs` stands for **multiply–accumulate operations**, each of which performs `a <- a + (b x c)`.
As shown in the text, one `MACs` has one `mul` and one `add`. That is why in many places `FLOPs` is nearly two times as `MACs`.
-However, the application in real world is far more complex. Let's consider a matrix multiplication example.
-`A` is an matrix of dimension `mxn` and `B` is an vector of `nx1`.
+However, applications in the real world are far more complex. Let's consider a matrix multiplication example. `A` is a matrix of dimension `mxn` and `B` is a vector of dimension `nx1`.
```python
for i in range(m):
for j in range(n):
C[i][j] += A[i][j] * B[j] # one mul-add
-```
+```
It would be `mn` `MACs` and `2mn` `FLOPs`. But such implementation is slow and parallelization is necessary to run faster
-
- ```python
+```python
for i in range(m):
- parallelfor j in range(n):
- d[j] = A[i][j] * B[j] # one mul
- C[i][j] = sum(d) # n adds
+ parallelfor j in range(n):
+ d[j] = A[i][j] * B[j] # one mul
+ C[i][j] = sum(d) # n adds
```
-Then the number of `MACs` is no longer `mn` .
-
+Then the number of `MACs` is no longer `mn`.
-When comparing MACs /FLOPs, we want the number to be implementation-agnostic and as general as possible. Therefore in THOP, **we only consider the number of multiplications** and ignore all other operations.
+When comparing MACs / FLOPs, we want the number to be implementation-agnostic and as general as possible. Therefore, in THOP, **we only consider the number of multiplications** and ignore all other operations.
-PS: The FLOPs is approximated by multiplying two.
\ No newline at end of file
+PS: The number of FLOPs is approximated by multiplying the number of MACs by two.
diff --git a/benchmark/evaluate_famous_models.py b/benchmark/evaluate_famous_models.py
index 7589526..5fb1cbf 100644
--- a/benchmark/evaluate_famous_models.py
+++ b/benchmark/evaluate_famous_models.py
@@ -1,5 +1,6 @@
import torch
from torchvision import models
+
from thop.profile import profile
model_names = sorted(
@@ -24,6 +25,4 @@
dsize = (1, 3, 299, 299)
inputs = torch.randn(dsize).to(device)
total_ops, total_params = profile(model, (inputs,), verbose=False)
- print(
- "%s | %.2f | %.2f" % (name, total_params / (1000 ** 2), total_ops / (1000 ** 3))
- )
+ print("%s | %.2f | %.2f" % (name, total_params / (1000**2), total_ops / (1000**3)))
diff --git a/benchmark/evaluate_rnn_models.py b/benchmark/evaluate_rnn_models.py
index 6def4ce..44cd3fb 100644
--- a/benchmark/evaluate_rnn_models.py
+++ b/benchmark/evaluate_rnn_models.py
@@ -1,5 +1,6 @@
import torch
import torch.nn as nn
+
from thop.profile import profile
input_size = 160
@@ -18,15 +19,9 @@
"BiRNN": nn.Sequential(nn.RNN(input_size, hidden_size, bidirectional=True)),
"BiGRU": nn.Sequential(nn.GRU(input_size, hidden_size, bidirectional=True)),
"BiLSTM": nn.Sequential(nn.LSTM(input_size, hidden_size, bidirectional=True)),
- "stacked-BiRNN": nn.Sequential(
- nn.RNN(input_size, hidden_size, bidirectional=True, num_layers=4)
- ),
- "stacked-BiGRU": nn.Sequential(
- nn.GRU(input_size, hidden_size, bidirectional=True, num_layers=4)
- ),
- "stacked-BiLSTM": nn.Sequential(
- nn.LSTM(input_size, hidden_size, bidirectional=True, num_layers=4)
- ),
+ "stacked-BiRNN": nn.Sequential(nn.RNN(input_size, hidden_size, bidirectional=True, num_layers=4)),
+ "stacked-BiGRU": nn.Sequential(nn.GRU(input_size, hidden_size, bidirectional=True, num_layers=4)),
+ "stacked-BiLSTM": nn.Sequential(nn.LSTM(input_size, hidden_size, bidirectional=True, num_layers=4)),
}
print("{} | {} | {}".format("Model", "Params(M)", "FLOPs(G)"))
@@ -49,9 +44,7 @@
# validate batch_first support
inputs = torch.randn(100, 32, input_size)
-ops_time_first = profile(
- nn.Sequential(nn.LSTM(input_size, hidden_size)), (inputs,), verbose=False
-)[0]
+ops_time_first = profile(nn.Sequential(nn.LSTM(input_size, hidden_size)), (inputs,), verbose=False)[0]
ops_batch_first = profile(
nn.Sequential(nn.LSTM(input_size, hidden_size, batch_first=True)),
(inputs.transpose(0, 1),),
diff --git a/setup.py b/setup.py
index 100bd16..0e59925 100644
--- a/setup.py
+++ b/setup.py
@@ -1,9 +1,10 @@
#!/usr/bin/env python
-import os, sys
-import shutil
import datetime
+import os
+import shutil
+import sys
-from setuptools import setup, find_packages
+from setuptools import find_packages, setup
from setuptools.command.install import install
readme = open("README.md").read()
diff --git a/tests/test_conv2d.py b/tests/test_conv2d.py
index 5dd22a0..5eb3978 100644
--- a/tests/test_conv2d.py
+++ b/tests/test_conv2d.py
@@ -1,13 +1,14 @@
-from jinja2 import StrictUndefined
import pytest
import torch
import torch.nn as nn
+from jinja2 import StrictUndefined
+
from thop import profile
class TestUtils:
def test_conv2d_no_bias(self):
- n, in_c, ih, iw = 1, 3, 32, 32 # torch.randint(1, 10, (4,)).tolist()
+ n, in_c, ih, iw = 1, 3, 32, 32 # torch.randint(1, 10, (4,)).tolist()
out_c, kh, kw = 12, 5, 5
s, p, d, g = 1, 1, 1, 1
@@ -17,11 +18,11 @@ def test_conv2d_no_bias(self):
_, _, oh, ow = out.shape
- flops, params = profile(net, inputs=(data, ))
+ flops, params = profile(net, inputs=(data,))
assert flops == 810000, f"{flops} v.s. {810000}"
def test_conv2d(self):
- n, in_c, ih, iw = 1, 3, 32, 32 # torch.randint(1, 10, (4,)).tolist()
+ n, in_c, ih, iw = 1, 3, 32, 32 # torch.randint(1, 10, (4,)).tolist()
out_c, kh, kw = 12, 5, 5
s, p, d, g = 1, 1, 1, 1
@@ -31,14 +32,14 @@ def test_conv2d(self):
_, _, oh, ow = out.shape
- flops, params = profile(net, inputs=(data, ))
+ flops, params = profile(net, inputs=(data,))
assert flops == 810000, f"{flops} v.s. {810000}"
-
+
def test_conv2d_random(self):
for i in range(10):
- out_c, kh, kw = torch.randint(1, 20, (3,)).tolist()
- n, in_c, ih, iw = torch.randint(1, 20, (4,)).tolist() # torch.randint(1, 10, (4,)).tolist()
- ih += kh
+ out_c, kh, kw = torch.randint(1, 20, (3,)).tolist()
+ n, in_c, ih, iw = torch.randint(1, 20, (4,)).tolist() # torch.randint(1, 10, (4,)).tolist()
+ ih += kh
iw += kw
s, p, d, g = 1, 1, 1, 1
@@ -48,6 +49,8 @@ def test_conv2d_random(self):
_, _, oh, ow = out.shape
- flops, params = profile(net, inputs=(data, ))
+ flops, params = profile(net, inputs=(data,))
print(flops, params)
- assert flops == n * out_c * oh * ow // g * in_c * kh * kw , f"{flops} v.s. {n * out_c * oh * ow // g * in_c * kh * kw}"
\ No newline at end of file
+ assert (
+ flops == n * out_c * oh * ow // g * in_c * kh * kw
+ ), f"{flops} v.s. {n * out_c * oh * ow // g * in_c * kh * kw}"
diff --git a/tests/test_matmul.py b/tests/test_matmul.py
index ae31b65..cdd8083 100644
--- a/tests/test_matmul.py
+++ b/tests/test_matmul.py
@@ -1,6 +1,7 @@
import pytest
import torch
import torch.nn as nn
+
from thop import profile
@@ -8,7 +9,7 @@ class TestUtils:
def test_matmul_case2(self):
n, in_c, out_c = 1, 100, 200
net = nn.Linear(in_c, out_c)
- flops, params = profile(net, inputs=(torch.randn(n, in_c), ))
+ flops, params = profile(net, inputs=(torch.randn(n, in_c),))
print(flops, params)
assert flops == n * in_c * out_c
@@ -16,14 +17,13 @@ def test_matmul_case2(self):
for i in range(10):
n, in_c, out_c = torch.randint(1, 500, (3,)).tolist()
net = nn.Linear(in_c, out_c)
- flops, params = profile(net, inputs=(torch.randn(n, in_c), ))
+ flops, params = profile(net, inputs=(torch.randn(n, in_c),))
print(flops, params)
assert flops == n * in_c * out_c
-
+
def test_conv2d(self):
n, in_c, out_c = torch.randint(1, 500, (3,)).tolist()
net = nn.Linear(in_c, out_c)
- flops, params = profile(net, inputs=(torch.randn(n, in_c), ))
+ flops, params = profile(net, inputs=(torch.randn(n, in_c),))
print(flops, params)
assert flops == n * in_c * out_c
-
diff --git a/tests/test_relu.py b/tests/test_relu.py
index 5e2a221..727e0bd 100644
--- a/tests/test_relu.py
+++ b/tests/test_relu.py
@@ -1,6 +1,7 @@
import pytest
import torch
import torch.nn as nn
+
from thop import profile
@@ -9,8 +10,6 @@ def test_relu(self):
n, in_c, out_c = 1, 100, 200
data = torch.randn(n, in_c)
net = nn.ReLU()
- flops, params = profile(net, inputs=(torch.randn(n, in_c), ))
+ flops, params = profile(net, inputs=(torch.randn(n, in_c),))
print(flops, params)
assert flops == 0
-
-
diff --git a/tests/test_utils.py b/tests/test_utils.py
index 3afdbfd..1b5682e 100644
--- a/tests/test_utils.py
+++ b/tests/test_utils.py
@@ -1,6 +1,7 @@
-from thop import utils
import pytest
+from thop import utils
+
class TestUtils:
def test_clever_format_returns_formatted_number(self):
@@ -14,4 +15,3 @@ def test_clever_format_returns_formatted_numbers(self):
format = "%.2f"
clever_nums = utils.clever_format(nums, format)
assert clever_nums == ("1.00B", "2.00B")
-
\ No newline at end of file
diff --git a/thop/__init__.py b/thop/__init__.py
index 1ce3290..7a022e4 100644
--- a/thop/__init__.py
+++ b/thop/__init__.py
@@ -1,7 +1,8 @@
-from .utils import clever_format
-from .profile import profile, profile_origin
# from .onnx_profile import OnnxProfile
import torch
+from .profile import profile, profile_origin
+from .utils import clever_format
+
default_dtype = torch.float64
-from .__version__ import __version__
\ No newline at end of file
+from .__version__ import __version__
diff --git a/thop/__version__.py b/thop/__version__.py
index d1f2e39..485f44a 100644
--- a/thop/__version__.py
+++ b/thop/__version__.py
@@ -1 +1 @@
-__version__ = "0.1.1"
\ No newline at end of file
+__version__ = "0.1.1"
diff --git a/thop/fx_profile.py b/thop/fx_profile.py
index 8faadf7..fe15f02 100644
--- a/thop/fx_profile.py
+++ b/thop/fx_profile.py
@@ -1,8 +1,9 @@
import logging
+from distutils.version import LooseVersion
+
import torch
import torch as th
import torch.nn as nn
-from distutils.version import LooseVersion
if LooseVersion(torch.__version__) < LooseVersion("1.8.0"):
logging.warning(
@@ -50,9 +51,7 @@ def count_fn_conv2d(input_shapes, output_shapes, *args, **kwargs):
bias_op = 0 # check it later
in_channel = x_shape[1]
- total_ops = calculate_conv(
- bias_op, kernel_parameters, out_shape.numel(), in_channel, groups
- ).item()
+ total_ops = calculate_conv(bias_op, kernel_parameters, out_shape.numel(), in_channel, groups).item()
return int(total_ops)
@@ -71,9 +70,7 @@ def count_nn_conv2d(module: nn.Conv2d, input_shapes, output_shapes):
in_channel = module.in_channels
groups = module.groups
kernel_ops = module.weight.shape[2:].numel()
- total_ops = calculate_conv(
- bias_op, kernel_ops, out_shape.numel(), in_channel, groups
- ).item()
+ total_ops = calculate_conv(bias_op, kernel_ops, out_shape.numel(), in_channel, groups).item()
return int(total_ops)
@@ -114,6 +111,7 @@ def count_nn_bn2d(module: nn.BatchNorm2d, input_shapes, output_shapes):
from torch.fx import symbolic_trace
from torch.fx.passes.shape_prop import ShapeProp
+
from .utils import prGreen, prRed, prYellow
@@ -135,9 +133,7 @@ def fx_profile(mod: nn.Module, input: th.Tensor, verbose=False):
for node in gm.graph.nodes:
# print(f"{node.target},\t{node.op},\t{node.meta['tensor_meta'].dtype},\t{node.meta['tensor_meta'].shape}")
- fprint(
- f"NodeOP:{node.op},\tTarget:{node.target},\tNodeName:{node.name},\tNodeArgs:{node.args}"
- )
+ fprint(f"NodeOP:{node.op},\tTarget:{node.target},\tNodeName:{node.name},\tNodeArgs:{node.args}")
# node_op_type = str(node.target).split(".")[-1]
node_flops = None
@@ -157,17 +153,9 @@ def fx_profile(mod: nn.Module, input: th.Tensor, verbose=False):
node_flops = 0
elif node.op == "call_function":
# torch internal functions
- key = (
- str(node.target)
- .split("at")[0]
- .replace("<", "")
- .replace(">", "")
- .strip()
- )
+ key = str(node.target).split("at")[0].replace("<", "").replace(">", "").strip()
if key in count_map:
- node_flops = count_map[key](
- input_shapes, output_shapes, *node.args, **node.kwargs
- )
+ node_flops = count_map[key](input_shapes, output_shapes, *node.args, **node.kwargs)
else:
missing_maps[key] = (node.op, key)
prRed(f"|{key}| is missing")
@@ -196,9 +184,7 @@ def fx_profile(mod: nn.Module, input: th.Tensor, verbose=False):
print(f"weight_shape: None")
else:
print(type(m))
- print(
- f"weight_shape: {mod.state_dict()[node.target + '.weight'].shape}"
- )
+ print(f"weight_shape: {mod.state_dict()[node.target + '.weight'].shape}")
v_maps[str(node.name)] = node.meta["tensor_meta"].shape
if node_flops is not None:
@@ -208,6 +194,7 @@ def fx_profile(mod: nn.Module, input: th.Tensor, verbose=False):
if len(missing_maps.keys()) > 0:
from pprint import pprint
+
print("Missing operators: ")
pprint(missing_maps)
return total_flops
diff --git a/thop/onnx_profile.py b/thop/onnx_profile.py
index 10da68c..0442f27 100644
--- a/thop/onnx_profile.py
+++ b/thop/onnx_profile.py
@@ -1,8 +1,9 @@
+import numpy as np
+import onnx
import torch
import torch.nn
-import onnx
from onnx import numpy_helper
-import numpy as np
+
from thop.vision.onnx_counter import onnx_operators
diff --git a/thop/profile.py b/thop/profile.py
index 6b15d27..d87afd6 100644
--- a/thop/profile.py
+++ b/thop/profile.py
@@ -1,19 +1,15 @@
from distutils.version import LooseVersion
-from thop.vision.basic_hooks import *
from thop.rnn_hooks import *
-
+from thop.vision.basic_hooks import *
# logger = logging.getLogger(__name__)
# logger.setLevel(logging.INFO)
-
from .utils import prGreen, prRed, prYellow
if LooseVersion(torch.__version__) < LooseVersion("1.0.0"):
logging.warning(
- "You are using an old version PyTorch {version}, which THOP does NOT support.".format(
- version=torch.__version__
- )
+ "You are using an old version PyTorch {version}, which THOP does NOT support.".format(version=torch.__version__)
)
default_dtype = torch.float64
@@ -96,9 +92,7 @@ def add_hooks(m):
m_type = type(m)
fn = None
- if (
- m_type in custom_ops
- ): # if defined both op maps, use custom_ops to overwrite.
+ if m_type in custom_ops: # if defined both op maps, use custom_ops to overwrite.
fn = custom_ops[m_type]
if m_type not in types_collection and verbose:
print("[INFO] Customize rule %s() %s." % (fn.__qualname__, m_type))
@@ -108,10 +102,7 @@ def add_hooks(m):
print("[INFO] Register %s() for %s." % (fn.__qualname__, m_type))
else:
if m_type not in types_collection and report_missing:
- prRed(
- "[WARN] Cannot find rule for %s. Treat it as zero Macs and zero Params."
- % m_type
- )
+ prRed("[WARN] Cannot find rule for %s. Treat it as zero Macs and zero Params." % m_type)
if fn is not None:
handler = m.register_forward_hook(fn)
@@ -191,10 +182,7 @@ def add_hooks(m: nn.Module):
print("[INFO] Register %s() for %s." % (fn.__qualname__, m_type))
else:
if m_type not in types_collection and report_missing:
- prRed(
- "[WARN] Cannot find rule for %s. Treat it as zero Macs and zero Params."
- % m_type
- )
+ prRed("[WARN] Cannot find rule for %s. Treat it as zero Macs and zero Params." % m_type)
if fn is not None:
handler_collection[m] = (
@@ -220,9 +208,7 @@ def dfs_count(module: nn.Module, prefix="\t") -> (int, int):
# else:
# m_ops, m_params = m.total_ops, m.total_params
next_dict = {}
- if m in handler_collection and not isinstance(
- m, (nn.Sequential, nn.ModuleList)
- ):
+ if m in handler_collection and not isinstance(m, (nn.Sequential, nn.ModuleList)):
m_ops, m_params = m.total_ops.item(), m.total_params.item()
else:
m_ops, m_params, next_dict = dfs_count(m, prefix=prefix + "\t")
diff --git a/thop/utils.py b/thop/utils.py
index 5f1c4bb..993761c 100644
--- a/thop/utils.py
+++ b/thop/utils.py
@@ -4,13 +4,16 @@
COLOR_GREEN = "92m"
COLOR_YELLOW = "93m"
+
def colorful_print(fn_print, color=COLOR_RED):
def actual_call(*args, **kwargs):
print(f"\033[{color}", end="")
fn_print(*args, **kwargs)
print("\033[00m", end="")
+
return actual_call
+
prRed = colorful_print(print, color=COLOR_RED)
prGreen = colorful_print(print, color=COLOR_GREEN)
prYellow = colorful_print(print, color=COLOR_YELLOW)
@@ -50,4 +53,4 @@ def clever_format(nums, format="%.2f"):
if __name__ == "__main__":
prRed("hello", "world")
prGreen("hello", "world")
- prYellow("hello", "world")
\ No newline at end of file
+ prYellow("hello", "world")
diff --git a/thop/vision/basic_hooks.py b/thop/vision/basic_hooks.py
index fb864b6..e8abcf6 100644
--- a/thop/vision/basic_hooks.py
+++ b/thop/vision/basic_hooks.py
@@ -1,10 +1,12 @@
import argparse
import logging
-from .calc_func import *
+
import torch
import torch.nn as nn
from torch.nn.modules.conv import _ConvNd
+from .calc_func import *
+
multiply_adds = 1
@@ -26,11 +28,11 @@ def count_convNd(m: _ConvNd, x, y: torch.Tensor):
bias_ops = 1 if m.bias is not None else 0
m.total_ops += calculate_conv2d_flops(
- input_size = list(x.shape),
- output_size = list(y.shape),
- kernel_size = list(m.weight.shape),
- groups = m.groups,
- bias = m.bias
+ input_size=list(x.shape),
+ output_size=list(y.shape),
+ kernel_size=list(m.weight.shape),
+ groups=m.groups,
+ bias=m.bias,
)
# N x Cout x H x W x (Cin x Kw x Kh + bias)
# m.total_ops += calculate_conv(
@@ -64,7 +66,7 @@ def count_normalization(m: nn.modules.batchnorm._BatchNorm, x, y):
x = x[0]
# bn is by default fused in inference
flops = calculate_norm(x.numel())
- if (getattr(m, 'affine', False) or getattr(m, 'elementwise_affine', False)):
+ if getattr(m, "affine", False) or getattr(m, "elementwise_affine", False):
flops *= 2
m.total_ops += flops
@@ -112,10 +114,7 @@ def count_avgpool(m, x, y):
def count_adap_avgpool(m, x, y):
- kernel = torch.div(
- torch.DoubleTensor([*(x[0].shape[2:])]),
- torch.DoubleTensor([*(y.shape[2:])])
- )
+ kernel = torch.div(torch.DoubleTensor([*(x[0].shape[2:])]), torch.DoubleTensor([*(y.shape[2:])]))
total_add = torch.prod(kernel)
num_elements = y.numel()
m.total_ops += calculate_adaptive_avg(total_add, num_elements)
diff --git a/thop/vision/calc_func.py b/thop/vision/calc_func.py
index f6af5ce..2e2f169 100644
--- a/thop/vision/calc_func.py
+++ b/thop/vision/calc_func.py
@@ -1,13 +1,16 @@
-import torch
-import numpy as np
import warnings
+import numpy as np
+import torch
+
+
def l_prod(in_list):
res = 1
for _ in in_list:
res *= _
return res
+
def l_sum(in_list):
res = 0
for _ in in_list:
@@ -25,6 +28,7 @@ def calculate_parameters(param_list):
def calculate_zero_ops():
return torch.DoubleTensor([int(0)])
+
def calculate_conv2d_flops(input_size: list, output_size: list, kernel_size: list, groups: int, bias: bool = False):
# n, out_c, oh, ow = output_size
# n, in_c, ih, iw = input_size
@@ -36,18 +40,19 @@ def calculate_conv2d_flops(input_size: list, output_size: list, kernel_size: lis
def calculate_conv(bias, kernel_size, output_size, in_channel, group):
warnings.warn("This API is being deprecated.")
- """inputs are all numbers!"""
+ """Inputs are all numbers!"""
return torch.DoubleTensor([output_size * (in_channel / group * kernel_size + bias)])
def calculate_norm(input_size):
- """input is a number not a array or tensor"""
+ """Input is a number not a array or tensor."""
return torch.DoubleTensor([2 * input_size])
+
def calculate_relu_flops(input_size):
# x[x < 0] = 0
return 0
-
+
def calculate_relu(input_size: torch.Tensor):
warnings.warn("This API is being deprecated")
diff --git a/thop/vision/efficientnet.py b/thop/vision/efficientnet.py
index b2fe162..8de88d1 100644
--- a/thop/vision/efficientnet.py
+++ b/thop/vision/efficientnet.py
@@ -3,8 +3,7 @@
import torch
import torch.nn as nn
-from torch.nn.modules.conv import _ConvNd
-
from efficientnet_pytorch.utils import Conv2dDynamicSamePadding, Conv2dStaticSamePadding
+from torch.nn.modules.conv import _ConvNd
register_hooks = {}
diff --git a/thop/vision/onnx_counter.py b/thop/vision/onnx_counter.py
index 8beb39a..52bd1e5 100644
--- a/thop/vision/onnx_counter.py
+++ b/thop/vision/onnx_counter.py
@@ -1,18 +1,20 @@
-import torch
import numpy as np
+import torch
from onnx import numpy_helper
+
from thop.vision.basic_hooks import zero_ops
+
from .calc_func import (
- counter_matmul,
- calculate_zero_ops,
+ calculate_avgpool,
calculate_conv,
- counter_mul,
calculate_norm,
+ calculate_softmax,
+ calculate_zero_ops,
+ counter_div,
+ counter_matmul,
+ counter_mul,
counter_pow,
counter_sqrt,
- counter_div,
- calculate_softmax,
- calculate_avgpool,
)
@@ -65,20 +67,12 @@ def onnx_counter_conv(diction, node):
group = attr.i
# print(dim_dil)
dim_input = diction[node.input[0]]
- output_size = np.append(
- dim_input[0 : -np.array(dim_kernel).size - 1], dim_weight[0]
- )
+ output_size = np.append(dim_input[0 : -np.array(dim_kernel).size - 1], dim_weight[0])
hw = np.array(dim_input[-np.array(dim_kernel).size :])
for i in range(hw.size):
- hw[i] = int(
- (hw[i] + 2 * dim_pad[i] - dim_dil[i] * (dim_kernel[i] - 1) - 1)
- / dim_stride[i]
- + 1
- )
+ hw[i] = int((hw[i] + 2 * dim_pad[i] - dim_dil[i] * (dim_kernel[i] - 1) - 1) / dim_stride[i] + 1)
output_size = np.append(output_size, hw)
- macs = calculate_conv(
- dim_bias, np.prod(dim_kernel), np.prod(output_size), dim_weight[1], group
- )
+ macs = calculate_conv(dim_bias, np.prod(dim_kernel), np.prod(output_size), dim_weight[1], group)
output_name = node.output[0]
# if '140' in diction: