Commit a0ccc3cf authored by Ilya Ovodov's avatar Ilya Ovodov

param.txt moved to model_name directory

parent 8a0c8960
......@@ -94,26 +94,25 @@ class AttrDict(OrderedDict):
assert self.has('data_root')
return os.path.join(self.data_root, self.get_model_name())
def save(self, base_fn = None, verbose = 1, can_overwrite = False, create_dirs = False):
def save(self, file_name=None, verbose = 1, can_overwrite = False, create_dirs = False):
'''
save to file adding '.param.txt' to name
'''
if base_fn is None:
base_fn = self.get_base_filename()
params_fn = base_fn + '.param.txt'
if file_name is None:
file_name = os.path.join(self.get_base_filename(), 'param.txt')
if not can_overwrite:
assert not os.path.exists(params_fn), "Can't save parameters to {}: File exists".format(params_fn)
assert not os.path.exists(file_name), "Can't save parameters to {}: File exists".format(params_fn)
if create_dirs:
dir_name = os.path.dirname(params_fn)
dir_name = os.path.dirname(file_name)
os.makedirs(dir_name, exist_ok=True)
with open(params_fn, 'w+') as f:
with open(file_name, 'w+') as f:
s = str(self)
s = s + '\nhash: ' + self.hash()
f.write(s)
if verbose >= 2:
print('params: '+ s)
if verbose >= 1:
print('saved to ' + params_fn)
print('saved to ' + str(file_name))
def load_from_str(s, data_root):
assert len(s) >= 2
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment