Testing #13 (Open)
wants to merge 77 commits into base: master
Commits (77)

caa0925  personalized clustering (Dec 30, 2020)
5bcfa6b  refactor testing (Dec 30, 2020)
07c4a9b  refactor testing (Dec 30, 2020)
3cb2301  refactor testing (Dec 30, 2020)
9c8c11f  refactor testing (Dec 30, 2020)
fa019f6  refactor testing (Dec 30, 2020)
58d744b  refactor testing (Dec 30, 2020)
3fdc1d7  refactor testing (Dec 30, 2020)
b1f3784  refactor testing (Dec 30, 2020)
14295de  refactor testing (Dec 30, 2020)
fb502d6  refactor testing (Dec 31, 2020)
9b6a696  refactor testing (Dec 31, 2020)
d59c54b  refactor testing (Dec 31, 2020)
a7f8fd5  base (Jan 1, 2021)
3060251  base (Jan 1, 2021)
ae5cbdd  base (Jan 1, 2021)
80e31d0  base (Jan 1, 2021)
888bbd7  base (Jan 1, 2021)
0516bd5  base (Jan 1, 2021)
afaaa94  base (Jan 1, 2021)
da2752b  base (Jan 1, 2021)
0539287  base (Jan 1, 2021)
0fc4e3f  base (Jan 1, 2021)
b206465  base (Jan 1, 2021)
ae84072  base (Jan 1, 2021)
a89d6df  base (Jan 1, 2021)
7650a9a  base (Jan 1, 2021)
84ab0a9  base (Jan 1, 2021)
9df9bb0  base (Jan 1, 2021)
91ff95b  base (Jan 2, 2021)
86c943d  base (Jan 2, 2021)
f7a7296  base (Jan 2, 2021)
28223d6  base (Jan 4, 2021)
6e2245c  sh (Jan 4, 2021)
80f15c0  kd for clustering (Jan 4, 2021)
f499816  kd for clustering (Jan 4, 2021)
a7d90f7  name (Jan 4, 2021)
dd6e171  add cpk (Jan 5, 2021)
2c2c670  add cpk (Jan 5, 2021)
991e430  add cpk (Jan 7, 2021)
e9bfdce  cdw 41 (Jan 8, 2021)
a2059bb  cdw 41 (Jan 8, 2021)
10208fc  kd whole (Jan 12, 2021)
ceb4a21  modify (Jan 12, 2021)
9247b24  cdw kd (Jan 12, 2021)
4359813  cdw kd (Jan 12, 2021)
92e205c  finch with distance and kmeans (Jan 22, 2021)
86dd906  finch with dis (Jan 22, 2021)
3234f1b  fix (Jan 22, 2021)
70c6bf4  fix kmeans (Jan 22, 2021)
ba625d4  change name (Feb 2, 2021)
06fcc54  sh file (Feb 2, 2021)
e79bb5a  setting (codergan, Apr 29, 2021)
2cc09b4  modify (codergan, Apr 29, 2021)
0322a93  modify (codergan, Apr 29, 2021)
24ec9ed  modify (codergan, Apr 29, 2021)
c365700  error shown in log (codergan, Apr 30, 2021)
8b18003  testing on remove load and store mat (codergan, May 5, 2021)
8cf73a7  testing on remove load and store mat, test each round (codergan, May 5, 2021)
9436650  testing on remove load and store mat, test each round (codergan, May 5, 2021)
f69d872  testing on remove load and store mat, test each round (codergan, May 5, 2021)
27c7b71  testing on remove load and store mat, test each round (codergan, May 5, 2021)
0a442c2  testing on remove load and store mat, test each round (codergan, May 5, 2021)
f2f25fb  testing on remove load and store mat, test each round (codergan, May 5, 2021)
ac471d7  testing on remove load and store mat, test each round (codergan, May 5, 2021)
5358a85  testing on remove load and store mat, test each round (codergan, May 5, 2021)
19cca3d  testing on remove load and store mat, test each round (codergan, May 5, 2021)
96c1d7d  testing on remove load and store mat, test each round (codergan, Apr 29, 2021)
dde1805  final resolved testing (codergan, May 5, 2021)
e3c4041  merge (codergan, May 5, 2021)
dae533e  sh file (codergan, May 6, 2021)
2170d4e  modify (codergan, May 6, 2021)
ccea2df  modify (codergan, May 6, 2021)
5f657e9  modify (codergan, May 7, 2021)
0529bc2  ignore (codergan, Oct 24, 2021)
a4dc21e  test code draft (codergan, Oct 24, 2021)
4eff54f  testing file (codergan, Oct 25, 2021)
3 changes: 3 additions & 0 deletions .gitignore
@@ -0,0 +1,3 @@
.idea/
__pycache__/
.DS_Store
51 changes: 31 additions & 20 deletions client.py
@@ -8,7 +8,7 @@
 import copy
 from optimization import Optimization
 class Client():
-    def __init__(self, cid, data, device, project_dir, model_name, local_epoch, lr, batch_size, drop_rate, stride):
+    def __init__(self, cid, data, device, project_dir, model_name, local_epoch, lr, batch_size, drop_rate, stride, clustering=False):
         self.cid = cid
         self.project_dir = project_dir
         self.model_name = model_name
@@ -21,25 +21,30 @@ def __init__(self, cid, data, device, project_dir, model_name, local_epoch, lr,
         self.dataset_sizes = self.data.train_dataset_sizes[cid]
         self.train_loader = self.data.train_loaders[cid]

-        self.full_model = get_model(self.data.train_class_sizes[cid], drop_rate, stride)
-        self.classifier = self.full_model.classifier.classifier
-        self.full_model.classifier.classifier = nn.Sequential()
-        self.model = self.full_model
-        self.distance=0
+        self.model = get_model(self.data.train_class_sizes[cid], drop_rate, stride)
+        self.classifier = copy.deepcopy(self.model.classifier.classifier)
+        self.model.classifier.classifier = nn.Sequential()
+        self.distance = 0
         self.optimization = Optimization(self.train_loader, self.device)
+        self.use_clustering = clustering
         # print("class name size",class_names_size[cid])

-    def train(self, federated_model, use_cuda):
+    def train(self, federated_model=None, use_cuda=False):
         self.y_err = []
         self.y_loss = []
-
-        self.model.load_state_dict(federated_model.state_dict())
+        if self.use_clustering:
+            print("using clustering; model was set beforehand")
+            assert federated_model is None
+            # self.model.classifier.classifier = nn.Sequential()
+            federated_model = copy.deepcopy(self.model)
+        else:
+            self.model.load_state_dict(federated_model.state_dict())
         self.model.classifier.classifier = self.classifier
         self.old_classifier = copy.deepcopy(self.classifier)
         self.model = self.model.to(self.device)

         self.model.train(True)
         optimizer = get_optimizer(self.model, self.lr)
-        scheduler = lr_scheduler.StepLR(optimizer, step_size=40, gamma=0.1)
+        # scheduler = lr_scheduler.StepLR(optimizer, step_size=40, gamma=0.1)

         criterion = nn.CrossEntropyLoss()

@@ -50,8 +55,7 @@ def train(self, federated_model, use_cuda):
            print('Epoch {}/{}'.format(epoch, self.local_epoch - 1))
            print('-' * 10)

-            scheduler.step()
-            self.model.train(True)
+            # scheduler.step()
            running_loss = 0.0
            running_corrects = 0.0

@@ -60,12 +64,12 @@ def train(self, federated_model, use_cuda):
                b, c, h, w = inputs.shape
                if b < self.batch_size:
                    continue
-                if use_cuda:
-                    inputs = Variable(inputs.cuda().detach())
-                    labels = Variable(labels.cuda().detach())
-                else:
-                    inputs, labels = Variable(inputs), Variable(labels)
-
+                # if use_cuda:
+                #     inputs = Variable(inputs.cuda().detach())
+                #     labels = Variable(labels.cuda().detach())
+                # else:
+                #     inputs, labels = Variable(inputs), Variable(labels)
+                inputs, labels = inputs.to(self.device), labels.to(self.device)
                optimizer.zero_grad()

                outputs = self.model(inputs)
@@ -106,6 +110,9 @@ def train(self, federated_model, use_cuda):
    def generate_soft_label(self, x, regularization):
        return self.optimization.kd_generate_soft_label(self.model, x, regularization)

+    def generate_custom_data_feature(self, inputs):
+        return self.optimization.generate_custom_data_feature(self.model, inputs)
+
    def get_model(self):
        return self.model

@@ -116,4 +123,8 @@ def get_train_loss(self):
        return self.y_loss[-1]

    def get_cos_distance_weight(self):
-        return self.distance
+        return self.distance
+
+    def set_model(self, model):
+        self.model = copy.deepcopy(model)
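For reference, a minimal sketch of how the two training paths introduced above might be driven. Everything here is illustrative: data, cluster_model, and the constructor arguments are placeholders, not values taken from this PR.

# Hedged sketch: data and cluster_model are placeholders defined elsewhere;
# the hyper-parameters below are illustrative, not values from this PR.
client = Client(cid='Market', data=data, device='cuda:0', project_dir='.',
                model_name='resnet50', local_epoch=1, lr=0.05,
                batch_size=32, drop_rate=0.5, stride=2, clustering=True)

client.set_model(cluster_model)   # clustering path: a cluster model is pushed first ...
client.train()                    # ... so train() runs with federated_model=None

# Without clustering, the federated model is passed explicitly:
#   client = Client(..., clustering=False)
#   client.train(federated_model, use_cuda=True)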

12 changes: 7 additions & 5 deletions data_utils.py
@@ -15,7 +15,7 @@ def __len__(self):
        return len(self.imgs)

    def __getitem__(self, index):
-        data,label = self.imgs[index]
+        data, label = self.imgs[index]
        return self.transform(Image.open(data)), label


@@ -95,6 +95,10 @@ def preprocess_train(self):

        print('Train dataset sizes:', self.train_dataset_sizes)
        print('Train class sizes:', self.train_class_sizes)
+        if "cuhk02" in self.datasets:
+            # cuhk02 is not labeled; we only use it for feature extraction in clustering
+            self.datasets.remove("cuhk02")
+            self.client_list.remove("cuhk02")

    def preprocess_test(self):
        """preprocess testing data, constructing test loaders
@@ -103,10 +107,8 @@ def preprocess_test(self):
        self.gallery_meta = {}
        self.query_meta = {}

-        for test_dir in self.datasets:
-            test_dir = 'data/'+test_dir+'/pytorch'
-
-            dataset = test_dir.split('/')[1]
+        for dataset in self.datasets:
+            test_dir = os.path.join(self.data_dir, dataset, "pytorch")
            gallery_dataset = datasets.ImageFolder(os.path.join(test_dir, 'gallery'))
            query_dataset = datasets.ImageFolder(os.path.join(test_dir, 'query'))
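As a quick illustration of the path change above, os.path.join now builds the test directory from self.data_dir instead of a hard-coded 'data/' prefix. A tiny sketch (assuming data_dir is 'data'; 'Market' is only an example dataset name):

import os

data_dir = 'data'  # assumed value of self.data_dir
print(os.path.join(data_dir, 'Market', 'pytorch'))           # data/Market/pytorch
print(os.path.join(data_dir, 'Market', 'pytorch', 'query'))  # data/Market/pytorch/query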
65 changes: 36 additions & 29 deletions evaluate.py
@@ -2,11 +2,11 @@
 import torch
 import numpy as np
 import os
-import argparse
-parser = argparse.ArgumentParser(description='Training')
-parser.add_argument('--result_dir', default='.', type=str)
-parser.add_argument('--dataset', default='no_dataset', type=str)
-args = parser.parse_args()
+# import argparse
+# parser = argparse.ArgumentParser(description='Training')
+# parser.add_argument('--result_dir', default='.', type=str)
+# parser.add_argument('--dataset', default='no_dataset', type=str)
+# args = parser.parse_args()

 #######################################################################
 # Evaluate
@@ -60,32 +60,39 @@ def compute_mAP(index, good_index, junk_index):

    return ap, cmc

-######################################################################
-result = scipy.io.loadmat(args.result_dir + '/pytorch_result.mat')
-
-query_feature = torch.FloatTensor(result['query_f'])
-query_cam = result['query_cam'][0]
-query_label = result['query_label'][0]
-gallery_feature = torch.FloatTensor(result['gallery_f'])
-gallery_cam = result['gallery_cam'][0]
-gallery_label = result['gallery_label'][0]
+def testing_model(result, dataset):
+    # result = scipy.io.loadmat(file_path)
+    # print("========= after loading ==========")
+    # for i in result:
+    #     print(i, np.array(result[i]).shape)

-query_feature = query_feature.cuda()
-gallery_feature = gallery_feature.cuda()
+    query_feature = torch.FloatTensor(result['query_f'])
+    query_cam = np.array(result['query_cam'])
+    query_label = np.array(result['query_label'])
+    gallery_feature = torch.FloatTensor(result['gallery_f'])
+    gallery_cam = np.array(result['gallery_cam'])
+    gallery_label = np.array(result['gallery_label'])
+    # print(type(query_feature),query_feature[:3])
+    # print(type(query_cam),query_cam[:3])
+    # print(type(query_label),query_label[:3])

-print(query_feature.shape)
-CMC = torch.IntTensor(len(gallery_label)).zero_()
-ap = 0.0
+    query_feature = query_feature.cuda()
+    gallery_feature = gallery_feature.cuda()

-for i in range(len(query_label)):
-    ap_tmp, CMC_tmp = evaluate(query_feature[i], query_label[i], query_cam[i], gallery_feature, gallery_label, gallery_cam)
-    if CMC_tmp[0]==-1:
-        continue
-    CMC = CMC + CMC_tmp
-    ap += ap_tmp
+    print(query_feature.shape)
+    CMC = torch.IntTensor(len(gallery_label)).zero_()
+    ap = 0.0

-CMC = CMC.float()
-CMC = CMC/len(query_label) #average CMC
-print(args.dataset+' Rank@1:%f Rank@5:%f Rank@10:%f mAP:%f'%(CMC[0], CMC[4], CMC[9], ap/len(query_label)))
-print('-'*15)
-print()
+    for i in range(len(query_label)):
+        ap_tmp, CMC_tmp = evaluate(query_feature[i], query_label[i], query_cam[i], gallery_feature, gallery_label, gallery_cam)
+        if CMC_tmp[0]==-1:
+            continue
+        CMC = CMC + CMC_tmp
+        ap += ap_tmp

+    CMC = CMC.float()
+    CMC = CMC/len(query_label)  # average CMC
+    print(dataset+' Rank@1:%f Rank@5:%f Rank@10:%f mAP:%f'%(CMC[0], CMC[4], CMC[9], ap/len(query_label)))
+    print('-'*15)
+    print()
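For reference, a minimal sketch of calling the new testing_model() (a CUDA device is assumed, since the function moves features to the GPU; 'Market' is only an example dataset name):

import scipy.io

result = scipy.io.loadmat('pytorch_result.mat')  # the .mat produced by the feature-extraction step
testing_model(result, 'Market')                  # prints Rank@1/5/10 and mAP for this dataset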
173 changes: 173 additions & 0 deletions finch.py
@@ -0,0 +1,173 @@
import time
import argparse
import numpy as np
from sklearn import metrics
import scipy.sparse as sp
import warnings

try:
    from pyflann import *

    pyflann_available = True
except Exception as e:
    warnings.warn('pyflann not installed: {}'.format(e))
    pyflann_available = False
    pass

RUN_FLANN = 70000


def clust_rank(mat, initial_rank=None, distance='cosine'):
    s = mat.shape[0]
    if initial_rank is not None:
        orig_dist = []
    elif s <= RUN_FLANN:
        orig_dist = metrics.pairwise.pairwise_distances(mat, mat, metric=distance)
        np.fill_diagonal(orig_dist, 1000.0)
        initial_rank = np.argmin(orig_dist, axis=1)
    else:
        if not pyflann_available:
            raise MemoryError("You should use pyflann for inputs larger than {} samples.".format(RUN_FLANN))
        print('Using flann to compute 1st-neighbours at this step ...')
        flann = FLANN()
        result, dists = flann.nn(mat, mat, num_neighbors=2, algorithm="kdtree", trees=8, checks=128)
        initial_rank = result[:, 1]
        orig_dist = []
        print('Step flann done ...')

    # The Clustering Equation
    A = sp.csr_matrix((np.ones_like(initial_rank, dtype=np.float32), (np.arange(0, s), initial_rank)), shape=(s, s))
    A = A + sp.eye(s, dtype=np.float32, format='csr')
    A = A @ A.T

    A = A.tolil()
    A.setdiag(0)
    return A, orig_dist


def get_clust(a, orig_dist, min_sim=None):
    if min_sim is not None:
        a[np.where((orig_dist * a.toarray()) > min_sim)] = 0

    num_clust, u = sp.csgraph.connected_components(csgraph=a, directed=True, connection='weak', return_labels=True)
    return u, num_clust


def cool_mean(M, u):
    _, nf = np.unique(u, return_counts=True)
    idx = np.argsort(u)
    M = M[idx, :]
    M = np.vstack((np.zeros((1, M.shape[1])), M))

    np.cumsum(M, axis=0, out=M)
    cnf = np.cumsum(nf)
    nf1 = np.insert(cnf, 0, 0)
    nf1 = nf1[:-1]

    M = M[cnf, :] - M[nf1, :]
    M = M / nf[:, None]
    return M


def get_merge(c, u, data):
    if len(c) != 0:
        _, ig = np.unique(c, return_inverse=True)
        c = u[ig]
    else:
        c = u

    mat = cool_mean(data, c)
    return c, mat


def update_adj(adj, d):
    # Update adj, keep one merge at a time
    idx = adj.nonzero()
    v = np.argsort(d[idx])
    v = v[:2]
    x = [idx[0][v[0]], idx[0][v[1]]]
    y = [idx[1][v[0]], idx[1][v[1]]]
    a = sp.lil_matrix(adj.get_shape())
    a[x, y] = 1
    return a


def req_numclust(c, data, req_clust, distance):
    iter_ = len(np.unique(c)) - req_clust
    c_, mat = get_merge([], c, data)
    for i in range(iter_):
        adj, orig_dist = clust_rank(mat, initial_rank=None, distance=distance)
        adj = update_adj(adj, orig_dist)
        u, _ = get_clust(adj, [], min_sim=None)
        c_, mat = get_merge(c_, u, data)
    return c_


def FINCH(data, min_sim=None, initial_rank=None, req_clust=None, distance='cosine', verbose=True):
    """ FINCH clustering algorithm.
    :param data: Input matrix with features in rows.
    :param initial_rank: Nx1 first integer neighbor indices (optional).
    :param req_clust: Set output number of clusters (optional). Not recommended.
    :param distance: One of ['cityblock', 'cosine', 'euclidean', 'l1', 'l2', 'manhattan'] Recommended 'cosine'.
    :param verbose: Print verbose output.
    :return:
        c: NxP matrix where P is the partition. Cluster label for every partition.
        num_clust: Number of clusters.
        req_c: Labels of required clusters (Nx1). Only set if `req_clust` is not None.
    The code implements the FINCH algorithm described in our CVPR 2019 paper
        Sarfraz et al. "Efficient Parameter-free Clustering Using First Neighbor Relations", CVPR2019
        https://arxiv.org/abs/1902.11266
    For academic purpose only. The code or its re-implementation should not be used for commercial use.
    Please contact the author below for licensing information.
    Copyright
    M. Saquib Sarfraz ([email protected])
    Karlsruhe Institute of Technology (KIT)
    """
    # Cast input data to float32
    data = data.astype(np.float32)

    # min_sim = None
    adj, orig_dist = clust_rank(data, initial_rank, distance)
    initial_rank = None
    group, num_clust = get_clust(adj, [], min_sim)
    c, mat = get_merge([], group, data)

    if verbose:
        print('Partition 0: {} clusters'.format(num_clust))
    if len(orig_dist) != 0:
        min_sim = np.max(orig_dist * adj.toarray())

    exit_clust = 2
    c_ = c
    k = 1
    num_clust = [num_clust]

    while exit_clust > 1:
        adj, orig_dist = clust_rank(mat, initial_rank, distance)
        u, num_clust_curr = get_clust(adj, orig_dist, min_sim)
        c_, mat = get_merge(c_, u, data)

        num_clust.append(num_clust_curr)
        c = np.column_stack((c, c_))
        exit_clust = num_clust[-2] - num_clust_curr

        if num_clust_curr == 1 or exit_clust < 1:
            num_clust = num_clust[:-1]
            c = c[:, :-1]
            break

        if verbose:
            print('Partition {}: {} clusters'.format(k, num_clust[k]))
        k += 1

    if req_clust is not None:
        if req_clust not in num_clust:
            ind = [i for i, v in enumerate(num_clust) if v >= req_clust]
            req_c = req_numclust(c[:, ind[-1]], data, req_clust, distance)
        else:
            req_c = c[:, num_clust.index(req_clust)]
    else:
        req_c = None

    return c, num_clust, req_c
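A minimal usage sketch of FINCH on synthetic data (the two-blob setup, shapes, and seed are illustrative assumptions, not part of this PR):

import numpy as np

rng = np.random.default_rng(0)
blob_a = rng.normal(0.0, 0.1, (100, 16))  # 100 samples around the origin
blob_b = rng.normal(5.0, 0.1, (100, 16))  # 100 samples around another center
data = np.vstack([blob_a, blob_b])

c, num_clust, req_c = FINCH(data, req_clust=2, distance='euclidean', verbose=True)
print(c.shape)    # (200, P): one column of cluster labels per recursive partition
print(num_clust)  # clusters per partition, e.g. [2]
print(req_c[:3])  # labels for the requested 2-cluster partition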
