Commit 8a0c8960 authored by Ilya Ovodov

create_object and create_*() accept parameters in various forms, including tuple; extra parameters for ctx.create_*() functions
parent 22f19afb
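A minimal usage sketch of the calling forms this change enables. The import path is hypothetical and the constructor values are illustrative; the default eval_func=eval is assumed to resolve the 'torch.*' names in the module that defines create_object:

# from ovotools.pytorch_tools import create_object  # hypothetical import path

# dict form (supported before this change)
conv = create_object({'type': 'torch.nn.Conv2d',
                      'params': {'in_channels': 64, 'out_channels': 32, 'kernel_size': 3}})

# new list/tuple form: (type,) or (type, params)
conv = create_object(['torch.nn.Conv2d',
                      {'in_channels': 64, 'out_channels': 32, 'kernel_size': 3}])

# new bare-string form
bce = create_object('torch.nn.BCELoss')

# list of specs -> list of objects
losses = create_object([['torch.nn.BCELoss',], ['torch.nn.L1Loss',]])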
@@ -137,7 +137,7 @@ class AttrDict(OrderedDict):
if verbose >= 2:
print('params: '+ str(params) + '\nhash: ' + params.hash())
if verbose >= 1:
print('loaded from ' + params_fn)
print('loaded from ' + str(params_fn))
return params
@@ -4,14 +4,20 @@ from torch.nn.modules.loss import _Loss
class SimpleLoss(_Loss):
def __init__(self, loss_func: _Loss, dict_key: str = None):
'''
Wrapper over any usual loss function that
1) stores the calculated loss as the `val` attribute
2) provides extra interfaces similar to CompositeLoss (`get_val`, `len`, `get_subval`)
3) if `dict_key` is defined, the loss is calculated against y_true[dict_key] instead of y_true
'''
def __init__(self, loss_func: _Loss, dict_key = None):
super(SimpleLoss, self).__init__()
self.loss_func = loss_func
self.dict_key = dict_key
self.val = None
def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
if self.dict_key and isinstance(y_true, dict):
if self.dict_key:
y_true = y_true[self.dict_key]
self.val = self.loss_func(y_pred, y_true)
return self.val
@@ -24,9 +30,21 @@ class SimpleLoss(_Loss):
return self.val
return call
def get_subval(self, index):
def call(*kargs, **kwargs):
return None
return call
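A hedged usage sketch for SimpleLoss; the dict key and tensor shapes below are illustrative:

import torch

# wrap BCELoss so it is computed against y_true['mask'] instead of y_true itself
loss = SimpleLoss(torch.nn.BCELoss(), dict_key='mask')
y_pred = torch.rand(4, 1)
y_true = {'mask': torch.randint(0, 2, (4, 1)).float()}
value = loss(y_pred, y_true)   # same as torch.nn.BCELoss()(y_pred, y_true['mask'])
print(loss.val)                # the last computed value is kept on the wrapper
print(loss.get_val()())        # callable accessor mirroring the CompositeLoss interface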
class CompositeLoss(_Loss):
'''
Wrapper to calculate a weighted sum of losses
Also stores the calculated total value and the calculated value of each component loss
'''
def __init__(self, loss_funcs: List):
'''
:param loss_funcs: list of (loss_function, weight,)
'''
super(CompositeLoss, self).__init__()
self.loss_funcs = loss_funcs
self.val = None
@@ -38,14 +56,24 @@ class CompositeLoss(_Loss):
return self.val
def __len__(self):
'''
:return: number of component losses
'''
return len(self.loss_funcs)
def get_val(self):
'''
:return: callable to get last calculated value
'''
def call(*kargs, **kwargs):
return self.val
return call
def get_subval(self, index):
'''
:param index: index of component loss
:return: callable to get the last calculated value of the component loss by index
'''
def call(*kargs, **kwargs):
return self.sub_vals[index]
return call
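A hedged usage sketch for CompositeLoss, assuming its forward() passes the same (y_pred, y_true) to every component; the weights are illustrative:

import torch

composite = CompositeLoss([(torch.nn.BCELoss(), 1.0),
                           (torch.nn.L1Loss(), 0.5)])
y_pred = torch.rand(4, 1)
y_true = torch.randint(0, 2, (4, 1)).float()
total = composite(y_pred, y_true)   # weighted sum: 1.0 * BCE + 0.5 * L1
print(len(composite))               # 2 component losses
print(composite.get_val()())        # last total value
print(composite.get_subval(0)())    # last value of the first component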
@@ -8,52 +8,68 @@ def create_object(params: dict, eval_func: Callable = eval, *args, **kwargs) ->
'''
Create object of type params['type'] using *args, **kwargs and parameters params['params'].
params['params'] is optional.
Other options are also available (see below)
Example:
Examples:
create_object({'type': 'torch.nn.Conv2d', 'params': {'in_channels': 64, 'out_channels': 32, 'kernel_size': 3} })
create_object({'type': 'torch.nn.BCELoss'})
:param params: dict describing the object. Must contain ['type'] and optional ['params']
create_object(['torch.nn.Conv2d', {'in_channels': 64, 'out_channels': 32, 'kernel_size': 3}])
create_object('torch.nn.BCELoss')
create_object([['torch.nn.BCELoss',], ['torch.nn.Conv2d', {'in_channels': 64, 'out_channels': 32, 'kernel_size': 3}]])
:param params: options available:
1) dict describing the object. Must contain ['type']: str and an optional ['params']: dict with constructor params
2) tuple or list with 1 member (type,) or 2 members (type, params); any further members are ignored
3) string containing the type
4) list or tuple of the options listed above; a list of objects is created
:param eval_func: function to convert the ['type'] string to an object class. The usual use case is calling eval(x)
in the context of the calling module
:param args: args to be passed to the constructor
:param kwargs:
:return: created object
:param kwargs: args to be passed to the constructor
:return: created object, or a list of objects if params is a list of params.
'''
all_kwargs = kwargs.copy()
p = params.get('params', dict())
all_kwargs.update(p)
print('creating: ', params['type'], repr(dict(p)))
obj = eval_func(params['type'])(*args, **all_kwargs)
return obj
if isinstance(params, dict):
assert isinstance(params['type'], str)
all_kwargs = kwargs.copy()
p = params.get('params', dict())
assert isinstance(p, dict)
all_kwargs.update(p)
print('creating: ', params['type'], repr(dict(p)))
obj = eval_func(params['type'])(*args, **all_kwargs)
return obj
elif isinstance(params, (list, tuple)):
if len(params) >= 1 and isinstance(params[0], str):
if len(params) == 1:
return create_object({'type': params[0]}, eval_func, *args, **kwargs)
elif len(params) >= 2 and isinstance(params[1], dict):
return create_object({'type': params[0], 'params': params[1]}, eval_func, *args, **kwargs)
return [create_object(pi, eval_func, *args, **kwargs) for pi in params]
elif isinstance(params, str):
return create_object({'type': params}, eval_func, *args, **kwargs)
else:
raise Exception("Invalid call to create_object: params is {}".format(params))
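A hedged sketch of the eval_func parameter: a lambda defined in the calling module makes eval() run with that module's globals, so type strings resolve against the caller's imports (torchvision here is purely illustrative):

import torchvision

model = create_object({'type': 'torchvision.models.resnet18',
                       'params': {'num_classes': 10}},
                      eval_func=lambda s: eval(s))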
def create_optional_object(params: dict, key: str, eval_func = eval, *args, **kwargs) -> object:
'''
Create object of type params[<key>]['type'] using *args, **kwargs and parameters params[<key>]['params'].
Create an object of the type and with the parameters defined in the optional params[<key>]
If no params[<key>] or params[<key>]['type'] is defined, returns None.
params[<key>]['params'] is optional
Example:
create_object({'type': 'torch.nn.Conv2d', 'params': {'in_channels': 64, 'out_channels': 32, 'kernel_size': 3} })
create_object({'type': 'torch.nn.BCELoss'})
See create_object() for details
:param params: dict containig params[<key>] describing the object. params[<key>] must contain ['type'] and optional ['params']
for object to be created
:param params: dict containing an optional params[<key>] describing the object.
if params[<key>] is defined, it must be a valid parameter for create_object()
:param key: string, key in params dict
:param eval_func: function to convert the ['type'] string to an object class. The usual use case is calling eval(x)
in the context of the calling module
:param args: args to be passed to the constructor
:param kwargs:
:param kwargs: args to be passed to the constructor
:return: created object
'''
p = params.get(key)
if not p:
print('NO '+ key + ' is set')
return None
if not p.get('type'):
print('NO '+ key + '["type"] is set')
return None
return create_object(p, eval_func, *args, **kwargs)
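A hedged sketch of create_optional_object with a config in which one section is deliberately absent; the config content is illustrative:

params = {'loss': {'type': 'torch.nn.BCELoss'}}

loss = create_optional_object(params, 'loss')           # -> torch.nn.BCELoss instance
sched = create_optional_object(params, 'lr_scheduler')  # key absent: prints a notice and returns None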
@@ -75,32 +91,32 @@ class Context(AttrDict):
optimizer = None,
)
def create_model(self, train=True):
return create_model(self, train)
def create_model(self, train=True, *args, **kwargs):
return create_model(self, train, *args, **kwargs)
def create_optim(self):
return create_optim(self)
def create_optim(self, *args, **kwargs):
return create_optim(self, *args, **kwargs)
def create_lr_scheduler(self):
return create_lr_scheduler(self)
def create_lr_scheduler(self, *args, **kwargs):
return create_lr_scheduler(self, *args, **kwargs)
def create_loss(self):
return create_loss(self)
def create_loss(self, *args, **kwargs):
return create_loss(self, *args, **kwargs)
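A hedged sketch of the extra parameters now accepted by the ctx.create_*() methods; ctx is assumed to be a fully populated Context, and num_classes is an illustrative constructor argument (extra kwargs are merged with the config's 'params' section, and config values take precedence on key clashes):

net = ctx.create_model(train=False, num_classes=10)  # num_classes is forwarded to the model constructor
optimizer = ctx.create_optim()                       # receives ctx.net.parameters() automatically
scheduler = ctx.create_lr_scheduler()                # receives optimizer=ctx.optimizer automatically
loss = ctx.create_loss()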
def create_model(context: Context, train=True) -> torch.nn.Module:
def create_model(context: Context, train=True, *args, **kwargs) -> torch.nn.Module:
'''
Creates model using standard context structure (context.params.model)
Stores model in context.net
:return updated context
'''
context.net = create_object(context.params.model, context.eval_func)
context.net = create_object(context.params.model, context.eval_func, *args, **kwargs)
context.net.train(train)
return context.net
def create_optim(context: Context) -> torch.optim.Optimizer:
def create_optim(context: Context, *args, **kwargs) -> torch.optim.Optimizer:
'''
Creates optimizer using standard context structure (context.params.optimizer)
Stores optimizer in context.optimizer
@@ -108,11 +124,11 @@ def create_optim(context: Context) -> torch.optim.Optimizer:
:return updated context
'''
context.optimizer = create_object(context.params.optimizer, context.eval_func, context.net.parameters())
context.optimizer = create_object(context.params.optimizer, context.eval_func, context.net.parameters(), *args, **kwargs)
return context.optimizer
def create_lr_scheduler(context: Context) -> object:
def create_lr_scheduler(context: Context, *args, **kwargs) -> object:
'''
Creates lr_scheduler using standard context structure (context.params.lr_scheduler)
Stores lr_scheduler in context.lr_scheduler
@@ -120,39 +136,47 @@ def create_lr_scheduler(context: Context) -> object:
:return updated context
'''
context.lr_scheduler = create_object(context.params.lr_scheduler, context.eval_func, optimizer=context.optimizer)
context.lr_scheduler = create_object(context.params.lr_scheduler, context.eval_func, *args, optimizer=context.optimizer, **kwargs)
return context.lr_scheduler
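A hedged sketch of the config sections consumed by create_optim() and create_lr_scheduler(); the concrete classes and values are illustrative, and context is assumed to already hold a model in context.net:

context.params.optimizer = {'type': 'torch.optim.Adam', 'params': {'lr': 1e-3}}
context.params.lr_scheduler = {'type': 'torch.optim.lr_scheduler.StepLR',
                               'params': {'step_size': 10, 'gamma': 0.5}}

opt = create_optim(context)           # Adam over context.net.parameters()
sched = create_lr_scheduler(context)  # StepLR bound to context.optimizer via optimizer=...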
def create_loss(context: Context) -> torch.nn.modules.loss._Loss:
def create_loss(context: Context, *args, **kwargs) -> torch.nn.modules.loss._Loss:
'''
Creates loss using standard context structure (context.params.loss)
Stores loss in context.loss
:return updated context
'''
context.loss = CreateCompositeLoss(context.params.loss, context.eval_func)
context.loss = CreateCompositeLoss(context.params.loss, context.eval_func, *args, **kwargs)
return context.loss
def CreateCompositeLoss(loss_params: dict, eval_func=eval) -> torch.nn.modules.loss._Loss:
def CreateCompositeLoss(loss_params: dict, eval_func=eval, *args, **kwargs) -> torch.nn.modules.loss._Loss:
'''
creates loss using loss_params
:param loss_params: dict can be dict (to create single loss or list of loss_params (to create composite loss)
loss dict can contain optional 'weight' (for composite loss) and 'mean' values.
:param loss_params: params in any form accepted by `create_object()`
If `loss_params` describes a list of losses, `CompositeLoss` is created; otherwise `SimpleLoss`.
If a loss is described in loss_params as a dict, it can contain the keys:
'key' - to calculate the loss against `y_true[key]` instead of `y_true`
'weight' - weight of the loss in the composite loss
'mean' - to wrap the loss in `MeanLoss` (calculate the loss over channels and then take the mean over channels).
:param eval_func:
:return:
:return: SimpleLoss or CompositeLoss object
'''
if isinstance(loss_params, dict):
loss = create_object(loss_params, eval_func)
if loss_params.get('mean', False):
loss = MeanLoss(loss)
loss = SimpleLoss(loss, loss_params.get('key'))
loss = create_object(loss_params, eval_func, *args, **kwargs)
if not isinstance(loss, (list, tuple)):
key = None
if isinstance(loss_params, dict):
if loss_params.get('mean', False):
loss = MeanLoss(loss)
key = loss_params.get('key')
loss = SimpleLoss(loss, key)
return loss
else:
loss_funcs = []
for loss_param in loss_params:
loss_i = CreateCompositeLoss(loss_param, eval_func=eval_func)
loss_funcs.append((loss_i, loss_param.get('weight', 1.),))
loss_i = CreateCompositeLoss(loss_param, eval_func=eval_func, *args, **kwargs)
weight = loss_param.get('weight', 1.) if isinstance(loss_param, dict) else 1.
loss_funcs.append((loss_i, weight,))
return CompositeLoss(loss_funcs)
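A hedged sketch of a loss config consumed by CreateCompositeLoss; the 'key' and 'weight' values are illustrative:

loss_params = [
    {'type': 'torch.nn.BCELoss', 'key': 'mask', 'weight': 1.0},  # computed against y_true['mask']
    {'type': 'torch.nn.L1Loss', 'weight': 0.5},
]
loss = CreateCompositeLoss(loss_params)  # -> CompositeLoss over two SimpleLoss wrappers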