Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
D
DASTM
Project
Project
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
3
Issues
3
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
马宁(18博)
DASTM
Commits
870e1755
Unverified
Commit
870e1755
authored
Oct 04, 2022
by
NingMa
Committed by
GitHub
Oct 04, 2022
Browse files
Options
Browse Files
Download
Plain Diff
Merge pull request #4 from HongyiZhang-1998/main
update
parents
201797ee
fdc86310
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
2 additions
and
169 deletions
+2
-169
protonet.py
protonet.py
+0
-19
prototypical_loss.py
prototypical_loss.py
+0
-86
train.py
train.py
+2
-64
No files found.
protonet.py
View file @
870e1755
...
@@ -31,8 +31,6 @@ class ProtoNet(nn.Module):
...
@@ -31,8 +31,6 @@ class ProtoNet(nn.Module):
st_graph
=
None
st_graph
=
None
node
=
0
node
=
0
self
.
model
=
ST_GCN_18
(
self
.
model
=
ST_GCN_18
(
in_channels
=
3
,
in_channels
=
3
,
num_class
=
60
,
num_class
=
60
,
...
@@ -77,7 +75,6 @@ class ProtoNet(nn.Module):
...
@@ -77,7 +75,6 @@ class ProtoNet(nn.Module):
if
dtw
>
0
:
if
dtw
>
0
:
dist
,
reg_loss
=
self
.
dtw_loss
(
zq
,
z_proto
)
dist
,
reg_loss
=
self
.
dtw_loss
(
zq
,
z_proto
)
else
:
else
:
#zq, z_proto = F.avg_pool2d(zq, zq.size()[2:]).view(n_class * n_query, c), F.avg_pool2d(z_proto, z_proto.size()[2:]).view(n_class, c)
zq
=
zq
.
view
(
n_class
*
n_query
,
-
1
)
zq
=
zq
.
view
(
n_class
*
n_query
,
-
1
)
z_proto
=
z_proto
.
view
(
n_class
,
-
1
)
z_proto
=
z_proto
.
view
(
n_class
,
-
1
)
dist
=
euclidean_dist
(
zq
,
z_proto
)
dist
=
euclidean_dist
(
zq
,
z_proto
)
...
@@ -163,7 +160,6 @@ class ProtoNet(nn.Module):
...
@@ -163,7 +160,6 @@ class ProtoNet(nn.Module):
loss
=
torch
.
tensor
(
0
)
.
float
()
.
to
(
gl
.
device
)
loss
=
torch
.
tensor
(
0
)
.
float
()
.
to
(
gl
.
device
)
for
i
in
range
(
x
.
size
()[
0
]):
for
i
in
range
(
x
.
size
()[
0
]):
transpose_X
=
x
[
i
]
transpose_X
=
x
[
i
]
...
@@ -174,24 +170,9 @@ class ProtoNet(nn.Module):
...
@@ -174,24 +170,9 @@ class ProtoNet(nn.Module):
method_loss
=
-
torch
.
mean
(
list_svd
[:
min
(
softmax_tgt
.
shape
[
0
],
softmax_tgt
.
shape
[
1
])])
method_loss
=
-
torch
.
mean
(
list_svd
[:
min
(
softmax_tgt
.
shape
[
0
],
softmax_tgt
.
shape
[
1
])])
loss
+=
method_loss
loss
+=
method_loss
return
loss
/
x
.
size
()[
0
]
return
loss
/
x
.
size
()[
0
]
def
idm_reg
(
self
,
x
):
n
,
t
,
c
=
x
.
size
()
reg_loss
=
torch
.
tensor
(
0
)
.
float
()
.
to
(
gl
.
device
)
thred
=
5
margin
=
2
weight
,
inverse_weight
=
self
.
get_W
(
x
,
thred
)
for
i
in
range
(
n
):
dist
=
euclidean_dist
(
x
[
i
,
:,
:],
x
[
i
,
:,
:])
# t * t
inverse_dist
=
torch
.
max
(
torch
.
zeros
(
t
,
t
)
.
to
(
gl
.
device
),
margin
-
dist
)
.
to
(
gl
.
device
)
reg_loss
+=
(
inverse_dist
*
inverse_weight
+
dist
*
weight
)
.
sum
()
return
reg_loss
/
n
def
forward
(
self
,
x
):
def
forward
(
self
,
x
):
x
=
self
.
model
(
x
)
x
=
self
.
model
(
x
)
...
...
prototypical_loss.py
deleted
100644 → 0
View file @
201797ee
# coding=utf-8
import
torch
from
torch.nn
import
functional
as
F
from
torch.nn.modules
import
Module
class
PrototypicalLoss
(
Module
):
'''
Loss class deriving from Module for the prototypical loss function defined below
'''
def
__init__
(
self
,
n_support
):
super
(
PrototypicalLoss
,
self
)
.
__init__
()
self
.
n_support
=
n_support
def
forward
(
self
,
input
,
target
):
return
prototypical_loss
(
input
,
target
,
self
.
n_support
)
def
euclidean_dist
(
x
,
y
):
'''
Compute euclidean distance between two tensors
'''
# x: N x D
# y: M x D
n
=
x
.
size
(
0
)
m
=
y
.
size
(
0
)
d
=
x
.
size
(
1
)
if
d
!=
y
.
size
(
1
):
raise
Exception
x
=
x
.
unsqueeze
(
1
)
.
expand
(
n
,
m
,
d
)
y
=
y
.
unsqueeze
(
0
)
.
expand
(
n
,
m
,
d
)
return
torch
.
pow
(
x
-
y
,
2
)
.
sum
(
2
)
def
prototypical_loss
(
input
,
target
,
n_support
):
'''
Inspired by https://github.com/jakesnell/prototypical-networks/blob/master/protonets/models/few_shot.py
Compute the barycentres by averaging the features of n_support
samples for each class in target, computes then the distances from each
samples' features to each one of the barycentres, computes the
log_probability for each n_query samples for each one of the current
classes, of appartaining to a class c, loss and accuracy are then computed
and returned
Args:
- input: the model output for a batch of samples
- target: ground truth for the above batch of samples
- n_support: number of samples to keep in account when computing
barycentres, for each one of the current classes
'''
target_cpu
=
target
.
to
(
'cpu'
)
input_cpu
=
input
.
to
(
'cpu'
)
def
supp_idxs
(
c
):
# FIXME when torch will support where as np
return
torch
.
nonzero
(
target_cpu
.
eq
(
c
),
as_tuple
=
False
)[:
n_support
]
.
squeeze
(
1
)
# FIXME when torch.unique will be available on cuda too
classes
=
torch
.
unique
(
target_cpu
)
n_classes
=
len
(
classes
)
# FIXME when torch will support where as np
# assuming n_query, n_target constants
n_query
=
target_cpu
.
eq
(
classes
[
0
]
.
item
())
.
sum
()
.
item
()
-
n_support
support_idxs
=
list
(
map
(
supp_idxs
,
classes
))
prototypes
=
torch
.
stack
([
input_cpu
[
idx_list
]
.
mean
(
0
)
for
idx_list
in
support_idxs
])
# FIXME when torch will support where as np
query_idxs
=
torch
.
stack
(
list
(
map
(
lambda
c
:
torch
.
nonzero
(
target_cpu
.
eq
(
c
),
as_tuple
=
False
)[
n_support
:],
classes
)))
.
view
(
-
1
)
query_samples
=
input
.
to
(
'cpu'
)[
query_idxs
]
dists
=
euclidean_dist
(
query_samples
,
prototypes
)
log_p_y
=
F
.
log_softmax
(
-
dists
,
dim
=
1
)
.
view
(
n_classes
,
n_query
,
-
1
)
target_inds
=
torch
.
arange
(
0
,
n_classes
)
target_inds
=
target_inds
.
view
(
n_classes
,
1
,
1
)
target_inds
=
target_inds
.
expand
(
n_classes
,
n_query
,
1
)
.
long
()
loss_val
=
-
log_p_y
.
gather
(
2
,
target_inds
)
.
squeeze
()
.
view
(
-
1
)
.
mean
()
_
,
y_hat
=
log_p_y
.
max
(
2
)
acc_val
=
y_hat
.
eq
(
target_inds
.
squeeze
())
.
float
()
.
mean
()
return
loss_val
,
acc_val
train.py
View file @
870e1755
...
@@ -29,8 +29,6 @@ def init_seed(opt):
...
@@ -29,8 +29,6 @@ def init_seed(opt):
def
init_dataset
(
opt
,
data_list
,
mode
):
def
init_dataset
(
opt
,
data_list
,
mode
):
# print('not extract frame')
# opt.extract_frame = 0
debug
=
False
debug
=
False
dataset
=
NTU_RGBD_Dataset
(
mode
=
mode
,
data_list
=
data_list
,
debug
=
debug
,
extract_frame
=
opt
.
extract_frame
)
dataset
=
NTU_RGBD_Dataset
(
mode
=
mode
,
data_list
=
data_list
,
debug
=
debug
,
extract_frame
=
opt
.
extract_frame
)
n_classes
=
len
(
np
.
unique
(
dataset
.
label
))
n_classes
=
len
(
np
.
unique
(
dataset
.
label
))
...
@@ -78,7 +76,6 @@ def init_optim(opt, model):
...
@@ -78,7 +76,6 @@ def init_optim(opt, model):
'''
'''
Initialize optimizer
Initialize optimizer
'''
'''
# optimizer = torch.optim.SGD(model.parameters(), lr=opt.learning_rate, momentum=0.9, weight_decay=5e-4, nesterov=True)
# optimizer = torch.optim.SGD(model.parameters(), lr=opt.learning_rate, momentum=0.9, weight_decay=5e-4, nesterov=True)
optimizer
=
torch
.
optim
.
Adam
(
params
=
model
.
parameters
(),
lr
=
opt
.
learning_rate
,
weight_decay
=
5e-4
)
optimizer
=
torch
.
optim
.
Adam
(
params
=
model
.
parameters
(),
lr
=
opt
.
learning_rate
,
weight_decay
=
5e-4
)
...
@@ -101,35 +98,6 @@ def save_list_to_file(path, thelist):
...
@@ -101,35 +98,6 @@ def save_list_to_file(path, thelist):
for
item
in
thelist
:
for
item
in
thelist
:
f
.
write
(
"
%
s
\n
"
%
item
)
f
.
write
(
"
%
s
\n
"
%
item
)
def
cosine
(
x
,
str
):
if
str
==
'not_encoder'
:
t_path
=
os
.
path
.
join
(
gl
.
experiment_root
,
'origin_t'
)
n
,
c
,
t
,
v
,
m
=
x
.
size
()
x
=
x
.
mean
(
4
)
else
:
t_path
=
os
.
path
.
join
(
gl
.
experiment_root
,
't'
)
n
,
c
,
t
,
v
=
x
.
size
()
for
i
in
range
(
t
-
1
):
if
not
os
.
path
.
exists
(
t_path
):
os
.
mkdir
(
t_path
)
f_path
=
os
.
path
.
join
(
t_path
,
'{}_{}.txt'
.
format
(
i
,
i
+
1
))
t1
,
t2
=
torch
.
transpose
(
x
[
0
,
:,
i
,
:],
1
,
0
),
torch
.
transpose
(
x
[
0
,
:,
i
+
1
,
:],
1
,
0
)
t1
=
t1
/
(
t1
.
norm
(
dim
=
1
,
keepdim
=
True
)
+
1e-8
)
t2
=
t2
/
(
t2
.
norm
(
dim
=
1
,
keepdim
=
True
)
+
1e-8
)
cos
=
torch
.
mm
(
t1
,
torch
.
transpose
(
t2
,
1
,
0
))
# print(cos)
np
.
savetxt
(
f_path
,
cos
.
cpu
()
.
detach
()
.
numpy
(),
fmt
=
'
%.2
f'
)
# print('--------------------')
t1
,
t2
=
torch
.
transpose
(
x
[
0
,
:,
0
,
:],
1
,
0
),
torch
.
transpose
(
x
[
0
,
:,
t
-
1
,
:],
1
,
0
)
t1
=
t1
/
(
t1
.
norm
(
dim
=
1
,
keepdim
=
True
)
+
1e-8
)
t2
=
t2
/
(
t2
.
norm
(
dim
=
1
,
keepdim
=
True
)
+
1e-8
)
cos
=
torch
.
mm
(
t1
,
torch
.
transpose
(
t2
,
1
,
0
))
# print(cos)
f_path
=
os
.
path
.
join
(
t_path
,
'{}_{}.txt'
.
format
(
0
,
t
-
1
))
np
.
savetxt
(
f_path
,
cos
.
cpu
()
.
detach
()
.
numpy
(),
fmt
=
'
%.2
f'
)
def
train
(
opt
,
tr_dataloader
,
model
,
optim
,
lr_scheduler
,
val_dataloader
=
None
,
test_dataloader
=
None
):
def
train
(
opt
,
tr_dataloader
,
model
,
optim
,
lr_scheduler
,
val_dataloader
=
None
,
test_dataloader
=
None
):
'''
'''
...
@@ -255,8 +223,6 @@ def train(opt, tr_dataloader, model, optim, lr_scheduler, val_dataloader=None, t
...
@@ -255,8 +223,6 @@ def train(opt, tr_dataloader, model, optim, lr_scheduler, val_dataloader=None, t
break
break
torch
.
save
(
model
.
state_dict
(),
last_model_path
)
torch
.
save
(
model
.
state_dict
(),
last_model_path
)
return
best_state
,
best_acc
return
best_state
,
best_acc
...
@@ -271,7 +237,7 @@ def test(opt, test_dataloader, model):
...
@@ -271,7 +237,7 @@ def test(opt, test_dataloader, model):
n_class_val
,
n_query_val
=
opt
.
classes_per_it_val
,
opt
.
num_query_val
n_class_val
,
n_query_val
=
opt
.
classes_per_it_val
,
opt
.
num_query_val
for
epoch
in
range
(
10
):
for
epoch
in
range
(
10
):
#
print('=== Epoch: {} ==='.format(epoch))
print
(
'=== Epoch: {} ==='
.
format
(
epoch
))
model
.
eval
()
model
.
eval
()
gl
.
epoch
=
epoch
gl
.
epoch
=
epoch
test_iter
=
iter
(
test_dataloader
)
test_iter
=
iter
(
test_dataloader
)
...
@@ -293,25 +259,6 @@ def test(opt, test_dataloader, model):
...
@@ -293,25 +259,6 @@ def test(opt, test_dataloader, model):
return
avg_acc
return
avg_acc
def
eval
(
opt
):
'''
Initialize everything and train
'''
options
=
get_parser
()
.
parse_args
()
if
torch
.
cuda
.
is_available
()
and
not
options
.
cuda
:
print
(
"WARNING: You have a CUDA device, so you should probably run with --cuda"
)
init_seed
(
options
)
test_dataloader
=
init_dataset
(
options
)[
-
1
]
model
=
init_protonet
(
options
)
model_path
=
os
.
path
.
join
(
opt
.
experiment_root
,
'best_model.pth'
)
model
.
load_state_dict
(
torch
.
load
(
model_path
))
test
(
opt
=
options
,
test_dataloader
=
test_dataloader
,
model
=
model
)
def
main
():
def
main
():
'''
'''
Initialize everything and train
Initialize everything and train
...
@@ -328,7 +275,6 @@ def main():
...
@@ -328,7 +275,6 @@ def main():
device
=
'cuda:{}'
.
format
(
options
.
device
)
if
torch
.
cuda
.
is_available
()
and
options
.
cuda
else
'cpu'
device
=
'cuda:{}'
.
format
(
options
.
device
)
if
torch
.
cuda
.
is_available
()
and
options
.
cuda
else
'cpu'
gl
.
device
=
device
gl
.
device
=
device
# print("device",device)
gl
.
gamma
=
options
.
gamma
gl
.
gamma
=
options
.
gamma
options
.
experiment_root
=
"../log/"
+
options
.
experiment_root
options
.
experiment_root
=
"../log/"
+
options
.
experiment_root
...
@@ -368,10 +314,6 @@ def main():
...
@@ -368,10 +314,6 @@ def main():
optim
=
optim
,
optim
=
optim
,
lr_scheduler
=
lr_scheduler
)
lr_scheduler
=
lr_scheduler
)
best_state
,
best_acc
=
res
best_state
,
best_acc
=
res
# print('Testing with last model..')
# test(opt=options,
# test_dataloader=test_dataloader,
# model=model)
model
.
load_state_dict
(
best_state
)
model
.
load_state_dict
(
best_state
)
model_path
=
os
.
path
.
join
(
options
.
experiment_root
,
'best_model.pth'
)
model_path
=
os
.
path
.
join
(
options
.
experiment_root
,
'best_model.pth'
)
...
@@ -383,12 +325,8 @@ def main():
...
@@ -383,12 +325,8 @@ def main():
elif
options
.
mode
==
'test'
:
elif
options
.
mode
==
'test'
:
print
(
'Testing with best model..'
)
print
(
'Testing with best model..'
)
test
(
opt
=
options
,
test
(
opt
=
options
,
test_dataloader
=
test_dataloader
=
test_dataloader
,
test_dataloader
,
model
=
model
)
model
=
model
)
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
main
()
main
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment