Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
O
OvoTools
Project
Project
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
林帅浩
OvoTools
Commits
6419986d
Commit
6419986d
authored
Mar 08, 2019
by
IlyaOvodov
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Timer, Trensorboard logger
parent
f21f1e8e
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
152 additions
and
30 deletions
+152
-30
ignite_tools.py
ovotools/ignite_tools.py
+100
-1
pytorch_tools.py
ovotools/pytorch_tools.py
+52
-29
No files found.
ovotools/ignite_tools.py
View file @
6419986d
import
copy
import
torch
from
ignite.engine
import
Events
import
collections
import
time
import
tensorboardX
class
IgniteTimes
:
class
TimerWatch
:
def
__init__
(
self
,
timer
,
name
):
self
.
name
=
name
self
.
timer
=
timer
def
__enter__
(
self
):
self
.
timer
.
start
(
self
.
name
)
return
self
def
__exit__
(
self
,
*
args
):
self
.
timer
.
end
(
self
.
name
)
return
False
def
__init__
(
self
,
engine
,
count_iters
=
False
,
measured_events
=
{}):
self
.
clocks
=
dict
()
self
.
sums
=
collections
.
defaultdict
(
float
)
self
.
counts
=
collections
.
defaultdict
(
int
)
for
name
,
(
event_engine
,
start_event
,
end_event
)
in
measured_events
.
items
():
event_engine
.
add_event_handler
(
start_event
,
self
.
on_start
,
name
)
event_engine
.
add_event_handler
(
end_event
,
self
.
on_end
,
name
)
event
=
Events
.
ITERATION_COMPLETED
if
count_iters
else
Events
.
EPOCH_COMPLETED
engine
.
add_event_handler
(
event
,
self
.
on_complete
)
def
reset_all
(
self
):
self
.
clocks
.
clear
()
self
.
sums
.
clear
()
self
.
counts
.
clear
()
def
start
(
self
,
name
):
assert
not
name
in
self
.
clocks
self
.
clocks
[
name
]
=
time
.
time
()
def
end
(
self
,
name
):
assert
name
in
self
.
clocks
t
=
time
.
time
()
-
self
.
clocks
[
name
]
self
.
counts
[
name
]
+=
1
self
.
sums
[
name
]
+=
t
self
.
clocks
.
pop
(
name
)
def
watch
(
self
,
name
):
return
self
.
TimerWatch
(
self
,
name
)
def
on_start
(
self
,
engine
,
name
):
self
.
start
(
name
)
def
on_end
(
self
,
engine
,
name
):
self
.
end
(
name
)
def
on_complete
(
self
,
engine
):
for
n
,
v
in
self
.
sums
.
items
():
engine
.
state
.
metrics
[
n
]
=
v
self
.
reset_all
()
class
BestModelBuffer
:
...
...
@@ -49,7 +108,7 @@ class LogTrainingResults:
for
key
,
loader
in
self
.
loaders_dict
.
items
():
self
.
evaluator
.
run
(
loader
)
for
k
,
v
in
self
.
evaluator
.
state
.
metrics
.
items
():
engine
.
state
.
metrics
[
key
+
'
.
'
+
k
]
=
v
engine
.
state
.
metrics
[
key
+
'
:
'
+
k
]
=
v
self
.
best_model_buffer
.
save_if_best
(
engine
)
if
event
==
Events
.
ITERATION_COMPLETED
:
str
=
"Epoch:{}.{}
\t
"
.
format
(
engine
.
state
.
epoch
,
engine
.
state
.
iteration
)
...
...
@@ -59,3 +118,43 @@ class LogTrainingResults:
print
(
str
)
with
open
(
self
.
params
.
get_base_filename
()
+
'.log'
,
'a'
)
as
f
:
f
.
write
(
str
+
'
\n
'
)
class
TensorBoardLogger
:
SERIES_PLOT_SEPARATOR
=
':'
GROUP_PLOT_SEPARATOR
=
'.'
def
__init__
(
self
,
trainer_engine
,
params
,
count_iters
=
False
,
period
=
1
):
log_dir
=
params
.
get_base_filename
()
self
.
writer
=
tensorboardX
.
SummaryWriter
(
log_dir
=
log_dir
,
flush_secs
=
10
)
event
=
Events
.
ITERATION_COMPLETED
if
count_iters
else
Events
.
EPOCH_COMPLETED
trainer_engine
.
add_event_handler
(
event
,
self
.
on_event
)
self
.
period
=
period
self
.
call_count
=
0
trainer_engine
.
add_event_handler
(
Events
.
COMPLETED
,
self
.
on_completed
)
def
on_completed
(
self
,
engine
):
self
.
writer
.
close
()
def
on_event
(
self
,
engine
):
'''
engine.state.metrics with name
*|* are interpreted as series(train,val).plot_name(metric)
*|*.* are interpreted as series(train,val).group(metric class).plot_name
'''
self
.
call_count
+=
1
if
self
.
call_count
%
self
.
period
!=
0
:
return
metrics
=
collections
.
defaultdict
(
dict
)
for
name
,
value
in
engine
.
state
.
metrics
.
items
():
name_parts
=
name
.
split
(
self
.
SERIES_PLOT_SEPARATOR
,
1
)
if
len
(
name_parts
)
==
1
:
name_parts
.
append
(
name_parts
[
0
])
metrics
[
name_parts
[
1
]
.
replace
(
self
.
GROUP_PLOT_SEPARATOR
,
'/'
)][
name_parts
[
0
]]
=
value
for
n
,
d
in
metrics
.
items
():
if
len
(
d
)
==
1
:
for
k
,
v
in
d
.
items
():
self
.
writer
.
add_scalar
(
n
,
v
,
self
.
call_count
)
else
:
self
.
writer
.
add_scalars
(
n
,
d
,
self
.
call_count
)
ovotools/pytorch_tools.py
View file @
6419986d
import
torch
import
numpy
as
np
class
DummyTimer
:
'''
replacement for IgniteTimer if it is not provided
'''
class
TimerWatch
:
def
__init__
(
self
,
timer
,
name
):
pass
def
__enter__
(
self
):
return
self
def
__exit__
(
self
,
*
args
):
return
False
def
__init__
(
self
):
pass
def
start
(
self
,
name
):
pass
def
end
(
self
,
name
):
pass
def
watch
(
self
,
name
):
return
self
.
TimerWatch
(
self
,
name
)
class
MarginBaseLoss
:
'''
L2-constrained Softmax Loss for Discriminative Face Verification https://arxiv.org/pdf/1703.09507
margin based loss with distance weighted sampling https://arxiv.org/pdf/1706.07567.pdf
'''
ignore_index
=
-
100
def
__init__
(
self
,
model
,
classes
,
device
,
params
):
def
__init__
(
self
,
model
,
classes
,
device
,
params
,
timer
=
DummyTimer
()
):
assert
params
.
data
.
samples_per_class
>=
2
self
.
model
=
model
self
.
device
=
device
...
...
@@ -15,15 +31,20 @@ class MarginBaseLoss:
self
.
classes
=
sorted
(
classes
)
self
.
classes_dict
=
{
v
:
i
for
i
,
v
in
enumerate
(
self
.
classes
)}
self
.
lambda_rev
=
1
/
params
.
distance_weighted_sampling
.
lambda_
self
.
timer
=
timer
print
(
'classes: '
,
len
(
self
.
classes
))
def
set_timer
(
self
,
timer
):
self
.
timer
=
timer
def
classes_to_ids
(
self
,
y_class
,
ignore_index
=
-
100
):
return
torch
.
tensor
([
self
.
classes_dict
.
get
(
int
(
c
.
item
()),
ignore_index
)
for
c
in
y_class
])
.
to
(
self
.
device
)
def
l2_loss
(
self
,
net_output
,
y_class
):
pred_class
=
net_output
[
0
]
class_nos
=
self
.
classes_to_ids
(
y_class
,
ignore_index
=
self
.
ignore_index
)
return
torch
.
nn
.
CrossEntropyLoss
(
ignore_index
=
self
.
ignore_index
)(
pred_class
,
class_nos
)
with
self
.
timer
.
watch
(
'time.l2_loss'
):
pred_class
=
net_output
[
0
]
class_nos
=
self
.
classes_to_ids
(
y_class
,
ignore_index
=
self
.
ignore_index
)
return
torch
.
nn
.
CrossEntropyLoss
(
ignore_index
=
self
.
ignore_index
)(
pred_class
,
class_nos
)
def
D
(
self
,
pred_embeddings
,
i
,
j
):
if
i
==
j
:
...
...
@@ -32,31 +53,33 @@ class MarginBaseLoss:
def
mb_loss
(
self
,
net_output
,
y_class
):
pred_embeddings
=
net_output
[
1
]
loss
=
0
n
=
len
(
pred_embeddings
)
# samples in batch
dim
=
pred_embeddings
[
0
]
.
shape
[
0
]
# dimensionality
for
i_start
in
range
(
0
,
n
,
self
.
params
.
data
.
samples_per_class
):
# start of class block
i_end
=
i_start
+
self
.
params
.
data
.
samples_per_class
# start of class block
for
i
in
range
(
i_start
,
i_end
-
1
):
d_ij
=
[
0
if
i
==
j
else
self
.
D
(
pred_embeddings
,
i
,
j
)
for
j
in
range
(
n
)]
weights
=
[
1
/
max
(
self
.
lambda_rev
,
pow
(
d
,
dim
-
2
)
*
pow
(
1
-
d
*
d
/
4
,
(
dim
-
3
)
/
2
))
# https://arxiv.org/pdf/1706.07567.pdf
for
id
,
d
in
enumerate
(
d_ij
)
if
id
!=
i
]
# dont join with itself
weights_same
=
np
.
asarray
(
weights
[
i_start
:
i_end
-
1
])
# i-th element already excluded
j
=
np
.
random
.
choice
(
range
(
i_start
,
i_end
-
1
),
p
=
weights_same
/
np
.
sum
(
weights_same
)
)
if
j
>=
i
:
j
+=
1
# for j in range(i+1, i_end): # positive pair
loss
+=
(
self
.
params
.
mb_loss
.
alpha
+
(
d_ij
[
j
]
-
self
.
model
.
mb_loss_beta
))
.
clamp
(
min
=
0
)
# select neg. pait
weights
[
i_start
:
i_end
-
1
]
=
[]
# i-th element already excluded
weights
=
np
.
asarray
(
weights
)
weights
=
weights
/
np
.
sum
(
weights
)
k
=
np
.
random
.
choice
(
range
(
0
,
n
-
self
.
params
.
data
.
samples_per_class
),
p
=
weights
)
if
k
>=
i_start
:
k
+=
self
.
params
.
data
.
samples_per_class
loss
+=
(
self
.
params
.
mb_loss
.
alpha
-
(
d_ij
[
k
]
-
self
.
model
.
mb_loss_beta
))
.
clamp
(
min
=
0
)
return
loss
[
0
]
/
len
(
pred_embeddings
)
with
self
.
timer
.
watch
(
'time.mb_loss'
):
pred_embeddings
=
net_output
[
1
]
loss
=
0
n
=
len
(
pred_embeddings
)
# samples in batch
dim
=
pred_embeddings
[
0
]
.
shape
[
0
]
# dimensionality
for
i_start
in
range
(
0
,
n
,
self
.
params
.
data
.
samples_per_class
):
# start of class block
i_end
=
i_start
+
self
.
params
.
data
.
samples_per_class
# start of class block
for
i
in
range
(
i_start
,
i_end
-
1
):
with
self
.
timer
.
watch
(
'time.d_ij'
):
d_ij
=
[
0
if
i
==
j
else
self
.
D
(
pred_embeddings
,
i
,
j
)
for
j
in
range
(
n
)]
weights
=
[
1
/
max
(
self
.
lambda_rev
,
pow
(
d
,
dim
-
2
)
*
pow
(
1
-
d
*
d
/
4
,
(
dim
-
3
)
/
2
))
# https://arxiv.org/pdf/1706.07567.pdf
for
id
,
d
in
enumerate
(
d_ij
)
if
id
!=
i
]
# dont join with itself
weights_same
=
np
.
asarray
(
weights
[
i_start
:
i_end
-
1
])
# i-th element already excluded
j
=
np
.
random
.
choice
(
range
(
i_start
,
i_end
-
1
),
p
=
weights_same
/
np
.
sum
(
weights_same
)
)
if
j
>=
i
:
j
+=
1
# for j in range(i+1, i_end): # positive pair
loss
+=
(
self
.
params
.
mb_loss
.
alpha
+
(
d_ij
[
j
]
-
self
.
model
.
mb_loss_beta
))
.
clamp
(
min
=
0
)
# select neg. pait
weights
[
i_start
:
i_end
-
1
]
=
[]
# i-th element already excluded
weights
=
np
.
asarray
(
weights
)
weights
=
weights
/
np
.
sum
(
weights
)
k
=
np
.
random
.
choice
(
range
(
0
,
n
-
self
.
params
.
data
.
samples_per_class
),
p
=
weights
)
if
k
>=
i_start
:
k
+=
self
.
params
.
data
.
samples_per_class
loss
+=
(
self
.
params
.
mb_loss
.
alpha
-
(
d_ij
[
k
]
-
self
.
model
.
mb_loss_beta
))
.
clamp
(
min
=
0
)
return
loss
[
0
]
/
len
(
pred_embeddings
)
def
loss
(
self
,
net_output
,
y_class
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment