diff --git a/.github/collect_env.py b/.github/collect_env.py index 4664c4f..9bf9d6d 100644 --- a/.github/collect_env.py +++ b/.github/collect_env.py @@ -159,14 +159,13 @@ def get_nvidia_smi(): def get_platform(): if sys.platform.startswith("linux"): return "linux" - elif sys.platform.startswith("win32"): + if sys.platform.startswith("win32"): return "win32" - elif sys.platform.startswith("cygwin"): + if sys.platform.startswith("cygwin"): return "cygwin" - elif sys.platform.startswith("darwin"): + if sys.platform.startswith("darwin"): return "darwin" - else: - return sys.platform + return sys.platform def get_mac_version(run_lambda): diff --git a/.github/dependabot.yml b/.github/dependabot.yml new file mode 100644 index 0000000..1d5dba9 --- /dev/null +++ b/.github/dependabot.yml @@ -0,0 +1,35 @@ +# To get started with Dependabot version updates, you'll need to specify which +# package ecosystems to update and where the package manifests are located. +# Please see the documentation for all configuration options: +# https://docs.github.com/en/code-security/dependabot/dependabot-version-updates/configuration-options-for-the-dependabot.yml-file + +version: 2 +updates: + - package-ecosystem: "github-actions" + directory: "/" + schedule: + interval: "monthly" + time: "06:00" + timezone: "Europe/Paris" + groups: + gh-actions: + patterns: + - "*" + reviewers: + - "frgfm" + assignees: + - "frgfm" + - package-ecosystem: "pip" + directory: "/" + schedule: + interval: "daily" + time: "06:00" + timezone: "Europe/Paris" + reviewers: + - "frgfm" + assignees: + - "frgfm" + allow: + - dependency-name: "ruff" + - dependency-name: "mypy" + - dependency-name: "pre-commit" diff --git a/.github/labeler.yml b/.github/labeler.yml index 95b2d1c..40e3213 100644 --- a/.github/labeler.yml +++ b/.github/labeler.yml @@ -1,31 +1,50 @@ 'module: crawler': -- torchscan/crawler.py +- changed-files: + - any-glob-to-any-file: torchscan/crawler.py 'module: modules': -- torchscan/modules/* +- 
changed-files: + - any-glob-to-any-file: torchscan/modules/* 'module: process': -- torchscan/process/* +- changed-files: + - any-glob-to-any-file: torchscan/process/* 'module: utils': -- torchscan/utils.py +- changed-files: + - any-glob-to-any-file: torchscan/utils.py 'ext: docs': -- docs/* +- changed-files: + - any-glob-to-any-file: docs/* 'ext: scripts': -- scripts/* +- changed-files: + - any-glob-to-any-file: scripts/* 'ext: tests': -- tests/* +- changed-files: + - any-glob-to-any-file: tests/* 'topic: ci': -- .github/* - -'topic: documentation': -- README.md -- CONTRIBUTING.md +- changed-files: + - any-glob-to-any-file: .github/* + +'topic: docs': +- changed-files: + - any-glob-to-any-file: + - README.md + - CONTRIBUTING.md + - CODE_OF_CONDUCT.md + - CITATION.cff + - LICENSE 'topic: build': -- setup.py -- pyproject.toml +- changed-files: + - any-glob-to-any-file: + - setup.py + - pyproject.toml + +'topic: style': +- changed-files: + - any-glob-to-any-file: .pre-commit-config.yaml diff --git a/.github/release.yml b/.github/release.yml index 8f7f47c..5962f0a 100644 --- a/.github/release.yml +++ b/.github/release.yml @@ -9,15 +9,15 @@ changelog: # NEW FEATURES - title: New Features 🚀 labels: - - "type: new feature" + - "type: feat" # BUG FIXES - title: Bug Fixes 🐛 labels: - - "type: bug" + - "type: fix" # IMPROVEMENTS - title: Improvements labels: - - "type: enhancement" + - "type: improvement" # MISC - title: Miscellaneous labels: diff --git a/.github/verify_labels.py b/.github/verify_labels.py index 58f6c99..83dc1d3 100644 --- a/.github/verify_labels.py +++ b/.github/verify_labels.py @@ -74,9 +74,7 @@ def parse_args(): ) parser.add_argument("pr", type=int, help="PR number") - args = parser.parse_args() - - return args + return parser.parse_args() if __name__ == "__main__": diff --git a/.github/workflows/builds.yml b/.github/workflows/builds.yml index 6b0180f..eb8cb52 100644 --- a/.github/workflows/builds.yml +++ b/.github/workflows/builds.yml @@ -13,43 +13,39 
@@ jobs: fail-fast: false matrix: os: [ubuntu-latest, macos-latest, windows-latest] - python: [3.8, 3.9, '3.10', 3.11] + python: [3.8, 3.9, '3.10', 3.11, 3.12] + exclude: + - os: macos-latest + python: 3.8 + - os: macos-latest + python: 3.9 + - os: macos-latest + python: '3.10' steps: - - uses: actions/checkout@v2 - - uses: actions/setup-python@v4 + - uses: actions/checkout@v4 + - uses: actions/setup-python@v5 with: python-version: ${{ matrix.python }} architecture: x64 - - name: Cache python modules - uses: actions/cache@v2 - with: - path: ~/.cache/pip - key: ${{ runner.os }}-python-${{ matrix.python }}-${{ hashFiles('pyproject.toml') }}-build - name: Install package run: | - python -m pip install --upgrade pip - pip install -e . + python -m pip install --upgrade uv + uv pip install --system -e . - name: Import package run: python -c "import torchscan; print(torchscan.__version__)" pypi: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v2 - - name: Set up Python - uses: actions/setup-python@v4 + - uses: actions/checkout@v4 + - uses: actions/setup-python@v5 with: - python-version: 3.9 + python-version: 3.11 architecture: x64 - - name: Cache python modules - uses: actions/cache@v2 - with: - path: ~/.cache/pip - key: ${{ runner.os }}-python-${{ matrix.python }}-${{ hashFiles('pyproject.toml') }}-build - name: Install dependencies run: | - python -m pip install --upgrade pip - pip install setuptools wheel twine --upgrade + python -m pip install --upgrade uv + uv pip install --system setuptools wheel twine --upgrade - run: | python setup.py sdist bdist_wheel twine check dist/* @@ -57,11 +53,11 @@ jobs: conda: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v2 - - uses: conda-incubator/setup-miniconda@v2 + - uses: actions/checkout@v4 + - uses: conda-incubator/setup-miniconda@v3 with: auto-update-conda: true - python-version: "3.9" + python-version: "3.11" - name: Install dependencies shell: bash -el {0} run: conda install -y conda-build conda-verify 
diff --git a/.github/workflows/doc-status.yml b/.github/workflows/doc-status.yml index d2f2d9d..69b7010 100644 --- a/.github/workflows/doc-status.yml +++ b/.github/workflows/doc-status.yml @@ -6,9 +6,9 @@ jobs: see-page-build-payload: runs-on: ubuntu-latest steps: - - uses: actions/setup-python@v4 + - uses: actions/setup-python@v5 with: - python-version: 3.9 + python-version: 3.11 architecture: x64 - name: check status run: | diff --git a/.github/workflows/docs.yml b/.github/workflows/docs.yml index e778325..4792276 100644 --- a/.github/workflows/docs.yml +++ b/.github/workflows/docs.yml @@ -11,22 +11,17 @@ jobs: os: [ubuntu-latest] python: [3.9] steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v4 with: persist-credentials: false - - uses: actions/setup-python@v4 + - uses: actions/setup-python@v5 with: python-version: ${{ matrix.python }} architecture: x64 - - name: Cache python modules - uses: actions/cache@v2 - with: - path: ~/.cache/pip - key: ${{ runner.os }}-python-${{ matrix.python }}-${{ hashFiles('pyproject.toml') }}-docs - name: Install dependencies run: | - python -m pip install --upgrade pip - pip install -e ".[docs]" + python -m pip install --upgrade uv + uv pip install --system -e ".[docs]" - name: Build documentation run: cd docs && bash build.sh @@ -35,12 +30,12 @@ run: test -e docs/build/index.html || exit - name: Install SSH Client 🔑 - uses: webfactory/ssh-agent@v0.4.1 + uses: webfactory/ssh-agent@v0.9.0 with: ssh-private-key: ${{ secrets.SSH_DEPLOY_KEY }} - name: Deploy to Github Pages - uses: JamesIves/github-pages-deploy-action@3.7.1 + uses: JamesIves/github-pages-deploy-action@v4 with: BRANCH: gh-pages FOLDER: 'docs/build' diff --git a/.github/workflows/pr-labels.yml b/.github/workflows/pr-labels.yml index 39e039b..237654a 100644 --- a/.github/workflows/pr-labels.yml +++ b/.github/workflows/pr-labels.yml @@ -10,17 +10,15 @@ jobs: if: github.event.pull_request.merged == true runs-on: ubuntu-latest steps: - - name: Checkout 
repository - uses: actions/checkout@v2 - - name: Set up python - uses: actions/setup-python@v4 + - uses: actions/checkout@v4 + - uses: actions/setup-python@v5 - name: Install requests run: pip install requests - name: Process commit and find merger responsible for labeling id: commit run: echo "::set-output name=merger::$(python .github/verify_labels.py ${{ github.event.pull_request.number }})" - name: Comment PR - uses: actions/github-script@0.3.0 + uses: actions/github-script@v7.0.1 if: ${{ steps.commit.outputs.merger != '' }} with: github-token: ${{ secrets.GITHUB_TOKEN }} diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml index c6d9c07..66b9027 100644 --- a/.github/workflows/publish.yml +++ b/.github/workflows/publish.yml @@ -9,20 +9,15 @@ jobs: if: "!github.event.release.prerelease" runs-on: ubuntu-latest steps: - - uses: actions/checkout@v2 - - uses: actions/setup-python@v4 + - uses: actions/checkout@v4 + - uses: actions/setup-python@v5 with: - python-version: 3.9 + python-version: 3.11 architecture: x64 - - name: Cache python modules - uses: actions/cache@v2 - with: - path: ~/.cache/pip - key: ${{ runner.os }}-python-${{ matrix.python }}-${{ hashFiles('pyproject.toml') }}-pypi - name: Install dependencies run: | - python -m pip install --upgrade pip - pip install setuptools wheel twine --upgrade + python -m pip install --upgrade uv + uv pip install --system setuptools wheel twine --upgrade - name: Build and publish env: TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }} @@ -38,27 +33,27 @@ jobs: runs-on: ubuntu-latest needs: pypi steps: - - uses: actions/checkout@v2 - - uses: actions/setup-python@v4 + - uses: actions/checkout@v4 + - uses: actions/setup-python@v5 with: - python-version: 3.9 + python-version: 3.11 architecture: x64 - name: Install package run: | - python -m pip install --upgrade pip - pip install torchscan + python -m pip install --upgrade uv + uv pip install --system torchscan python -c "import torchscan; 
print(torchscan.__version__)" conda: if: "!github.event.release.prerelease" runs-on: ubuntu-latest steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v4 - name: Miniconda setup uses: conda-incubator/setup-miniconda@v3 with: auto-update-conda: true - python-version: 3.9 + python-version: 3.11 - name: Install dependencies shell: bash -el {0} run: conda install -y conda-build conda-verify anaconda-client @@ -83,7 +78,7 @@ jobs: uses: conda-incubator/setup-miniconda@v3 with: auto-update-conda: true - python-version: 3.9 + python-version: 3.11 auto-activate-base: true - name: Install package shell: bash -el {0} diff --git a/.github/workflows/pull_requests.yml b/.github/workflows/pull_requests.yml index 2d9ac11..cc2205f 100644 --- a/.github/workflows/pull_requests.yml +++ b/.github/workflows/pull_requests.yml @@ -8,21 +8,15 @@ jobs: docs: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v2 - - name: Set up Python - uses: actions/setup-python@v4 + - uses: actions/checkout@v4 + - uses: actions/setup-python@v5 with: python-version: 3.9 architecture: x64 - - name: Cache python modules - uses: actions/cache@v2 - with: - path: ~/.cache/pip - key: ${{ runner.os }}-python-${{ matrix.python }}-${{ hashFiles('pyproject.toml') }}-docs - name: Install dependencies run: | - python -m pip install --upgrade pip - pip install -e ".[docs]" + python -m pip install --upgrade uv + uv pip install --system -e ".[docs]" - name: Build documentation run: cd docs && bash build.sh diff --git a/.github/workflows/style.yml b/.github/workflows/style.yml index 57dade8..a67e7e9 100644 --- a/.github/workflows/style.yml +++ b/.github/workflows/style.yml @@ -12,16 +12,17 @@ jobs: strategy: matrix: os: [ubuntu-latest] - python: [3.9] + python: [3.11] steps: - - uses: actions/checkout@v2 - - uses: actions/setup-python@v4 + - uses: actions/checkout@v4 + - uses: actions/setup-python@v5 with: python-version: ${{ matrix.python }} architecture: x64 - name: Run ruff run: | - pip install 
ruff==0.1.14 + python -m pip install --upgrade uv + uv pip install --system -e '.[quality]' ruff --version ruff check --diff . @@ -30,24 +31,17 @@ jobs: strategy: matrix: os: [ubuntu-latest] - python: [3.9] + python: [3.11] steps: - - uses: actions/checkout@v2 - - uses: actions/setup-python@v4 + - uses: actions/checkout@v4 + - uses: actions/setup-python@v5 with: python-version: ${{ matrix.python }} architecture: x64 - - name: Cache python modules - uses: actions/cache@v2 - with: - path: ~/.cache/pip - key: ${{ runner.os }}-python-${{ matrix.python }}-${{ hashFiles('pyproject.toml') }} - - name: Install dependencies - run: | - python -m pip install --upgrade pip - pip install -e ".[quality]" --upgrade - name: Run mypy run: | + python -m pip install --upgrade uv + uv pip install --system -e '.[quality]' mypy --version mypy @@ -56,16 +50,17 @@ jobs: strategy: matrix: os: [ubuntu-latest] - python: [3.9] + python: [3.11] steps: - - uses: actions/checkout@v2 - - uses: actions/setup-python@v4 + - uses: actions/checkout@v4 + - uses: actions/setup-python@v5 with: python-version: ${{ matrix.python }} architecture: x64 - name: Run ruff run: | - pip install ruff==0.1.14 + python -m pip install --upgrade uv + uv pip install --system -e '.[quality]' ruff --version ruff format --check --diff . 
@@ -74,16 +69,17 @@ jobs: strategy: matrix: os: [ubuntu-latest] - python: [3.9] + python: [3.11] steps: - - uses: actions/checkout@v2 - - uses: actions/setup-python@v4 + - uses: actions/checkout@v4 + - uses: actions/setup-python@v5 with: python-version: ${{ matrix.python }} architecture: x64 - name: Run pre-commit hooks run: | - pip install pre-commit + python -m pip install --upgrade uv + uv pip install --system -e '.[quality]' git checkout -b temp pre-commit install pre-commit --version diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index a627fb4..3ca7b5f 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -12,27 +12,22 @@ jobs: strategy: matrix: os: [ubuntu-latest] - python: [3.9] + python: [3.11] steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v4 with: persist-credentials: false - - uses: actions/setup-python@v4 + - uses: actions/setup-python@v5 with: python-version: ${{ matrix.python }} architecture: x64 - - name: Cache python modules - uses: actions/cache@v2 - with: - path: ~/.cache/pip - key: ${{ runner.os }}-python-${{ matrix.python }}-${{ hashFiles('pyproject.toml') }}-pytest - name: Install dependencies run: | - python -m pip install --upgrade pip - pip install -e ".[test]" --upgrade + python -m pip install --upgrade uv + uv pip install --system -e ".[test]" --upgrade - name: Run unittests run: pytest --cov=torchscan --cov-report xml tests/ - - uses: actions/upload-artifact@v2 + - uses: actions/upload-artifact@v4 with: name: coverage-reports path: ./coverage.xml @@ -41,10 +36,10 @@ jobs: runs-on: ubuntu-latest needs: pytest steps: - - uses: actions/checkout@v2 - - uses: actions/download-artifact@v2 + - uses: actions/checkout@v4 + - uses: actions/download-artifact@v4 - name: Upload coverage to Codecov - uses: codecov/codecov-action@v3 + uses: codecov/codecov-action@v4 with: token: ${{ secrets.CODECOV_TOKEN }} flags: unittests @@ -57,7 +52,7 @@ jobs: matrix: os: [ubuntu-latest] steps: - - 
uses: actions/checkout@v2 + - uses: actions/checkout@v4 with: persist-credentials: false - name: Check the headers diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index d94bf8a..72d6294 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,5 +1,5 @@ default_language_version: - python: python3.9 + python: python3.11 repos: - repo: https://github.com/pre-commit/pre-commit-hooks rev: v4.5.0 @@ -22,7 +22,7 @@ repos: - id: requirements-txt-fixer - id: trailing-whitespace - repo: https://github.com/charliermarsh/ruff-pre-commit - rev: 'v0.1.14' + rev: 'v0.6.4' hooks: - id: ruff args: diff --git a/Makefile b/Makefile index e9885b3..3b87619 100644 --- a/Makefile +++ b/Makefile @@ -7,7 +7,7 @@ quality: # this target runs checks on all files and potentially modifies some of them style: ruff format . - ruff --fix . + ruff check --fix . # Run tests for the library test: diff --git a/pyproject.toml b/pyproject.toml index fba5175..2df81fc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -40,8 +40,8 @@ test = [ "pytest-pretty>=1.0.0,<2.0.0", ] quality = [ - "ruff==0.1.14", - "mypy==1.8.0", + "ruff==0.6.4", + "mypy==1.10.0", "pre-commit>=3.0.0,<4.0.0", ] docs = [ @@ -89,38 +89,49 @@ testpaths = ["torchscan/"] source = ["torchscan/"] [tool.ruff] +line-length = 120 +target-version = "py311" +preview = true + +[tool.ruff.lint] select = [ + "F", # pyflakes "E", # pycodestyle errors "W", # pycodestyle warnings - "D101", "D103", # pydocstyle missing docstring in public function/class - "D201","D202","D207","D208","D214","D215","D300","D301","D417", "D419", # pydocstyle - "F", # pyflakes "I", # isort - "C4", # flake8-comprehensions - "B", # flake8-bugbear - "CPY001", # flake8-copyright - "ISC", # flake8-implicit-str-concat - "PYI", # flake8-pyi - "NPY", # numpy - "PERF", # perflint - "RUF", # ruff specific - "PTH", # flake8-use-pathlib - "S", # flake8-bandit "N", # pep8-naming - "T10", # flake8-debugger - "T20", # flake8-print - "PT", # 
flake8-pytest-style - "LOG", # flake8-logging - "SIM", # flake8-simplify + "D101", "D103", # pydocstyle missing docstring in public function/class + "D201","D202","D207","D208","D214","D215","D300","D301","D417", "D419", # pydocstyle "YTT", # flake8-2020 "ANN", # flake8-annotations "ASYNC", # flake8-async + "S", # flake8-bandit "BLE", # flake8-blind-except + "B", # flake8-bugbear "A", # flake8-builtins + "COM", # flake8-commas + "CPY", # flake8-copyright + "C4", # flake8-comprehensions + "T10", # flake8-debugger + "ISC", # flake8-implicit-str-concat "ICN", # flake8-import-conventions + "LOG", # flake8-logging "PIE", # flake8-pie + "T20", # flake8-print + "PYI", # flake8-pyi + "PT", # flake8-pytest-style + "Q", # flake8-quotes + "RET", # flake8-return + "SLF", # flake8-self + "SIM", # flake8-simplify "ARG", # flake8-unused-arguments + "PTH", # flake8-use-pathlib + "PERF", # perflint + "NPY", # numpy + "FAST", # fastapi "FURB", # refurb + "RUF", # ruff specific + "N", # pep8-naming ] ignore = [ "E501", # line too long, handled by black @@ -138,35 +149,34 @@ ignore = [ "N812", # lowercase imported as non-lowercase "ISC001", # implicit string concatenation (handled by format) "ANN401", # Dynamically typed expressions (typing.Any) are disallowed + "SLF001", # Private member accessed ] exclude = [".git"] -line-length = 120 -target-version = "py39" -preview = true -[tool.ruff.format] -quote-style = "double" -indent-style = "space" +[tool.ruff.lint.flake8-quotes] +docstring-quotes = "double" + +[tool.ruff.lint.isort] +known-first-party = ["torchscan", "app"] +known-third-party = ["torch", "torchvision"] -[tool.ruff.per-file-ignores] +[tool.ruff.lint.per-file-ignores] "**/__init__.py" = ["I001", "F401", "CPY001"] "scripts/**.py" = ["D", "T201", "N812", "S101", "ANN"] ".github/**.py" = ["D", "T201", "S602", "S101", "ANN"] "docs/**.py" = ["E402", "D103", "ANN", "A001", "ARG001"] -"tests/**.py" = ["D101", "D103", "CPY001", "S101", "PT011", "ANN"] +"tests/**.py" = ["D101", 
"D103", "CPY001", "S101", "PT011", "ANN", "SLF001"] "demo/**.py" = ["D103", "ANN"] "setup.py" = ["T201"] "torchscan/process/memory.py" = ["S60"] -[tool.ruff.flake8-quotes] -docstring-quotes = "double" +[tool.ruff.format] +quote-style = "double" +indent-style = "space" -[tool.ruff.isort] -known-first-party = ["torchscan", "app"] -known-third-party = ["torch", "torchvision"] [tool.mypy] -python_version = "3.9" +python_version = "3.11" files = "torchscan/" show_error_codes = true pretty = true diff --git a/scripts/benchmark.py b/scripts/benchmark.py index 570f129..889a025 100644 --- a/scripts/benchmark.py +++ b/scripts/benchmark.py @@ -58,7 +58,7 @@ def main(): headers = ["Model", "Params (M)", "FLOPs (G)", "MACs (G)", "DMAs (G)", "RF"] max_w = [20, 10, 10, 10, 10, 10] - info_str = [(" " * margin).join([f"{col_name:<{col_w}}" for col_name, col_w in zip(headers, max_w)])] + info_str = [(" " * margin).join([f"{col_name:<{col_w}}" for col_name, col_w in zip(headers, max_w, strict=False)])] info_str.append("-" * len(info_str[0])) print("\n".join(info_str)) for name in TORCHVISION_MODELS: diff --git a/tests/test_utils.py b/tests/test_utils.py index 3edafdc..4865fcf 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -16,7 +16,7 @@ def test_wrap_string(): wrap = "[...]" assert utils.wrap_string(example, max_len, mode="end") == example[: max_len - len(wrap)] + wrap - assert utils.wrap_string(example, max_len, mode="mid") == f"{example[:max_len - 2 - len(wrap)]}{wrap}.a" + assert utils.wrap_string(example, max_len, mode="mid") == f"{example[: max_len - 2 - len(wrap)]}{wrap}.a" assert utils.wrap_string(example, len(example), mode="end") == example with pytest.raises(ValueError): _ = utils.wrap_string(example, max_len, mode="test") diff --git a/torchscan/crawler.py b/torchscan/crawler.py index c98dc38..ea51b34 100644 --- a/torchscan/crawler.py +++ b/torchscan/crawler.py @@ -70,7 +70,8 @@ def crawl_module( dtype = [dtype] * len(input_shape) # Tensor arguments input_ts 
= [ - torch.rand(1, *in_shape).to(dtype=_dtype, device=device) for in_shape, _dtype in zip(input_shape, dtype) + torch.rand(1, *in_shape).to(dtype=_dtype, device=device) + for in_shape, _dtype in zip(input_shape, dtype, strict=False) ] pre_fw_handles, post_fw_handles = [], [] diff --git a/torchscan/modules/flops.py b/torchscan/modules/flops.py index 296ccf8..25d2fd3 100644 --- a/torchscan/modules/flops.py +++ b/torchscan/modules/flops.py @@ -30,41 +30,40 @@ def module_flops(module: Module, inputs: Tuple[Tensor, ...], out: Tensor) -> int """ if isinstance(module, (nn.Identity, nn.Flatten)): return 0 - elif isinstance(module, nn.Linear): + if isinstance(module, nn.Linear): return flops_linear(module, inputs) - elif isinstance(module, nn.ReLU): + if isinstance(module, nn.ReLU): return flops_relu(module, inputs) - elif isinstance(module, nn.ELU): + if isinstance(module, nn.ELU): return flops_elu(module, inputs) - elif isinstance(module, nn.LeakyReLU): + if isinstance(module, nn.LeakyReLU): return flops_leakyrelu(module, inputs) - elif isinstance(module, nn.ReLU6): + if isinstance(module, nn.ReLU6): return flops_relu6(module, inputs) - elif isinstance(module, nn.Tanh): + if isinstance(module, nn.Tanh): return flops_tanh(module, inputs) - elif isinstance(module, nn.Sigmoid): + if isinstance(module, nn.Sigmoid): return flops_sigmoid(module, inputs) - elif isinstance(module, _ConvTransposeNd): + if isinstance(module, _ConvTransposeNd): return flops_convtransposend(module, inputs, out) - elif isinstance(module, _ConvNd): + if isinstance(module, _ConvNd): return flops_convnd(module, inputs, out) - elif isinstance(module, _BatchNorm): + if isinstance(module, _BatchNorm): return flops_bn(module, inputs) - elif isinstance(module, _MaxPoolNd): + if isinstance(module, _MaxPoolNd): return flops_maxpool(module, inputs, out) - elif isinstance(module, _AvgPoolNd): + if isinstance(module, _AvgPoolNd): return flops_avgpool(module, inputs, out) - elif isinstance(module, 
_AdaptiveMaxPoolNd): + if isinstance(module, _AdaptiveMaxPoolNd): return flops_adaptive_maxpool(module, inputs, out) - elif isinstance(module, _AdaptiveAvgPoolNd): + if isinstance(module, _AdaptiveAvgPoolNd): return flops_adaptive_avgpool(module, inputs, out) - elif isinstance(module, nn.Dropout): + if isinstance(module, nn.Dropout): return flops_dropout(module, inputs) - elif isinstance(module, nn.Transformer): + if isinstance(module, nn.Transformer): return flops_transformer(module, inputs) - else: - warnings.warn(f"Module type not supported: {module.__class__.__name__}", stacklevel=1) - return 0 + warnings.warn(f"Module type not supported: {module.__class__.__name__}", stacklevel=1) + return 0 def flops_linear(module: nn.Linear, inputs: Tuple[Tensor, ...]) -> int: @@ -118,8 +117,7 @@ def flops_dropout(module: nn.Dropout, inputs: Tuple[Tensor, ...]) -> int: if module.p > 0: # Sample a random number for each input element return inputs[0].numel() - else: - return 0 + return 0 def flops_convtransposend(module: _ConvTransposeNd, inputs: Tuple[Tensor, ...], out: Tensor) -> int: @@ -198,7 +196,7 @@ def flops_adaptive_maxpool(_: _AdaptiveMaxPoolNd, inputs: Tuple[Tensor, ...], ou # Approximate kernel_size using ratio of spatial shapes between input and output kernel_size = tuple( i_size // o_size if (i_size % o_size) == 0 else i_size - o_size * (i_size // o_size) + 1 - for i_size, o_size in zip(inputs[0].shape[2:], out.shape[2:]) + for i_size, o_size in zip(inputs[0].shape[2:], out.shape[2:], strict=False) ) # for each spatial output element, check max element in kernel scope @@ -210,7 +208,7 @@ def flops_adaptive_avgpool(_: _AdaptiveAvgPoolNd, inputs: Tuple[Tensor, ...], ou # Approximate kernel_size using ratio of spatial shapes between input and output kernel_size = tuple( i_size // o_size if (i_size % o_size) == 0 else i_size - o_size * (i_size // o_size) + 1 - for i_size, o_size in zip(inputs[0].shape[2:], out.shape[2:]) + for i_size, o_size in 
zip(inputs[0].shape[2:], out.shape[2:], strict=False) ) # for each spatial output element, sum elements in kernel scope and div by kernel size diff --git a/torchscan/modules/macs.py b/torchscan/modules/macs.py index 8b596f8..b7bdb6d 100644 --- a/torchscan/modules/macs.py +++ b/torchscan/modules/macs.py @@ -28,35 +28,32 @@ def module_macs(module: Module, inp: Tensor, out: Tensor) -> int: """ if isinstance(module, nn.Linear): return macs_linear(module, inp, out) - elif isinstance(module, (nn.Identity, nn.ReLU, nn.ELU, nn.LeakyReLU, nn.ReLU6, nn.Tanh, nn.Sigmoid, nn.Flatten)): + if isinstance(module, (nn.Identity, nn.ReLU, nn.ELU, nn.LeakyReLU, nn.ReLU6, nn.Tanh, nn.Sigmoid, nn.Flatten)): return 0 - elif isinstance(module, _ConvTransposeNd): + if isinstance(module, _ConvTransposeNd): return macs_convtransposend(module, inp, out) - elif isinstance(module, _ConvNd): + if isinstance(module, _ConvNd): return macs_convnd(module, inp, out) - elif isinstance(module, _BatchNorm): + if isinstance(module, _BatchNorm): return macs_bn(module, inp, out) - elif isinstance(module, _MaxPoolNd): + if isinstance(module, _MaxPoolNd): return macs_maxpool(module, inp, out) - elif isinstance(module, _AvgPoolNd): + if isinstance(module, _AvgPoolNd): return macs_avgpool(module, inp, out) - elif isinstance(module, _AdaptiveMaxPoolNd): + if isinstance(module, _AdaptiveMaxPoolNd): return macs_adaptive_maxpool(module, inp, out) - elif isinstance(module, _AdaptiveAvgPoolNd): + if isinstance(module, _AdaptiveAvgPoolNd): return macs_adaptive_avgpool(module, inp, out) - elif isinstance(module, nn.Dropout): - return 0 - else: - warnings.warn(f"Module type not supported: {module.__class__.__name__}", stacklevel=1) + if isinstance(module, nn.Dropout): return 0 + warnings.warn(f"Module type not supported: {module.__class__.__name__}", stacklevel=1) + return 0 def macs_linear(module: nn.Linear, _: Tensor, out: Tensor) -> int: """MACs estimation for `torch.nn.Linear`""" # batch size * out_chan * 
macs_per_elt (bias already counted in accumulation) - mm_mac = module.in_features * reduce(mul, out.shape) - - return mm_mac + return module.in_features * reduce(mul, out.shape) def macs_convtransposend(module: _ConvTransposeNd, inp: Tensor, out: Tensor) -> int: @@ -79,10 +76,9 @@ def macs_convnd(module: _ConvNd, inp: Tensor, out: Tensor) -> int: effective_in_chan = inp.shape[1] // module.groups # N * mac window_mac = effective_in_chan * window_macs_per_chan - conv_mac = out.numel() * window_mac + return out.numel() * window_mac # bias already counted in accumulation - return conv_mac def macs_bn(module: _BatchNorm, inp: Tensor, _: Tensor) -> int: @@ -133,7 +129,7 @@ def macs_adaptive_maxpool(_: _AdaptiveMaxPoolNd, inp: Tensor, out: Tensor) -> in # Approximate kernel_size using ratio of spatial shapes between input and output kernel_size = tuple( i_size // o_size if (i_size % o_size) == 0 else i_size - o_size * (i_size // o_size) + 1 - for i_size, o_size in zip(inp.shape[2:], out.shape[2:]) + for i_size, o_size in zip(inp.shape[2:], out.shape[2:], strict=False) ) # for each spatial output element, check max element in kernel scope @@ -145,7 +141,7 @@ def macs_adaptive_avgpool(_: _AdaptiveAvgPoolNd, inp: Tensor, out: Tensor) -> in # Approximate kernel_size using ratio of spatial shapes between input and output kernel_size = tuple( i_size // o_size if (i_size % o_size) == 0 else i_size - o_size * (i_size // o_size) + 1 - for i_size, o_size in zip(inp.shape[2:], out.shape[2:]) + for i_size, o_size in zip(inp.shape[2:], out.shape[2:], strict=False) ) # for each spatial output element, sum elements in kernel scope and div by kernel size diff --git a/torchscan/modules/memory.py b/torchscan/modules/memory.py index 6fd6bbe..a98f8b1 100644 --- a/torchscan/modules/memory.py +++ b/torchscan/modules/memory.py @@ -30,33 +30,32 @@ def module_dmas(module: Module, inp: Tensor, out: Tensor) -> int: """ if isinstance(module, nn.Identity): return dmas_identity(module, inp, out) - 
elif isinstance(module, nn.Flatten): + if isinstance(module, nn.Flatten): return dmas_flatten(module, inp, out) - elif isinstance(module, nn.Linear): + if isinstance(module, nn.Linear): return dmas_linear(module, inp, out) - elif isinstance(module, (nn.ReLU, nn.ReLU6)): + if isinstance(module, (nn.ReLU, nn.ReLU6)): return dmas_relu(module, inp, out) - elif isinstance(module, (nn.ELU, nn.LeakyReLU)): + if isinstance(module, (nn.ELU, nn.LeakyReLU)): return dmas_act_single_param(module, inp, out) - elif isinstance(module, nn.Sigmoid): + if isinstance(module, nn.Sigmoid): return dmas_sigmoid(module, inp, out) - elif isinstance(module, nn.Tanh): + if isinstance(module, nn.Tanh): return dmas_tanh(module, inp, out) - elif isinstance(module, _ConvTransposeNd): + if isinstance(module, _ConvTransposeNd): return dmas_convtransposend(module, inp, out) - elif isinstance(module, _ConvNd): + if isinstance(module, _ConvNd): return dmas_convnd(module, inp, out) - elif isinstance(module, _BatchNorm): + if isinstance(module, _BatchNorm): return dmas_bn(module, inp, out) - elif isinstance(module, (_MaxPoolNd, _AvgPoolNd)): + if isinstance(module, (_MaxPoolNd, _AvgPoolNd)): return dmas_pool(module, inp, out) - elif isinstance(module, (_AdaptiveMaxPoolNd, _AdaptiveAvgPoolNd)): + if isinstance(module, (_AdaptiveMaxPoolNd, _AdaptiveAvgPoolNd)): return dmas_adaptive_pool(module, inp, out) - elif isinstance(module, nn.Dropout): + if isinstance(module, nn.Dropout): return dmas_dropout(module, inp, out) - else: - warnings.warn(f"Module type not supported: {module.__class__.__name__}", stacklevel=1) - return 0 + warnings.warn(f"Module type not supported: {module.__class__.__name__}", stacklevel=1) + return 0 def num_params(module: Module) -> int: @@ -209,7 +208,7 @@ def dmas_adaptive_pool(_: Union[_AdaptiveMaxPoolNd, _AdaptiveAvgPoolNd], inp: Te # Approximate kernel_size using ratio of spatial shapes between input and output kernel_size = tuple( i_size // o_size if (i_size % o_size) == 0 else 
i_size - o_size * (i_size // o_size) + 1 - for i_size, o_size in zip(inp.shape[2:], out.shape[2:]) + for i_size, o_size in zip(inp.shape[2:], out.shape[2:], strict=False) ) # Each output element required K ** 2 memory accesses input_dma = reduce(mul, kernel_size) * out.numel() diff --git a/torchscan/modules/receptive.py b/torchscan/modules/receptive.py index 25611ae..dfb7bde 100644 --- a/torchscan/modules/receptive.py +++ b/torchscan/modules/receptive.py @@ -45,15 +45,14 @@ def module_rf(module: Module, inp: Tensor, out: Tensor) -> Tuple[float, float, f ), ): return 1.0, 1.0, 0.0 - elif isinstance(module, _ConvTransposeNd): + if isinstance(module, _ConvTransposeNd): return rf_convtransposend(module, inp, out) - elif isinstance(module, (_ConvNd, _MaxPoolNd, _AvgPoolNd)): + if isinstance(module, (_ConvNd, _MaxPoolNd, _AvgPoolNd)): return rf_aggregnd(module, inp, out) - elif isinstance(module, (_AdaptiveMaxPoolNd, _AdaptiveAvgPoolNd)): + if isinstance(module, (_AdaptiveMaxPoolNd, _AdaptiveAvgPoolNd)): return rf_adaptive_poolnd(module, inp, out) - else: - warnings.warn(f"Module type not supported: {module.__class__.__name__}", stacklevel=1) - return 1.0, 1.0, 0.0 + warnings.warn(f"Module type not supported: {module.__class__.__name__}", stacklevel=1) + return 1.0, 1.0, 0.0 def rf_convtransposend(module: _ConvTransposeNd, _: Tensor, __: Tensor) -> Tuple[float, float, float]: diff --git a/torchscan/utils.py b/torchscan/utils.py index 7c093e9..477996f 100644 --- a/torchscan/utils.py +++ b/torchscan/utils.py @@ -18,10 +18,9 @@ def format_name(name: str, depth: int = 0) -> str: """ if depth == 0: return name - elif depth == 1: + if depth == 1: return f"├─{name}" - else: - return f"{'| ' * (depth - 1)}└─{name}" + return f"{'| ' * (depth - 1)}└─{name}" def wrap_string(s: str, max_len: int, delimiter: str = ".", wrap: str = "[...]", mode: str = "end") -> str: @@ -41,12 +40,11 @@ def wrap_string(s: str, max_len: int, delimiter: str = ".", wrap: str = "[...]", if mode == "end": 
return s[: max_len - len(wrap)] + wrap - elif mode == "mid": + if mode == "mid": final_part = s.rpartition(delimiter)[-1] wrapped_end = f"{wrap}.{final_part}" return s[: max_len - len(wrapped_end)] + wrapped_end - else: - raise ValueError("received an unexpected value of argument `mode`") + raise ValueError("received an unexpected value of argument `mode`") def unit_scale(val: float) -> Tuple[float, str]: @@ -59,14 +57,13 @@ def unit_scale(val: float) -> Tuple[float, str]: """ if val // 1e12 > 0: return val / 1e12, "T" - elif val // 1e9 > 0: + if val // 1e9 > 0: return val / 1e9, "G" - elif val // 1e6 > 0: + if val // 1e6 > 0: return val / 1e6, "M" - elif val // 1e3 > 0: + if val // 1e3 > 0: return val / 1e3, "k" - else: - return val, "" + return val, "" def format_s(f_string: str, min_w: Optional[int] = None, max_w: Optional[int] = None) -> str: @@ -135,11 +132,12 @@ def format_info( for v, s in zip( col_w, format_line_str(layer, col_w=None, wrap_mode=wrap_mode, receptive_field=True, effective_rf_stats=True), + strict=False, ) ] # Truncate columns that are too long - col_w = list(starmap(min, zip(col_w, max_w))) + col_w = list(starmap(min, zip(col_w, max_w, strict=False))) if not receptive_field: col_w = col_w[:4] @@ -159,7 +157,7 @@ def format_info( # Header info_str = [ thin_line, - margin_str.join([f"{col_name:<{col_w}}" for col_name, col_w in zip(headers, col_w)]), + margin_str.join([f"{col_name:<{col_w}}" for col_name, col_w in zip(headers, col_w, strict=False)]), thick_line, ]