hi-template / Commits / fb1ea036

Commit fb1ea036 authored Apr 01, 2022 by 吴磊(20博)
init
parent 300bc2b9
Showing 4 changed files with 647 additions and 0 deletions
.gitignore           +3    -0
trainer/trainer.py   +46   -0
utils/__init__.py    +0    -0
utils/utils.py       +598  -0
.gitignore  0 → 100644
.idea
.DS_Store
\ No newline at end of file
trainer/trainer.py
import torch

from utils.utils import *


def train_one_epoch(model, optimizer, data_loader, device, epoch, loss,
                    loss_weights, warmup=False):
    lr_scheduler = None
    log_bar = None
    if is_main_process():
        log_bar = ProgBar(len(data_loader))
    for i, (inputs, targets) in enumerate(data_loader):
        # Normalize inputs/targets to lists of tensors and move them to the device.
        if isinstance(inputs, torch.Tensor):
            inputs = [inputs]
        if isinstance(targets, torch.Tensor):
            targets = [targets]
        for idx_inputs in range(len(inputs)):
            inputs[idx_inputs] = inputs[idx_inputs].to(device)
        for idx_target in range(len(targets)):
            targets[idx_target] = targets[idx_target].to(device)

        outputs = model(*inputs)
        if isinstance(outputs, torch.Tensor):
            outputs = [outputs]

        # Build one dict of named, weighted loss values per model output.
        losses = []
        it = zip(outputs, targets, loss, loss_weights)
        for output, target, loss_dict, loss_weight in it:
            loss_value_dict = dict()
            for loss_name, loss_fn in loss_dict.items():
                loss_value = loss_fn(output, target)
                loss_value_dict.update({loss_name: loss_value * loss_weight})
            losses.append(loss_value_dict)

        losses_value = sum([sum(loss_v for loss_v in loss_value_dict.values())
                            for loss_value_dict in losses])
        optimizer.zero_grad()
        losses_value.backward()
        optimizer.step()

        # Reduce each per-output loss dict across processes for logging.
        loss_dict_reduced = []
        for loss_value_dict in losses:
            loss_dict_reduced.append(reduce_dict(loss_value_dict))
        if is_main_process() and log_bar is not None:
            logs = []
            for l_i, loss_dict in enumerate(loss_dict_reduced):
                for k, v in loss_dict.items():
                    logs.append((k + "_%d" % l_i, v.item()))
            log_bar.update(i + 1, logs)
    return None
\ No newline at end of file
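For orientation, a minimal sketch of how `train_one_epoch` expects its arguments to be shaped: `loss` is a list with one dict of named loss functions per model output, and `loss_weights` holds one scalar per output. The toy model, data, and loss below are illustrative assumptions, not part of this commit.

# Illustrative sketch (assumed names; the repository root is assumed to be on PYTHONPATH).
import torch
from torch.utils.data import DataLoader, TensorDataset

from trainer.trainer import train_one_epoch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = torch.nn.Linear(10, 1).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)

data_loader = DataLoader(TensorDataset(torch.randn(64, 10), torch.randn(64, 1)),
                         batch_size=8)

# One dict of named loss functions per model output, one weight per output.
loss = [{"mse": torch.nn.MSELoss()}]
loss_weights = [1.0]

for epoch in range(2):
    train_one_epoch(model, optimizer, data_loader, device, epoch, loss, loss_weights)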
utils/__init__.py  0 → 100644
utils/utils.py  0 → 100644
import sys
from collections import defaultdict, deque
import datetime
import pickle
import time
import torch
import torch.distributed as dist
import errno
import os
import random
import numpy as np

class ProgBar(object):
    """Displays a progress bar.
    Arguments:
        target: Total number of steps expected, None if unknown.
        width: Progress bar width on screen.
        verbose: Verbosity mode, 0 (silent), 1 (verbose), 2 (semi-verbose)
        stateful_metrics: Iterable of string names of metrics that should *not* be
            averaged over time. Metrics in this list will be displayed as-is. All
            others will be averaged by the progbar before display.
        interval: Minimum visual progress update interval (in seconds).
        unit_name: Display name for step counts (usually "step" or "sample").
    """

    def __init__(self, target, width=30, verbose=1, interval=0.05,
                 stateful_metrics=None, unit_name='step'):
        self.target = target
        self.width = width
        self.verbose = verbose
        self.interval = interval
        self.unit_name = unit_name
        if stateful_metrics:
            self.stateful_metrics = set(stateful_metrics)
        else:
            self.stateful_metrics = set()

        self._dynamic_display = ((hasattr(sys.stdout, 'isatty') and
                                  sys.stdout.isatty()) or
                                 'ipykernel' in sys.modules or
                                 'posix' in sys.modules or
                                 'PYCHARM_HOSTED' in os.environ)
        self._total_width = 0
        self._seen_so_far = 0
        # We use a dict + list to avoid garbage collection
        # issues found in OrderedDict
        self._values = {}
        self._values_order = []
        self._start = time.time()
        self._last_update = 0
        self._time_after_first_step = None

    def update(self, current, values=None, finalize=None):
        """Updates the progress bar.
        Arguments:
            current: Index of current step.
            values: List of tuples: `(name, value_for_last_step)`. If `name` is in
                `stateful_metrics`, `value_for_last_step` will be displayed as-is.
                Else, an average of the metric over time will be displayed.
            finalize: Whether this is the last update for the progress bar. If
                `None`, defaults to `current >= self.target`.
        """
        if finalize is None:
            if self.target is None:
                finalize = False
            else:
                finalize = current >= self.target

        values = values or []
        for k, v in values:
            if k not in self._values_order:
                self._values_order.append(k)
            if k not in self.stateful_metrics:
                # In the case that progress bar doesn't have a target value in the first
                # epoch, both on_batch_end and on_epoch_end will be called, which will
                # cause 'current' and 'self._seen_so_far' to have the same value. Force
                # the minimal value to 1 here, otherwise stateful_metric will be 0s.
                value_base = max(current - self._seen_so_far, 1)
                if k not in self._values:
                    self._values[k] = [v * value_base, value_base]
                else:
                    self._values[k][0] += v * value_base
                    self._values[k][1] += value_base
            else:
                # Stateful metrics output a numeric value. This representation
                # means "take an average from a single value" but keeps the
                # numeric formatting.
                self._values[k] = [v, 1]
        self._seen_so_far = current

        now = time.time()
        info = ' - %.0fs' % (now - self._start)
        if self.verbose == 1:
            if now - self._last_update < self.interval and not finalize:
                return

            prev_total_width = self._total_width
            if self._dynamic_display:
                sys.stdout.write('\b' * prev_total_width)
                sys.stdout.write('\r')
            else:
                sys.stdout.write('\n')

            if self.target is not None:
                numdigits = int(np.log10(self.target)) + 1
                bar = ('%' + str(numdigits) + 'd/%d [') % (current, self.target)
                prog = float(current) / self.target
                prog_width = int(self.width * prog)
                if prog_width > 0:
                    bar += ('=' * (prog_width - 1))
                    if current < self.target:
                        bar += '>'
                    else:
                        bar += '='
                bar += ('.' * (self.width - prog_width))
                bar += ']'
            else:
                bar = '%7d/Unknown' % current

            self._total_width = len(bar)
            sys.stdout.write(bar)

            time_per_unit = self._estimate_step_duration(current, now)

            if self.target is None or finalize:
                if time_per_unit >= 1 or time_per_unit == 0:
                    info += ' %.0fs/%s' % (time_per_unit, self.unit_name)
                elif time_per_unit >= 1e-3:
                    info += ' %.0fms/%s' % (time_per_unit * 1e3, self.unit_name)
                else:
                    info += ' %.0fus/%s' % (time_per_unit * 1e6, self.unit_name)
            else:
                eta = time_per_unit * (self.target - current)
                if eta > 3600:
                    eta_format = '%d:%02d:%02d' % (eta // 3600,
                                                   (eta % 3600) // 60, eta % 60)
                elif eta > 60:
                    eta_format = '%d:%02d' % (eta // 60, eta % 60)
                else:
                    eta_format = '%ds' % eta

                info = ' - ETA: %s' % eta_format

            for k in self._values_order:
                info += ' - %s:' % k
                if isinstance(self._values[k], list):
                    avg = np.mean(self._values[k][0] / max(1, self._values[k][1]))
                    if abs(avg) > 1e-3:
                        info += ' %.4f' % avg
                    else:
                        info += ' %.4e' % avg
                else:
                    info += ' %s' % self._values[k]

            self._total_width += len(info)
            if prev_total_width > self._total_width:
                info += (' ' * (prev_total_width - self._total_width))

            if finalize:
                info += '\n'

            sys.stdout.write(info)
            sys.stdout.flush()

        elif self.verbose == 2:
            if finalize:
                numdigits = int(np.log10(self.target)) + 1
                count = ('%' + str(numdigits) + 'd/%d') % (current, self.target)
                info = count + info
                for k in self._values_order:
                    info += ' - %s:' % k
                    avg = np.mean(self._values[k][0] / max(1, self._values[k][1]))
                    if avg > 1e-3:
                        info += ' %.4f' % avg
                    else:
                        info += ' %.4e' % avg
                info += '\n'

                sys.stdout.write(info)
                sys.stdout.flush()

        self._last_update = now

    def add(self, n, values=None):
        self.update(self._seen_so_far + n, values)

    def _estimate_step_duration(self, current, now):
        """Estimate the duration of a single step.
        Given the step number `current` and the corresponding time `now`
        this function returns an estimate for how long a single step
        takes. If this is called before one step has been completed
        (i.e. `current == 0`) then zero is given as an estimate. The duration
        estimate ignores the duration of the (assumed to be non-representative)
        first step for estimates when more steps are available (i.e. `current>1`).
        Arguments:
            current: Index of current step.
            now: The current time.
        Returns: Estimate of the duration of a single step.
        """
        if current:
            # there are a few special scenarios here:
            # 1) somebody is calling the progress bar without ever supplying step 1
            # 2) somebody is calling the progress bar and supplies step one multiple
            #    times, e.g. as part of a finalizing call
            # in these cases, we just fall back to the simple calculation
            if self._time_after_first_step is not None and current > 1:
                time_per_unit = (now - self._time_after_first_step) / (current - 1)
            else:
                time_per_unit = (now - self._start) / current

            if current == 1:
                self._time_after_first_step = now
            return time_per_unit
        else:
            return 0

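A minimal usage sketch for `ProgBar`; the step count and metric name are illustrative.

# Illustrative sketch: drive the bar by hand over 100 steps.
import time
from utils.utils import ProgBar

bar = ProgBar(target=100, unit_name='step')
for step in range(1, 101):
    time.sleep(0.01)                               # stand-in for real work
    bar.update(step, values=[('loss', 1.0 / step)])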
def default(method):
    """Decorates a method to detect overrides in subclasses."""
    method._is_default = True  # pylint: disable=protected-access
    return method


def is_default(method):
    """Check if a method is decorated with the `default` wrapper."""
    return getattr(method, '_is_default', False)

class SmoothedValue(object):
    """Track a series of values and provide access to smoothed values over a
    window or the global series average.
    """

    def __init__(self, window_size=20, fmt=None):
        if fmt is None:
            fmt = "{median:.4f} ({global_avg:.4f})"
        self.deque = deque(maxlen=window_size)
        self.total = 0.0
        self.count = 0
        self.fmt = fmt

    def update(self, value, n=1):
        self.deque.append(value)
        self.count += n
        self.total += value * n

    def synchronize_between_processes(self):
        """
        Warning: does not synchronize the deque!
        """
        if not is_dist_avail_and_initialized():
            return
        t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')
        dist.barrier()
        dist.all_reduce(t)
        t = t.tolist()
        self.count = int(t[0])
        self.total = t[1]

    @property
    def median(self):
        d = torch.tensor(list(self.deque))
        return d.median().item()

    @property
    def avg(self):
        d = torch.tensor(list(self.deque), dtype=torch.float32)
        return d.mean().item()

    @property
    def global_avg(self):
        return self.total / self.count

    @property
    def max(self):
        return max(self.deque)

    @property
    def value(self):
        return self.deque[-1]

    def __str__(self):
        return self.fmt.format(
            median=self.median,
            avg=self.avg,
            global_avg=self.global_avg,
            max=self.max,
            value=self.value)

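A minimal usage sketch for `SmoothedValue`, showing the windowed median versus the global average; the values are illustrative.

# Illustrative sketch: windowed median vs. global average.
from utils.utils import SmoothedValue

meter = SmoothedValue(window_size=5)
for v in [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]:
    meter.update(v)

print(meter.median)      # median of the last 5 values -> 4.0
print(meter.global_avg)  # average of all 6 values -> 3.5
print(str(meter))        # default fmt -> "4.0000 (3.5000)"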
def all_gather(data):
    """
    Run all_gather on arbitrary picklable datasets (not necessarily tensors)
    Args:
        data: any picklable object
    Returns:
        list[data]: list of datasets gathered from each rank
    """
    world_size = get_world_size()
    if world_size == 1:
        return [data]

    # serialized to a Tensor
    buffer = pickle.dumps(data)
    storage = torch.ByteStorage.from_buffer(buffer)
    tensor = torch.ByteTensor(storage).to("cuda")

    # obtain Tensor size of each rank
    local_size = torch.tensor([tensor.numel()], device="cuda")
    size_list = [torch.tensor([0], device="cuda") for _ in range(world_size)]
    dist.all_gather(size_list, local_size)
    size_list = [int(size.item()) for size in size_list]
    max_size = max(size_list)

    # receiving Tensor from all ranks
    # we pad the tensor because torch all_gather does not support
    # gathering tensors of different shapes
    tensor_list = []
    for _ in size_list:
        tensor_list.append(torch.empty((max_size,), dtype=torch.uint8, device="cuda"))
    if local_size != max_size:
        padding = torch.empty(size=(max_size - local_size,), dtype=torch.uint8, device="cuda")
        tensor = torch.cat((tensor, padding), dim=0)
    dist.all_gather(tensor_list, tensor)

    data_list = []
    for size, tensor in zip(size_list, tensor_list):
        buffer = tensor.cpu().numpy().tobytes()[:size]
        data_list.append(pickle.loads(buffer))

    return data_list

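A hedged sketch of `all_gather`; it assumes an already-initialized process group with one GPU per rank (for example via `setup` further down this file).

# Illustrative sketch; assumes torch.distributed is initialized and each rank has a GPU.
from utils.utils import all_gather, get_rank

payload = {"rank": get_rank(), "note": "any picklable object"}
gathered = all_gather(payload)   # list with one entry per rank, identical on every rank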
def reduce_dict(input_dict, average=True):
    """
    Args:
        input_dict (dict): all the values will be reduced
        average (bool): whether to do average or sum
    Reduce the values in the dictionary from all processes so that all processes
    have the averaged results. Returns a dict with the same fields as
    input_dict, after reduction.
    """
    world_size = get_world_size()
    if world_size < 2:
        return input_dict
    with torch.no_grad():
        names = []
        values = []
        # sort the keys so that they are consistent across processes
        for k in sorted(input_dict.keys()):
            names.append(k)
            values.append(input_dict[k])
        values = torch.stack(values, dim=0)
        dist.all_reduce(values)
        if average:
            values /= world_size
        reduced_dict = {k: v for k, v in zip(names, values)}
    return reduced_dict

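A hedged sketch of `reduce_dict`; it assumes an initialized process group and a CUDA device, and the loss names are made up.

# Illustrative sketch; with world_size == 1 the dict is returned unchanged.
import torch
from utils.utils import reduce_dict

local_losses = {"cls": torch.tensor(0.5, device="cuda"),   # assumes a CUDA device
                "reg": torch.tensor(1.0, device="cuda")}
averaged = reduce_dict(local_losses, average=True)   # same averaged values on every rank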
class MetricLogger(object):
    def __init__(self, delimiter="\t"):
        self.meters = defaultdict(SmoothedValue)
        self.delimiter = delimiter

    def update(self, **kwargs):
        for k, v in kwargs.items():
            if isinstance(v, torch.Tensor):
                v = v.item()
            assert isinstance(v, (float, int))
            self.meters[k].update(v)

    def __getattr__(self, attr):
        if attr in self.meters:
            return self.meters[attr]
        if attr in self.__dict__:
            return self.__dict__[attr]
        raise AttributeError("'{}' object has no attribute '{}'".format(
            type(self).__name__, attr))

    def __str__(self):
        loss_str = []
        for name, meter in self.meters.items():
            loss_str.append(
                "{}: {}".format(name, str(meter))
            )
        return self.delimiter.join(loss_str)

    def synchronize_between_processes(self):
        for meter in self.meters.values():
            meter.synchronize_between_processes()

    def add_meter(self, name, meter):
        self.meters[name] = meter

    def log_every(self, iterable, print_freq, header=None):
        i = 0
        if not header:
            header = ''
        start_time = time.time()
        end = time.time()
        iter_time = SmoothedValue(fmt='{avg:.4f}')
        data_time = SmoothedValue(fmt='{avg:.4f}')
        space_fmt = ':' + str(len(str(len(iterable)))) + 'd'
        if torch.cuda.is_available():
            log_msg = self.delimiter.join([
                header,
                '[{0' + space_fmt + '}/{1}]',
                'eta: {eta}',
                '{meters}',
                'time: {time}',
                'datasets: {datasets}',
                'max mem: {memory:.0f}'
            ])
        else:
            log_msg = self.delimiter.join([
                header,
                '[{0' + space_fmt + '}/{1}]',
                'eta: {eta}',
                '{meters}',
                'time: {time}',
                'datasets: {datasets}'
            ])
        MB = 1024.0 * 1024.0
        for obj in iterable:
            data_time.update(time.time() - end)
            yield obj
            iter_time.update(time.time() - end)
            if i % print_freq == 0 or i == len(iterable) - 1:
                eta_seconds = iter_time.global_avg * (len(iterable) - i)
                eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
                # the keyword must match the '{datasets}' placeholder in log_msg above
                if torch.cuda.is_available():
                    print(log_msg.format(
                        i, len(iterable), eta=eta_string,
                        meters=str(self),
                        time=str(iter_time), datasets=str(data_time),
                        memory=torch.cuda.max_memory_allocated() / MB))
                else:
                    print(log_msg.format(
                        i, len(iterable), eta=eta_string,
                        meters=str(self),
                        time=str(iter_time), datasets=str(data_time)))
            i += 1
            end = time.time()
        total_time = time.time() - start_time
        total_time_str = str(datetime.timedelta(seconds=int(total_time)))
        print('{} Total time: {} ({:.4f} s / it)'.format(
            header, total_time_str, total_time / len(iterable)))

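A minimal usage sketch for `MetricLogger.log_every`; the loader and the loss computation are illustrative stand-ins.

# Illustrative sketch: wrap a loader, print every 10 iterations, accumulate a loss meter.
import torch
from torch.utils.data import DataLoader, TensorDataset
from utils.utils import MetricLogger

loader = DataLoader(TensorDataset(torch.randn(100, 4), torch.randn(100, 1)), batch_size=4)
metric_logger = MetricLogger(delimiter="  ")

for inputs, targets in metric_logger.log_every(loader, print_freq=10, header="Epoch: [0]"):
    loss = ((inputs.sum(dim=1, keepdim=True) - targets) ** 2).mean()
    metric_logger.update(loss=loss)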
def collate_fn(batch):
    # return tuple(batch[0]), batch[1], tuple(batch[2])
    return tuple(zip(*batch))

def warmup_lr_scheduler(optimizer, warmup_iters, warmup_factor):
    def f(x):
        if x >= warmup_iters:
            return 1
        alpha = float(x) / warmup_iters
        return warmup_factor * (1 - alpha) + alpha

    return torch.optim.lr_scheduler.LambdaLR(optimizer, f)

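The inner `f(x)` ramps the learning-rate multiplier linearly from `warmup_factor` to 1 over `warmup_iters` steps; a brief illustrative sketch follows, with assumed model and hyperparameters.

# Illustrative sketch: the multiplier starts at warmup_factor and reaches 1 after warmup_iters.
import torch
from utils.utils import warmup_lr_scheduler

model = torch.nn.Linear(2, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scheduler = warmup_lr_scheduler(optimizer, warmup_iters=100, warmup_factor=0.001)

for it in range(100):
    optimizer.step()      # effective lr ramps from 0.1 * 0.001 up to 0.1
    scheduler.step()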
def mkdir(path):
    try:
        os.makedirs(path)
    except OSError as e:
        if e.errno != errno.EEXIST:
            raise


def setup_for_distributed(is_master):
    """
    This function disables printing when not in master process
    """
    import builtins as __builtin__
    builtin_print = __builtin__.print

    def print(*args, **kwargs):
        force = kwargs.pop('force', False)
        if is_master or force:
            builtin_print(*args, **kwargs)

    __builtin__.print = print

def is_dist_avail_and_initialized():
    if not dist.is_available():
        return False
    if not dist.is_initialized():
        return False
    return True


def get_world_size():
    if not is_dist_avail_and_initialized():
        return 1
    return dist.get_world_size()


def get_rank():
    if not is_dist_avail_and_initialized():
        return 0
    return dist.get_rank()


def is_main_process():
    return get_rank() == 0


def save_on_master(*args, **kwargs):
    if is_main_process():
        torch.save(*args, **kwargs)

def init_distributed_mode(args):
    if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
        args.rank = int(os.environ["RANK"])
        args.world_size = int(os.environ['WORLD_SIZE'])
        args.gpu = int(os.environ['LOCAL_RANK'])
    elif 'SLURM_PROCID' in os.environ:
        args.rank = int(os.environ['SLURM_PROCID'])
        args.gpu = args.rank % torch.cuda.device_count()
    else:
        print('Not using distributed mode')
        args.distributed = False
        return

    args.distributed = True

    torch.cuda.set_device(args.gpu)
    args.dist_backend = 'nccl'
    print('| distributed init (rank {}): {}'.format(
        args.rank, args.dist_url), flush=True)
    torch.distributed.init_process_group(backend=args.dist_backend,
                                         init_method=args.dist_url,
                                         world_size=args.world_size,
                                         rank=args.rank)
    torch.distributed.barrier()
    setup_for_distributed(args.rank == 0)

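A hedged sketch of the argparse namespace `init_distributed_mode` expects (it reads `args.dist_url` and fills in `rank`, `world_size`, `gpu`, `distributed`); the `--dist-url` default and the launch command are assumptions.

# Illustrative sketch; RANK / WORLD_SIZE / LOCAL_RANK (and MASTER_ADDR / MASTER_PORT
# for 'env://') are set by a launcher such as: torchrun --nproc_per_node=4 main.py
import argparse
from utils.utils import init_distributed_mode

parser = argparse.ArgumentParser()
parser.add_argument('--dist-url', default='env://')
args = parser.parse_args()

init_distributed_mode(args)
if args.distributed:
    print('rank', args.rank, 'of', args.world_size, 'on GPU', args.gpu)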
def setup(rank, world_size, addr="127.0.0.1", port="1895"):
    os.environ['MASTER_ADDR'] = addr
    os.environ['MASTER_PORT'] = port
    dist.init_process_group("nccl", rank=rank, world_size=world_size)


def cleanup():
    dist.destroy_process_group()

def select(prob_a):
    num_ = ['choice_a', 'choice_b']
    # list of probabilities for the two choices
    r_ = [prob_a, 1 - prob_a]
    sum_ = 0
    ran = random.random()
    for num, r in zip(num_, r_):
        sum_ += r
        if ran < sum_:
            break
    return num

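A brief illustrative sketch of `select`, which returns 'choice_a' with probability `prob_a` and 'choice_b' otherwise.

# Illustrative sketch: empirical frequencies approach the requested probabilities.
from collections import Counter
from utils.utils import select

counts = Counter(select(0.7) for _ in range(10000))
# counts['choice_a'] should be close to 7000, counts['choice_b'] close to 3000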
import logging


def get_logger(filename, verbosity=1, name=None):
    level_dict = {0: logging.DEBUG, 1: logging.INFO, 2: logging.WARNING}
    formatter = logging.Formatter(
        "[%(asctime)s][%(filename)s][line:%(lineno)d][%(levelname)s] %(message)s"
    )
    logger = logging.getLogger(name)
    logger.setLevel(level_dict[verbosity])

    fh = logging.FileHandler(filename, "w")
    fh.setFormatter(formatter)
    logger.addHandler(fh)

    sh = logging.StreamHandler()
    sh.setFormatter(formatter)
    logger.addHandler(sh)

    return logger
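A minimal usage sketch for `get_logger`; the log file name is an assumed path.

# Illustrative sketch; messages go to both the file (opened with mode "w", so it is
# truncated) and the console, using the same format.
from utils.utils import get_logger

logger = get_logger("train.log", verbosity=1, name="hi-template")
logger.info("starting epoch %d", 0)
# emits lines like: [<asctime>][<caller file>][line:<n>][INFO] starting epoch 0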