Skip to content

Commit

Permalink
Comment out unused code
Browse files Browse the repository at this point in the history
  • Loading branch information
jdeschamps committed Dec 20, 2023
1 parent 45dc548 commit 96652ca
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 52 deletions.
85 changes: 41 additions & 44 deletions src/careamics/utils/torch_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
These functions are used to control certain aspects and behaviours of PyTorch.
"""
import logging
import os
import sys

import torch
Expand Down Expand Up @@ -48,46 +47,44 @@ def compile_model(model: torch.nn.Module) -> torch.nn.Module:
return model

Check warning on line 47 in src/careamics/utils/torch_utils.py

View check run for this annotation

Codecov / codecov/patch

src/careamics/utils/torch_utils.py#L47

Added line #L47 was not covered by tests


def seed_everything(seed: int) -> None:
"""
Seed all random number generators for reproducibility.
Parameters
----------
seed : int
Seed.
"""
import random

import numpy as np

random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)


def setup_cudnn_reproducibility(
deterministic: bool = True, benchmark: bool = True
) -> None:
"""
Prepare CuDNN benchmark and sets it to be deterministic/non-deterministic mode.
https://pytorch.org/docs/stable/notes/randomness.html#cuda-convolution-benchmarking.
Parameters
----------
deterministic : bool
Deterministic mode, if running CuDNN backend.
benchmark : bool
If True, uses CuDNN heuristics to figure out which algorithm will be most
performant for your model architecture and input. False may slow down training.
"""
if torch.cuda.is_available():
if deterministic:
deterministic = os.environ.get("CUDNN_DETERMINISTIC", "True") == "True"
torch.backends.cudnn.deterministic = deterministic

if benchmark:
benchmark = os.environ.get("CUDNN_BENCHMARK", "True") == "True"
torch.backends.cudnn.benchmark = benchmark
# def seed_everything(seed: int) -> None:
# """
# Seed all random number generators for reproducibility.

# Parameters
# ----------
# seed : int
# Seed.
# """
# import random

# import numpy as np

# random.seed(seed)
# np.random.seed(seed)
# torch.manual_seed(seed)
# torch.cuda.manual_seed_all(seed)


# def setup_cudnn_reproducibility(
# deterministic: bool = True, benchmark: bool = True
# ) -> None:
# """
# Prepare CuDNN benchmark and sets it to be deterministic/non-deterministic mode.

# Parameters
# ----------
# deterministic : bool
# Deterministic mode, if running CuDNN backend.
# benchmark : bool
# If True, uses CuDNN heuristics to figure out which algorithm will be most
# performant for your model architecture and input. False may slow down training
# """
# if torch.cuda.is_available():
# if deterministic:
# deterministic = os.environ.get("CUDNN_DETERMINISTIC", "True") == "True"
# torch.backends.cudnn.deterministic = deterministic

# if benchmark:
# benchmark = os.environ.get("CUDNN_BENCHMARK", "True") == "True"
# torch.backends.cudnn.benchmark = benchmark
15 changes: 7 additions & 8 deletions tests/utils/test_torch_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@

from careamics.utils.torch_utils import (
get_device,
setup_cudnn_reproducibility,
)


Expand All @@ -14,10 +13,10 @@ def test_get_device(device):
assert device.type == "cuda" if torch.cuda.is_available() else "cpu"


@pytest.mark.gpu
@pytest.mark.parametrize("deterministic", [True, False])
@pytest.mark.parametrize("benchmark", [True, False])
def test_setup_cudnn_reproducibility(deterministic, benchmark):
setup_cudnn_reproducibility(deterministic=deterministic, benchmark=benchmark)
assert torch.backends.cudnn.deterministic == deterministic
assert torch.backends.cudnn.benchmark == benchmark
# @pytest.mark.gpu
# @pytest.mark.parametrize("deterministic", [True, False])
# @pytest.mark.parametrize("benchmark", [True, False])
# def test_setup_cudnn_reproducibility(deterministic, benchmark):
# setup_cudnn_reproducibility(deterministic=deterministic, benchmark=benchmark)
# assert torch.backends.cudnn.deterministic == deterministic
# assert torch.backends.cudnn.benchmark == benchmark

0 comments on commit 96652ca

Please sign in to comment.