Source code for mmv_im2im.proj_trainer

#!/usr/bin/env python
# -*- coding: utf-8 -*-
import logging
from pathlib import Path
from importlib import import_module
import pytorch_lightning as pl
import torch
from mmv_im2im.data_modules import get_data_module
from mmv_im2im.utils.misc import parse_ops_list
import pyrallis

import warnings

warnings.simplefilter(action="ignore", category=FutureWarning)

###############################################################################

log = logging.getLogger(__name__)

###############################################################################


[docs]class ProjectTrainer(object): """ entry for training models Parameters ---------- cfg: configuration """ def __init__(self, cfg): # seed everything before start pl.seed_everything(123, workers=True) # extract the three major chuck of the config self.model_cfg = cfg.model self.train_cfg = cfg.trainer self.data_cfg = cfg.data # define variables self.model = None self.data = None
[docs] def run_training(self): # set up data self.data = get_data_module(self.data_cfg) # set up model model_category = self.model_cfg.framework model_module = import_module(f"mmv_im2im.models.pl_{model_category}") my_model_func = getattr(model_module, "Model") self.model = my_model_func(self.model_cfg, verbose=self.train_cfg.verbose) if self.model_cfg.model_extra is not None: if "resume" in self.model_cfg.model_extra: self.model = self.model.load_from_checkpoint( self.model_cfg.model_extra["resume"] ) elif "pre-train" in self.model_cfg.model_extra: pre_train = torch.load(self.model_cfg.model_extra["pre-train"]) # TODO: hacky solution to remove a wrongly registered key pre_train["state_dict"].pop("criterion.xym", None) self.model.load_state_dict(pre_train["state_dict"]) # set up training if self.train_cfg.callbacks is None: callback_list = [] else: callback_list = parse_ops_list(self.train_cfg.callbacks) trainer = pl.Trainer(callbacks=callback_list, **self.train_cfg.params) # save the configuration in the log directory save_path = Path(trainer.log_dir) save_path.mkdir(parents=True, exist_ok=True) pyrallis.dump(self.model_cfg, open(save_path / Path("model_config.yaml"), "w")) pyrallis.dump(self.train_cfg, open(save_path / Path("train_config.yaml"), "w")) pyrallis.dump(self.data_cfg, open(save_path / Path("data_config.yaml"), "w")) # start training print("start training ... ") trainer.fit(model=self.model, datamodule=self.data)