#!/usr/bin/env python
# -*- coding: utf-8 -*-


"""MobileNetV2 U-Net model for image segmentation using SSL

The MobileNetV2 architecture is based on an inverted residual structure, in
which the input and output of the residual block are thin bottleneck layers,
as opposed to traditional residual models, which use expanded representations
at the input. MobileNetV2 uses lightweight depthwise convolutions to filter
features in the intermediate expansion layer. This configuration implements a
MobileNetV2 U-Net model, henceforth named M2U-Net, which combines the strengths
of U-Net for medical segmentation applications with the speed of MobileNetV2
networks. This version of our model includes a loss that is suitable for
Semi-Supervised Learning (SSL).

References: [SANDLER-2018]_, [RONNEBERGER-2015]_
"""

from torch.optim.lr_scheduler import MultiStepLR

from bob.ip.binseg.engine.adabound import AdaBound
from bob.ip.binseg.models.losses import MixJacLoss
from bob.ip.binseg.models.m2unet import m2unet

# config
lr = 0.001  # initial (Adam-like) learning rate
betas = (0.9, 0.999)  # decay rates of the gradient and squared-gradient averages
eps = 1e-08  # added to the denominator for numerical stability
weight_decay = 0  # no L2 penalty
final_lr = 0.1  # final (SGD-like) learning rate the bounds converge to
gamma = 1e-3  # convergence speed of the bound functions
amsbound = False  # plain AdaBound rather than the AMSBound variant

# multiply the learning rate by scheduler_gamma at each milestone
scheduler_milestones = [900]
scheduler_gamma = 0.1

model = m2unet()

# optimizer
optimizer = AdaBound(
    model.parameters(),
    lr=lr,
    betas=betas,
    final_lr=final_lr,
    gamma=gamma,
    eps=eps,
    weight_decay=weight_decay,
    amsbound=amsbound,
)

# criterion
criterion = MixJacLoss(lambda_u=0.05, jacalpha=0.7)
ssl = True  # train in semi-supervised mode

# scheduler
scheduler = MultiStepLR(
    optimizer, milestones=scheduler_milestones, gamma=scheduler_gamma
)
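To make the docstring's description of the inverted residual structure concrete, the sketch below counts the convolution weights (biases and batch-norm parameters ignored) in one such block: a 1x1 expansion from a thin bottleneck to a wide layer, a depthwise filter in that wide layer, and a linear 1x1 projection back down. The expansion factor t = 6 follows the MobileNetV2 paper [SANDLER-2018]_; the function names are hypothetical and the channel counts are illustrative, not those used by `m2unet`.

```python
def conv_weights(c_in: int, c_out: int, k: int) -> int:
    """Weights of a regular k-by-k convolution (biases ignored)."""
    return k * k * c_in * c_out


def depthwise_weights(channels: int, k: int) -> int:
    """Weights of a depthwise k-by-k convolution: one k-by-k filter per channel."""
    return k * k * channels


def inverted_residual_weights(c_in: int, c_out: int, t: int = 6, k: int = 3) -> dict:
    """Per-layer weight counts of one inverted residual block (illustrative only)."""
    hidden = t * c_in  # width of the intermediate expansion layer
    return {
        "expand_1x1": conv_weights(c_in, hidden, 1),    # thin -> wide
        "depthwise_kxk": depthwise_weights(hidden, k),  # filter in the wide space
        "project_1x1": conv_weights(hidden, c_out, 1),  # wide -> thin (linear)
    }


if __name__ == "__main__":
    block = inverted_residual_weights(32, 32)
    print(block)
    # A regular 3x3 convolution at the same expanded width, for comparison.
    print(conv_weights(6 * 32, 6 * 32, 3))
```

With 32 channels in and out, the depthwise layer needs 1,728 weights, whereas a regular 3x3 convolution at the same expanded width of 192 channels would need 331,776; that gap is what makes filtering in the wide layer affordable.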
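The AdaBound settings in the configuration (`final_lr`, `gamma`) control dynamic bounds that clip each parameter's effective step size. The sketch below assumes the bound functions of the original AdaBound optimizer, which `bob.ip.binseg.engine.adabound` appears to follow: the bounds start far apart (Adam-like behaviour) and converge to `final_lr` as the step count grows (SGD-like behaviour). It also assumes that, as in the reference implementation, `final_lr` is rescaled by the ratio of the current to the initial learning rate, so a scheduler decay lowers the bounds too; that ratio is exposed as the hypothetical `lr_ratio` argument. Check the module source for the exact formula.

```python
def adabound_bounds(step: int, final_lr: float = 0.1, gamma: float = 1e-3,
                    lr_ratio: float = 1.0) -> tuple:
    """Assumed AdaBound lower/upper clipping bounds for the per-parameter step size.

    ``lr_ratio`` is the current learning rate divided by the initial one
    (1.0 before any scheduler decay, 0.1 after one decay with gamma=0.1).
    """
    if step < 1:
        raise ValueError("step counts start at 1")
    target = final_lr * lr_ratio  # value both bounds converge to
    lower = target * (1 - 1 / (gamma * step + 1))
    upper = target * (1 + 1 / (gamma * step))
    return lower, upper


if __name__ == "__main__":
    for step in (1, 1_000, 100_000):
        lo, hi = adabound_bounds(step)
        print(f"step {step:>7}: [{lo:.6f}, {hi:.6f}]")
```

At step 1,000 with gamma = 1e-3 the bounds are [0.05, 0.2]; they tighten around `final_lr` as training proceeds, which is why a larger `gamma` means a faster switch to SGD-like steps.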
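`MultiStepLR` multiplies the learning rate by `gamma` each time the epoch counter reaches a milestone, so the rate in effect is `lr * gamma ** k`, where `k` is the number of milestones that are less than or equal to the current epoch. The function below is a hypothetical stand-alone restatement of that rule, handy for checking a schedule without building a model; it is not part of PyTorch or `bob.ip.binseg`. Its defaults mirror the configuration (`lr = 0.001`, milestones `[900]`, gamma 0.1).

```python
from bisect import bisect_right


def multistep_lr(epoch: int, base_lr: float = 0.001,
                 milestones=(900,), gamma: float = 0.1) -> float:
    """Learning rate in effect at ``epoch`` under a MultiStepLR-style schedule.

    ``milestones`` must be sorted in increasing order; every milestone the
    epoch has reached (epoch >= milestone) multiplies the rate by ``gamma``.
    """
    decays = bisect_right(list(milestones), epoch)  # milestones <= epoch
    return base_lr * gamma ** decays


if __name__ == "__main__":
    for epoch in (0, 899, 900, 1000):
        print(epoch, multistep_lr(epoch))
```

This assumes `scheduler.step()` is called once per epoch; if a trainer steps the scheduler at a different cadence, the milestone is counted in those units instead.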
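The criterion line pairs `MixJacLoss(lambda_u=0.05, jacalpha=0.7)` with `ssl = True`. The plain-Python sketch below illustrates one plausible reading of those parameters and is an assumption, not the library's definition: a labelled term that mixes binary cross-entropy and soft-Jaccard loss with weight `jacalpha`, plus a BCE term on unlabelled samples (scored against hypothetical `pseudo_labels`) scaled by `lambda_u` and a ramp-up factor. All function names here are hypothetical; consult `bob.ip.binseg.models.losses` for the actual computation.

```python
import math


def sigmoid(z: float) -> float:
    """Numerically stable logistic function."""
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def bce_with_logits(z: float, t: float) -> float:
    """Stable binary cross-entropy for one logit ``z`` and target ``t``."""
    return max(z, 0.0) - z * t + math.log1p(math.exp(-abs(z)))


def soft_jaccard_bce(logits, targets, alpha=0.7, eps=1e-8):
    """Labelled term: alpha * mean BCE + (1 - alpha) * (1 - soft Jaccard index)."""
    probs = [sigmoid(z) for z in logits]
    bce = sum(bce_with_logits(z, t) for z, t in zip(logits, targets)) / len(logits)
    intersection = sum(p * t for p, t in zip(probs, targets))
    union = sum(probs) + sum(targets) - intersection
    jaccard = intersection / (union + eps)
    return alpha * bce + (1.0 - alpha) * (1.0 - jaccard)


def mixed_ssl_loss(labelled_logits, labels, unlabelled_logits, pseudo_labels,
                   ramp_up=1.0, lambda_u=0.05, jacalpha=0.7):
    """Assumed structure, not bob.ip.binseg's code: labelled term plus a
    down-weighted BCE term on unlabelled samples."""
    labelled = soft_jaccard_bce(labelled_logits, labels, alpha=jacalpha)
    unlabelled = sum(
        bce_with_logits(z, t) for z, t in zip(unlabelled_logits, pseudo_labels)
    ) / len(unlabelled_logits)
    return labelled + lambda_u * ramp_up * unlabelled


if __name__ == "__main__":
    good = mixed_ssl_loss([8.0, -8.0], [1.0, 0.0], [6.0, -6.0], [1.0, 0.0])
    bad = mixed_ssl_loss([-8.0, 8.0], [1.0, 0.0], [6.0, -6.0], [1.0, 0.0])
    print(good, bad)
```

Under this reading, a small `lambda_u` such as 0.05 keeps the unlabelled term from dominating the supervised signal, and a ramp-up factor of 0 reduces the loss to the labelled term alone.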