Source code for bob.ip.binseg.engine.predictor

#!/usr/bin/env python
# -*- coding: utf-8 -*-

import datetime
import logging
import os
import time

import h5py
import numpy
import torch
import torchvision.transforms.functional as VF

from tqdm import tqdm

from ...common.data.utils import overlayed_image

logger = logging.getLogger(__name__)


def _save_hdf5(stem, prob, output_folder):
    """
    Saves prediction maps as image in the same format as the test image


    Parameters
    ----------
    stem : str
        the name of the file without extension on the original dataset

    prob : PIL.Image.Image
        Monochrome Image with prediction maps

    output_folder : str
        path where to store predictions

    """

    fullpath = os.path.join(output_folder, f"{stem}.hdf5")
    tqdm.write(f"Saving {fullpath}...")
    os.makedirs(os.path.dirname(fullpath), exist_ok=True)
    with h5py.File(fullpath, "w") as f:
        data = prob.cpu().squeeze(0).numpy()
        f.create_dataset(
            "array", data=data, compression="gzip", compression_opts=9
        )


def _save_image(stem, extension, data, output_folder):
    """Saves a PIL image into a file

    Parameters
    ----------

    stem : str
        the name of the file without extension on the original dataset

    extension : str
        an extension for the file to be saved (e.g. ``.png``)

    data : PIL.Image.Image
        RGB image with the original image, preloaded

    output_folder : str
        path where to store results

    """

    fullpath = os.path.join(output_folder, stem + extension)
    tqdm.write(f"Saving {fullpath}...")
    os.makedirs(os.path.dirname(fullpath), exist_ok=True)
    data.save(fullpath)


def _save_overlayed_png(stem, image, prob, output_folder):
    """Overlays prediction predictions vessel tree with original test image


    Parameters
    ----------

    stem : str
        the name of the file without extension on the original dataset

    image : torch.Tensor
        Tensor with RGB input image

    prob : torch.Tensor
        Tensor with 1-D prediction map

    output_folder : str
        path where to store results

    """

    image = VF.to_pil_image(image)
    prob = VF.to_pil_image(prob.cpu())
    _save_image(stem, ".png", overlayed_image(image, prob), output_folder)


[docs]def run(model, data_loader, name, device, output_folder, overlayed_folder): """ Runs inference on input data, outputs HDF5 files with predictions Parameters --------- model : :py:class:`torch.nn.Module` neural network model (e.g. driu, hed, unet) data_loader : py:class:`torch.torch.utils.data.DataLoader` name : str the local name of this dataset (e.g. ``train``, or ``test``), to be used when saving measures files. device : :py:class:`torch.device` device to use output_folder : str folder where to store output prediction maps (HDF5 files) and model summary overlayed_folder : str folder where to store output images (PNG files) """ logger.info(f"Output folder: {output_folder}") os.makedirs(output_folder, exist_ok=True) logger.info(f"Device: {device}") model.eval() # set evaluation mode model.to(device) # set/cast parameters to device sigmoid = torch.nn.Sigmoid() # use sigmoid for predictions # Setup timers start_total_time = time.time() times = [] len_samples = [] output_folder = os.path.join(output_folder, name) overlayed_folder = ( os.path.join(overlayed_folder, name) if overlayed_folder is not None else overlayed_folder ) for samples in tqdm(data_loader, desc="batches", leave=False, disable=None): names = samples[0] images = samples[1].to( device=device, non_blocking=torch.cuda.is_available() ) with torch.no_grad(): start_time = time.perf_counter() outputs = model(images) # necessary check for HED/Little W-Net architecture that use # several outputs for loss calculation instead of just the last one if isinstance(outputs, (list, tuple)): outputs = outputs[-1] predictions = sigmoid(outputs) batch_time = time.perf_counter() - start_time times.append(batch_time) len_samples.append(len(images)) for stem, img, prob in zip(names, images, predictions): _save_hdf5(stem, prob, output_folder) if overlayed_folder is not None: _save_overlayed_png(stem, img, prob, overlayed_folder) # report operational summary total_time = datetime.timedelta(seconds=int(time.time() - start_total_time)) logger.info(f"Total time: {total_time}") average_batch_time = numpy.mean(times) logger.info(f"Average batch time: {average_batch_time:g}s") average_image_time = numpy.sum(numpy.array(times) * len_samples) / float( sum(len_samples) ) logger.info(f"Average image time: {average_image_time:g}s")