Commit 53003683 authored by Ilya Ovodov

params repr layout with margins

parent b59c3474
@@ -38,12 +38,33 @@ class AttrDict(OrderedDict):
             assert '.' not in k, "AttrDict: attribute '" + k + "' is invalid ('.' char is not allowed)"
             if isinstance(v, dict):
                 self[k] = AttrDict(v)
+            elif isinstance(v, list):
+                self[k] = [AttrDict(item) if isinstance(item, dict) else item for item in v]
 
     def __repr__(self):
-        return ('{\n' +
-                '\n'.join([repr(x[0]) + ' : ' + repr(x[1]) + ','
-                           for x in vars(self).items() if not x[0].startswith('__') and x[0] != 'data_root']) +
-                ' \n}')
+        def write_item(item, margin='\n'):
+            if isinstance(item, dict):
+                s = '{'
+                margin2 = margin + '    '
+                for k, v in item.items():
+                    if not k.startswith('__') and k != 'data_root':
+                        s += margin2 + "'{0}': ".format(k) + write_item(v, margin=margin2) + ","
+                if item.items():
+                    s += margin
+                s += '}'
+            elif isinstance(item, (list, tuple)):
+                s = '[' if isinstance(item, list) else '('
+                for v in item:
+                    if isinstance(v, dict):
+                        s += margin + '    '
+                    else:
+                        s += ' '
+                    s += write_item(v, margin=margin + '    ') + ","
+                s += ' ' + (']' if isinstance(item, list) else ')')
+            else:
+                s = repr(item)
+            return s
+        return write_item(self)
 
     def has(self, name):
         '''
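The rewritten __repr__ replaces the old flat, one-level dump with a recursive write_item: dict entries are printed one per line at an ever-growing margin, lists and tuples stay inline unless an element is itself a dict, and everything else falls back to repr. Together with the new elif branch that converts dicts inside lists to AttrDict, this makes deeply nested configs print as readable, reparseable Python literals. A minimal standalone sketch of the same idea (plain dicts, hypothetical name pretty_repr, indent step assumed to be four spaces):

def pretty_repr(item, margin='\n'):
    # Dicts: one "'key': value," per line, indented one step past margin.
    if isinstance(item, dict):
        s = '{'
        margin2 = margin + '    '
        for k, v in item.items():
            s += margin2 + "'{0}': ".format(k) + pretty_repr(v, margin=margin2) + ","
        if item:
            s += margin
        s += '}'
    # Lists/tuples: inline, except dict elements start on their own line.
    elif isinstance(item, (list, tuple)):
        s = '[' if isinstance(item, list) else '('
        for v in item:
            s += (margin + '    ') if isinstance(v, dict) else ' '
            s += pretty_repr(v, margin=margin + '    ') + ","
        s += ' ' + (']' if isinstance(item, list) else ')')
    else:
        s = repr(item)
    return s

print(pretty_repr({'data': {'val_folds': (4,), 'batch_size': 100}}))
# {
#     'data': {
#         'val_folds': ( 4, ),
#         'batch_size': 100,
#     },
# }

Note the trailing comma inside tuples: a one-element tuple prints as ( 4, ), so ast.literal_eval reads it back as a tuple rather than a parenthesized scalar.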
@@ -95,11 +116,10 @@ class AttrDict(OrderedDict):
         print('saved to ' + params_fn)
 
     def load_from_str(s, data_root):
-        assert len(s) >= 2
-        assert s[0][0] == '{' and s[-1][-2:] == '}\n'
         s = ''.join(s)
         s = s.replace('\n', '')
+        assert len(s) >= 2
+        assert s[0][0] == '{'
+        assert s[-1][-1] == '}'
         params = ast.literal_eval(s)
         if data_root:
             params.data_root = data_root
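load_from_str is the inverse of this format: parsing goes through ast.literal_eval, so the saved text must stay a plain Python literal. The commit also moves the sanity asserts after the join and newline-stripping, so they now validate the normalized string as a whole instead of requiring the raw text to end with '}\n'. A small sketch of the flow (s arrives as a list of lines, which is why the old code could index s[0][0]):

import ast

# Simulate the body of load_from_str on a hand-made line list.
lines = ["{\n", "    'x': ( 4, ),\n", "}\n"]
s = ''.join(lines).replace('\n', '')
assert len(s) >= 2 and s[0] == '{' and s[-1] == '}'
print(ast.literal_eval(s))  # {'x': (4,)}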
@@ -121,13 +141,97 @@ class AttrDict(OrderedDict):
     return params
 
 if __name__=='__main__':
-    m = AttrDict(b = AttrDict(b1="b1v", b2 = "qwe"), a=1, c = "qweqweqwe")
+    m = AttrDict(
+        data_root='abc',
+        model_name='NN_results/segmentation/unet',
+        data=AttrDict(
+            val_folds=(4,),
+            batch_size=100,
+            resize=(128, 800),  # H, W
+            crop=(128, 256),  # H, W
+            train_augmentations=AttrDict(
+                # HorizontalFlip = AttrDict(),
+                # RandomBrightnessContrast = AttrDict(),
+            ),
+            crop_for_val=False,
+            add_coordconv=False,
+        ),
+        model=AttrDict(
+            type='segmentation_models_pytorch.Unet',
+            params=AttrDict(
+                encoder_name='resnet34',  # 'se_resnext50_32x4d' 'resnet34'
+                encoder_weights='imagenet',
+            ),
+            # load_from = 'NN_results/segmentation/unet_66fa48/models/best.t7',
+        ),
+        dann=AttrDict(
+            use_dann=False,
+            lambda_max=1.,
+            epochs=40,
+            weight=0.1,
+        ),
+        loss=[
+            AttrDict(
+                type='torch.nn.BCEWithLogitsLoss',
+                params=AttrDict(),
+            ),
+            AttrDict(
+                type='pytorch_toolbelt.losses.dice.DiceLoss',
+                params = AttrDict(
+                    mode='multilabel',
+                    log_loss=True,
+                    smooth=1,
+                ),
+                weight = 0.5,
+            ),
+        ],
+        optim='torch.optim.SGD',
+        optim_params=AttrDict(
+            lr=0.2,
+            momentum=0.9,
+            weight_decay=5e-4,  # 0.001,
+            # nesterov = False,
+        ),
+        lr_finder=AttrDict(
+            iters_num=200,
+            log_lr_start=-4,
+            log_lr_end=-0,
+        ),
+        ls_cheduler='torch.optim.lr_scheduler.ReduceLROnPlateau',
+        clr=AttrDict(
+            warmup_epochs=1,
+            min_lr=0.0002,
+            max_lr=1e-1,
+            period_epochs=40,
+            scale_max_lr=0.95,
+            scale_min_lr=0.95,
+        ),
+        ReduceLROnPlateau_params=AttrDict(
+            mode='min',
+            factor=0.5,
+            patience=10,
+            min_lr=2.e-4,
+        ),
+        StepLR_params=AttrDict(
+            step_size=20,
+            gamma=0.5,
+        ),
+        MultiStepLR_params=AttrDict(
+            milestones=[25, 50, 75, 100, 125, 150, 175, 200, 215, 230, 245, 260, 275, 290, 300],
+            gamma=0.5,
+        ),
+    )
+    print(repr(m))
     fn = 'test_' + m.hash()
-    m.save(fn)
-    mm = AttrDict.Load(fn)
+    m.save(fn, can_overwrite=True)
+    m.save(fn+'0', can_overwrite=True)
+    mm = AttrDict.load(fn+'.param.txt')
     import os
-    os.remove(fn + '.param.txt')
+    mm.save(fn, can_overwrite=True)
     print(m)
     print(mm)
-    assert str(m)==str(mm)
\ No newline at end of file
+    assert str(m)==str(mm)
+    os.remove(fn + '.param.txt')
+    os.remove(fn + '0.param.txt')
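The final assert in the test passes even though mm was loaded from a file that contains no data_root: write_item skips 'data_root' and double-underscore keys, so they never reach the serialized text, and str(m) == str(mm) compares only the persisted fields. A tiny illustration (assuming this module is importable, hypothetically as params):

from params import AttrDict  # hypothetical module name for this file

a = AttrDict(data_root='abc', x=1)
b = AttrDict(x=1)
assert str(a) == str(b)  # 'data_root' never appears in the repr

The commit also includes a sample of the saved format, reproduced below: the pretty-printed literal followed by the hash line.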
{
    'model_name': 'NN_results/segmentation/unet',
    'data': {
        'val_folds': ( 4, ),
        'batch_size': 100,
        'resize': ( 128, 800, ),
        'crop': ( 128, 256, ),
        'train_augmentations': {},
        'crop_for_val': False,
        'add_coordconv': False,
    },
    'model': {
        'type': 'segmentation_models_pytorch.Unet',
        'params': {
            'encoder_name': 'resnet34',
            'encoder_weights': 'imagenet',
        },
    },
    'dann': {
        'use_dann': False,
        'lambda_max': 1.0,
        'epochs': 40,
        'weight': 0.1,
    },
    'loss': [
        {
            'type': 'torch.nn.BCEWithLogitsLoss',
            'params': {},
        },
        {
            'type': 'pytorch_toolbelt.losses.dice.DiceLoss',
            'params': {
                'mode': 'multilabel',
                'log_loss': True,
                'smooth': 1,
            },
            'weight': 0.5,
        }, ],
    'optim': 'torch.optim.SGD',
    'optim_params': {
        'lr': 0.2,
        'momentum': 0.9,
        'weight_decay': 0.0005,
    },
    'lr_finder': {
        'iters_num': 200,
        'log_lr_start': -4,
        'log_lr_end': 0,
    },
    'ls_cheduler': 'torch.optim.lr_scheduler.ReduceLROnPlateau',
    'clr': {
        'warmup_epochs': 1,
        'min_lr': 0.0002,
        'max_lr': 0.1,
        'period_epochs': 40,
        'scale_max_lr': 0.95,
        'scale_min_lr': 0.95,
    },
    'ReduceLROnPlateau_params': {
        'mode': 'min',
        'factor': 0.5,
        'patience': 10,
        'min_lr': 0.0002,
    },
    'StepLR_params': {
        'step_size': 20,
        'gamma': 0.5,
    },
    'MultiStepLR_params': {
        'milestones': [ 25, 50, 75, 100, 125, 150, 175, 200, 215, 230, 245, 260, 275, 290, 300, ],
        'gamma': 0.5,
    },
}
hash: d444b2
\ No newline at end of file
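Because the trailing hash: line is not part of the Python literal, anything that parses this file must drop it before calling ast.literal_eval. A sketch of reading the sample back (the file name is hypothetical, following the test's 'test_' + hash convention):

import ast

# Strip the non-literal 'hash: ...' trailer, then parse the rest.
with open('test_d444b2.param.txt') as f:
    lines = [ln for ln in f if not ln.startswith('hash:')]
params = ast.literal_eval(''.join(lines).replace('\n', ''))
print(params['data']['val_folds'])  # (4,)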