Commit 8a0c8960 authored by Ilya Ovodov's avatar Ilya Ovodov

create_object and create_*() accept parameters in various forms, including…

create_object and create_*() accept parameters in various forms, including tuple; extra parameters for ctx.create_*() functions
parent 22f19afb
...@@ -137,7 +137,7 @@ class AttrDict(OrderedDict): ...@@ -137,7 +137,7 @@ class AttrDict(OrderedDict):
if verbose >= 2: if verbose >= 2:
print('params: '+ str(params) + '\nhash: ' + params.hash()) print('params: '+ str(params) + '\nhash: ' + params.hash())
if verbose >= 1: if verbose >= 1:
print('loaded from ' + params_fn) print('loaded from ' + str(params_fn))
return params return params
......
...@@ -4,14 +4,20 @@ from torch.nn.modules.loss import _Loss ...@@ -4,14 +4,20 @@ from torch.nn.modules.loss import _Loss
class SimpleLoss(_Loss): class SimpleLoss(_Loss):
def __init__(self, loss_func: _Loss, dict_key: str = None): '''
Wrapper over any usual loss function, that
1) stores calculated loss as `val` attribute
2) provides extra interfaces similar to CompositeLoss (`get_val`, `len`, `get_subval`)
3) if `key` is defined, loss is calculated against y_true[key] instead of y_true
'''
def __init__(self, loss_func: _Loss, dict_key = None):
super(SimpleLoss, self).__init__() super(SimpleLoss, self).__init__()
self.loss_func = loss_func self.loss_func = loss_func
self.dict_key = dict_key self.dict_key = dict_key
self.val = None self.val = None
def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor: def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
if self.dict_key and isinstance(y_true, dict): if self.dict_key:
y_true = y_true[self.dict_key] y_true = y_true[self.dict_key]
self.val = self.loss_func(y_pred, y_true) self.val = self.loss_func(y_pred, y_true)
return self.val return self.val
...@@ -24,9 +30,21 @@ class SimpleLoss(_Loss): ...@@ -24,9 +30,21 @@ class SimpleLoss(_Loss):
return self.val return self.val
return call return call
def get_subval(self, index):
def call(*kargs, **kwargs):
return None
return call
class CompositeLoss(_Loss): class CompositeLoss(_Loss):
'''
Wrapper to calculate weighted sum of losses
Also stores calculates value and calculated values of each loss being composed of
'''
def __init__(self, loss_funcs: List): def __init__(self, loss_funcs: List):
'''
:param loss_funcs: list of (loss_function, weight,)
'''
super(CompositeLoss, self).__init__() super(CompositeLoss, self).__init__()
self.loss_funcs = loss_funcs self.loss_funcs = loss_funcs
self.val = None self.val = None
...@@ -38,14 +56,24 @@ class CompositeLoss(_Loss): ...@@ -38,14 +56,24 @@ class CompositeLoss(_Loss):
return self.val return self.val
def __len__(self): def __len__(self):
'''
:return: number of component losses
'''
return len(self.loss_funcs) return len(self.loss_funcs)
def get_val(self): def get_val(self):
'''
:return: callable to get last calculated value
'''
def call(*kargs, **kwargs): def call(*kargs, **kwargs):
return self.val return self.val
return call return call
def get_subval(self, index): def get_subval(self, index):
'''
:param index: index of component loss
:return: callable to get last calculates value of component loss by index
'''
def call(*kargs, **kwargs): def call(*kargs, **kwargs):
return self.sub_vals[index] return self.sub_vals[index]
return call return call
......
...@@ -8,52 +8,68 @@ def create_object(params: dict, eval_func: Callable = eval, *args, **kwargs) -> ...@@ -8,52 +8,68 @@ def create_object(params: dict, eval_func: Callable = eval, *args, **kwargs) ->
''' '''
Create object of type params['type'] using *args, **kwargs and parameters params['params']. Create object of type params['type'] using *args, **kwargs and parameters params['params'].
params['params'] is optional. params['params'] is optional.
Also other options available (see below)
Example: Examples:
create_object({'type': 'torch.nn.Conv2d', 'params': {'in_channels': 64, 'out_channels': 32, 'kernel_size': 3} }) create_object({'type': 'torch.nn.Conv2d', 'params': {'in_channels': 64, 'out_channels': 32, 'kernel_size': 3} })
create_object({'type': 'torch.nn.BCELoss'}) create_object({'type': 'torch.nn.BCELoss'})
create_object(['torch.nn.Conv2d', {'in_channels': 64, 'out_channels': 32, 'kernel_size': 3} },])
:param params: dict describing the object. Must contain ['type'] and optional ['params'] create_object('torch.nn.BCELoss')
create_object([['torch.nn.BCELoss',], ['torch.nn.Conv2d', {'in_channels': 64, 'out_channels': 32, 'kernel_size': 3} }]])
:param params: options available
1) dict describing the object. Must contain ['type']: str and optional ['params']: dict with constructor params
2) tuple or list with 1 (type,) or 2 members (type, params) or more (other members are ignored)
3) string contaning type
4) list or tuple of options listed above. List of objects is created
:param eval_func: function to convert ['type'] string to object class. Usual usecase is calling eval(x) :param eval_func: function to convert ['type'] string to object class. Usual usecase is calling eval(x)
in a context of the calling module in a context of the calling module
:param args: args to be passed to the constructor :param args: args to be passed to the constructor
:param kwargs: :param kwargs: args to be passed to the constructor
:return: created object :return: created object or list of objects if params is list of params.
''' '''
all_kwargs = kwargs.copy() if isinstance(params, dict):
p = params.get('params', dict()) assert isinstance(params['type'], str)
all_kwargs.update(p) all_kwargs = kwargs.copy()
print('creating: ', params['type'], repr(dict(p))) p = params.get('params', dict())
obj = eval_func(params['type'])(*args, **all_kwargs) assert isinstance(p, dict)
return obj all_kwargs.update(p)
print('creating: ', params['type'], repr(dict(p)))
obj = eval_func(params['type'])(*args, **all_kwargs)
return obj
elif isinstance(params, (list, tuple)):
if len(params) >= 1 and isinstance(params[0], str):
if len(params) == 1:
return create_object({'type': params[0]}, eval_func, *args, **kwargs)
elif len(params) >= 2 and isinstance(params[1], dict):
return create_object({'type': params[0], 'params': params[1]}, eval_func, *args, **kwargs)
return [create_object(pi, eval_func, *args, **kwargs) for pi in params]
elif isinstance(params, str):
return create_object({'type': params}, eval_func, *args, **kwargs)
else:
raise Exception("Invalid call to create_object: params is {}".format(params))
def create_optional_object(params: dict, key: str, eval_func = eval, *args, **kwargs) -> object: def create_optional_object(params: dict, key: str, eval_func = eval, *args, **kwargs) -> object:
''' '''
Create object of type params[<key>]['type'] using *args, **kwargs and parameters params[<key>]['params']. Create object of type and with parameters defined in optional params[<key>]
If no params[<key>] or params[<key>]['type'] is defined, returns None. If no params[<key>] or params[<key>]['type'] is defined, returns None.
params[<key>]['params'] is optional
Example: See create_object() for details
create_object({'type': 'torch.nn.Conv2d', 'params': {'in_channels': 64, 'out_channels': 32, 'kernel_size': 3} })
create_object({'type': 'torch.nn.BCELoss'})
:param params: dict containig params[<key>] describing the object. params[<key>] must contain ['type'] and optional ['params'] :param params: dict containig optional params[<key>] describing the object.
for object to be created if params[<key>] is defined, it must be valid parameter for create_object()
:param key: string, key in params dict :param key: string, key in params dict
:param eval_func: function to convert ['type'] string to object class. Usual usecase is calling eval(x) :param eval_func: function to convert ['type'] string to object class. Usual usecase is calling eval(x)
in a context of the calling module in a context of the calling module
:param args: args to be passed to the constructor :param args: args to be passed to the constructor
:param kwargs: :param kwargs: args to be passed to the constructor
:return: created object :return: created object
''' '''
p = params.get(key) p = params.get(key)
if not p: if not p:
print('NO '+ key + ' is set') print('NO '+ key + ' is set')
return None return None
if not p.get('type'):
print('NO '+ key + '["type"] is set')
return None
return create_object(p, eval_func, *args, **kwargs) return create_object(p, eval_func, *args, **kwargs)
...@@ -75,32 +91,32 @@ class Context(AttrDict): ...@@ -75,32 +91,32 @@ class Context(AttrDict):
optimizer = None, optimizer = None,
) )
def create_model(self, train=True): def create_model(self, train=True, *args, **kwargs):
return create_model(self, train) return create_model(self, train, *args, **kwargs)
def create_optim(self): def create_optim(self, *args, **kwargs):
return create_optim(self) return create_optim(self, *args, **kwargs)
def create_lr_scheduler(self): def create_lr_scheduler(self, *args, **kwargs):
return create_lr_scheduler(self) return create_lr_scheduler(self, *args, **kwargs)
def create_loss(self): def create_loss(self, *args, **kwargs):
return create_loss(self) return create_loss(self, *args, **kwargs)
def create_model(context: Context, train=True) -> torch.nn.Module: def create_model(context: Context, train=True, *args, **kwargs) -> torch.nn.Module:
''' '''
Creates model using standard context structure (context.params.model) Creates model using standard context structure (context.params.model)
Stores model in context.net Stores model in context.net
:return updated context :return updated context
''' '''
context.net = create_object(context.params.model, context.eval_func) context.net = create_object(context.params.model, context.eval_func, *args, **kwargs)
context.net.train(train) context.net.train(train)
return context.net return context.net
def create_optim(context: Context) -> torch.optim.Optimizer: def create_optim(context: Context, *args, **kwargs) -> torch.optim.Optimizer:
''' '''
Creates optimizer using standard context structure (context.params.optimizer) Creates optimizer using standard context structure (context.params.optimizer)
Stores optimizer in context.optimizer Stores optimizer in context.optimizer
...@@ -108,11 +124,11 @@ def create_optim(context: Context) -> torch.optim.Optimizer: ...@@ -108,11 +124,11 @@ def create_optim(context: Context) -> torch.optim.Optimizer:
:return updated context :return updated context
''' '''
context.optimizer = create_object(context.params.optimizer, context.eval_func, context.net.parameters()) context.optimizer = create_object(context.params.optimizer, context.eval_func, context.net.parameters(), *args, **kwargs)
return context.optimizer return context.optimizer
def create_lr_scheduler(context: Context) -> object: def create_lr_scheduler(context: Context, *args, **kwargs) -> object:
''' '''
Creates lr_scheduler using standard context structure (context.params.lr_scheduler) Creates lr_scheduler using standard context structure (context.params.lr_scheduler)
Stores lr_scheduler in context.lr_scheduler Stores lr_scheduler in context.lr_scheduler
...@@ -120,39 +136,47 @@ def create_lr_scheduler(context: Context) -> object: ...@@ -120,39 +136,47 @@ def create_lr_scheduler(context: Context) -> object:
:return updated context :return updated context
''' '''
context.lr_scheduler = create_object(context.params.lr_scheduler, context.eval_func, optimizer=context.optimizer) context.lr_scheduler = create_object(context.params.lr_scheduler, context.eval_func, *args, optimizer=context.optimizer, **kwargs)
return context.lr_scheduler return context.lr_scheduler
def create_loss(context: Context) -> torch.nn.modules.loss._Loss: def create_loss(context: Context, *args, **kwargs) -> torch.nn.modules.loss._Loss:
''' '''
Creates loss using standard context structure (context.params.loss) Creates loss using standard context structure (context.params.loss)
Stores loss in context.loss Stores loss in context.loss
:return updated context :return updated context
''' '''
context.loss = CreateCompositeLoss(context.params.loss, context.eval_func) context.loss = CreateCompositeLoss(context.params.loss, context.eval_func, *args, **kwargs)
return context.loss return context.loss
def CreateCompositeLoss(loss_params: dict, eval_func=eval) -> torch.nn.modules.loss._Loss: def CreateCompositeLoss(loss_params: dict, eval_func=eval, *args, **kwargs) -> torch.nn.modules.loss._Loss:
''' '''
creates loss using loss_params creates loss using loss_params
:param loss_params: dict can be dict (to create single loss or list of loss_params (to create composite loss) :param loss_params: params in any form accepted by to `creat_object()`
loss dict can contain optional 'weight' (for composite loss) and 'mean' values. if `loss_params` describes list of losses. `CompositeLoss` is created, otherwise `SimpleLoss`.
if loss is described in loss_params as dict, it can contain keys:
'key' - to calculate loss against `y_true[key]` instead of `y_true`
'weight' - weight of loss in composite loss
'mean' - to create `MeanLoss` (calculate loss over channels and then mean over channels).
:param eval_func: :param eval_func:
:return: :return: SimpleLoss or CompositeLoss object
''' '''
if isinstance(loss_params, dict): loss = create_object(loss_params, eval_func, *args, **kwargs)
loss = create_object(loss_params, eval_func) if not isinstance(loss, (list, tuple)):
if loss_params.get('mean', False): key = None
loss = MeanLoss(loss) if isinstance(loss_params, dict):
loss = SimpleLoss(loss, loss_params.get('key')) if loss_params.get('mean', False):
loss = MeanLoss(loss)
key = loss_params.get('key')
loss = SimpleLoss(loss, key)
return loss return loss
else: else:
loss_funcs = [] loss_funcs = []
for loss_param in loss_params: for loss_param in loss_params:
loss_i = CreateCompositeLoss(loss_param, eval_func=eval_func) loss_i = CreateCompositeLoss(loss_param, eval_func=eval_func, *args, **kwargs)
loss_funcs.append((loss_i, loss_param.get('weight', 1.),)) weight = loss_param.get('weight', 1.) if isinstance(loss_param, dict) else 1.
loss_funcs.append((loss_i, weight,))
return CompositeLoss(loss_funcs) return CompositeLoss(loss_funcs)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment