add code for feature distribution analysis
1 parent d527c9f, commit 7261ea0
Showing 14 changed files with 825 additions and 0 deletions.
@@ -0,0 +1,81 @@
# Feature Extraction and Distance Matrix Calculation

This Python script performs feature extraction and distance matrix calculation for a given dataset of images. It supports several models for feature extraction and norm types for distance calculation, and allows customization of parameters through command-line arguments.

## Table of Contents

- [Prerequisites](#prerequisites)
- [Usage](#usage)
- [Customization](#customization)
- [Folder Structure](#folder-structure)
- [Output](#output)
- [Acknowledgments](#acknowledgments)

## Prerequisites

Before using this script, ensure you have the following dependencies installed:

- Python 3.x
- PyTorch
- torchvision
- NumPy
- Pillow (PIL)
- tqdm

You can install these dependencies with `pip`:

```bash
pip install torch torchvision numpy pillow tqdm
```

## Usage

To use this script, follow these steps:

1. Clone this repository or download the script to your local machine:
   ```
   git clone xxx
   ```
2. Navigate to the directory where the script is located:
   ```
   cd xxx
   ```
3. Run the script with the desired parameters. Here's the basic usage:
   ```
   python script_name.py --model [model] --image_type [image_type] --norm [norm_type] --csv [csv_file] --batch_size [batch_size] --data_root [data_root_directory]
   ```
   For example:
   ```
   python main.py --model resnet18 --image_type texture --norm l2 --csv "CSV/256_20x.csv"
   ```
4. Wait for the script to perform feature extraction and distance matrix calculation.
5. The results, including the distance matrix, are saved in the `data` folder in the current directory.
## Customization

The following command-line options are available (a combined example follows this list):

* Model Selection (`--model`): Choose one of the following models for feature extraction:
  - resnet18
  - alexnet
  - convnext
  - vgg11
  - vit
  - dinov2
* Image Type (`--image_type`): Specify the type of images in your dataset. Choose between "heightmap" and "texture".
* Norm Type (`--norm`): Select the norm type for distance calculation. Currently, only "l2" (Euclidean distance) is supported.
* CSV File (`--csv`): Provide the path to the CSV file containing image information.
* Batch Size (`--batch_size`): Set the batch size for feature extraction. The default is 10.
* Data Root Directory (`--data_root`): Specify the root directory of your dataset.
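Several of these options can be combined in a single call. The following is an illustrative invocation only; the CSV name, batch size, and dataset root are placeholders to adapt to your own setup:

```bash
python main.py --model dinov2 --image_type heightmap --norm l2 \
    --csv "CSV/512_50x.csv" --batch_size 32 --data_root /path/to/LUA_Dataset
```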
## Output

The script generates the following output (a short loading example follows this list):

- A distance matrix saved as a NumPy array in the `data` folder. The filename format is `<csv_filename>_<image_type>_<model>_<norm>.npy`.
- Information about the image closest to the mean feature vector, appended to a file called `closest_to_mean.txt`.
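As a minimal sketch of how the saved matrix can be used (assuming the example command above was run, so `data/256_20x_texture_resnet18_l2.npy` exists):

```python
import numpy as np

# Load the distance matrix written by the script
# (filename pattern: <csv_filename>_<image_type>_<model>_<norm>.npy).
dist = np.load("data/256_20x_texture_resnet18_l2.npy")

# dist[i, j] is the L2 distance between the features of image i and image j,
# in the order the images are listed in the CSV file.
print(dist.shape)

# Example: nearest neighbour of the first image (index 0 is the image itself).
nearest = np.argsort(dist[0])[1]
print(f"Nearest neighbour of image 0: index {nearest}, distance {dist[0, nearest]:.4f}")
```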
48 changes: 48 additions & 0 deletions
paper_figure/visual_feature_distribution/closest_to_mean.txt
@@ -0,0 +1,48 @@
256_20x texture resnet18 l2 44.933701 /mnt/SSD_SATA_0/DATASET/LUA_Dataset/256/20x/texture/fern/TRIBO_0C1_FLINT_FERN_9h_20X_005_021.bmp
256_20x heightmap resnet18 l2 44.933701 /mnt/SSD_SATA_0/DATASET/LUA_Dataset/256/20x/texture/fern/TRIBO_0C1_FLINT_FERN_9h_20X_005_021.bmp
256_50x texture resnet18 l2 54.821190 /mnt/SSD_SATA_0/DATASET/LUA_Dataset/256/50x/texture/beforeuse/TRIBO_023_FLINT_BEFOREUSE_0h_50X_002_016.bmp
256_50x heightmap resnet18 l2 54.821190 /mnt/SSD_SATA_0/DATASET/LUA_Dataset/256/50x/texture/beforeuse/TRIBO_023_FLINT_BEFOREUSE_0h_50X_002_016.bmp
512_20x texture resnet18 l2 44.804939 /mnt/SSD_SATA_0/DATASET/LUA_Dataset/512/20x/texture/fern/TRIBO_0C1_FLINT_FERN_9h_20X_020_003.bmp
512_20x heightmap resnet18 l2 44.804939 /mnt/SSD_SATA_0/DATASET/LUA_Dataset/512/20x/texture/fern/TRIBO_0C1_FLINT_FERN_9h_20X_020_003.bmp
512_50x texture resnet18 l2 51.226322 /mnt/SSD_SATA_0/DATASET/LUA_Dataset/512/50x/texture/beforeuse/TRIBO_023_FLINT_BEFOREUSE_0h_50X_002_001.bmp
512_50x heightmap resnet18 l2 51.226322 /mnt/SSD_SATA_0/DATASET/LUA_Dataset/512/50x/texture/beforeuse/TRIBO_023_FLINT_BEFOREUSE_0h_50X_002_001.bmp
865_20x texture resnet18 l2 45.037403 /mnt/SSD_SATA_0/DATASET/LUA_Dataset/865/20x/texture/bone/TRIBO_008_FLINT_BONE_5h_20X_065_001.bmp
865_20x heightmap resnet18 l2 45.037403 /mnt/SSD_SATA_0/DATASET/LUA_Dataset/865/20x/texture/bone/TRIBO_008_FLINT_BONE_5h_20X_065_001.bmp
865_50x texture resnet18 l2 56.155617 /mnt/SSD_SATA_0/DATASET/LUA_Dataset/865/50x/texture/barley/TRIBO_0A2_FLINT_BARLEY_9h_50X_010_001.bmp
865_50x heightmap resnet18 l2 56.155617 /mnt/SSD_SATA_0/DATASET/LUA_Dataset/865/50x/texture/barley/TRIBO_0A2_FLINT_BARLEY_9h_50X_010_001.bmp
256_20x texture vgg11 l2 24.920313 /mnt/SSD_SATA_0/DATASET/LUA_Dataset/256/20x/texture/beforeuse/TRIBO_022_FLINT_BEFOREUSE_0h_20X_002_001.bmp
256_20x heightmap vgg11 l2 24.920313 /mnt/SSD_SATA_0/DATASET/LUA_Dataset/256/20x/texture/beforeuse/TRIBO_022_FLINT_BEFOREUSE_0h_20X_002_001.bmp
256_50x texture vgg11 l2 28.182840 /mnt/SSD_SATA_0/DATASET/LUA_Dataset/256/50x/texture/beforeuse/TRIBO_008_FLINT_BEFOREUSE_0h_50X_003_005.bmp
256_50x heightmap vgg11 l2 28.182840 /mnt/SSD_SATA_0/DATASET/LUA_Dataset/256/50x/texture/beforeuse/TRIBO_008_FLINT_BEFOREUSE_0h_50X_003_005.bmp
512_20x texture vgg11 l2 21.526806 /mnt/SSD_SATA_0/DATASET/LUA_Dataset/512/20x/texture/ivory/TRIBO_014_FLINT_IVORY_5h_20X_017_003.bmp
512_20x heightmap vgg11 l2 21.526806 /mnt/SSD_SATA_0/DATASET/LUA_Dataset/512/20x/texture/ivory/TRIBO_014_FLINT_IVORY_5h_20X_017_003.bmp
512_50x texture vgg11 l2 25.522821 /mnt/SSD_SATA_0/DATASET/LUA_Dataset/512/50x/texture/beforeuse/TRIBO_022_FLINT_BEFOREUSE_0h_50X_003_002.bmp
512_50x heightmap vgg11 l2 25.522821 /mnt/SSD_SATA_0/DATASET/LUA_Dataset/512/50x/texture/beforeuse/TRIBO_022_FLINT_BEFOREUSE_0h_50X_003_002.bmp
865_20x texture vgg11 l2 22.694120 /mnt/SSD_SATA_0/DATASET/LUA_Dataset/865/20x/texture/barley/TRIBO_0A1_FLINT_BARLEY_9h_20X_009_001.bmp
865_20x heightmap vgg11 l2 22.694120 /mnt/SSD_SATA_0/DATASET/LUA_Dataset/865/20x/texture/barley/TRIBO_0A1_FLINT_BARLEY_9h_20X_009_001.bmp
865_50x texture vgg11 l2 27.138252 /mnt/SSD_SATA_0/DATASET/LUA_Dataset/865/50x/texture/barley/TRIBO_0A2_FLINT_BARLEY_9h_50X_014_001.bmp
865_50x heightmap vgg11 l2 27.138252 /mnt/SSD_SATA_0/DATASET/LUA_Dataset/865/50x/texture/barley/TRIBO_0A2_FLINT_BARLEY_9h_50X_014_001.bmp
256_20x texture convnext l2 21.231182 /mnt/SSD_SATA_0/DATASET/LUA_Dataset/256/20x/texture/beforeuse/TRIBO_012_FLINT_BEFOREUSE_0h_20X_001_021.bmp
256_20x heightmap convnext l2 21.231182 /mnt/SSD_SATA_0/DATASET/LUA_Dataset/256/20x/texture/beforeuse/TRIBO_012_FLINT_BEFOREUSE_0h_20X_001_021.bmp
256_50x texture convnext l2 26.675819 /mnt/SSD_SATA_0/DATASET/LUA_Dataset/256/50x/texture/beforeuse/TRIBO_012_FLINT_BEFOREUSE_0h_50X_003_002.bmp
256_50x heightmap convnext l2 26.675819 /mnt/SSD_SATA_0/DATASET/LUA_Dataset/256/50x/texture/beforeuse/TRIBO_012_FLINT_BEFOREUSE_0h_50X_003_002.bmp
512_20x texture convnext l2 20.582735 /mnt/SSD_SATA_0/DATASET/LUA_Dataset/512/20x/texture/bone/TRIBO_008_FLINT_BONE_5h_20X_004_002.bmp
512_20x heightmap convnext l2 20.582735 /mnt/SSD_SATA_0/DATASET/LUA_Dataset/512/20x/texture/bone/TRIBO_008_FLINT_BONE_5h_20X_004_002.bmp
512_50x texture convnext l2 23.087261 /mnt/SSD_SATA_0/DATASET/LUA_Dataset/512/50x/texture/beforeuse/TRIBO_008_FLINT_BEFOREUSE_0h_50X_002_003.bmp
512_50x heightmap convnext l2 23.087261 /mnt/SSD_SATA_0/DATASET/LUA_Dataset/512/50x/texture/beforeuse/TRIBO_008_FLINT_BEFOREUSE_0h_50X_002_003.bmp
865_20x texture convnext l2 20.701147 /mnt/SSD_SATA_0/DATASET/LUA_Dataset/865/20x/texture/beforeuse/TRIBO_016_FLINT_BEFOREUSE_0h_20X_002_001.bmp
865_20x heightmap convnext l2 20.701147 /mnt/SSD_SATA_0/DATASET/LUA_Dataset/865/20x/texture/beforeuse/TRIBO_016_FLINT_BEFOREUSE_0h_20X_002_001.bmp
865_50x texture convnext l2 24.126169 /mnt/SSD_SATA_0/DATASET/LUA_Dataset/865/50x/texture/beforeuse/TRIBO_004_FLINT_BEFOREUSE_0h_50X_001_001.bmp
865_50x heightmap convnext l2 24.126169 /mnt/SSD_SATA_0/DATASET/LUA_Dataset/865/50x/texture/beforeuse/TRIBO_004_FLINT_BEFOREUSE_0h_50X_001_001.bmp
256_20x texture dinov2 l2 20.723244 /mnt/SSD_SATA_0/DATASET/LUA_Dataset/256/20x/texture/sprucewood/TRIBO_012_FLINT_SPRUCEWOOD_5h_20X_010_001.bmp
256_20x heightmap dinov2 l2 20.723244 /mnt/SSD_SATA_0/DATASET/LUA_Dataset/256/20x/texture/sprucewood/TRIBO_012_FLINT_SPRUCEWOOD_5h_20X_010_001.bmp
256_50x texture dinov2 l2 26.677032 /mnt/SSD_SATA_0/DATASET/LUA_Dataset/256/50x/texture/bone/TRIBO_020_FLINT_BONE_5h_50X_004_013.bmp
256_50x heightmap dinov2 l2 26.677032 /mnt/SSD_SATA_0/DATASET/LUA_Dataset/256/50x/texture/bone/TRIBO_020_FLINT_BONE_5h_50X_004_013.bmp
512_20x texture dinov2 l2 21.638500 /mnt/SSD_SATA_0/DATASET/LUA_Dataset/512/20x/texture/fern/TRIBO_0C1_FLINT_FERN_9h_20X_020_006.bmp
512_20x heightmap dinov2 l2 21.638500 /mnt/SSD_SATA_0/DATASET/LUA_Dataset/512/20x/texture/fern/TRIBO_0C1_FLINT_FERN_9h_20X_020_006.bmp
512_50x texture dinov2 l2 27.832876 /mnt/SSD_SATA_0/DATASET/LUA_Dataset/512/50x/texture/beforeuse/TRIBO_025_FLINT_BEFOREUSE_0h_50X_004_001.bmp
512_50x heightmap dinov2 l2 27.832876 /mnt/SSD_SATA_0/DATASET/LUA_Dataset/512/50x/texture/beforeuse/TRIBO_025_FLINT_BEFOREUSE_0h_50X_004_001.bmp
865_20x texture dinov2 l2 21.984348 /mnt/SSD_SATA_0/DATASET/LUA_Dataset/865/20x/texture/barley/TRIBO_0A2_FLINT_BARLEY_9h_20X_014_001.bmp
865_20x heightmap dinov2 l2 21.984348 /mnt/SSD_SATA_0/DATASET/LUA_Dataset/865/20x/texture/barley/TRIBO_0A2_FLINT_BARLEY_9h_20X_014_001.bmp
865_50x texture dinov2 l2 28.085300 /mnt/SSD_SATA_0/DATASET/LUA_Dataset/865/50x/texture/beforeuse/TRIBO_025_FLINT_BEFOREUSE_0h_50X_004_001.bmp
865_50x heightmap dinov2 l2 28.085300 /mnt/SSD_SATA_0/DATASET/LUA_Dataset/865/50x/texture/beforeuse/TRIBO_025_FLINT_BEFOREUSE_0h_50X_004_001.bmp
Empty file.
@@ -0,0 +1,190 @@
import os, sys
import csv
import numpy as np
import torch
from torchvision import transforms
from PIL import Image
from tqdm import tqdm
import argparse
from models import *


def read_image_tensor(image_path):
    # load an image from disk and convert it to a torch tensor (C, H, W)
    image = Image.open(image_path)

    transform = transforms.ToTensor()
    tensor_image = transform(image)

    return tensor_image


def read_csv_file(file_path):
    # read the CSV file into a list of row dictionaries;
    # csv.DictReader already consumes the header row, so no extra skip is needed
    data = []
    with open(file_path, mode='r', newline='') as file:
        reader = csv.DictReader(file)
        for row in reader:
            data.append(row)
    return data


def csv_to_path(file_path, image_type="texture"):
    # build absolute image paths from the CSV rows; relies on the module-level
    # DATA_FOLDER_ROOT set in the __main__ block
    path = []
    # read csv file
    csv_data = read_csv_file(file_path)
    # check image type
    assert image_type in ["texture", "heightmap"]
    # the CSV filename encodes resolution and zoom, e.g. "256_20x.csv"
    RES = os.path.basename(file_path)[:-4].split("_")[0]
    ZOOM = os.path.basename(file_path)[:-4].split("_")[1]
    for row in csv_data:
        path.append(os.path.join(DATA_FOLDER_ROOT, RES, ZOOM, image_type, row['Worked Material'].lower(), row['Image Name']))
        assert os.path.exists(path[-1]), f"ERROR: CANNOT LOCATE FILE: {path[-1]}"

    return path

def batch_feature_extractor(feature_extractor, preprocess, image_path_list, batch_size, device):

    feature_tensor_all = []
    print(f"\nBatch Feature Extraction")
    for i in tqdm(range(0, len(image_path_list), batch_size)):
        image_tensor_batch = []
        for image_path in image_path_list[i:i+batch_size]:
            image_tensor_batch.append(read_image_tensor(image_path))

        image_tensor_batch = torch.stack(image_tensor_batch)
        batch = preprocess(image_tensor_batch.to(device))
        feature = feature_extractor(batch)

        # move each batch of features to the CPU as a NumPy array to limit GPU memory use
        feature_tensor_all.append(feature.detach().cpu().numpy())

    # stack all batches and flatten each feature map into a single vector per image
    feature_tensor_all = np.vstack(feature_tensor_all)
    feature_tensor_all = feature_tensor_all.reshape(feature_tensor_all.shape[0], -1)

    return feature_tensor_all

def distance_matrix_gpu(feature_tensor, norm_type, device):

    num_vectors = feature_tensor.shape[0]
    distance_matrix = torch.zeros((num_vectors, num_vectors)).to(device)

    if norm_type == "l2":
        distance_matrix = torch.cdist(feature_tensor, feature_tensor, p=2)
    elif norm_type == "l1":
        distance_matrix = torch.cdist(feature_tensor, feature_tensor, p=1)
    elif norm_type == "inf":
        # the infinity norm corresponds to p=float('inf') in torch.cdist
        distance_matrix = torch.cdist(feature_tensor, feature_tensor, p=float('inf'))

    return distance_matrix

def find_closest_to_mean(feature_array, image_path_list, csv_path):
    # calculate the mean feature vector over all images
    mean_feature = np.mean(feature_array, axis=0)

    # calculate the Euclidean distance between each feature vector and the mean
    distances = np.linalg.norm(feature_array - mean_feature, axis=1)

    # find the index of the feature vector that is closest to the mean
    closest_idx = np.argmin(distances)

    filename = image_path_list[closest_idx]

    # append one tab-separated line: csv name, image type, model, norm, distance, image path
    with open('closest_to_mean.txt', 'a') as f:
        content = f"{os.path.basename(csv_path[:-4])}\t{IMAGE_TYPE}\t{MODEL}\t{NORM_TYPE}\t{distances[closest_idx]:4f}\t{filename}"
        f.write(content + '\n')

if __name__ == "__main__":

    # change working dir to script location
    os.chdir(sys.path[0])

    print(f"\n##################################################")
    print(f"ML Toolkits: {os.path.basename(os.getcwd())}")

    parser = argparse.ArgumentParser(description="Feature extraction and distance matrix calculation")

    parser.add_argument("--model", choices=["resnet18", "alexnet", "convnext", "vgg11", "vit", "dinov2"], default="resnet18", help="Choose the model (default: resnet18)")
    parser.add_argument("--image_type", choices=["heightmap", "texture"], default="texture", help="Choose the image type (default: texture)")
    parser.add_argument("--norm", choices=["l2"], default="l2", help="Choose the norm type (default: l2)")
    parser.add_argument("--csv", required=True, help="Path to the CSV file containing image information")
    parser.add_argument("--batch_size", type=int, default=10, help="Batch size for feature extraction (default: 10)")
    parser.add_argument("--data_root", default="/mnt/SSD_SATA_0/DATASET/LUA_Dataset", help="Root folder of the dataset (LUA_Dataset)")

    args = parser.parse_args()
    MODEL = args.model
    IMAGE_TYPE = args.image_type
    NORM_TYPE = args.norm
    CSV_PATH = args.csv
    BATCH_SIZE = args.batch_size
    DATA_FOLDER_ROOT = args.data_root

    # Check for GPU availability
    if torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")
    print(f"Using GPU: {torch.cuda.is_available()}")

    print(f"Model:\t\t{MODEL.upper()}")
    print(f"Image Type:\t{IMAGE_TYPE.upper()}")
    print(f'CSV FILE:\t{CSV_PATH}')
    print(f"Batch Size:\t{BATCH_SIZE}\n")

    # read csv file and build the image path list for the selected image type
    CSV_PATH = os.path.join(DATA_FOLDER_ROOT, CSV_PATH)
    image_path_list = csv_to_path(CSV_PATH, IMAGE_TYPE)
    print(f"Number of Images: {len(image_path_list)}")

    ################## CHANGE THIS PART FOR DIFFERENT MODEL ##################
    # feature extractor initialization
    if MODEL == "resnet18":
        feature_extractor, preprocess = feature_extractor_resnet18()
    elif MODEL == "alexnet":
        feature_extractor, preprocess = feature_extractor_alexnet()
    elif MODEL == "convnext":
        feature_extractor, preprocess = feature_extractor_convnext()
    elif MODEL == "vgg11":
        feature_extractor, preprocess = feature_extractor_vgg11()
    elif MODEL == "vit":
        feature_extractor, preprocess = feature_extractor_vit()
    elif MODEL == "dinov2":
        feature_extractor, preprocess = feature_extractor_dinov2()
    else:
        sys.exit("ERROR: MODEL NOT SUPPORTED")
    ##########################################################################

    # perform batch feature extraction
    with torch.no_grad():
        feature_tensor_all = batch_feature_extractor(feature_extractor.to(device), preprocess, image_path_list, BATCH_SIZE, device)

    # find the image that is closest to the mean feature vector
    find_closest_to_mean(feature_tensor_all, image_path_list, CSV_PATH)

    print(f"Feature Tensor: {feature_tensor_all.shape}")
    feature_tensor_all = torch.from_numpy(feature_tensor_all).to(device)

    # calculate distance matrix
    distance_matrix = distance_matrix_gpu(feature_tensor_all, NORM_TYPE, device)
    print(f"Distance Matrix: {distance_matrix.shape}")

    # save distance matrix
    os.makedirs("data", exist_ok=True)
    filename = f"{os.path.basename(CSV_PATH[:-4])}_{IMAGE_TYPE}_{MODEL}_{NORM_TYPE}.npy"
    print(f"\nSaving distance matrix to {filename}...")
    np.save(f"data/{filename}", distance_matrix.cpu().numpy())

    print("Done\n")