#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Adapted from https://github.com/facebookresearch/maskrcnn-benchmark/blob/master/maskrcnn_benchmark/engine/trainer.py
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
import logging
import torch
import os
from bob.ip.binseg.utils.model_serialization import load_state_dict
from bob.ip.binseg.utils.model_zoo import cache_url
class Checkpointer:
    """Save and restore model, optimizer and scheduler state on disk.

    Adapted from `maskrcnn-benchmark`_ under MIT license.

    Checkpoints are written with ``torch.save`` to ``<save_dir>/<name>.pth``.
    The path of the most recently written checkpoint is recorded in the
    plain-text file ``<save_dir>/last_checkpoint`` so that :py:meth:`load`
    can pick it up again.

    Parameters
    ----------
    model
        Object providing ``state_dict()``; its state is stored under the
        ``"model"`` key of every checkpoint.
    optimizer
        Optional object providing ``state_dict()`` and ``load_state_dict()``;
        stored under the ``"optimizer"`` key.
    scheduler
        Optional object providing ``state_dict()`` and ``load_state_dict()``;
        stored under the ``"scheduler"`` key.
    save_dir : str
        Directory that receives the checkpoints and the ``last_checkpoint``
        tag file.  When empty, :py:meth:`save` does nothing.
    save_to_disk
        When falsy (the default is ``None``), :py:meth:`save` does nothing.
    logger
        Logger used for progress messages.  Defaults to
        ``logging.getLogger(__name__)``.
    """

    def __init__(
        self,
        model,
        optimizer=None,
        scheduler=None,
        save_dir="",
        save_to_disk=None,
        logger=None,
    ):
        self.model = model
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.save_dir = save_dir
        self.save_to_disk = save_to_disk
        if logger is None:
            logger = logging.getLogger(__name__)
        self.logger = logger

    def save(self, name, **kwargs):
        """Write a checkpoint and tag it as the most recent one.

        Nothing is written when ``save_dir`` is empty or ``save_to_disk`` is
        falsy.

        Parameters
        ----------
        name : str
            Base name of the checkpoint; the file is
            ``<save_dir>/<name>.pth``.
        **kwargs
            Extra entries merged into the checkpoint dictionary.  They are
            applied last, so they override the ``"model"``, ``"optimizer"``
            and ``"scheduler"`` entries on a key clash.
        """
        if not self.save_dir:
            return
        if not self.save_to_disk:
            return

        data = {}
        data["model"] = self.model.state_dict()
        if self.optimizer is not None:
            data["optimizer"] = self.optimizer.state_dict()
        if self.scheduler is not None:
            data["scheduler"] = self.scheduler.state_dict()
        data.update(kwargs)

        save_file = os.path.join(self.save_dir, "{}.pth".format(name))
        self.logger.info("Saving checkpoint to {}".format(save_file))
        torch.save(data, save_file)
        self.tag_last_checkpoint(save_file)

    def load(self, f=None):
        """Restore state from a checkpoint.

        If a ``last_checkpoint`` tag file exists in ``save_dir``, the path it
        contains is used and ``f`` is ignored.

        Parameters
        ----------
        f : str, optional
            Checkpoint to load when no ``last_checkpoint`` tag file exists.

        Returns
        -------
        dict
            Whatever is left in the checkpoint after ``"model"`` and, when
            they were loaded, ``"optimizer"`` and ``"scheduler"`` have been
            removed.  An empty dictionary when no checkpoint was found.
        """
        if self.has_checkpoint():
            # override argument with existing checkpoint
            f = self.get_checkpoint_file()
        if not f:
            # no checkpoint could be found
            # ``Logger.warn`` is a deprecated alias that raises a
            # DeprecationWarning; ``warning`` is the supported spelling.
            self.logger.warning(
                "No checkpoint found. Initializing model from scratch"
            )
            return {}
        self.logger.info("Loading checkpoint from {}".format(f))
        checkpoint = self._load_file(f)
        self._load_model(checkpoint)
        if "optimizer" in checkpoint and self.optimizer:
            self.logger.info("Loading optimizer from {}".format(f))
            self.optimizer.load_state_dict(checkpoint.pop("optimizer"))
        if "scheduler" in checkpoint and self.scheduler:
            self.logger.info("Loading scheduler from {}".format(f))
            self.scheduler.load_state_dict(checkpoint.pop("scheduler"))
        # return any further checkpoint data
        return checkpoint

    def has_checkpoint(self):
        """Return ``True`` if ``<save_dir>/last_checkpoint`` exists."""
        save_file = os.path.join(self.save_dir, "last_checkpoint")
        return os.path.exists(save_file)

    def get_checkpoint_file(self):
        """Return the path stored in ``<save_dir>/last_checkpoint``.

        Returns
        -------
        str
            The stripped contents of the tag file, or an empty string when
            the file cannot be read.
        """
        save_file = os.path.join(self.save_dir, "last_checkpoint")
        try:
            with open(save_file, "r") as f:
                last_saved = f.read()
                last_saved = last_saved.strip()
        except IOError:
            # if file doesn't exist, maybe because it has just been
            # deleted by a separate process
            last_saved = ""
        return last_saved

    def tag_last_checkpoint(self, last_filename):
        """Record ``last_filename`` in ``<save_dir>/last_checkpoint``.

        Any previous contents of the tag file are overwritten.
        """
        save_file = os.path.join(self.save_dir, "last_checkpoint")
        with open(save_file, "w") as f:
            f.write(last_filename)

    def _load_file(self, f):
        """Deserialize checkpoint ``f`` with ``torch.load`` onto the CPU."""
        # NOTE(review): ``torch.load`` unpickles its input; only feed it
        # checkpoint files that come from a trusted source.
        return torch.load(f, map_location=torch.device("cpu"))

    def _load_model(self, checkpoint):
        """Pop ``"model"`` from ``checkpoint`` and load it into the model."""
        load_state_dict(self.model, checkpoint.pop("model"))
class DetectronCheckpointer(Checkpointer):
    """Checkpointer that also accepts URLs and bare state dictionaries.

    Construction is identical to :py:class:`Checkpointer`; only the way a
    checkpoint file is read differs (see :py:meth:`_load_file`).
    """

    def __init__(
        self,
        model,
        optimizer=None,
        scheduler=None,
        save_dir="",
        save_to_disk=None,
        logger=None,
    ):
        # Nothing is added here: every argument is forwarded to the base class.
        super(DetectronCheckpointer, self).__init__(
            model,
            optimizer=optimizer,
            scheduler=scheduler,
            save_dir=save_dir,
            save_to_disk=save_to_disk,
            logger=logger,
        )

    def _load_file(self, f):
        """Load checkpoint ``f``, which may be a local path or a URL.

        A value starting with ``"http"`` is first passed to ``cache_url`` and
        the path it returns is loaded instead.  If the loaded object has no
        ``"model"`` key, it is wrapped as ``{"model": <loaded object>}`` so
        the result always carries that key.
        """
        location = f
        if location.startswith("http"):
            # Remote checkpoint: resolve it to the path given by cache_url.
            local_copy = cache_url(location)
            self.logger.info(
                "url {} cached in {}".format(location, local_copy)
            )
            location = local_copy

        payload = super(DetectronCheckpointer, self)._load_file(location)
        # Wrap the loaded object when it lacks a "model" entry.
        return payload if "model" in payload else {"model": payload}