Skip to content

Commit

Permalink
add code for feature distribution analysis
Browse files Browse the repository at this point in the history
  • Loading branch information
hanwenzhao committed Nov 14, 2023
1 parent d527c9f commit 7261ea0
Show file tree
Hide file tree
Showing 14 changed files with 825 additions and 0 deletions.
81 changes: 81 additions & 0 deletions paper_figure/visual_feature_distribution/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
# Feature Extraction and Distance Matrix Calculation

This Python script is designed for performing feature extraction and distance matrix calculation for a given dataset of images. It supports various models for feature extraction, norm types for distance calculation, and allows customization of parameters through command-line arguments.

## Table of Contents

- [Prerequisites](#prerequisites)
- [Usage](#usage)
- [Customization](#customization)
- [Folder Structure](#folder-structure)
- [Output](#output)
- [Acknowledgments](#acknowledgments)

## Prerequisites

Before using this script, ensure you have the following dependencies installed:

- Python 3.x
- PyTorch
- torchvision
- NumPy
- PIL (Python Imaging Library)
- tqdm

You can typically install these dependencies using `pip`:

```bash
pip install torch torchvision numpy pillow tqdm
```

## Usage
To use this script, follow these steps:

1. Clone this repository or download the script to your local machine.
```
git clone xxx
```
2. Navigate to the directory where the script is located:
```
cd xxx
```
3. Run the script with the desired parameters. Here's the basic usage:
```
python script_name.py --model [model] --image_type [image_type] --norm [norm_type] --csv [csv_file] --batch_size [batch_size] --data_root [data_root_directory]
```
For example,
```
python main.py --model resnet18 --image_type texture --norm l2 --csv "CSV/256_20x.csv"
```
4. Wait for the script to perform feature extraction and distance matrix calculation.
5. The results, including the distance matrix, will be saved in the data folder in the current directory.
## Customization
* Model Selection (--model): You can choose from the following models for feature extraction:
- resnet18
- alexnet
- convnext
- vgg11
- vit
- dinov2
* Image Type (--image_type): Specify the type of images in your dataset. Choose between "heightmap" and "texture."
* Norm Type (--norm): Select the norm type for distance calculation. Currently, only "l2" (Euclidean distance) is supported.
* CSV File (--csv): Provide the path to the CSV file containing image information.
* Batch Size (--batch_size): Set the batch size for feature extraction. The default is 10.
* Data Root Directory (--data_root): Specify the root directory of your dataset.
## Output
The script will generate the following output:
- A distance matrix saved as a NumPy array in the data folder. The filename format is csv_filename_image_type_model_norm.npy.
- Information about the image closest to the mean feature vector will be saved in a file called "closest_to_mean.txt."
48 changes: 48 additions & 0 deletions paper_figure/visual_feature_distribution/closest_to_mean.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
256_20x texture resnet18 l2 44.933701 /mnt/SSD_SATA_0/DATASET/LUA_Dataset/256/20x/texture/fern/TRIBO_0C1_FLINT_FERN_9h_20X_005_021.bmp
256_20x heightmap resnet18 l2 44.933701 /mnt/SSD_SATA_0/DATASET/LUA_Dataset/256/20x/texture/fern/TRIBO_0C1_FLINT_FERN_9h_20X_005_021.bmp
256_50x texture resnet18 l2 54.821190 /mnt/SSD_SATA_0/DATASET/LUA_Dataset/256/50x/texture/beforeuse/TRIBO_023_FLINT_BEFOREUSE_0h_50X_002_016.bmp
256_50x heightmap resnet18 l2 54.821190 /mnt/SSD_SATA_0/DATASET/LUA_Dataset/256/50x/texture/beforeuse/TRIBO_023_FLINT_BEFOREUSE_0h_50X_002_016.bmp
512_20x texture resnet18 l2 44.804939 /mnt/SSD_SATA_0/DATASET/LUA_Dataset/512/20x/texture/fern/TRIBO_0C1_FLINT_FERN_9h_20X_020_003.bmp
512_20x heightmap resnet18 l2 44.804939 /mnt/SSD_SATA_0/DATASET/LUA_Dataset/512/20x/texture/fern/TRIBO_0C1_FLINT_FERN_9h_20X_020_003.bmp
512_50x texture resnet18 l2 51.226322 /mnt/SSD_SATA_0/DATASET/LUA_Dataset/512/50x/texture/beforeuse/TRIBO_023_FLINT_BEFOREUSE_0h_50X_002_001.bmp
512_50x heightmap resnet18 l2 51.226322 /mnt/SSD_SATA_0/DATASET/LUA_Dataset/512/50x/texture/beforeuse/TRIBO_023_FLINT_BEFOREUSE_0h_50X_002_001.bmp
865_20x texture resnet18 l2 45.037403 /mnt/SSD_SATA_0/DATASET/LUA_Dataset/865/20x/texture/bone/TRIBO_008_FLINT_BONE_5h_20X_065_001.bmp
865_20x heightmap resnet18 l2 45.037403 /mnt/SSD_SATA_0/DATASET/LUA_Dataset/865/20x/texture/bone/TRIBO_008_FLINT_BONE_5h_20X_065_001.bmp
865_50x texture resnet18 l2 56.155617 /mnt/SSD_SATA_0/DATASET/LUA_Dataset/865/50x/texture/barley/TRIBO_0A2_FLINT_BARLEY_9h_50X_010_001.bmp
865_50x heightmap resnet18 l2 56.155617 /mnt/SSD_SATA_0/DATASET/LUA_Dataset/865/50x/texture/barley/TRIBO_0A2_FLINT_BARLEY_9h_50X_010_001.bmp
256_20x texture vgg11 l2 24.920313 /mnt/SSD_SATA_0/DATASET/LUA_Dataset/256/20x/texture/beforeuse/TRIBO_022_FLINT_BEFOREUSE_0h_20X_002_001.bmp
256_20x heightmap vgg11 l2 24.920313 /mnt/SSD_SATA_0/DATASET/LUA_Dataset/256/20x/texture/beforeuse/TRIBO_022_FLINT_BEFOREUSE_0h_20X_002_001.bmp
256_50x texture vgg11 l2 28.182840 /mnt/SSD_SATA_0/DATASET/LUA_Dataset/256/50x/texture/beforeuse/TRIBO_008_FLINT_BEFOREUSE_0h_50X_003_005.bmp
256_50x heightmap vgg11 l2 28.182840 /mnt/SSD_SATA_0/DATASET/LUA_Dataset/256/50x/texture/beforeuse/TRIBO_008_FLINT_BEFOREUSE_0h_50X_003_005.bmp
512_20x texture vgg11 l2 21.526806 /mnt/SSD_SATA_0/DATASET/LUA_Dataset/512/20x/texture/ivory/TRIBO_014_FLINT_IVORY_5h_20X_017_003.bmp
512_20x heightmap vgg11 l2 21.526806 /mnt/SSD_SATA_0/DATASET/LUA_Dataset/512/20x/texture/ivory/TRIBO_014_FLINT_IVORY_5h_20X_017_003.bmp
512_50x texture vgg11 l2 25.522821 /mnt/SSD_SATA_0/DATASET/LUA_Dataset/512/50x/texture/beforeuse/TRIBO_022_FLINT_BEFOREUSE_0h_50X_003_002.bmp
512_50x heightmap vgg11 l2 25.522821 /mnt/SSD_SATA_0/DATASET/LUA_Dataset/512/50x/texture/beforeuse/TRIBO_022_FLINT_BEFOREUSE_0h_50X_003_002.bmp
865_20x texture vgg11 l2 22.694120 /mnt/SSD_SATA_0/DATASET/LUA_Dataset/865/20x/texture/barley/TRIBO_0A1_FLINT_BARLEY_9h_20X_009_001.bmp
865_20x heightmap vgg11 l2 22.694120 /mnt/SSD_SATA_0/DATASET/LUA_Dataset/865/20x/texture/barley/TRIBO_0A1_FLINT_BARLEY_9h_20X_009_001.bmp
865_50x texture vgg11 l2 27.138252 /mnt/SSD_SATA_0/DATASET/LUA_Dataset/865/50x/texture/barley/TRIBO_0A2_FLINT_BARLEY_9h_50X_014_001.bmp
865_50x heightmap vgg11 l2 27.138252 /mnt/SSD_SATA_0/DATASET/LUA_Dataset/865/50x/texture/barley/TRIBO_0A2_FLINT_BARLEY_9h_50X_014_001.bmp
256_20x texture convnext l2 21.231182 /mnt/SSD_SATA_0/DATASET/LUA_Dataset/256/20x/texture/beforeuse/TRIBO_012_FLINT_BEFOREUSE_0h_20X_001_021.bmp
256_20x heightmap convnext l2 21.231182 /mnt/SSD_SATA_0/DATASET/LUA_Dataset/256/20x/texture/beforeuse/TRIBO_012_FLINT_BEFOREUSE_0h_20X_001_021.bmp
256_50x texture convnext l2 26.675819 /mnt/SSD_SATA_0/DATASET/LUA_Dataset/256/50x/texture/beforeuse/TRIBO_012_FLINT_BEFOREUSE_0h_50X_003_002.bmp
256_50x heightmap convnext l2 26.675819 /mnt/SSD_SATA_0/DATASET/LUA_Dataset/256/50x/texture/beforeuse/TRIBO_012_FLINT_BEFOREUSE_0h_50X_003_002.bmp
512_20x texture convnext l2 20.582735 /mnt/SSD_SATA_0/DATASET/LUA_Dataset/512/20x/texture/bone/TRIBO_008_FLINT_BONE_5h_20X_004_002.bmp
512_20x heightmap convnext l2 20.582735 /mnt/SSD_SATA_0/DATASET/LUA_Dataset/512/20x/texture/bone/TRIBO_008_FLINT_BONE_5h_20X_004_002.bmp
512_50x texture convnext l2 23.087261 /mnt/SSD_SATA_0/DATASET/LUA_Dataset/512/50x/texture/beforeuse/TRIBO_008_FLINT_BEFOREUSE_0h_50X_002_003.bmp
512_50x heightmap convnext l2 23.087261 /mnt/SSD_SATA_0/DATASET/LUA_Dataset/512/50x/texture/beforeuse/TRIBO_008_FLINT_BEFOREUSE_0h_50X_002_003.bmp
865_20x texture convnext l2 20.701147 /mnt/SSD_SATA_0/DATASET/LUA_Dataset/865/20x/texture/beforeuse/TRIBO_016_FLINT_BEFOREUSE_0h_20X_002_001.bmp
865_20x heightmap convnext l2 20.701147 /mnt/SSD_SATA_0/DATASET/LUA_Dataset/865/20x/texture/beforeuse/TRIBO_016_FLINT_BEFOREUSE_0h_20X_002_001.bmp
865_50x texture convnext l2 24.126169 /mnt/SSD_SATA_0/DATASET/LUA_Dataset/865/50x/texture/beforeuse/TRIBO_004_FLINT_BEFOREUSE_0h_50X_001_001.bmp
865_50x heightmap convnext l2 24.126169 /mnt/SSD_SATA_0/DATASET/LUA_Dataset/865/50x/texture/beforeuse/TRIBO_004_FLINT_BEFOREUSE_0h_50X_001_001.bmp
256_20x texture dinov2 l2 20.723244 /mnt/SSD_SATA_0/DATASET/LUA_Dataset/256/20x/texture/sprucewood/TRIBO_012_FLINT_SPRUCEWOOD_5h_20X_010_001.bmp
256_20x heightmap dinov2 l2 20.723244 /mnt/SSD_SATA_0/DATASET/LUA_Dataset/256/20x/texture/sprucewood/TRIBO_012_FLINT_SPRUCEWOOD_5h_20X_010_001.bmp
256_50x texture dinov2 l2 26.677032 /mnt/SSD_SATA_0/DATASET/LUA_Dataset/256/50x/texture/bone/TRIBO_020_FLINT_BONE_5h_50X_004_013.bmp
256_50x heightmap dinov2 l2 26.677032 /mnt/SSD_SATA_0/DATASET/LUA_Dataset/256/50x/texture/bone/TRIBO_020_FLINT_BONE_5h_50X_004_013.bmp
512_20x texture dinov2 l2 21.638500 /mnt/SSD_SATA_0/DATASET/LUA_Dataset/512/20x/texture/fern/TRIBO_0C1_FLINT_FERN_9h_20X_020_006.bmp
512_20x heightmap dinov2 l2 21.638500 /mnt/SSD_SATA_0/DATASET/LUA_Dataset/512/20x/texture/fern/TRIBO_0C1_FLINT_FERN_9h_20X_020_006.bmp
512_50x texture dinov2 l2 27.832876 /mnt/SSD_SATA_0/DATASET/LUA_Dataset/512/50x/texture/beforeuse/TRIBO_025_FLINT_BEFOREUSE_0h_50X_004_001.bmp
512_50x heightmap dinov2 l2 27.832876 /mnt/SSD_SATA_0/DATASET/LUA_Dataset/512/50x/texture/beforeuse/TRIBO_025_FLINT_BEFOREUSE_0h_50X_004_001.bmp
865_20x texture dinov2 l2 21.984348 /mnt/SSD_SATA_0/DATASET/LUA_Dataset/865/20x/texture/barley/TRIBO_0A2_FLINT_BARLEY_9h_20X_014_001.bmp
865_20x heightmap dinov2 l2 21.984348 /mnt/SSD_SATA_0/DATASET/LUA_Dataset/865/20x/texture/barley/TRIBO_0A2_FLINT_BARLEY_9h_20X_014_001.bmp
865_50x texture dinov2 l2 28.085300 /mnt/SSD_SATA_0/DATASET/LUA_Dataset/865/50x/texture/beforeuse/TRIBO_025_FLINT_BEFOREUSE_0h_50X_004_001.bmp
865_50x heightmap dinov2 l2 28.085300 /mnt/SSD_SATA_0/DATASET/LUA_Dataset/865/50x/texture/beforeuse/TRIBO_025_FLINT_BEFOREUSE_0h_50X_004_001.bmp
Empty file.
190 changes: 190 additions & 0 deletions paper_figure/visual_feature_distribution/main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,190 @@
import os, sys
import csv
import numpy as np
import torch
from torchvision import transforms
from torchvision.models import resnet50, ResNet50_Weights
from PIL import Image
from tqdm import tqdm
import argparse
from models import *


def read_image_tensor(image_path):
image = Image.open(image_path)

transform = transforms.ToTensor()
tensor_image = transform(image)

return tensor_image

def read_csv_file(file_path):
data = []
with open(file_path, mode='r', newline='') as file:
reader = csv.DictReader(file)
next(reader) # Skip the header row
for row in reader:
data.append(row)
return data

def csv_to_path(file_path, image_type="texture"):
path = []
# read csv file
csv_data = read_csv_file(file_path)
# check image type
assert image_type in ["texture", "heightmap"]
# use loop to generate path
RES = os.path.basename(file_path)[:-4].split("_")[0]
ZOOM = os.path.basename(file_path)[:-4].split("_")[1]
for row in csv_data:
path.append(os.path.join(DATA_FOLDER_ROOT, RES, ZOOM, image_type, row['Worked Material'].lower(), row['Image Name']))
assert os.path.exists(path[-1]), f"ERROR: CAN NOT LOCATED FILE: {path[-1]}"

return path

def batch_feature_extractor(feature_extractor, preprocess, image_path_list, batch_size, device):

feature_tensor_all = []
print(f"\nBatch Feature Extraction")
for i in tqdm(range(0, len(image_path_list), batch_size)):
image_tensor_batch = []
for i in image_path_list[i:i+batch_size]:
image_tensor_batch.append(read_image_tensor(i))

image_tensor_batch = torch.stack(image_tensor_batch)
batch = preprocess(image_tensor_batch.to(device))
feature = feature_extractor(batch)

# feature_tensor_all.append(feature)
feature_tensor_all.append(feature.detach().cpu().numpy())


# feature_tensor_all = torch.cat(feature_tensor_all, dim=0)
feature_tensor_all = np.vstack(feature_tensor_all)
feature_tensor_all = feature_tensor_all.reshape(feature_tensor_all.shape[0], -1)


return feature_tensor_all


def distance_matrix_gpu(feature_tensor, norm_type, device):

num_vectors = feature_tensor.shape[0]
distance_matrix = torch.zeros((num_vectors, num_vectors)).to(device)

if norm_type == "l2":
distance_matrix = torch.cdist(feature_tensor, feature_tensor, p=2)
elif norm_type == "l1":
distance_matrix = torch.cdist(feature_tensor, feature_tensor, p=1)
elif norm_type == "inf":
distance_matrix = torch.cdist(feature_tensor, feature_tensor, p=0)

return distance_matrix

def find_cloest_to_mean(feature_array, image_path_list, csv_path):
# calculate the mean
mean_feature = np.mean(feature_tensor_all, axis=0)

# Calculate the Euclidean distance between each feature tensor and the mean
distances = np.linalg.norm(feature_array - mean_feature, axis=1)

# Find the index of the feature tensor that is closest to the mean
closest_idx = np.argmin(distances)

filename = image_path_list[closest_idx]

with open(f'closest_to_mean.txt', 'a') as f:
content = f"{os.path.basename(csv_path[:-4])}\t{IMAGE_TYPE}\t{MODEL}\t{NORM_TYPE}\t{distances[closest_idx]:4f}\t{filename}"

f.write(content + '\n')

if __name__ == "__main__":

# change working dir to script location
os.chdir(sys.path[0])

print(f"\n##################################################")
print(f"ML Toolkits: {os.path.basename(os.getcwd())}")

parser = argparse.ArgumentParser(description="Feature extraction and distance matrix calculation")

parser.add_argument("--model", choices=["resnet18", "alexnet", "convnext", "vgg11", "vit", "dinov2"], default="resnet18", help="Choose the model (default: resnet18)")
parser.add_argument("--image_type", choices=["heightmap", "texture"], default="texture", help="Choose the image type (default: texture)")
parser.add_argument("--norm", choices=["l2"], default="l2", help="Choose the norm type (default: l2)")
parser.add_argument("--csv", required=True, help="Path to the CSV file containing image information")
parser.add_argument("--batch_size", type=int, default=10, help="Batch size for feature extraction (default: 10)")
parser.add_argument("--data_root", default="/mnt/SSD_SATA_0/DATASET/LUA_Dataset", help="Root folder of the dataset (LUA_Dataset)")

args = parser.parse_args()
MODEL = args.model
IMAGE_TYPE = args.image_type
NORM_TYPE = args.norm
CSV_PATH = args.csv
BATCH_SIZE = args.batch_size
DATA_FOLDER_ROOT = args.data_root

# Check for GPU availability
if torch.cuda.is_available():
device = torch.device("cuda")
print(f"Using GPU: {torch.cuda.is_available()}")
else:
device = torch.device("cpu")
print(f"Using GPU: {torch.cuda.is_available()}")

print(f"Model:\t\t{MODEL.upper()}")
print(f"Image Type:\t{IMAGE_TYPE.upper()}")
print(f'CSV FILE:\t{CSV_PATH}')
print(f"Batch Size:\t{BATCH_SIZE}\n")



# read csv file
CSV_PATH = os.path.join(DATA_FOLDER_ROOT, CSV_PATH)
image_path_list = csv_to_path(CSV_PATH)
# image_path_list = [os.path.join(DATA_FOLDER_ROOT, i) for i in image_path_list]
print(f"Number of Images: {len(image_path_list)}")

################## CHANGE THIS PART FOR DIFFERENT MODEL ##################
# feature extractor initialization
if MODEL == "resnet18":
feature_extractor, preprocess = feature_extractor_resnet18()
elif MODEL == "alexnet":
feature_extractor, preprocess = feature_extractor_alexnet()
elif MODEL == "convnext":
feature_extractor, preprocess = feature_extractor_convnext()
elif MODEL == "vgg11":
feature_extractor, preprocess = feature_extractor_vgg11()
elif MODEL == "vit":
feature_extractor, preprocess = feature_extractor_vit()
elif MODEL == "dinov2":
feature_extractor, preprocess = feature_extractor_dinov2()
else:
print("ERROR: MODEL NOT SUPPORTED")

##########################################################################

# perform batch feature extraction
with torch.no_grad():
feature_tensor_all = batch_feature_extractor(feature_extractor.to(device), preprocess, image_path_list, BATCH_SIZE, device)

# find the image that is cloest to the mean
find_cloest_to_mean(feature_tensor_all, image_path_list, CSV_PATH)

print(f"Feature Tensor: {feature_tensor_all.shape}")
# feature_tensor_all = feature_tensor_all.reshape(feature_tensor_all.size(0), -1)
feature_tensor_all = torch.from_numpy(feature_tensor_all).to(device)

# calculate distance matrix
distance_matrix = distance_matrix_gpu(feature_tensor_all, NORM_TYPE, device)
print(f"Distance Matrix: {distance_matrix.shape}")

# save distance matrix
filename = f"{os.path.basename(CSV_PATH[:-4])}_{IMAGE_TYPE}_{MODEL}_{NORM_TYPE}.npy"
print(f"\nSaving distance matrix to {filename}...")
np.save(f"data/{filename}", distance_matrix.cpu().numpy())

print("Done\n")




Loading

0 comments on commit 7261ea0

Please sign in to comment.