Skip to content

Commit

Permalink
added cupy acceleration
Browse files Browse the repository at this point in the history
  • Loading branch information
Lucas Camillo authored and Lucas Camillo committed Aug 3, 2024
1 parent 827c3ae commit 3df8458
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 12 deletions.
1 change: 1 addition & 0 deletions pyaging/predict/_pred.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,5 +115,6 @@ def predict_age(

# Flush memory
gc.collect()
torch.cuda.empty_cache()

logger.done()
23 changes: 11 additions & 12 deletions pyaging/predict/_pred_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,12 @@
from anndata.experimental.pytorch import AnnLoader
from torch.utils.data import DataLoader, TensorDataset

try:
import cupy as cp
CUPY_AVAILABLE = cp.cuda.is_available()
except:
CUPY_AVAILABLE = False

from ..logger import LoggerManager, main_tqdm, silence_logger
from ..models import *
from ..utils import download, load_clock_metadata, progress
Expand Down Expand Up @@ -152,7 +158,7 @@ def check_features_in_adata(
"""

# Preallocate the data matrix
X_model = np.empty((adata.n_obs, len(model.features)), order="F")
adata.obsm[f"X_{model.metadata['clock_name']}"] = cp.empty((adata.n_obs, len(model.features))) if CUPY_AVAILABLE else np.empty((adata.n_obs, len(model.features)), order="F")

# Find indices of matching features in adata.var_names
feature_indices = {feature: i for i, feature in enumerate(adata.var_names)}
Expand All @@ -165,13 +171,10 @@ def check_features_in_adata(
# Assign values for existing features
existing_features_mask = ~missing_features_mask
existing_features_indices = model_feature_indices[existing_features_mask]
X_model[:, existing_features_mask] = np.asfortranarray(adata.X)[:, existing_features_indices]
adata.obsm[f"X_{model.metadata['clock_name']}"][:, existing_features_mask] = adata.X[:, existing_features_indices]

# Handle missing features
if model.reference_values is not None:
X_model[:, missing_features_mask] = np.array(model.reference_values)[missing_features_mask]
else:
X_model[:, missing_features_mask] = 0
adata.obsm[f"X_{model.metadata['clock_name']}"][:, missing_features_mask] = np.array(model.reference_values)[missing_features_mask] if model.reference_values is not None else 0

# Calculate missing features statistics
num_missing_features = len(missing_features)
Expand Down Expand Up @@ -215,10 +218,6 @@ def check_features_in_adata(
indent_level=indent_level + 1,
)

# Add matrix to obsm
adata.obsm[f"X_{model.metadata['clock_name']}"] = X_model


@progress("Predict ages with model")
def predict_ages_with_model(
adata: anndata.AnnData,
Expand Down Expand Up @@ -303,14 +302,14 @@ def predict_ages_with_model(

# Use the AnnLoader for batched prediction
predictions = []
with torch.no_grad():
with torch.inference_mode():
for batch in main_tqdm(dataloader, indent_level=indent_level + 1, logger=logger):
batch_pred = model(batch.obsm[f"X_{model.metadata['clock_name']}"])
predictions.append(batch_pred)
# Concatenate all batch predictions
predictions = torch.cat(predictions)
return predictions

return predictions

@progress("Add predicted ages and clock metadata to adata")
def add_pred_ages_and_clock_metadata_adata(
Expand Down
9 changes: 9 additions & 0 deletions pyaging/preprocess/_preprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,12 @@
PYBIGWIG_AVAILABLE = True
except ImportError:
PYBIGWIG_AVAILABLE = False

try:
import cupy as cp
CUPY_AVAILABLE = cp.cuda.is_available()
except:
CUPY_AVAILABLE = False

from ..logger import LoggerManager, main_tqdm, silence_logger
from ._preprocess_utils import *
Expand Down Expand Up @@ -200,6 +206,9 @@ def df_to_adata(
if "X_imputed" in adata.layers:
add_unstructured_data(adata, imputer_strategy, logger)

# Move adata.X to GPU if possible
adata.X = cp.array(adata.X) if CUPY_AVAILABLE else np.asfortranarray(adata.X)

logger.done()

return adata
Expand Down

0 comments on commit 3df8458

Please sign in to comment.