# Source code for linna.nnutils

import json
import logging
import os
import shutil
import numpy as np
import scipy.misc 
try:
    from StringIO import StringIO  # Python 2.7
except ImportError:
    from io import BytesIO         # Python 3.x
import torch
import matplotlib
matplotlib.use('Agg')

import matplotlib.pyplot as plt

class Params():
    """Hyperparameter container populated from a JSON file.

    Every top-level key of the JSON object becomes an attribute of the
    instance, so values can be read and written with dotted access.

    Example:
    ```
    params = Params(json_path)
    print(params.learning_rate)
    params.learning_rate = 0.5  # change the value of learning_rate in params
    ```
    """

    def __init__(self, json_path):
        # Construction is simply an initial load of the given file.
        self.update(json_path)

    def save(self, json_path):
        """Write every attribute of this instance to `json_path` as indented JSON."""
        with open(json_path, 'w') as out_file:
            json.dump(self.__dict__, out_file, indent=4)

    def update(self, json_path):
        """Load parameters from a JSON file, overwriting same-named attributes."""
        with open(json_path) as in_file:
            loaded = json.load(in_file)
        self.__dict__.update(loaded)

    @property
    def dict(self):
        """Give dict-like access to the instance, e.g. `params.dict['learning_rate']`.

        The returned mapping is the instance's own `__dict__`, so changes made
        through it are reflected in the attributes and vice versa.
        """
        return self.__dict__
class RunningAverage():
    """Maintain the running arithmetic mean of a stream of values.

    Example:
    ```
    loss_avg = RunningAverage()
    loss_avg.update(2)
    loss_avg.update(4)
    loss_avg()  # -> 3.0
    ```
    """

    def __init__(self):
        # Number of observations seen and their sum.
        self.steps, self.total = 0, 0

    def update(self, val):
        """Add one observation `val`."""
        self.total += val
        self.steps += 1

    def __call__(self):
        """Return the mean of all observations passed to `update`.

        Raises ZeroDivisionError if `update` has never been called.
        """
        count = float(self.steps)
        return self.total / count
def set_logger(log_path):
    """Send log records both to the terminal and to the file `log_path`.

    Having a logger means every message printed to the terminal is also kept
    in a permanent file (typically `model_dir/train.log`).

    The root logger's level is always set to INFO. Handlers are attached only
    when the root logger has none yet, so repeated calls do not duplicate
    output.

    Example:
    ```
    logging.info("Starting training...")
    ```

    Args:
        log_path: (string) where to log
    """
    root = logging.getLogger()
    root.setLevel(logging.INFO)
    if root.handlers:
        # Already configured; leave existing handlers alone.
        return

    # File output carries a timestamp and level; console output is the bare message.
    to_file = logging.FileHandler(log_path)
    to_file.setFormatter(logging.Formatter('%(asctime)s:%(levelname)s: %(message)s'))

    to_console = logging.StreamHandler()
    to_console.setFormatter(logging.Formatter('%(message)s'))

    for handler in (to_file, to_console):
        root.addHandler(handler)
def save_dict_to_json(d, json_path):
    """Save a dict of float-castable values to a JSON file.

    Every value is converted with `float()` because `json` cannot serialize
    numpy scalars/arrays directly.

    Args:
        d: (dict) of float-castable values (np.float, int, float, etc.)
        json_path: (string) path to json file

    Raises:
        TypeError or ValueError: if some value cannot be converted with
            `float()`. In that case `json_path` is not touched.
    """
    # Convert BEFORE opening: opening with 'w' truncates the file, so a
    # conversion failure afterwards would wipe any previous contents.
    serializable = {k: float(v) for k, v in d.items()}
    with open(json_path, 'w') as f:
        json.dump(serializable, f, indent=4)
def save_checkpoint(state, is_best, checkpoint):
    """Save model and training parameters at checkpoint + 'last.pth.tar'.

    If is_best==True, also copy it to checkpoint + 'best.pth.tar'.

    Args:
        state: (dict) contains model's state_dict, may contain other keys such
            as epoch, optimizer state_dict
        is_best: (bool) True if it is the best model seen till now
        checkpoint: (string) folder where parameters are to be saved; it and
            any missing parent folders are created if needed
    """
    filepath = os.path.join(checkpoint, 'last.pth.tar')
    if not os.path.exists(checkpoint):
        print("Checkpoint Directory does not exist! Making directory {}".format(checkpoint))
        # makedirs (not mkdir) so nested paths work when parents are missing;
        # exist_ok guards against another process creating it concurrently.
        os.makedirs(checkpoint, exist_ok=True)
    torch.save(state, filepath)
    if is_best:
        shutil.copyfile(filepath, os.path.join(checkpoint, 'best.pth.tar'))
def load_checkpoint(checkpoint, model, optimizer=None, device=None, ismpi=False):
    """Load model parameters (state_dict) from a checkpoint file.

    If optimizer is provided, also load its state_dict from the checkpoint's
    'optim_dict' entry (assumed to be present).

    Args:
        checkpoint: (string) filename which needs to be loaded
        model: (torch.nn.Module) model for which the parameters are loaded
        optimizer: (torch.optim) optional: resume optimizer from checkpoint
        device: optional value forwarded to `torch.load` as `map_location`
            (e.g. 'cpu'); None keeps torch's default behavior
        ismpi: (bool) if True, load model weights from the 'mpi_state_dict'
            entry instead of 'state_dict'

    Returns:
        The full checkpoint dict as returned by `torch.load`.

    Raises:
        FileNotFoundError: if `checkpoint` does not exist.
    """
    if not os.path.exists(checkpoint):
        raise FileNotFoundError("File doesn't exist {}".format(checkpoint))
    # map_location=None is torch.load's default, so one call covers both cases.
    state = torch.load(checkpoint, map_location=device)
    weights_key = 'mpi_state_dict' if ismpi else 'state_dict'
    model.load_state_dict(state[weights_key])
    if optimizer is not None:
        optimizer.load_state_dict(state['optim_dict'])
    return state
def gen_plot(plotarr, shape, labels=None):
    """Plot the given curves and return the figure as an in-memory PNG.

    Args:
        plotarr: iterable of array-likes, each drawn as one curve with `plt.plot`.
        shape: figure size forwarded to matplotlib as `figsize`.
        labels: optional sequence of legend labels, one per curve; defaults to
            ("predict", "data"). Curves beyond the number of labels are drawn
            without a legend entry instead of raising IndexError.

    Returns:
        io.BytesIO positioned at offset 0 containing the PNG data.
    """
    import io  # BytesIO holds binary PNG data on both Python 2.7 and 3.x

    if labels is None:
        labels = ("predict", "data")
    fig = plt.figure(figsize=shape)
    try:
        for i, arr in enumerate(plotarr):
            plt.plot(arr, label=labels[i] if i < len(labels) else None)
        plt.ylabel("number(h^3Mpc^-3/dex)")
        plt.legend()
        plt.title("smf")
        buf = io.BytesIO()
        fig.savefig(buf, format='png')
    finally:
        # Always release the figure, even if plotting fails, so repeated
        # calls do not accumulate open figures.
        plt.close(fig)
    buf.seek(0)
    return buf