PyTorch implementation for learning a domain-invariant representation for cross-domain fashion item retrieval.
- Tested on Linux: Ubuntu 16.04.1, and CUDA-9.0
mkdir UDMA_Codebase && cd UDMA_Codebase
- Clone this repo:
git clone https://github.com/almazan/deep-image-retrieval.git
cd deep-image-retrieval/dirtorch
mkdir data && cd data
- Download the pre-trained Resnet101-TL-GeM model, available from Deep Image Retrieval, and save it in the data folder.
wget https://cvhci.anthropomatik.kit.edu/~datasets-publisher/published_datasets/domain_adaption/UDMA/dirtorch_model/Resnet101-TL-GeM.pt
- Set up a virtual environment
cd ../..
python3 -m venv venv
source venv/bin/activate
- Install packages
pip install torch==1.3.1 torchvision==0.4.2
pip install pytorch-metric-learning==0.9.65
pip install pandas==0.25.3
pip install h5py==2.10.0 matplotlib==3.1.2
pip install opencv-python==4.1.2.30
- Clone the UDMA repo in deep-image-retrieval
>> export PYTHONPATH=$PYTHONPATH:$PWD
>> git clone https://github.com/vivoutlaw/UDMA.git
>> cd UDMA
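Before moving on, it can help to confirm that the active virtual environment matches the versions pinned in the install step. The following is a minimal sketch, not part of the UDMA repo: the helper names `installed_version` and `check_pins` are hypothetical, and it assumes the installed version strings equal the pins exactly (a wheel with a local suffix such as a CUDA tag would be reported as a mismatch).

```python
# Pins copied from the "Install packages" step above.
PINS = {
    "torch": "1.3.1",
    "torchvision": "0.4.2",
    "pytorch-metric-learning": "0.9.65",
    "pandas": "0.25.3",
    "h5py": "2.10.0",
    "matplotlib": "3.1.2",
    "opencv-python": "4.1.2.30",
}


def installed_version(name):
    """Return the installed version of a distribution, or None if absent."""
    try:
        from importlib.metadata import version, PackageNotFoundError  # Python 3.8+
    except ImportError:
        # Older interpreters: fall back to setuptools' pkg_resources.
        import pkg_resources
        try:
            return pkg_resources.get_distribution(name).version
        except pkg_resources.DistributionNotFound:
            return None
    try:
        return version(name)
    except PackageNotFoundError:
        return None


def check_pins(pins, lookup=installed_version):
    """Return {package: (expected, found)} for every pin that does not match."""
    mismatches = {}
    for name, expected in pins.items():
        found = lookup(name)
        if found != expected:
            mismatches[name] = (expected, found)
    return mismatches


if __name__ == "__main__":
    for name, (expected, found) in check_pins(PINS).items():
        print(f"{name}: expected {expected}, found {found}")
```

Run it inside the activated venv; no output means every pin matches. The `lookup` argument exists only so the comparison can be exercised without the packages installed.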
Download the pre-extracted features for both the DeepFashion and Street2Shop datasets using the following script. The features were extracted using our pre-trained model on DeepFashion. Trainset (ClusteringFeats) and Testset (Street2Shop | DeepFashion)
bash ./features/download_dataset.sh Street2Shop
bash ./features/download_dataset.sh DeepFashion
bash ./features/download_dataset.sh ClusteringFeats
Feature extraction using the pre-trained UDMA-MLP model:
>> CUDA_VISIBLE_DEVICES=0 python test_mlp.py --WS=WS5 --model-name=DeepFashion --comb=L12 --optimizer=ADAM --eval-dataset=Street2Shop --load-epoch=45000 --batch-size=2000 --resume --finch-part=0
Quantitative Results: DF-BL, UDMA-MLP:
>> cd evaluation_scripts
>> eval_final_s2s_retrieval('Street2Shop', 'DeepFashion_ADAM_ALL', 60, 'X', 'regular') % DF-BL
mAP = 0.2297, r1 precision = 0.3297, r5 precision = 0.4508, r10 precision = 0.4921, r20 precision = 0.5381, r50 precision = 0.5939
>> eval_final_mlp_s2s_retrieval('Street2Shop', 'DeepFashion_ADAM_ALL', 60 , 'X', 'regular', 'L12_0_WS5', 45000) % UDMA-MLP
mAP = 0.2451, r1 precision = 0.3612, r5 precision = 0.4774, r10 precision = 0.5270, r20 precision = 0.5639, r50 precision = 0.6240
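To make the comparison between the two result lines explicit, the absolute gains of UDMA-MLP over the DF-BL baseline can be computed directly from the numbers above. This is only an arithmetic sketch; the list names and the `gains` helper are illustrative and do not come from the evaluation scripts.

```python
# Values copied from the two Street2Shop result lines above.
METRICS = ["mAP", "r1", "r5", "r10", "r20", "r50"]
DF_BL = [0.2297, 0.3297, 0.4508, 0.4921, 0.5381, 0.5939]
UDMA_MLP = [0.2451, 0.3612, 0.4774, 0.5270, 0.5639, 0.6240]


def gains(baseline, adapted):
    """Absolute gain of `adapted` over `baseline`, rounded to 4 decimals."""
    return [round(a - b, 4) for b, a in zip(baseline, adapted)]


for name, delta in zip(METRICS, gains(DF_BL, UDMA_MLP)):
    print(f"{name}: {delta:+.4f}")
```

UDMA-MLP is ahead on all six metrics, with gains ranging from 0.0154 (mAP) to 0.0349 (r10).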
Weighting strategy used for MLP training:
>> python -W ignore weighting_strategy_part1.py --finch-part=0
>> python -W ignore weighting_strategy_part2.py --comb=L12 --optimizer=ADAM --finch-part=0
- Download the pre-trained model trained on DeepFashion (trainval set) using the Resnet101-TL-GeM model.
bash ./models/download_our_DF_model.sh DeepFashion
Script for UDMA-MLP training:
>> CUDA_VISIBLE_DEVICES=0 python -W ignore train_mlp.py --WS=WS5 --dataset=DeepFashion --comb=L12 --optimizer=ADAM --num-threads=8 --batch-size=128 --lr=1e-4 --resume-df --load-epoch-df=60 --epochs=45000 --finch-part=0 --batch-category-size=12
Finetuning Full Resnet101-TL-GeM model on DeepFashion dataset
In our work, we use both the train and val sets of DeepFashion for model training, and test on the test set.
- Change train_test_type from trainval to train to train the model only on the train set.
Quantitative Results: DF test set:
>> cd evaluation_scripts
>> eval_df_retrieval('DeepFashion', 'DeepFashion_ADAM_ALL', 60, 'X', 'regular') % DF test set
mAP = 0.3075, r1 precision = 0.3107, r5 precision = 0.5209, r10 precision = 0.5994, r20 precision = 0.6712, r50 precision = 0.7603
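The README does not spell out how the reported numbers are defined. The sketch below shows one common way such retrieval metrics are computed, under two loudly stated assumptions: mAP is the mean of per-query average precision, and "rK precision" is read as a top-K hit rate (the reported values grow with K, which is consistent with that reading). The function names are hypothetical, and the MATLAB evaluation scripts may differ in detail.

```python
def average_precision(ranked_relevance, num_relevant=None):
    """Average precision for one query.

    ranked_relevance: list of booleans, True where the item retrieved at
    that rank is a correct match.
    num_relevant: total number of correct matches in the gallery; defaults
    to the number of matches present in the ranked list (an assumption).
    """
    hits = 0
    precision_sum = 0.0
    for rank, is_match in enumerate(ranked_relevance, start=1):
        if is_match:
            hits += 1
            precision_sum += hits / rank  # precision at this rank
    total = num_relevant if num_relevant is not None else hits
    return precision_sum / total if total else 0.0


def hit_at_k(ranked_relevance, k):
    """1.0 if any of the top-k retrieved items is a correct match, else 0.0."""
    return 1.0 if any(ranked_relevance[:k]) else 0.0


def evaluate(queries, ks=(1, 5, 10, 20, 50)):
    """Average AP and hit@k over a list of per-query relevance lists."""
    n = len(queries)
    results = {"mAP": sum(average_precision(q) for q in queries) / n}
    for k in ks:
        results["r%d" % k] = sum(hit_at_k(q, k) for q in queries) / n
    return results
```

For example, `evaluate([[True, False, True], [False, True, False]])` averages the two per-query APs (5/6 and 1/2) and reports a top-1 hit rate of 0.5, since only the first query has a match at rank 1.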
Script for training the full model:
>> CUDA_VISIBLE_DEVICES=0,1,2,3 python main_train_df.py --dataset=DeepFashion --df-comb=ALL --optimizer=ADAM --num-threads=8 --batch-size=128 --lr=1e-4 --epochs=60 --checkpoint=../dirtorch/data/Resnet101-TL-GeM.pt
- After the model is trained, we use the last fc layer of this model for UDMA-MLP.
- Optional. Script for feature extraction: Download the DeepFashion and Street2Shop datasets. For the bounding boxes of Street2Shop, please see dataset_files. Also modify path_to_images_ with the correct path to the images.
# Train Set (fc feats) ############### DeepFashion evaluation and Street2Shop Features.
CUDA_VISIBLE_DEVICES=0 python main_extract_train_feats.py --model-name=DeepFashion --df-comb=ALL --optimizer=ADAM --eval-dataset=DeepFashion --load-epoch=60 --batch-size=256 --resume --layer=X
CUDA_VISIBLE_DEVICES=0 python main_extract_train_feats.py --model-name=DeepFashion --df-comb=ALL --optimizer=ADAM --eval-dataset=Street2Shop --load-epoch=60 --batch-size=256 --resume --layer=X
# Train Set (GeM normalized feats) -- Comment out the lines after "x.squeeze_()" in dirtorch/nets/rmac_resnext.py
CUDA_VISIBLE_DEVICES=0 python main_extract_train_feats.py --model-name=DeepFashion --df-comb=ALL --optimizer=ADAM --eval-dataset=DeepFashion --load-epoch=60 --batch-size=256 --resume --layer=X-1
CUDA_VISIBLE_DEVICES=0 python main_extract_train_feats.py --model-name=DeepFashion --df-comb=ALL --optimizer=ADAM --eval-dataset=Street2Shop --load-epoch=60 --batch-size=256 --resume --layer=X-1
# Test Set (fc feats) ############### DeepFashion evaluation and Street2Shop Features.
CUDA_VISIBLE_DEVICES=0 python main_extract_test_feats.py --model-name=DeepFashion --df-comb=ALL --optimizer=ADAM --eval-dataset=DeepFashion --load-epoch=60 --batch-size=256 --resume --layer=X
CUDA_VISIBLE_DEVICES=0 python main_extract_test_feats.py --model-name=DeepFashion --df-comb=ALL --optimizer=ADAM --eval-dataset=Street2Shop --load-epoch=60 --batch-size=256 --resume --layer=X
# Test Set (GeM normalized feats) -- Comment out the lines after "x.squeeze_()" in dirtorch/nets/rmac_resnext.py
CUDA_VISIBLE_DEVICES=0 python main_extract_test_feats.py --model-name=DeepFashion --df-comb=ALL --optimizer=ADAM --eval-dataset=DeepFashion --load-epoch=60 --batch-size=256 --resume --layer=X-1
CUDA_VISIBLE_DEVICES=0 python main_extract_test_feats.py --model-name=DeepFashion --df-comb=ALL --optimizer=ADAM --eval-dataset=Street2Shop --load-epoch=60 --batch-size=256 --resume --layer=X-1
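The eight extraction commands above are the full grid of two scripts, two layers (X for fc feats, X-1 for GeM normalized feats) and two datasets. As a convenience, they can be generated rather than copied by hand. This is a sketch only: the `extraction_commands` helper is hypothetical and uses exactly the flags shown above. It builds the command strings but does not run them, because the X-1 runs need the manual edit to dirtorch/nets/rmac_resnext.py first.

```python
from itertools import product

SCRIPTS = ["main_extract_train_feats.py", "main_extract_test_feats.py"]
LAYERS = ["X", "X-1"]  # fc feats, GeM normalized feats
DATASETS = ["DeepFashion", "Street2Shop"]


def extraction_commands(gpu=0):
    """Build the eight feature-extraction command lines listed above."""
    commands = []
    for script, layer, dataset in product(SCRIPTS, LAYERS, DATASETS):
        commands.append(
            f"CUDA_VISIBLE_DEVICES={gpu} python {script} "
            "--model-name=DeepFashion --df-comb=ALL --optimizer=ADAM "
            f"--eval-dataset={dataset} --load-epoch=60 --batch-size=256 "
            f"--resume --layer={layer}"
        )
    return commands


if __name__ == "__main__":
    for command in extraction_commands():
        print(command)
```

The commands come out in the same order as listed above: train set before test set, and within each, layer X before X-1.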
If you find the code and datasets useful in your research, please cite:
@inproceedings{udma,
author = {Vivek Sharma and Naila Murray and Diane Larlus and M. Saquib Sarfraz and Rainer Stiefelhagen and Gabriela Csurka},
title = {Unsupervised Meta-Domain Adaptation for Fashion Retrieval},
booktitle = {WACV},
year = {2021}
}
@inproceedings{finch,
author = {M. Saquib Sarfraz and Vivek Sharma and Rainer Stiefelhagen},
title = {Efficient Parameter-free Clustering Using First Neighbor Relations},
booktitle = {CVPR},
year = {2019}
}
@inproceedings{kucer,
title={A detect-then-retrieve model for multi-domain fashion item retrieval},
author={Kucer, Michal and Murray, Naila},
booktitle={CVPR Workshops},
year={2019}
}