Skip to content

Commit

Permalink
checking n-pair
Browse files Browse the repository at this point in the history
  • Loading branch information
estija committed Oct 27, 2020
1 parent c48d392 commit f0ae718
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 17 deletions.
55 changes: 42 additions & 13 deletions loss/embedding_aug_mx.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,24 +113,29 @@ def get_min_dis(F, dis0,label,a1l,a2l):
def get_sum_exp_dis(F, dis0,label,a1l,a2l):
k=0
for l in range(label.shape[0]):
id1=[i for i in range(a1l.shape[0]) if a1l[i]==label[l] ]
id2=[i for i in range(a2l.shape[0]) if a2l[i]==label[l] ]
id1=(a1l==label[l]) #[i for i in range(a1l.shape[0]) if a1l[i]==label[l] ]
id2=(a2l==label[l]) #[i for i in range(a2l.shape[0]) if a2l[i]==label[l] ]

if len(id1)>0 or len(id2)>0:
if len(id1)<1:
dist=F.sum(F.exp(-dis0[id2]))
elif len(id2)<1:
dist=F.sum(F.exp(-dis0[id1]))
if F.sum(id1)>0 or F.sum(id2)>0:
if F.sum(id1)<1:
dist=F.sum(F.exp(-F.contrib.boolean_mask(dis0,id2)))
num=F.sum(id2)
elif F.sum(id2)<1:
dist=F.sum(F.exp(-F.contrib.boolean_mask(dis0,id1)))
num=F.sum(id1)
else:
dist=F.sum(F.concat(F.sum(F.exp(-dis0[id1])), F.sum(F.exp(-dis0[id2])), dim=0))
dist=F.sum(F.concat(F.sum(F.exp(-F.contrib.boolean_mask(dis0,id1))), F.sum(F.exp(-F.contrib.boolean_mask(dis0,id2))), dim=0))
num=F.sum(id1)+F.sum(id2)

if k==0:
k=k+1
dis=dist
numf=num
else:
dis=F.concat(dis,dist,dim=0)
numf=F.concat(numf,num,dim=0)

return dis
return dis,numf

def get_pos_dis(F, dis_ap, labelsorg):
N = dis_ap.shape[0]
Expand Down Expand Up @@ -252,9 +257,27 @@ def pair_mining(F, dis_ap, dis_an, ids, a1l, a2l, ind, labels, num_ins, th, alph

#print('COUNT FOR emptiness per pair...', float(count)/float(len(ids)))
return dis_neg, dis_pos


def check_corners(F,X1,X2,X3,X4):
dis13=F.expand_dims(F.sum(X1*X3, axis=0),axis=1)
dis14=F.expand_dims(F.sum(X1*X4, axis=0),axis=1)
dis23=F.expand_dims(F.sum(X2*X3, axis=0),axis=1)
dis24=F.expand_dims(F.sum(X2*X4, axis=0),axis=1)

dis1=dis13+7*(F.sign(dis13-dis14)**2-1+F.sign(dis13-dis23)**2-1+F.sign(dis13-dis24)**2-1)
dis2=dis14+8*(F.sign(dis14-dis23)**2-1+F.sign(dis14-dis24)**2-1)
dis3=dis23+9*(F.sign(dis23-dis24)**2-1)
dis4=dis24

dis=F.concat(dis1,dis2,dis3,dis4,dim=1)
#print(dis)
dis=F.max(dis,axis=1)
#print(dis)
return dis


def get_opt_emb_dis(F, embeddings, labels, num_instance, l2_norm=True, multisim=False):
def get_opt_emb_dis(F, embeddings, labels, num_instance, l2_norm=True, multisim=False, npair=False):
batch_size = embeddings.shape[0]
dim=embeddings.shape[1]

Expand Down Expand Up @@ -282,8 +305,8 @@ def get_opt_emb_dis(F, embeddings, labels, num_instance, l2_norm=True, multisim=
X2l=X2l[ind]

if len(ind)<2 or len(indx)<2:
print('============================YESS=======================================')
print('Similarities...', sim)
#print('============================YESS=======================================')
#print('Similarities...', sim)
ind=[i for i in range(sim.shape[0])]

if num_instance==2:
Expand All @@ -297,7 +320,10 @@ def get_opt_emb_dis(F, embeddings, labels, num_instance, l2_norm=True, multisim=
print('dis_ap', dis_ap1)

else:
dis_ap1 = F.sum(X1*X2, axis=1) #num_ins=2 for n-pair
if npair:
dis_ap1 = F.sum(X1*X2, axis=1) #num_ins=2 for n-pair
else:
dis_ap1 = F.sqrt(F.sum((X1-X2)*(X1-X2), axis=1)+1e-20)

X1, X2, X3, X4, a1l, a2l, ids = concat(F,X1,X2,X1l,X2l)

Expand All @@ -309,6 +335,9 @@ def get_opt_emb_dis(F, embeddings, labels, num_instance, l2_norm=True, multisim=
dis = F.sqrt(F.sum((X1-X3)*(X1-X3), axis=1)+1e-20)

else:
if npair:
dis = check_corners(F,F.transpose(X1), F.transpose(X2), F.transpose(X3), F.transpose(X4))
else:
dis = opt_pts_lin(F.transpose(X1), F.transpose(X2), F.transpose(X3), F.transpose(X4))

#dis_an = get_min_dis(F, dis, ids, a1l, a2l) #for hphn-triplet
Expand Down
13 changes: 10 additions & 3 deletions loss/ml_losses.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,11 +133,18 @@ def hybrid_forward(self, F, embeddings, labels):
self.batch_size = embeddings.shape[0]

gen_start_time = time.time()
dist_ap, dist_an0, ids, a1l, a2l = get_opt_emb_dis(F, embeddings, labels, self.num_instance, self.l2_norm)
dist_an = get_sum_exp_dis(F, -dist_an0, ids, a1l, a2l)
dist_ap, dist_an0, ids, a1l, a2l = get_opt_emb_dis(F, embeddings, labels, self.num_instance, l2_norm=False,npair=True )
dist_an, numf = get_sum_exp_dis(F, -dist_an0, ids, a1l, a2l)
print(dist_an)
print(F.exp(-dist_ap))
gen_time = time.time() - gen_start_time

X1=embeddings[0:self.batch_size:2]
X2=embeddings[1:self.batch_size:2]
l2_reg=F.sum(X1*X1, axis = 1)+F.sum(X2*X2, axis = 1)
print(F.sqrt(F.sum((X1-X2)**2, axis = 1)))

loss = F.log(1.0 + F.exp(-dist_ap)*dist_an)
loss = F.log(1.0 + F.exp(-dist_ap)*dist_an) #+ 0.000075*l2_reg

total_time = time.time() - total_start_time

Expand Down
2 changes: 1 addition & 1 deletion main.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,7 @@ def main():
# evaluate_and_log(summary_writer, evaluator, ranks, step, epoch, best_metrics)
best_metrics = evaluate_and_log(summary_writer, evaluator, args.recallk,
global_step, epoch + 1,
best_metrics=best_metrics, args.data_name)
best_metrics=best_metrics, data_name=args.data_name)
if best_metrics[0] != old_best_metric:
save_path = os.path.join(args.save_dir, 'model_epoch_%05d.params' % (epoch + 1))
model.save_parameters(save_path)
Expand Down

0 comments on commit f0ae718

Please sign in to comment.