Commit a7d4a324 authored by IlyaOvodov

9ea497 updated distance weighted sampling

parent 6419986d
@@ -114,7 +114,7 @@ class LogTrainingResults:
             str = "Epoch:{}.{}\t".format(engine.state.epoch, engine.state.iteration)
         else:
             str = "Epoch:{}\t".format(engine.state.epoch)
-        str += '\t'.join(['{}:{:.3f}'.format(k, v) for k, v in engine.state.metrics.items()])
+        str += '\t'.join(['{}:{:.5f}'.format(k, v) for k, v in engine.state.metrics.items()])
         print(str)
         with open(self.params.get_base_filename() + '.log', 'a') as f:
             f.write(str + '\n')
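For reference, the only change in this file is logging precision: metrics are now written with five decimal places instead of three, so small per-epoch movements stay visible. A minimal illustration of the resulting log line (metric names and values are made up, not from the commit):

```python
# Illustrative only: the log line format after the change.
metrics = {'loss': 0.0123456, 'accuracy': 0.987654}
line = "Epoch:3\t" + '\t'.join('{}:{:.5f}'.format(k, v) for k, v in metrics.items())
print(line)  # Epoch:3	loss:0.01235	accuracy:0.98765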
@@ -23,7 +23,7 @@ class MarginBaseLoss:
     margin based loss with distance weighted sampling https://arxiv.org/pdf/1706.07567.pdf
     '''
     ignore_index = -100
-    def __init__(self, model, classes, device, params, timer = DummyTimer()):
+    def __init__(self, model, classes, device, params):
         assert params.data.samples_per_class >= 2
         self.model = model
         self.device = device
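For context, the constructor reads several nested fields from `params`. A hedged sketch of the configuration shape this class appears to assume (field names are taken from the diff; the container type and the concrete values are stand-ins, not from the project):

```python
# Assumed params layout, inferred from the attributes referenced in this file.
from types import SimpleNamespace

params = SimpleNamespace(
    data=SimpleNamespace(samples_per_class=4),              # >= 2 so a positive pair exists per class block
    distance_weighted_sampling=SimpleNamespace(lambda_=8),  # clamp for the sampling weights
    mb_loss=SimpleNamespace(alpha=0.2),                     # margin alpha of the margin-based loss
)
```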
@@ -31,7 +31,7 @@ class MarginBaseLoss:
         self.classes = sorted(classes)
         self.classes_dict = {v: i for i, v in enumerate(self.classes)}
         self.lambda_rev = 1/params.distance_weighted_sampling.lambda_
-        self.timer = timer
+        self.timer = DummyTimer()
         print('classes: ', len(self.classes))

     def set_timer(self, timer):
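With this change the timer is no longer a (mutable) default argument; it defaults to `DummyTimer()` and is injected via `set_timer`. The only interface the loss relies on is `watch(name)` used as a context manager. A no-op sketch consistent with that usage (the project's real `DummyTimer` may differ):

```python
from contextlib import contextmanager

class DummyTimer:
    '''No-op timer: satisfies the `with timer.watch(name):` protocol without measuring.'''
    @contextmanager
    def watch(self, name):
        yield  # do nothing; a real timer would record elapsed time under `name`
```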
@@ -44,13 +44,11 @@ class MarginBaseLoss:
         with self.timer.watch('time.l2_loss'):
             pred_class = net_output[0]
             class_nos = self.classes_to_ids(y_class, ignore_index=self.ignore_index)
-            return torch.nn.CrossEntropyLoss(ignore_index=self.ignore_index)(pred_class, class_nos)
-
-    def D(self, pred_embeddings, i, j):
-        if i == j:
-            return 0
-        return torch.dist(pred_embeddings[i], pred_embeddings[j]).item()
+            self.l2_loss_val = torch.nn.CrossEntropyLoss(ignore_index=self.ignore_index)(pred_class, class_nos)
+            return self.l2_loss_val
+
+    def last_l2_loss(self, net_output, y_class):
+        return self.l2_loss_val

     def mb_loss(self, net_output, y_class):
         with self.timer.watch('time.mb_loss'):
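The `*_val` caching introduced in this commit lets separate logging metrics reuse values computed once per batch: `loss()` evaluates both terms, while `last_l2_loss`/`last_mb_loss` just return the stored results. A hedged usage sketch (the `criterion` name and the wiring are illustrative, not from the diff):

```python
# Illustrative wiring: compute once, report components without recomputation.
total = criterion.loss(net_output, y_class)       # runs l2_loss and mb_loss, caches both
l2 = criterion.last_l2_loss(net_output, y_class)  # cached value; arguments are ignored
mb = criterion.last_mb_loss(net_output, y_class)  # cached value; arguments are ignored
```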
@@ -58,29 +56,37 @@ class MarginBaseLoss:
             loss = 0
             n = len(pred_embeddings)  # samples in batch
             dim = pred_embeddings[0].shape[0]  # dimensionality
+            with self.timer.watch('time.d_ij'):
+                assert len(pred_embeddings.shape) == 2, pred_embeddings.shape
+                # all pairwise distances at once: ||a-b||^2 = ||a||^2 + ||b||^2 - 2*a.b
+                norm = (pred_embeddings ** 2).sum(1)
+                self.d_ij = norm.view(-1, 1) + norm.view(1, -1) - 2.0 * torch.mm(pred_embeddings, torch.transpose(pred_embeddings, 0, 1))
+                self.d_ij = torch.sqrt(torch.clamp(self.d_ij, min=0.0) + 1.0e-8)
             for i_start in range(0, n, self.params.data.samples_per_class):  # start of class block
                 i_end = i_start + self.params.data.samples_per_class  # end of class block
-                for i in range(i_start, i_end - 1):
-                    with self.timer.watch('time.d_ij'):
-                        d_ij = [0 if i == j else self.D(pred_embeddings, i, j) for j in range(n)]
-                    weights = [1/max(self.lambda_rev, pow(d, dim-2)*pow(1 - d*d/4, (dim-3)/2))  # https://arxiv.org/pdf/1706.07567.pdf
-                               for id, d in enumerate(d_ij) if id != i]  # don't pair with itself
-                    weights_same = np.asarray(weights[i_start: i_end-1])  # i-th element already excluded
-                    j = np.random.choice(range(i_start, i_end-1), p=weights_same/np.sum(weights_same))
-                    if j >= i:
-                        j += 1
-                    # for j in range(i+1, i_end):  # positive pair
-                    loss += (self.params.mb_loss.alpha + (d_ij[j] - self.model.mb_loss_beta)).clamp(min=0)
+                for i in range(i_start, i_end):
+                    d = self.d_ij[i, :].detach()
+                    # Gaussian approximation of the pairwise-distance density, which
+                    # concentrates around sqrt(2) = 1.4142135623730951 for large dim:
+                    # https://arxiv.org/pdf/1706.07567.pdf
+                    prob = torch.exp(-(d - 1.4142135623730951)**2 * dim)
+                    weights = (1/prob.clamp(min=self.lambda_rev)).cpu().numpy()
+                    weights[i] = 0  # don't pair with itself
+                    # select positive pair
+                    weights_same = weights[i_start: i_end]  # i-th element has zero weight
+                    j = np.random.choice(range(i_start, i_end), p=weights_same/np.sum(weights_same))
+                    assert j != i
+                    loss += (self.params.mb_loss.alpha + (self.d_ij[i, j] - self.model.mb_loss_beta)).clamp(min=0)  # https://arxiv.org/pdf/1706.07567.pdf
                     # select negative pair
-                    weights[i_start: i_end - 1] = []  # i-th element already excluded
-                    weights = np.asarray(weights)
-                    weights = weights/np.sum(weights)
-                    k = np.random.choice(range(0, n - self.params.data.samples_per_class), p=weights)
+                    weights = np.delete(weights, np.s_[i_start: i_end], axis=0)
+                    k = np.random.choice(range(0, n - self.params.data.samples_per_class), p=weights/np.sum(weights))
                     if k >= i_start:
                         k += self.params.data.samples_per_class
-                    loss += (self.params.mb_loss.alpha - (d_ij[k] - self.model.mb_loss_beta)).clamp(min=0)
-            return loss[0] / len(pred_embeddings)
+                    loss += (self.params.mb_loss.alpha - (self.d_ij[i, k] - self.model.mb_loss_beta)).clamp(min=0)  # https://arxiv.org/pdf/1706.07567.pdf
+            self.mb_loss_val = loss[0] / len(pred_embeddings)
+            return self.mb_loss_val
+
+    def last_mb_loss(self, net_output, y_class):
+        return self.mb_loss_val

     def loss(self, net_output, y_class):
-        return self.l2_loss(net_output, y_class) + self.mb_loss(net_output, y_class)
+        self.loss_val = self.l2_loss(net_output, y_class) + self.mb_loss(net_output, y_class)
+        return self.loss_val
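For reference, the sampling weights on both sides of this hunk come from Section 4 of https://arxiv.org/pdf/1706.07567.pdf: for unit-normalized embeddings in n dimensions the pairwise-distance density is q(d) ∝ d^(n-2) * (1 - d²/4)^((n-3)/2), candidates are drawn with probability ∝ min(λ, q(d)⁻¹), and for large n this density concentrates around √2, which motivates the Gaussian approximation exp(-(d - √2)² · n) that replaces the exact formula in the new code. A standalone sketch comparing the two weightings (λ, dim, and the distance values are illustrative, not from the commit):

```python
import numpy as np

def q_exact(d, dim):
    """Exact pairwise-distance density on the unit sphere (up to normalization)."""
    return d ** (dim - 2) * (1 - d * d / 4) ** ((dim - 3) / 2)

def q_gauss(d, dim):
    """Gaussian approximation around sqrt(2), as used in the updated code."""
    return np.exp(-(d - np.sqrt(2.0)) ** 2 * dim)

lambda_ = 8.0  # clamp from params.distance_weighted_sampling.lambda_ (illustrative value)
dim = 128      # embedding dimensionality (illustrative value)
d = np.array([0.5, 1.0, np.sqrt(2.0), 1.8])

for q in (q_exact, q_gauss):
    # min(lambda, 1/q(d)): rare distances get high but bounded sampling weight
    w = np.minimum(lambda_, 1.0 / np.maximum(q(d, dim), 1e-30))
    print(q.__name__, w / w.sum())  # normalized sampling weights
```

Both weightings clamp at λ away from √2, so in high dimensions the cheap Gaussian form gives essentially the same sampling distribution as the exact density while avoiding the `pow` underflow/overflow of the per-pair version it replaces.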