Skip to content

Commit

Permalink
Merge branch 'master' into newmetric/vmaf
Browse files Browse the repository at this point in the history
  • Loading branch information
SkafteNicki authored Mar 5, 2025
2 parents ce56ce4 + d6a1ad2 commit 1c41c8c
Show file tree
Hide file tree
Showing 7 changed files with 57 additions and 28 deletions.
2 changes: 1 addition & 1 deletion .azure/gpu-unittests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ jobs:
- bash: |
du -h --max-depth=1 .
python -m pytest $(TEST_DIRS) \
-m "not DDP" --numprocesses=5 --dist=loadfile \
-m "not DDP" --numprocesses=9 --dist=loadfile \
--cov=torchmetrics --timeout=240 --durations=100 \
--reruns 3 --reruns-delay 1
workingDirectory: "tests/"
Expand Down
19 changes: 14 additions & 5 deletions .github/assistant.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,21 +16,23 @@
import os
import re
import sys
from pathlib import Path
from typing import Optional, Union

import fire
from packaging.version import parse

_REQUEST_TIMEOUT = 10
_PATH_ROOT = os.path.dirname(os.path.dirname(__file__))
_PATH_REPO_ROOT = Path(__file__).resolve().parent.parent
_PATH_DIR_TESTS = _PATH_REPO_ROOT / "tests"
_PKG_WIDE_SUBPACKAGES = ("utilities", "helpers")
LUT_PYTHON_TORCH = {
"3.8": "1.4",
"3.9": "1.7.1",
"3.10": "1.11",
"3.11": "1.13",
}
_path_root = lambda *ds: os.path.join(_PATH_ROOT, *ds)
_path_root = lambda *ds: os.path.join(_PATH_REPO_ROOT, *ds)
REQUIREMENTS_FILES = (*glob.glob(_path_root("requirements", "*.txt")), _path_root("requirements.txt"))


Expand Down Expand Up @@ -190,10 +192,17 @@ def _crop_path(fname: str, paths: tuple[str] = ("src/torchmetrics/", "tests/unit
if as_list: # keep only unique
return list(test_modules)

test_modules = [f"unittests/{md}" for md in set(test_modules)]
not_exists = [p for p in test_modules if os.path.exists(p)]
test_modules = [os.path.join("unittests", fp) for fp in set(test_modules)]
# filter only existing modules
not_exists = [fp for fp in test_modules if not (_PATH_DIR_TESTS / fp).exists()]
if not_exists:
raise ValueError(f"Missing following paths: {not_exists}")
logging.debug(f"Missing following paths: {not_exists}")
# filter only existing path in repo
test_modules = [fp for fp in test_modules if (_PATH_DIR_TESTS / fp).exists()]
if not test_modules:
logging.debug("No tests were changed -> rather test everything...")
return _return_all

return " ".join(test_modules)

@staticmethod
Expand Down
15 changes: 9 additions & 6 deletions .github/workflows/ci-checks.yml
Original file line number Diff line number Diff line change
Expand Up @@ -13,19 +13,22 @@ concurrency:

jobs:
check-code:
uses: Lightning-AI/utilities/.github/workflows/check-typing.yml@v0.12.0
uses: Lightning-AI/utilities/.github/workflows/check-typing.yml@v0.13.1
with:
actions-ref: v0.12.0
actions-ref: v0.13.1
extra-typing: "typing"

check-schema:
uses: Lightning-AI/utilities/.github/workflows/[email protected]
uses: Lightning-AI/utilities/.github/workflows/[email protected]
with:
actions-ref: v0.13.1
azure-schema-version: "v1.208.0"

check-package:
if: github.event.pull_request.draft == false
uses: Lightning-AI/utilities/.github/workflows/check-package.yml@v0.12.0
uses: Lightning-AI/utilities/.github/workflows/check-package.yml@v0.13.1
with:
actions-ref: v0.12.0
actions-ref: v0.13.1
artifact-name: dist-packages-${{ github.sha }}
import-name: "torchmetrics"
testing-matrix: |
Expand All @@ -35,7 +38,7 @@ jobs:
}
check-md-links:
uses: Lightning-AI/utilities/.github/workflows/check-md-links.yml@v0.12.0
uses: Lightning-AI/utilities/.github/workflows/check-md-links.yml@v0.13.1
with:
base-branch: master
config-file: ".github/markdown-links-config.json"
8 changes: 4 additions & 4 deletions .github/workflows/clear-cache.yml
Original file line number Diff line number Diff line change
Expand Up @@ -23,18 +23,18 @@ on:
jobs:
cron-clear:
if: github.event_name == 'schedule' || github.event_name == 'pull_request'
uses: Lightning-AI/utilities/.github/workflows/cleanup-caches.yml@v0.12.0
uses: Lightning-AI/utilities/.github/workflows/cleanup-caches.yml@v0.13.1
with:
scripts-ref: v0.12.0
scripts-ref: v0.13.1
dry-run: ${{ github.event_name == 'pull_request' }}
pattern: "pip-latest"
age-days: 7

direct-clear:
if: github.event_name == 'workflow_dispatch' || github.event_name == 'pull_request'
uses: Lightning-AI/utilities/.github/workflows/cleanup-caches.yml@v0.12.0
uses: Lightning-AI/utilities/.github/workflows/cleanup-caches.yml@v0.13.1
with:
scripts-ref: v0.12.0
scripts-ref: v0.13.1
dry-run: ${{ github.event_name == 'pull_request' }}
pattern: ${{ inputs.pattern || 'pypi_wheels' }} # setting str in case of PR / debugging
age-days: ${{ fromJSON(inputs.age-days) || 0 }} # setting 0 in case of PR / debugging
2 changes: 1 addition & 1 deletion requirements/_docs.txt
Original file line number Diff line number Diff line change
Expand Up @@ -31,4 +31,4 @@ pydantic > 1.0.0, < 3.0.0
# todo: until this has resolution - https://github.com/sphinx-gallery/sphinx-gallery/issues/1290
# Image
scikit-image ~=0.22; python_version < "3.10"
scikit-image ~=0.19; python_version > "3.9" # we do not use `> =` because of oldest replacement
scikit-image ~=0.25; python_version > "3.9" # we do not use `> =` because of oldest replacement
2 changes: 1 addition & 1 deletion requirements/audio_test.txt
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# NOTE: the upper bound for the package version is only set for CI stability, and it is dropped while installing this package
# in case you want to preserve/enforce restrictions on the latest compatible version, add "strict" as an in-line comment

mir-eval >=0.6, <=0.8.1
mir-eval >=0.6, <=0.8.2
fast-bss-eval >=0.1.0, <0.1.5
torch_complex <0.5.0 # needed for fast-bss-eval
srmrpy @ git+https://github.com/Lightning-Sandbox/SRMRpy
37 changes: 27 additions & 10 deletions src/torchmetrics/multimodal/clip_score.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,16 +71,33 @@ class CLIPScore(Metric):
As input to ``forward`` and ``update`` the metric accepts the following input
- source: Source input. This can be:
- Images: (:class:`~torch.Tensor` or list of tensors): tensor with images feed to the feature extractor with. If
a single tensor it should have shape ``(N, C, H, W)``. If a list of tensors, each tensor should have shape
``(C, H, W)``. ``C`` is the number of channels, ``H`` and ``W`` are the height and width of the image.
- Text: (:class:`~str` or :class:`~list` of :class:`~str`): text to compare with the images, one for each image.
- target: Target input. This can be:
- Images: (:class:`~torch.Tensor` or list of tensors): tensor with images feed to the feature extractor with. If
a single tensor it should have shape ``(N, C, H, W)``. If a list of tensors, each tensor should have shape
``(C, H, W)``. ``C`` is the number of channels, ``H`` and ``W`` are the height and width of the image.
- Text: (:class:`~str` or :class:`~list` of :class:`~str`): text to compare with the images, one for each image.
- source: Source input.
This can be:
- Images: ``Tensor`` or list of ``Tensor``
If a single tensor, it should have shape ``(N, C, H, W)``.
If a list of tensors, each tensor should have shape ``(C, H, W)``.
``C`` is the number of channels, ``H`` and ``W`` are the height and width of the image.
- Text: ``str`` or list of ``str``
Either a single caption or a list of captions.
- target: Target input.
This can be:
- Images: ``Tensor`` or list of ``Tensor``
If a single tensor, it should have shape ``(N, C, H, W)``.
If a list of tensors, each tensor should have shape ``(C, H, W)``.
``C`` is the number of channels, ``H`` and ``W`` are the height and width of the image.
- Text: ``str`` or list of ``str``
Either a single caption or a list of captions.
As output of `forward` and `compute` the metric returns the following output
Expand Down

0 comments on commit 1c41c8c

Please sign in to comment.