Skip to content

Commit

Permalink
CDL: some requested changes. Remove redundant variable, remove unused…
Browse files Browse the repository at this point in the history
… dtype arg, rename set_device, etc
  • Loading branch information
cleong110 committed Dec 5, 2024
1 parent a495c67 commit 0e54bf9
Showing 1 changed file with 8 additions and 35 deletions.
43 changes: 8 additions & 35 deletions pose_evaluation/metrics/embedding_distance_metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,13 +30,11 @@ def __init__(
self,
kind: ValidDistanceKinds = "cosine",
device: Union[torch.device, str] = None,
dtype=torch.float32,
dtype=None,
):
"""
Initialize the embedding distance metric.
Args:
kind (ValidDistanceKinds): The type of distance metric.
kind (ValidDistanceKinds): The type of distance metric, e.g. "cosine", or "euclidean".
device (Union[torch.device, str]): The device to use for computation.
If None, automatically detects.
dtype (torch.dtype): The data type to use for tensors.
Expand All @@ -52,6 +50,8 @@ def __init__(
if dtype is None:
dtype = torch.get_default_dtype()

self.dtype = dtype

# Dispatch table for metric computations
self._metric_dispatch = {
"cosine": self.cosine_distances,
Expand All @@ -60,9 +60,7 @@ def __init__(
"manhattan": self.manhattan_distances,
}

self.dtype = dtype

def set_device(self, device: Union[torch.device, str]) -> None:
def to(self, device: Union[torch.device, str]) -> None:
"""
Explicitly set the device used for tensors.
Expand All @@ -72,35 +70,16 @@ def set_device(self, device: Union[torch.device, str]) -> None:
self.device = torch.device(device)
logger.info(f"Device set to: {self.device}")

def _to_tensor_on_device(self, data: TensorConvertableType, dtype=None) -> Tensor:
"""
Convert input data to a tensor on the specified device.
Args:
data (TensorConvertableType: The input data to convert.
dtype (torch.dtype): The data type for the tensor.
Returns:
Tensor: Tensor representation of the data on the specified device.
"""
if dtype is None:
dtype = self.dtype
return st_util._convert_to_tensor(data).to(device=self.device, dtype=dtype)

def _to_batch_tensor_on_device(self, data: TensorConvertableType, dtype=None) -> Tensor:
def _to_batch_tensor_on_device(self, data: TensorConvertableType) -> Tensor:
"""
Convert input data to a batch tensor on the specified device.
Args:
data (TensorConvertableType): The input data to convert.
dtype (torch.dtype): The data type for the tensor.
Returns:
Tensor: Batch tensor representation of the data on the specified device.
"""
if dtype is None:
dtype = self.dtype

# better performance this way, see https://github.com/pytorch/pytorch/issues/13918
if isinstance(data, list) and all(isinstance(x, np.ndarray) for x in data):
data = np.asanyarray(data)
Expand All @@ -110,7 +89,7 @@ def _to_batch_tensor_on_device(self, data: TensorConvertableType, dtype=None) ->
# https://stackoverflow.com/questions/55050717/converting-list-of-tensors-to-tensors-pytorch
data = torch.stack(data)

return st_util._convert_to_batch_tensor(data).to(device=self.device, dtype=dtype)
return st_util._convert_to_batch_tensor(data).to(device=self.device)

def score(
self,
Expand All @@ -128,7 +107,6 @@ def score(
Number: The calculated distance.
"""

return self.score_all(hypothesis, reference).item()

def score_all(
Expand Down Expand Up @@ -163,12 +141,7 @@ def score_all(
except RuntimeError as e:
raise TypeError(f"Inputs must support conversion to device tensors: {e}") from e

if self.kind not in self._metric_dispatch:
logger.error(f"Unsupported distance metric: {self.kind}")
raise ValueError(f"Unsupported distance metric: {self.kind}")

distance_matrix = self._metric_dispatch[self.kind](hypotheses, references)
return distance_matrix
return self._metric_dispatch[self.kind](hypotheses, references)

def dot_product(self, hypotheses: TensorConvertableType, references: TensorConvertableType) -> Tensor:
"""
Expand Down

0 comments on commit 0e54bf9

Please sign in to comment.