import json
import logging
import os
import shutil
import numpy as np
import scipy.misc
try:
from StringIO import StringIO # Python 2.7
except ImportError:
from io import BytesIO # Python 3.x
import torch
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
class Params:
    """Hyperparameter container backed by a json file.

    Every top-level key of the json object becomes an attribute of the instance.

    Example:
    ```
    params = Params(json_path)
    print(params.learning_rate)
    params.learning_rate = 0.5  # change the value of learning_rate in params
    ```
    """

    def __init__(self, json_path):
        """Read `json_path` and expose each of its keys as an attribute."""
        with open(json_path) as json_file:
            loaded = json.load(json_file)
        vars(self).update(loaded)

    def save(self, json_path):
        """Write all current attributes to `json_path` as json (4-space indent)."""
        with open(json_path, 'w') as json_file:
            json.dump(vars(self), json_file, indent=4)

    def update(self, json_path):
        """Loads parameters from json file, overwriting attributes that already exist"""
        with open(json_path) as json_file:
            loaded = json.load(json_file)
        vars(self).update(loaded)

    @property
    def dict(self):
        """Gives dict-like access to Params instance by `params.dict['learning_rate']`"""
        # This is the live attribute dictionary, not a copy.
        return vars(self)
class RunningAverage:
    """Keeps the arithmetic mean of all values passed to `update`.

    Example:
    ```
    loss_avg = RunningAverage()
    loss_avg.update(2)
    loss_avg.update(4)
    loss_avg()  # -> 3.0
    ```
    """

    def __init__(self):
        # Number of values seen so far and their sum.
        self.steps, self.total = 0, 0

    def update(self, val):
        """Add `val` to the running sum and count one more step."""
        self.total += val
        self.steps += 1

    def __call__(self):
        """Return the current mean (sum divided by the number of updates).

        Raises ZeroDivisionError if `update` has not been called yet.
        """
        count = float(self.steps)
        return self.total / count
def set_logger(log_path):
    """Configure the root logger to write to the terminal and to `log_path`.

    Having a logger is useful so that everything printed to the terminal is also
    kept in a permanent file, e.g. `model_dir/train.log`.

    Example:
    ```
    logging.info("Starting training...")
    ```

    Args:
        log_path: (string) where to log
    """
    root_logger = logging.getLogger()
    root_logger.setLevel(logging.INFO)

    # Handlers are attached only when the root logger has none, so calling this
    # function again does not duplicate every message.
    if root_logger.handlers:
        return

    # File output: timestamp and level in front of each message.
    to_file = logging.FileHandler(log_path)
    to_file.setFormatter(logging.Formatter('%(asctime)s:%(levelname)s: %(message)s'))
    root_logger.addHandler(to_file)

    # Console output: the bare message only.
    to_console = logging.StreamHandler()
    to_console.setFormatter(logging.Formatter('%(message)s'))
    root_logger.addHandler(to_console)
def save_dict_to_json(d, json_path):
    """Saves dict of floats in json file

    Args:
        d: (dict) of float-castable values (np.float, int, float, etc.)
        json_path: (string) path to json file

    Raises:
        TypeError, ValueError: if a value of `d` cannot be converted with `float()`.
            In that case `json_path` is left untouched.
    """
    # json does not accept numpy scalars, so every value is cast to a plain float.
    # The cast happens BEFORE the file is opened: opening with 'w' truncates the
    # file, and a failed conversion must not wipe out a previously saved result.
    as_floats = {k: float(v) for k, v in d.items()}
    with open(json_path, 'w') as f:
        json.dump(as_floats, f, indent=4)
def save_checkpoint(state, is_best, checkpoint):
    """Saves model and training parameters at checkpoint + 'last.pth.tar'. If is_best==True, also saves
    checkpoint + 'best.pth.tar'

    Args:
        state: (dict) contains model's state_dict, may contain other keys such as epoch, optimizer state_dict
        is_best: (bool) True if it is the best model seen till now
        checkpoint: (string) folder where parameters are to be saved; created
            (including missing parent folders) if it does not exist
    """
    filepath = os.path.join(checkpoint, 'last.pth.tar')
    if not os.path.exists(checkpoint):
        print("Checkpoint Directory does not exist! Making directory {}".format(checkpoint))
        # makedirs (not mkdir) so that a nested folder such as
        # `experiments/run1/ckpt` works even when its parents are missing.
        os.makedirs(checkpoint)
    torch.save(state, filepath)
    if is_best:
        # The best model is a plain copy of the file that was just written.
        shutil.copyfile(filepath, os.path.join(checkpoint, 'best.pth.tar'))
def load_checkpoint(checkpoint, model, optimizer=None, device=None, ismpi=False):
    """Loads model parameters (state_dict) from file_path. If optimizer is provided, loads state_dict of
    optimizer assuming it is present in checkpoint.

    Args:
        checkpoint: (string) filename which needs to be loaded
        model: (torch.nn.Module) model for which the parameters are loaded
        optimizer: (torch.optim) optional: resume optimizer from checkpoint
        device: optional: passed to `torch.load` as `map_location`; None keeps
            torch's default placement
        ismpi: (bool) if True the model weights are read from the
            'mpi_state_dict' entry instead of 'state_dict'

    Returns:
        The full object returned by `torch.load` for the checkpoint file.

    Raises:
        IOError: if `checkpoint` does not exist.
    """
    if not os.path.exists(checkpoint):
        # The original `raise("...")` raised a plain string, which itself fails
        # with a TypeError; raise a real exception carrying the message instead.
        raise IOError("File doesn't exist {}".format(checkpoint))

    # NOTE: torch.load unpickles the file; only load checkpoints from trusted sources.
    # map_location=None is torch's default, so one call covers both cases.
    state = torch.load(checkpoint, map_location=device)

    weights_key = 'mpi_state_dict' if ismpi else 'state_dict'
    model.load_state_dict(state[weights_key])

    if optimizer is not None:
        optimizer.load_state_dict(state['optim_dict'])

    return state
def gen_plot(plotarr, shape):
    """Create a pyplot plot and save to buffer.

    Args:
        plotarr: iterable of series to draw; the first is labelled "predict" and
            the second "data" (only these two labels are defined)
        shape: figure size, passed to `plt.figure(figsize=...)`

    Returns:
        An in-memory buffer holding the figure as PNG, rewound to position 0.
    """
    labels = ["predict", "data"]
    fig = plt.figure(figsize=shape)
    try:
        for i, arr in enumerate(plotarr):
            plt.plot(arr, label=labels[i])
        plt.ylabel("number(h^3Mpc^-3/dex)")
        plt.legend()
        plt.title("smf")
        # Only one of StringIO / BytesIO was imported at the top of the file,
        # depending on the Python version; the missing one shows up as NameError.
        try:
            buf = StringIO()
        except NameError:
            buf = BytesIO()
        plt.savefig(buf, format='png')
    finally:
        # Always release the figure, even if plotting or saving failed.
        plt.close(fig)
    buf.seek(0)
    return buf