
Script for Testing Mammo-CLIP Model #10

Closed
GuilhermeJC13 opened this issue Aug 14, 2024 · 21 comments

Comments

@GuilhermeJC13

Congratulations on your impressive work with Mammo-CLIP. I am currently interested in testing this model on a new dataset and analyzing the resulting embeddings. Please let me know if there is an existing script available for this purpose.

Thank you for your time and assistance.

@shantanu-ai
Member

Hi,
Thanks for taking an interest in our work. So you want to get the embeddings from the Mammo-CLIP encoders so that you can test them on new data, is my understanding correct? Do you want the embeddings from the vision encoder or the text encoder?

@shantanu-ai
Member

If you want to get the vision embeddings, you can refer here.

Also, you can follow the workflow for the linear probe:

  python ./src/codebase/train_classifier.py \
    --data-dir '/restricted/projectnb/batmanlab/shawn24/PhD/RSNA_Breast_Imaging/Dataset' \
    --img-dir 'External/Vindr/vindr-mammo-a-large-scale-benchmark-dataset-for-computer-aided-detection-and-diagnosis-in-full-field-digital-mammography-1.0.0/images_png' \
    --csv-file 'External/Vindr/vindr-mammo-a-large-scale-benchmark-dataset-for-computer-aided-detection-and-diagnosis-in-full-field-digital-mammography-1.0.0/vindr_detection_v1_folds.csv' \
    --clip_chk_pt_path "/restricted/projectnb/batmanlab/shawn24/PhD/Mammo-CLIP/src/codebase/outputs/upmc_clip/b5_detector_period_n/checkpoints/fold_0/b5-model-best-epoch-7.tar" \
    --data_frac 1.0 \
    --dataset 'ViNDr' \
    --arch 'upmc_breast_clip_det_b5_period_n_lp' \
    --label "Mass" \
    --epochs 30 \
    --batch-size 8 \
    --num-workers 0 \
    --print-freq 10000 \
    --log-freq 500 \
    --running-interactive 'n' \
    --n_folds 1 \
    --lr 5.0e-5 \
    --weighted-BCE 'y' \
    --balanced-dataloader 'n' 

This script gets the embeddings from the encoders and trains a linear classifier at the same time. If you go to the experiments.py file (lines 296-297) and breast_clip_classifier.py (lines 53-56), you will see where the embeddings are computed. From the breast_clip_classifier.py file, you can return the embedding directly and save it in experiments.py.
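
For illustration, here is a minimal sketch of that change; the attribute names (image_encoder, clf_head) are placeholders, not the exact code in breast_clip_classifier.py:

def forward(self, x, return_features=False):
    # Placeholder sketch: expose the vision-encoder embedding before the linear head.
    feats = self.image_encoder(x)       # embedding from the Mammo-CLIP vision encoder
    if return_features:
        return feats                    # save these from experiments.py instead of the logits
    return self.clf_head(feats)         # original behavior: logits from the linear probe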

@shantanu-ai
Member

shantanu-ai commented Aug 15, 2024

@GuilhermeJC13
You can use this:

import torch
import gc
import time
from pathlib import Path

import numpy as np
import pandas as pd
from sklearn.metrics import f1_score
from torch.optim import Adam
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
from transformers import get_cosine_schedule_with_warmup
import sys
sys.path.append('/restricted/projectnb/batmanlab/shawn24/PhD/Mammo-CLIP/src/codebase/')

from Classifiers.models.breast_clip_classifier import BreastClipClassifier
from Datasets.dataset_utils import get_dataloader_RSNA
from breastclip.scheduler import LinearWarmupCosineAnnealingLR
from metrics import pfbeta_binarized, pr_auc, compute_auprc, auroc, compute_accuracy_np_array
from utils import seed_all, AverageMeter, timeSince
from breastclip.model.modules import load_image_encoder, LinearClassifier

class Args:
    def __init__(self):
        self.tensorboard_path = '/restricted/projectnb/batmanlab/shawn24/PhD/Mammo-CLIP/log'
        self.checkpoints = '/restricted/projectnb/batmanlab/shawn24/PhD/Mammo-CLIP/checkpoints'
        self.output_path = '/restricted/projectnb/batmanlab/shawn24/PhD/Mammo-CLIP/out'
        self.data_dir = '/restricted/projectnb/batmanlab/shared/Data/RSNA_Breast_Imaging/Dataset'
        self.img_dir = 'RSNA_Cancer_Detection/train_images_png'
        self.clip_chk_pt_path = '/restricted/projectnb/batmanlab/shawn24/PhD/Breast-CLIP/src/codebase/outputs/upmc_clip/b5_detector_period_n/checkpoints/fold_0/b5-model-best-epoch-7.tar'
        self.csv_file = 'RSNA_Cancer_Detection/train_folds.csv'
        self.dataset = 'RSNA'
        self.data_frac = 1.0
        self.arch = 'upmc_breast_clip_det_b5_period_n_ft'
        self.label = 'cancer'
        self.detector_threshold = 0.1
        self.swin_encoder = 'microsoft/swin-tiny-patch4-window7-224'
        self.pretrained_swin_encoder = 'y'
        self.swin_model_type = 'y'
        self.VER = '084'
        self.epochs_warmup = 0
        self.num_cycles = 0.5
        self.alpha = 10
        self.sigma = 15
        self.p = 1.0
        self.mean = 0.3089279
        self.std = 0.25053555408335154
        self.focal_alpha = 0.6
        self.focal_gamma = 2.0
        self.num_classes = 1
        self.n_folds = 4
        self.start_fold = 0
        self.seed = 10
        self.batch_size = 1
        self.num_workers = 4
        self.epochs = 9
        self.lr = 5.0e-5
        self.weight_decay = 1e-4
        self.warmup_epochs = 1
        self.img_size = [1520, 912]
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.apex = 'y'
        self.print_freq = 5000
        self.log_freq = 1000
        self.running_interactive = 'n'
        self.inference_mode = 'n'
        self.model_type = "Classifier"
        self.weighted_BCE = 'n'
        self.balanced_dataloader = 'n'

# Create an instance of the Args class
args = Args()

# Now you can use args just like you would in your script
print(args.tensorboard_path) 
# /restricted/projectnb/batmanlab/shawn24/PhD/Mammo-CLIP/log
args.model_base_name = 'efficientnetb5'
args.data_dir = Path(args.data_dir)
args.df = pd.read_csv(args.data_dir / args.csv_file)
args.df = args.df.fillna(0)
args.cur_fold = 0
args.train_folds = args.df[
                (args.df['fold'] == 1) | (args.df['fold'] == 2)].reset_index(drop=True)
args.valid_folds = args.df[args.df['fold'] == args.cur_fold].reset_index(drop=True)

print(f"train_folds shape: {args.train_folds.shape}")
print(f"valid_folds shape: {args.valid_folds.shape}")
# train_folds shape: (27258, 15)
# valid_folds shape: (13682, 15)

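# Load the pretrained Mammo-CLIP checkpoint and build the RSNA dataloaders from its image-encoder config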
ckpt = torch.load(args.clip_chk_pt_path, map_location="cpu")
args.image_encoder_type = ckpt["config"]["model"]["image_encoder"]["name"]
train_loader, valid_loader = get_dataloader_RSNA(args)
print(f'train_loader: {len(train_loader)}, valid_loader: {len(valid_loader)}')
# Compose([
#   HorizontalFlip(p=0.5),
#   VerticalFlip(p=0.5),
#   Affine(p=0.5, interpolation=1, mask_interpolation=0, cval=0.0, mode=0, scale={'x': (0.8, 1.2), 'y': (0.8, 1.2)}, translate_percent={'x': (0.1, 0.1), 'y': (0.1, 0.1)}, translate_px=None, rotate=(20.0, 20.0), fit_output=False, shear={'x': (20.0, 20.0), 'y': (20.0, 20.0)}, cval_mask=0.0, keep_ratio=False, rotate_method='largest_box', balanced_scale=False),
#   ElasticTransform(p=0.5, alpha=10.0, sigma=15.0, interpolation=1, border_mode=4, value=None, mask_value=None, approximate=False, same_dxdy=False),
# ], p=1.0, bbox_params=None, keypoint_params=None, additional_targets={}, is_check_shapes=True)
# None
# train_loader: 3407, valid_loader: 1711

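# Rebuild the vision encoder from the checkpoint config and load the image-encoder weights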
n_class = 1
print(ckpt["config"]["model"]["image_encoder"])
config = ckpt["config"]["model"]["image_encoder"]
image_encoder = load_image_encoder(ckpt["config"]["model"]["image_encoder"])
image_encoder_weights = {}
for k in ckpt["model"].keys():
    if k.startswith("image_encoder."):
        image_encoder_weights[".".join(k.split(".")[1:])] = ckpt["model"][k]
image_encoder.load_state_dict(image_encoder_weights, strict=True)
image_encoder_type = ckpt["config"]["model"]["image_encoder"]["model_type"]
image_encoder = image_encoder.to(args.device)

print(image_encoder_type)
print(config["name"].lower()) 
# cnn
# tf_efficientnet_b5_ns-detect

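# Run one validation batch through the vision encoder; the B5 backbone yields a 2048-d embedding per image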
progress_iter = tqdm(enumerate(valid_loader), desc=f"[tutorial]",
                     total=len(valid_loader))
for step, data in progress_iter:
    inputs = data['x'].to(args.device)
    inputs = inputs.squeeze(1).permute(0, 3, 1, 2)
    batch_size = inputs.size(0)
 
    image_features = image_encoder(inputs)
    print(image_features.shape)
    break
    # torch.Size([1, 2048])
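
To go beyond a single batch, here is a short sketch (not part of the original notebook) that collects embeddings for the whole validation split and saves them; it reuses image_encoder, valid_loader, and args from above, and the output file name is only illustrative:

all_embeddings = []
image_encoder.eval()
with torch.no_grad():
    for step, data in tqdm(enumerate(valid_loader), total=len(valid_loader)):
        inputs = data['x'].to(args.device)
        inputs = inputs.squeeze(1).permute(0, 3, 1, 2)
        feats = image_encoder(inputs)                      # [batch_size, 2048]
        all_embeddings.append(feats.detach().cpu().numpy())

emb_np = np.concatenate(all_embeddings, axis=0)            # [num_valid_images, 2048]
np.save('valid_image_embeddings.npy', emb_np)              # illustrative output path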

@GuilhermeJC13
Author

@shantanu-ai Is clip_chk_pt_path the path to the model available on Hugging Face?

@shantanu-ai shantanu-ai reopened this Aug 26, 2024
@shantanu-ai
Member

@GuilhermeJC13
Author

Hello @shantanu-ai,

Every time I try to run the scripts, I keep getting the error "KeyError: filename 'storages' not found" when I try to load the model using "ckpt = torch.load(args.clip_chk_pt_path)". I'm not sure what's causing this error. Could it be due to corrupted or improperly formatted models? Am I doing something wrong in the execution?

I unzipped the Hugging Face file and turned it into a .tar.

@shantanu-ai
Member

shantanu-ai commented Aug 28, 2024

Hi @GuilhermeJC13
The files are good. I think you probably set the path incorrectly. For B5, download only this file and follow this notebook.
There was a bug in the above notebook in the call to the encoder; I fixed it and tested it, so now it is good. Thanks for pointing it out.

You only need the b5-model-best-epoch-7.tar file.

This checkpoint is also available on Google Drive.

If the problem persists, can you please share the code?
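
For reference, loading the checkpoint should just be a direct torch.load on the downloaded .tar; do not unzip or re-archive it:

import torch

# The .tar file is the serialized checkpoint itself; pass it to torch.load directly.
ckpt = torch.load("b5-model-best-epoch-7.tar", map_location="cpu")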

@shantanu-ai shantanu-ai pinned this issue Aug 28, 2024
@shantanu-ai
Member

Also, under the hood, the notebook is calling this function.

If you want to modify anything custom, you can modify the forward function of the above method.

Also, we uploaded a tutorial notebook on setting up a classifier using the Mammo-CLIP vision encoder. You can take a look as well.

@shantanu-ai
Member

Let me know if you have further issues. If not, let me know if I can close the issue.

@shantanu-ai
Member

I am closing the issue. If you have further queries, let us know.

@GuilhermeJC13
Author

Hi @shantanu-ai,

Thanks, this helped!

I plan to do the same thing I did with image encoding, but this time with vision and text encoders together. In short, I want to extract the actual embedding from the CLIP model. Do you know if you already have a script to get these embeddings?

I really appreciate your attention!

@shantanu-ai shantanu-ai reopened this Sep 30, 2024
@shantanu-ai
Member

Hi @GuilhermeJC13
Can you clarify what you mean by "extract the actual embedding from the CLIP model"? Do you want that from the text encoder of Mammo-CLIP, or from the original CLIP?

For text embeddings from Mammo-CLIP:

def save_rsna_text_emb(clip_model, args):
    prompts = create_rsna_mammo_prompts()
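    # NOTE: create_rsna_mammo_prompts() and save_sent_dict_rsna() are helpers from the author's other codebase (not in this repo); `prompts` is not used below.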
    sentences_list_unique = save_sent_dict_rsna(args, sent_level=True)
    idx = 0
    text_embeddings_list = []
    with torch.no_grad():
        with tqdm(total=len(sentences_list_unique)) as t:
            for sent in sentences_list_unique:
                text_token = clip_model["tokenizer"](
                    sent, padding="longest", truncation=True, return_tensors="pt", max_length=256)

                text_emb = clip_model["model"].encode_text(text_token.to(args.device))
                text_emb = clip_model["model"].text_projection(text_emb) if clip_model["model"].projection else text_emb
                text_emb = text_emb / torch.norm(text_emb, dim=1, keepdim=True)
                text_emb = text_emb.detach().cpu().numpy()
                text_embeddings_list.append(text_emb)

                t.set_postfix(batch_id='{0}'.format(idx + 1))
                t.update()
                idx += 1

    text_emb_np = np.concatenate(text_embeddings_list, axis=0)
    print(f"Sent list shape: {len(sentences_list_unique)}")
    print(f"Text embedding shape: {text_emb_np.shape}")
    np.save(args.save_path / f"sent_emb_word_ge_{args.report_word_ge}.npy", text_emb_np)
    print(f"files saved at: {args.save_path}")

Note: I copied this code from another project of mine, and that codebase is messy, so it may contain trivial errors that you can fix.

For extracting embeddings from CLIP:

We compared our model with CLIP as a baseline, so we did not save the CLIP embeddings. If you want to set up the baseline, refer to this issue and then you can use the code I shared earlier.
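
If it helps, here is a rough sketch (not from our codebase) of comparing saved image and text embeddings with cosine similarity; it assumes both arrays have already been projected into the shared Mammo-CLIP space, and the file names are placeholders:

import numpy as np

img_emb = np.load('image_embeddings_projected.npy')    # [N_images, d], placeholder file
txt_emb = np.load('sentence_embeddings.npy')           # [N_sentences, d], placeholder file

# L2-normalize so the dot product equals cosine similarity
img_emb = img_emb / np.linalg.norm(img_emb, axis=1, keepdims=True)
txt_emb = txt_emb / np.linalg.norm(txt_emb, axis=1, keepdims=True)

similarity = img_emb @ txt_emb.T           # [N_images, N_sentences]
best_sentence = similarity.argmax(axis=1)  # index of the most similar sentence per image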

@Al-Dai

Al-Dai commented Oct 28, 2024

Is it possible to evaluate the downstream model provided on Hugging Face? I don't know what I am doing wrong, but I am defining the model as follows:

  n_class = 1
  model = BreastClipClassifier(args, ckpt=ckpt, n_class=n_class)
  model.load_state_dict(torch.load(args.clf_chk_pr_path)["model"])
  model = model.to(args.device)
  model.eval()
  
where the classifier checkpoint path (clf_chk_pr_path) is Downstream_evalualtion_b5_fold0/classification/Models/Classifier/fine_tune/mass/upmc_breast_clip_det_b5_period_n_ft_seed_10_fold0_best_acc_cancer_ver084.pth, and I am doing the usual prediction:

  for step, data in progress_iter:
    inputs = data['x'].to(args.device)
    inputs = inputs.squeeze(1).permute(0, 3, 1, 2)
    batch_size = inputs.size(0)
    with torch.cuda.amp.autocast(enabled=True):
        y_preds = model(inputs)  # Get raw model outputs (logits)
    
        # Apply sigmoid activation to get probabilities
        probabilities = torch.sigmoid(y_preds)
    
        # Compare probabilities with threshold 0.5
        predictions = (probabilities >= 0.5).float()
    
        # Display predictions with labels
        for i, pred in enumerate(predictions):
            label = "Cancer" if pred == 1 else "No Cancer"
            print(f"Sample {i}: {label} (Probability: {probabilities[i].item():.4f})")

I am testing this on the RSNA folder, but I am only getting 18% correct. I don't know what I am doing wrong.

@shantanu-ai
Member

shantanu-ai commented Oct 28, 2024

@Al-Dai
Can you use the valid_fn() in this file? Also, did you preprocess the RSNA images with this script? And make sure the transforms are correct.
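
As an illustration of the kind of mismatch to check (the repo's dataset code is the authoritative source, and whether scaling/normalization happens in the transform or in the dataset class may differ), evaluation-time preprocessing should roughly be: resize to img_size, no augmentation, and normalize with the mean/std from the Args above:

import cv2
import numpy as np

def preprocess_eval(png_path, img_size=(1520, 912), mean=0.3089279, std=0.25053555408335154):
    # Illustrative sketch only: resize, scale to [0, 1], normalize; no flips/elastic at eval time.
    img = cv2.imread(png_path, cv2.IMREAD_UNCHANGED).astype(np.float32)
    img = cv2.resize(img, (img_size[1], img_size[0]))   # cv2 expects (width, height)
    img = img / img.max()                                # adjust for the bit depth of your PNGs
    img = (img - mean) / std
    return img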

@Al-Dai

Al-Dai commented Oct 28, 2024

Sure, I will give it a try! Thanks for the quick response.

Should I keep the args as defined in the notebook, or do I need to change them for the fold-0 downstream weights when initializing the BreastClipClassifier class?

class Args:
    def __init__(self):
        self.tensorboard_path = '/restricted/projectnb/batmanlab/shawn24/PhD/Mammo-CLIP/log'
        self.checkpoints = '/restricted/projectnb/batmanlab/shawn24/PhD/Mammo-CLIP/checkpoints'
        self.output_path = '/restricted/projectnb/batmanlab/shawn24/PhD/Mammo-CLIP/out'
        self.data_dir = '/restricted/projectnb/batmanlab/shared/Data/RSNA_Breast_Imaging/Dataset'
        self.img_dir = 'RSNA_Cancer_Detection/train_images_png'
        self.clip_chk_pt_path = '/restricted/projectnb/batmanlab/shawn24/PhD/Breast-CLIP/src/codebase/outputs/upmc_clip/b5_detector_period_n/checkpoints/fold_0/b5-model-best-epoch-7.tar'
        self.clf_chk_pr_path = '/restricted/projectnb/batmanlab/shawn24/PhD/Mammo-CLIP/src/temp/upmc_breast_clip_det_b5_period_n_ft_seed_10_fold0_best_aucroc_ver084.pth'
        self.csv_file = 'RSNA_Cancer_Detection/train_folds.csv'
        self.dataset = 'RSNA'
        self.data_frac = 1.0
        self.arch = 'upmc_breast_clip_det_b5_period_n_ft'
        self.label = 'cancer'
        self.detector_threshold = 0.1
        self.swin_encoder = 'microsoft/swin-tiny-patch4-window7-224'
        self.pretrained_swin_encoder = 'y'
        self.swin_model_type = 'y'
        self.VER = '084'
        self.epochs_warmup = 0
        self.num_cycles = 0.5
        self.alpha = 10
        self.sigma = 15
        self.p = 1.0
        self.mean = 0.3089279
        self.std = 0.25053555408335154
        self.focal_alpha = 0.6
        self.focal_gamma = 2.0
        self.num_classes = 1
        self.n_folds = 4
        self.start_fold = 0
        self.seed = 10
        self.batch_size = 1
        self.num_workers = 4
        self.epochs = 9
        self.lr = 5.0e-5
        self.weight_decay = 1e-4
        self.warmup_epochs = 1
        self.img_size = [1520, 912]
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.apex = 'y'
        self.print_freq = 5000
        self.log_freq = 1000
        self.running_interactive = 'n'
        self.inference_mode = 'n'
        self.model_type = "Classifier"
        self.weighted_BCE = 'n'
        self.balanced_dataloader = 'n'

@shantanu-ai
Member

@Al-Dai
The args here and the argparse flags in train_classifier.py are the same. If you follow that script, it will go to BreastClipClassifier, so set the args accordingly. The important thing is to do the preprocessing of RSNA. Also, as a sanity check, you can do it on VinDr: mass, calcification, and density classification. For VinDr, you don't need to perform preprocessing; we directly uploaded the preprocessed files here.

@Al-Dai

Al-Dai commented Oct 28, 2024

Thanks, I applied the correct transformation and it worked!

Another side question: if I want to train Mammo-CLIP from scratch, say with torchrun --nproc_per_node=4 ./src/codebase/train.py --config-name pre_train_b5_clip.yaml, would I need to have the UPMC dataset, or is it possible to train with VinDr or RSNA alone?

@kayhan-batmanghelich
Contributor

kayhan-batmanghelich commented Oct 28, 2024 via email

@shantanu-ai
Member

@Al-Dai
As mentioned by Kayhan, the UPMC dataset is a private one. VinDr and RSNA do not have reports. To train Mammo-CLIP, you need at least one image+text dataset. The results will be better if you mix an image+label dataset (e.g., RSNA or VinDr) with the image+text dataset. So, if you have any image+text dataset, you can train Mammo-CLIP; just follow the UPMC settings for your own image+text data. Here, "text" means radiology reports.

@Al-Dai

Al-Dai commented Oct 29, 2024

I understand, I went through your work, and I think it's excellent! I wanted to thank you.

One last question: how is the location identified with text? I read in the paper this line: '...With Mammo-FActOR, Mammo-CLIP vision encoder excels in localization tasks, accurately identifying findings like masses and calcifications using descriptive sentences, without relying on ground truth bounding boxes.' I was wondering how these lines are generated or if there is code for it.

@shantanu-ai
Member

@Al-Dai,
So, this is weak localization using text; that is the Mammo-FActOR part of the paper (see Section 2.3). TL;DR: we have templated sentences constructed with the help of a radiologist. We use these sentences and the vision encoder of the trained Mammo-CLIP to train a lightweight projector (Eq. 3 in the paper) to learn the mapping: which activation unit (neuron) in the Mammo-CLIP representation corresponds to a mammography finding (mass or calcification).
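
A conceptual sketch of that idea (not the paper's exact Eq. 3 and not the released Mammo-FActOR code; the dimensions are placeholders):

import torch
import torch.nn as nn

class Projector(nn.Module):
    # Lightweight projector: maps frozen vision-encoder features into the sentence-embedding
    # space so that similarity with templated finding sentences gives weak localization signals.
    def __init__(self, vision_dim=2048, text_dim=512):
        super().__init__()
        self.proj = nn.Linear(vision_dim, text_dim)

    def forward(self, vision_feats, sent_emb):
        # vision_feats: [B, vision_dim] or spatial features [B, HW, vision_dim]
        # sent_emb:     [num_sentences, text_dim], L2-normalized templated-sentence embeddings
        v = self.proj(vision_feats)
        v = v / v.norm(dim=-1, keepdim=True)
        return v @ sent_emb.t()   # similarity scores, used as weak evidence for each finding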
