Skip to content

Commit

Permalink
Fixes frd-score and requirements.
Browse files Browse the repository at this point in the history
  • Loading branch information
RichardObi committed Apr 24, 2024
1 parent 3709a66 commit 80a01a7
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 96 deletions.
14 changes: 8 additions & 6 deletions requirements.in
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
nibabel~=3.2.1
nox~=2024.4.15
# Library dependencies
numpy>=1.23.5
opencv_contrib_python_headless~=4.8.1.78
Pillow~=10.3.0
pyradiomics==3.0.1a3
pytest~=8.1.1
scipy~=1.10.0
setuptools~=65.6.3
SimpleITK~=2.3.1
torch~=2.0.0
tqdm~=4.64.1
tqdm~=4.64.1

# Required for testing
setuptools~=65.6.3
pytest~=8.1.1
nibabel~=3.2.1
nox~=2024.4.15
107 changes: 32 additions & 75 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,119 +4,76 @@
#
# pip-compile
#
argcomplete~=3.3.0
argcomplete==3.3.0
# via nox
colorlog~=6.8.2
colorlog==6.8.2
# via nox
distlib~=0.3.8
distlib==0.3.8
# via virtualenv
docopt~=0.6.2
docopt==0.6.2
# via pykwalify
exceptiongroup~=1.2.0
exceptiongroup==1.2.1
# via pytest
filelock~=3.13.4
# via
# torch
# virtualenv
imageio~=2.34.0
# via scikit-image
iniconfig~=2.0.0
filelock==3.13.4
# via virtualenv
iniconfig==2.0.0
# via pytest
jinja2~=3.1.3
# via torch
lazy-loader~=0.4
# via scikit-image
markupsafe~=2.1.5
# via jinja2
mpmath~=1.2.1
# via sympy
networkx~=3.3
# via
# scikit-image
# torch
nibabel~=3.2.1
nibabel==3.2.2
# via -r requirements.in
nox~=2024.4.15
nox==2024.4.15
# via -r requirements.in
numpy>=1.23.5
numpy==1.26.4
# via
# -r requirements.in
# imageio
# nibabel
# opencv-contrib-python-headless
# pyradiomics
# pywavelets
# radiomics
# scikit-image
# scipy
# tifffile
opencv-contrib-python-headless~=4.8.1.78
opencv-contrib-python-headless==4.8.1.78
# via -r requirements.in
packaging~=24.0
packaging==24.0
# via
# lazy-loader
# nibabel
# nox
# pytest
# scikit-image
pillow~=10.3.0
# via
# -r requirements.in
# imageio
# scikit-image
platformdirs~=4.2.0
pillow==10.3.0
# via -r requirements.in
platformdirs==4.2.1
# via virtualenv
pluggy~=1.4.0
pluggy==1.5.0
# via pytest
pykwalify~=1.8.0
pykwalify==1.8.0
# via pyradiomics
pyradiomics==3.0.1a3 #3.1.0 #~=3.0.1a3
#pyradiomics @ git+https://github.com/AIM-Harvard/pyradiomics@releases/tag/v3.1.0
pyradiomics==3.0.1a3
# via -r requirements.in
pytest~=8.1.1
pytest==8.1.1
# via -r requirements.in
python-dateutil~=2.9.0.post0
python-dateutil==2.9.0.post0
# via pykwalify
pywavelets~=1.6.0
# via
# pyradiomics
# radiomics
radiomics~=0.1
# via -r requirements.in
ruamel-yaml~=0.18.6
pywavelets==1.6.0
# via pyradiomics
ruamel-yaml==0.18.6
# via pykwalify
ruamel-yaml-clib~=0.2.8
ruamel-yaml-clib==0.2.8
# via ruamel-yaml
scikit-image~=0.23.1
# via radiomics
scipy~=1.10.0
# via
# -r requirements.in
# radiomics
# scikit-image
simpleitk~=2.3.1
scipy==1.10.1
# via -r requirements.in
simpleitk==2.3.1
# via
# -r requirements.in
# pyradiomics
six~=1.16.0
six==1.16.0
# via
# pyradiomics
# python-dateutil
sympy~=1.12
# via torch
tifffile~=2024.4.18
# via scikit-image
tomli~=2.0.1
tomli==2.0.1
# via
# nox
# pytest
torch~=2.0.0
# via -r requirements.in
tqdm~=4.64.1
tqdm==4.64.1
# via -r requirements.in
typing-extensions~=4.11.0
# via torch
virtualenv~=20.25.3
virtualenv==20.26.0
# via nox

# The following packages are considered to be unsafe in a requirements file:
Expand Down
28 changes: 13 additions & 15 deletions src/frd/frd_score.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,15 +19,13 @@
import pathlib
import time
from pathlib import Path
from typing import List

import cv2
import numpy as np
import SimpleITK as sitk
import torch
from PIL import Image

from radiomics import featureextractor
from scipy import linalg, stats
from scipy import linalg
from tqdm import tqdm


Expand Down Expand Up @@ -59,7 +57,7 @@ def parse_args() -> argparse.Namespace:
"features."
)
parser.add_argument(
"--paths",
"paths",
type=str,
nargs=2,
help="The two paths to the generated images or to .npz statistic files",
Expand Down Expand Up @@ -752,7 +750,7 @@ def calculate_frd_given_paths(
normalization_type,
normalization_range,
is_mask_used=True,
paths_mask=None,
paths_masks=None,
resize_size=None,
verbose=False,
save_features=True,
Expand All @@ -771,7 +769,7 @@ def calculate_frd_given_paths(
normalization_range,
feature_extractor,
is_mask_used=is_mask_used,
path_mask=None if paths_mask is None else paths_mask[0],
path_mask=None if paths_masks is None else paths_masks[0],
resize_size=resize_size,
verbose=verbose,
save_features=save_features,
Expand All @@ -784,7 +782,7 @@ def calculate_frd_given_paths(
normalization_range,
feature_extractor,
is_mask_used=is_mask_used,
path_mask=None if paths_mask is None else paths_mask[1],
path_mask=None if paths_masks is None else paths_masks[1],
resize_size=resize_size,
verbose=verbose,
save_features=save_features,
Expand All @@ -802,7 +800,7 @@ def save_frd_stats(
normalization_type: str,
normalization_range: list,
is_mask_used=True,
paths_mask=None,
paths_masks=None,
resize_size=None,
verbose=False,
save_features=True,
Expand Down Expand Up @@ -835,7 +833,7 @@ def save_frd_stats(
normalization_range=normalization_range,
feature_extractor=feature_extractor,
is_mask_used=is_mask_used,
path_mask=None if paths_mask is None else paths_mask[0],
path_mask=None if paths_masks is None else paths_masks[0],
resize_size=resize_size,
verbose=verbose,
save_features=save_features,
Expand All @@ -859,32 +857,32 @@ def main():

if args.save_stats:
save_frd_stats(
args.path,
args.paths,
features=features,
normalization_type=args.normalization_type,
normalization_range=args.normalization_range,
is_mask_used=args.is_mask_used,
paths_mask=args.paths_mask,
paths_masks=args.paths_masks,
resize_size=args.resize_size,
verbose=args.verbose,
save_features=args.save_features,
)
return

frd_value = calculate_frd_given_paths(
args.path,
args.paths,
features=features,
normalization_type=args.normalization_type,
normalization_range=args.normalization_range,
is_mask_used=args.is_mask_used,
paths_mask=args.paths_mask,
paths_masks=args.paths_masks,
resize_size=args.resize_size,
verbose=args.verbose,
save_features=args.save_features,
)
# logging the result
logging.info(
f"Fréchet Radiomics Distance: {frd_value}. Based on features: {features} with normalization type: {args.normalization_type} and normalization range: {args.normalization_range}{f', with masks: {args.paths_mask}' if args.is_mask_used else ''}{f', resized to {args.resize_size}' if args.resize is not None else ''}."
f"Fréchet Radiomics Distance: {frd_value}. Based on features: {features} with normalization type: {args.normalization_type} and normalization range: {args.normalization_range}{f', with masks: {args.paths_masks}' if args.is_mask_used else ''}{f', resized to {args.resize_size}' if args.resize_size is not None else ''}."
)
print(f"FRD: {frd_value}")

Expand Down

0 comments on commit 80a01a7

Please sign in to comment.