Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
#!/usr/bin/env python
# -*- coding: utf-8 -*-

"""MobileNetV2 U-Net model for image segmentation using SSL

The MobileNetV2 architecture is based on an inverted residual structure where
the input and output of the residual block are thin bottleneck layers opposite
to traditional residual models which use expanded representations in the
input.  MobileNetV2 uses lightweight depthwise convolutions to filter features
in the intermediate expansion layer.  This model implements a MobileNetV2 U-Net
model, henceforth named M2U-Net, combining the strengths of U-Net for medical
segmentation applications and the speed of MobileNetV2 networks.  This version
of our model includes a loss that is suitable for Semi-Supervised Learning
(SSL).

References: [SANDLER-2018]_, [RONNEBERGER-2015]_
"""

from torch.optim.lr_scheduler import MultiStepLR

from bob.ip.binseg.engine.adabound import AdaBound
from bob.ip.binseg.models.losses import MixJacLoss
from bob.ip.binseg.models.m2unet import m2unet

# config
# Hyper-parameters forwarded verbatim to the AdaBound optimizer below.
# NOTE(review): their exact semantics are defined by
# ``bob.ip.binseg.engine.adabound.AdaBound`` (not visible here); the names
# match the reference AdaBound implementation -- confirm against that class.
lr = 0.001
betas = (0.9, 0.999)
eps = 1e-08  # previously assigned twice with the same value; kept once
weight_decay = 0
final_lr = 0.1
gamma = 1e-3
amsbound = False

# Learning-rate schedule settings consumed by MultiStepLR below: the learning
# rate is multiplied by ``scheduler_gamma`` each time the scheduler's step
# counter reaches one of the milestones.
scheduler_milestones = [900]
scheduler_gamma = 0.1

model = m2unet()

# optimizer
optimizer = AdaBound(
    model.parameters(),
    lr=lr,
    betas=betas,
    final_lr=final_lr,
    gamma=gamma,
    eps=eps,
    weight_decay=weight_decay,
    amsbound=amsbound,
)

# criterion
# NOTE(review): the meaning of ``lambda_u`` and ``jacalpha`` is defined by
# MixJacLoss (not visible here) -- confirm before tuning.
criterion = MixJacLoss(lambda_u=0.05, jacalpha=0.7)
# Marks this configuration as semi-supervised (see module docstring);
# presumably read by the training engine -- verify against the caller.
ssl = True

# scheduler
scheduler = MultiStepLR(
    optimizer, milestones=scheduler_milestones, gamma=scheduler_gamma
)