Commit 1c24e674 authored by IlyaOvodov's avatar IlyaOvodov

MBloss without replacement, fix LogTrainingResults, TensorboardLogger

parent a7d4a324
...@@ -103,13 +103,12 @@ class LogTrainingResults: ...@@ -103,13 +103,12 @@ class LogTrainingResults:
self.params = params self.params = params
def __call__(self, engine, event): def __call__(self, engine, event):
if event == Events.ITERATION_COMPLETED and engine.state.epoch != 1:
return
for key,loader in self.loaders_dict.items(): for key,loader in self.loaders_dict.items():
self.evaluator.run(loader) self.evaluator.run(loader)
for k,v in self.evaluator.state.metrics.items(): for k,v in self.evaluator.state.metrics.items():
engine.state.metrics[key+':'+k] = v engine.state.metrics[key+':'+k] = v
self.best_model_buffer.save_if_best(engine) if self.best_model_buffer:
self.best_model_buffer.save_if_best(engine)
if event == Events.ITERATION_COMPLETED: if event == Events.ITERATION_COMPLETED:
str = "Epoch:{}.{}\t".format(engine.state.epoch, engine.state.iteration) str = "Epoch:{}.{}\t".format(engine.state.epoch, engine.state.iteration)
else: else:
...@@ -158,3 +157,6 @@ class TensorBoardLogger: ...@@ -158,3 +157,6 @@ class TensorBoardLogger:
self.writer.add_scalar(n, v, self.call_count) self.writer.add_scalar(n, v, self.call_count)
else: else:
self.writer.add_scalars(n, d, self.call_count) self.writer.add_scalars(n, d, self.call_count)
for path, writer in self.writer.all_writers.items():
writer.flush()
...@@ -60,7 +60,7 @@ class MarginBaseLoss: ...@@ -60,7 +60,7 @@ class MarginBaseLoss:
with self.timer.watch('time.d_ij'): with self.timer.watch('time.d_ij'):
assert len(pred_embeddings.shape) == 2, pred_embeddings.shape assert len(pred_embeddings.shape) == 2, pred_embeddings.shape
norm = (pred_embeddings ** 2).sum(1) norm = (pred_embeddings ** 2).sum(1)
self.d_ij = norm.view(-1, 1) + norm.view(1, -1) - 2.0 * torch.mm(pred_embeddings, torch.transpose(pred_embeddings, 0, 1)) self.d_ij = norm.view(-1, 1) + norm.view(1, -1) - 2.0 * torch.mm(pred_embeddings, torch.transpose(pred_embeddings, 0, 1)) #https://discuss.pytorch.org/t/efficient-distance-matrix-computation/9065/8
self.d_ij = torch.sqrt(torch.clamp(self.d_ij, min=0.0) + 1.0e-8) self.d_ij = torch.sqrt(torch.clamp(self.d_ij, min=0.0) + 1.0e-8)
for i_start in range(0, n, self.params.data.samples_per_class): # start of class block for i_start in range(0, n, self.params.data.samples_per_class): # start of class block
...@@ -72,12 +72,12 @@ class MarginBaseLoss: ...@@ -72,12 +72,12 @@ class MarginBaseLoss:
weights[i] = 0 # dont join with itself weights[i] = 0 # dont join with itself
# select positive pair # select positive pair
weights_same = weights[i_start: i_end] # i-th element already excluded weights_same = weights[i_start: i_end] # i-th element already excluded
j = np.random.choice(range(i_start, i_end), p = weights_same/np.sum(weights_same) ) j = np.random.choice(range(i_start, i_end), p = weights_same/np.sum(weights_same), replace=False)
assert j != i assert j != i
loss += (self.params.mb_loss.alpha + (self.d_ij[i,j] - self.model.mb_loss_beta)).clamp(min=0) #https://arxiv.org/pdf/1706.07567.pdf loss += (self.params.mb_loss.alpha + (self.d_ij[i,j] - self.model.mb_loss_beta)).clamp(min=0) #https://arxiv.org/pdf/1706.07567.pdf
# select neg. pait # select neg. pait
weights = np.delete(weights, np.s_[i_start: i_end], axis=0) weights = np.delete(weights, np.s_[i_start: i_end], axis=0)
k = np.random.choice(range(0, n - self.params.data.samples_per_class), p = weights/np.sum(weights)) k = np.random.choice(range(0, n - self.params.data.samples_per_class), p = weights/np.sum(weights), replace=False)
if k >= i_start: if k >= i_start:
k += self.params.data.samples_per_class k += self.params.data.samples_per_class
loss += (self.params.mb_loss.alpha - (self.d_ij[i,k] - self.model.mb_loss_beta)).clamp(min=0) #https://arxiv.org/pdf/1706.07567.pdf loss += (self.params.mb_loss.alpha - (self.d_ij[i,k] - self.model.mb_loss_beta)).clamp(min=0) #https://arxiv.org/pdf/1706.07567.pdf
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment