Unverified Commit 85b722e4 authored by NingMa's avatar NingMa Committed by GitHub

Update protonet.py

parent 870e1755
......@@ -146,7 +146,7 @@ class ProtoNet(nn.Module):
x = x.unsqueeze(1).expand(n, m, t, c).reshape(n * m, t, c)
y = y.unsqueeze(0).expand(n, m, t, c).reshape(n * m, t, c)
sdtw = SoftDTW(gamma=gl.gamma, normalize=False, attention=self.attention_x, attention_y=self.attention_y)
sdtw = SoftDTW(gamma=gl.gamma, normalize=True, attention=self.attention_x, attention_y=self.attention_y)
loss = sdtw(x, y)
return loss.view(n, m)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment