Commit fdc86310 authored by 张泓钇

update

parent 201797ee
@@ -31,8 +31,6 @@ class ProtoNet(nn.Module):
        st_graph = None
        node = 0
        self.model = ST_GCN_18(
            in_channels=3,
            num_class=60,
@@ -77,7 +75,6 @@ class ProtoNet(nn.Module):
        if dtw > 0:
            dist, reg_loss = self.dtw_loss(zq, z_proto)
        else:
            # zq, z_proto = F.avg_pool2d(zq, zq.size()[2:]).view(n_class * n_query, c), F.avg_pool2d(z_proto, z_proto.size()[2:]).view(n_class, c)
            zq = zq.view(n_class * n_query, -1)
            z_proto = z_proto.view(n_class, -1)
            dist = euclidean_dist(zq, z_proto)
@@ -163,7 +160,6 @@ class ProtoNet(nn.Module):
        loss = torch.tensor(0).float().to(gl.device)
        for i in range(x.size()[0]):
            transpose_X = x[i]
@@ -174,24 +170,9 @@ class ProtoNet(nn.Module):
            method_loss = -torch.mean(list_svd[:min(softmax_tgt.shape[0], softmax_tgt.shape[1])])
            loss += method_loss
        return loss / x.size()[0]
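    # The hunk above keeps only the tail of this loss; `list_svd` and
    # `softmax_tgt` are defined in the elided lines. A minimal sketch of the
    # idea, assuming `softmax_tgt` is a (batch, classes) softmax output and
    # torch >= 1.8 for torch.linalg.svdvals: maximizing the mean of the
    # leading singular values of the prediction matrix (a batch-nuclear-norm
    # style term) favours confident, diverse predictions, hence the negative
    # sign when it is added to the loss.
    def _bnm_style_loss_sketch(self, softmax_tgt):
        list_svd = torch.linalg.svdvals(softmax_tgt)  # singular values, largest first
        k = min(softmax_tgt.shape[0], softmax_tgt.shape[1])
        return -torch.mean(list_svd[:k])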
    def idm_reg(self, x):
        n, t, c = x.size()
        reg_loss = torch.tensor(0).float().to(gl.device)
        thred = 5
        margin = 2
        weight, inverse_weight = self.get_W(x, thred)
        for i in range(n):
            dist = euclidean_dist(x[i, :, :], x[i, :, :])  # t * t
            inverse_dist = torch.max(torch.zeros(t, t).to(gl.device), margin - dist).to(gl.device)
            reg_loss += (inverse_dist * inverse_weight + dist * weight).sum()
        return reg_loss / n
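    # A self-contained sketch of the regularizer above, with a placeholder
    # similarity mask (the real self.get_W is not shown in this diff): frame
    # pairs whose squared distance falls under `thred` are treated as similar
    # and pulled together, while the remaining pairs are pushed at least
    # `margin` apart through a hinge term.
    @staticmethod
    def _idm_reg_sketch(x, thred=5.0, margin=2.0):
        n, t, c = x.size()
        reg_loss = x.new_zeros(())
        for i in range(n):
            dist = torch.cdist(x[i], x[i]) ** 2       # (t, t) squared distances
            weight = (dist < thred).float()           # assumed similarity mask
            inverse_weight = 1.0 - weight
            hinge = torch.clamp(margin - dist, min=0.0)
            reg_loss = reg_loss + (hinge * inverse_weight + dist * weight).sum()
        return reg_loss / n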
    def forward(self, x):
        x = self.model(x)
...
# coding=utf-8
import torch
from torch.nn import functional as F
from torch.nn.modules import Module


class PrototypicalLoss(Module):
    '''
    Loss class deriving from Module for the prototypical loss function defined below
    '''
    def __init__(self, n_support):
        super(PrototypicalLoss, self).__init__()
        self.n_support = n_support

    def forward(self, input, target):
        return prototypical_loss(input, target, self.n_support)
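# A usage sketch of the wrapper (names assumed, not from this file):
#
#     criterion = PrototypicalLoss(n_support=5)
#     loss, acc = criterion(model_output, labels)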
def euclidean_dist(x, y):
    '''
    Compute the pairwise squared Euclidean distance between two tensors
    '''
    # x: N x D
    # y: M x D
    n = x.size(0)
    m = y.size(0)
    d = x.size(1)
    if d != y.size(1):
        raise ValueError('x and y must have the same feature dimension')
    x = x.unsqueeze(1).expand(n, m, d)
    y = y.unsqueeze(0).expand(n, m, d)
    return torch.pow(x - y, 2).sum(2)
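# A quick sanity check of the broadcasting above (a sketch, assuming only
# torch): with N = 2 queries and M = 3 prototypes in D = 4 dimensions the
# result is a 2 x 3 matrix of squared distances.
def _euclidean_dist_demo():
    x = torch.randn(2, 4)
    y = torch.randn(3, 4)
    d = euclidean_dist(x, y)
    assert d.shape == (2, 3)
    assert torch.allclose(d[0, 0], ((x[0] - y[0]) ** 2).sum())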
def prototypical_loss(input, target, n_support):
    '''
    Inspired by https://github.com/jakesnell/prototypical-networks/blob/master/protonets/models/few_shot.py
    Compute the barycentres by averaging the features of the n_support
    samples of each class in target, then compute the distance from each
    query sample's features to each barycentre and the log-probability of
    each query sample belonging to each of the current classes; the loss
    and accuracy are then computed and returned.
    Args:
    - input: the model output for a batch of samples
    - target: ground truth for the above batch of samples
    - n_support: number of samples per class to take into account when
      computing the barycentres
    '''
    target_cpu = target.to('cpu')
    input_cpu = input.to('cpu')

    def supp_idxs(c):
        # FIXME when torch will support where as np
        return torch.nonzero(target_cpu.eq(c), as_tuple=False)[:n_support].squeeze(1)

    # FIXME when torch.unique will be available on cuda too
    classes = torch.unique(target_cpu)
    n_classes = len(classes)
    # assuming a constant number of query samples per class
    n_query = target_cpu.eq(classes[0].item()).sum().item() - n_support
    support_idxs = list(map(supp_idxs, classes))
    prototypes = torch.stack([input_cpu[idx_list].mean(0) for idx_list in support_idxs])
    query_idxs = torch.stack(list(map(lambda c: torch.nonzero(target_cpu.eq(c), as_tuple=False)[n_support:], classes))).view(-1)
    query_samples = input_cpu[query_idxs]
    dists = euclidean_dist(query_samples, prototypes)
    log_p_y = F.log_softmax(-dists, dim=1).view(n_classes, n_query, -1)
    target_inds = torch.arange(0, n_classes)
    target_inds = target_inds.view(n_classes, 1, 1)
    target_inds = target_inds.expand(n_classes, n_query, 1).long()
    loss_val = -log_p_y.gather(2, target_inds).squeeze().view(-1).mean()
    _, y_hat = log_p_y.max(2)
    acc_val = y_hat.eq(target_inds.squeeze()).float().mean()
    return loss_val, acc_val
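# A minimal episode sketch (shapes assumed, not from the original file): a
# 3-way task with 2 support and 2 query samples per class and 16-dim
# embeddings; `input` is ordered per class with support samples first, as
# supp_idxs expects.
def _prototypical_loss_demo():
    n_classes, n_support, n_query, dim = 3, 2, 2, 16
    target = torch.arange(n_classes).repeat_interleave(n_support + n_query)
    input = torch.randn(n_classes * (n_support + n_query), dim)
    loss_val, acc_val = prototypical_loss(input, target, n_support)
    print(loss_val.item(), acc_val.item())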
@@ -29,8 +29,6 @@ def init_seed(opt):
def init_dataset(opt, data_list, mode):
    # print('not extract frame')
    # opt.extract_frame = 0
    debug = False
    dataset = NTU_RGBD_Dataset(mode=mode, data_list=data_list, debug=debug, extract_frame=opt.extract_frame)
    n_classes = len(np.unique(dataset.label))
@@ -78,7 +76,6 @@ def init_optim(opt, model):
    '''
    Initialize optimizer
    '''
    # optimizer = torch.optim.SGD(model.parameters(), lr=opt.learning_rate, momentum=0.9, weight_decay=5e-4, nesterov=True)
    optimizer = torch.optim.Adam(params=model.parameters(), lr=opt.learning_rate, weight_decay=5e-4)
@@ -101,35 +98,6 @@ def save_list_to_file(path, thelist):
    for item in thelist:
        f.write("%s\n" % item)
def cosine(x, str):
    if str == 'not_encoder':
        t_path = os.path.join(gl.experiment_root, 'origin_t')
        n, c, t, v, m = x.size()
        x = x.mean(4)
    else:
        t_path = os.path.join(gl.experiment_root, 't')
        n, c, t, v = x.size()
    for i in range(t - 1):
        if not os.path.exists(t_path):
            os.mkdir(t_path)
        f_path = os.path.join(t_path, '{}_{}.txt'.format(i, i + 1))
        t1, t2 = torch.transpose(x[0, :, i, :], 1, 0), torch.transpose(x[0, :, i + 1, :], 1, 0)
        t1 = t1 / (t1.norm(dim=1, keepdim=True) + 1e-8)
        t2 = t2 / (t2.norm(dim=1, keepdim=True) + 1e-8)
        cos = torch.mm(t1, torch.transpose(t2, 1, 0))
        # print(cos)
        np.savetxt(f_path, cos.cpu().detach().numpy(), fmt='%.2f')
    # print('--------------------')
    t1, t2 = torch.transpose(x[0, :, 0, :], 1, 0), torch.transpose(x[0, :, t - 1, :], 1, 0)
    t1 = t1 / (t1.norm(dim=1, keepdim=True) + 1e-8)
    t2 = t2 / (t2.norm(dim=1, keepdim=True) + 1e-8)
    cos = torch.mm(t1, torch.transpose(t2, 1, 0))
    # print(cos)
    f_path = os.path.join(t_path, '{}_{}.txt'.format(0, t - 1))
    np.savetxt(f_path, cos.cpu().detach().numpy(), fmt='%.2f')
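# A usage sketch (shapes and paths assumed, not from this file): dump the
# V x V cosine-similarity matrices between adjacent frames of one raw
# skeleton sequence, given that gl.experiment_root is writable:
#
#     gl.experiment_root = '/tmp/exp'
#     feats = torch.randn(1, 3, 20, 25, 2)   # (n, c, t, v, m)
#     cosine(feats, 'not_encoder')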
def train(opt, tr_dataloader, model, optim, lr_scheduler, val_dataloader=None, test_dataloader=None):
    '''
@@ -255,8 +223,6 @@ def train(opt, tr_dataloader, model, optim, lr_scheduler, val_dataloader=None, t
            break
    torch.save(model.state_dict(), last_model_path)
    return best_state, best_acc
@@ -271,7 +237,7 @@ def test(opt, test_dataloader, model):
    n_class_val, n_query_val = opt.classes_per_it_val, opt.num_query_val
    for epoch in range(10):
        print('=== Epoch: {} ==='.format(epoch))
        model.eval()
        gl.epoch = epoch
        test_iter = iter(test_dataloader)
@@ -293,25 +259,6 @@ def test(opt, test_dataloader, model):
    return avg_acc
def eval(opt):
    '''
    Initialize everything and evaluate the best saved model
    '''
    options = get_parser().parse_args()
    if torch.cuda.is_available() and not options.cuda:
        print("WARNING: You have a CUDA device, so you should probably run with --cuda")
    init_seed(options)
    test_dataloader = init_dataset(options)[-1]
    model = init_protonet(options)
    model_path = os.path.join(opt.experiment_root, 'best_model.pth')
    model.load_state_dict(torch.load(model_path))
    test(opt=options,
         test_dataloader=test_dataloader,
         model=model)
def main():
    '''
    Initialize everything and train
@@ -328,7 +275,6 @@ def main():
    device = 'cuda:{}'.format(options.device) if torch.cuda.is_available() and options.cuda else 'cpu'
    gl.device = device
    # print("device", device)
    gl.gamma = options.gamma
    options.experiment_root = "../log/" + options.experiment_root
@@ -368,10 +314,6 @@ def main():
                  optim=optim,
                  lr_scheduler=lr_scheduler)
    best_state, best_acc = res
    # print('Testing with last model..')
    # test(opt=options,
    #      test_dataloader=test_dataloader,
    #      model=model)
    model.load_state_dict(best_state)
    model_path = os.path.join(options.experiment_root, 'best_model.pth')
@@ -383,12 +325,8 @@ def main():
    elif options.mode == 'test':
        print('Testing with best model..')
        test(opt=options,
             test_dataloader=test_dataloader,
             model=model)


if __name__ == '__main__':
    main()