#!/usr/bin/env python
# -*- coding: utf-8 -*-

"""DRIU Network for Vessel Segmentation using SSL and Batch Normalization

Deep Retinal Image Understanding (DRIU) is a unified framework for retinal
image analysis that provides both retinal vessel and optic disc segmentation
using deep Convolutional Neural Networks (CNNs). This version of our model is
trained with a loss suitable for Semi-Supervised Learning (SSL) and also
includes batch normalization as a regularization mechanism.

Reference: [MANINIS-2016]_
"""

from torch.optim.lr_scheduler import MultiStepLR

from bob.ip.binseg.engine.adabound import AdaBound
from bob.ip.binseg.models.driu_bn import driu_bn
from bob.ip.binseg.models.losses import MixJacLoss

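# config
lr = 0.001  # initial (Adam-like) learning rate
betas = (0.9, 0.999)  # running-average coefficients for gradient and its square
eps = 1e-08  # term added to the denominator for numerical stability
weight_decay = 0  # L2 penalty (disabled)
final_lr = 0.1  # final (SGD-like) learning rate the AdaBound bounds converge to
gamma = 1e-3  # convergence speed of the AdaBound bound functions
amsbound = False  # set to True to use the AMSBound variant

scheduler_milestones = [900]  # scheduler step at which the learning rate decays
scheduler_gamma = 0.1  # multiplicative decay factor applied at each milestone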

model = driu_bn()

# optimizer
optimizer = AdaBound(
    model.parameters(),
    lr=lr,
    betas=betas,
    final_lr=final_lr,
    gamma=gamma,
    eps=eps,
    weight_decay=weight_decay,
    amsbound=amsbound,
)
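
# Illustrative sketch (hypothetical helper, not part of bob.ip.binseg and not
# used during training). It shows how AdaBound's ``final_lr`` and ``gamma``
# shape the optimizer. The sketch assumes the bundled AdaBound follows the
# reference implementation by Luo et al. (2019), in which:
#   * each per-parameter step size is clipped to [lower, upper];
#   * both bounds converge to ``final_lr`` as the step count grows, so the
#     optimizer drifts from Adam-like to SGD-like behaviour;
#   * ``final_lr`` is rescaled by ``current_lr / base_lr``, which lets a
#     learning-rate scheduler (such as the MultiStepLR configured in this
#     file) shrink the bounds as well.
def _adabound_step_bounds(step, final_lr=0.1, gamma=1e-3,
                          base_lr=1e-3, current_lr=1e-3):
    """Return the (lower, upper) step-size bounds at optimizer step ``step``."""
    if step < 1:
        raise ValueError("step must be >= 1")
    # A scheduler cannot change final_lr directly, so it is rescaled by the
    # ratio between the current and the initial learning rates.
    scaled_final_lr = final_lr * current_lr / base_lr
    lower = scaled_final_lr * (1.0 - 1.0 / (gamma * step + 1.0))
    upper = scaled_final_lr * (1.0 + 1.0 / (gamma * step))
    return lower, upper


# With the values above, step 1 gives a very wide interval (about [1e-4, 100]).
# After one million steps, both bounds are within about 0.1% of
# final_lr = 0.1.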

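# criterion
# lambda_u weights the unlabelled-data term; jacalpha balances BCE against the
# soft-Jaccard term on labelled data
criterion = MixJacLoss(lambda_u=0.05, jacalpha=0.7)
ssl = True  # train in semi-supervised mode (MixJacLoss expects unlabelled data)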
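
# Illustrative sketch (hypothetical helpers in pure Python; not the library
# code and not used during training). It shows what the criterion above
# computes, *assuming* MixJacLoss works as follows:
#   * a soft-Jaccard/BCE loss on labelled pixels, plus
#   * a BCE loss on unlabelled pixels, whose targets are, e.g., pseudo-labels
#     guessed by the model, weighted by ``lambda_u`` times a ramp-up factor.
# ``ramp_up`` is an assumed parameter. Inputs are flat lists of logits and of
# targets in [0, 1].
def _sigmoid(x):
    import math

    # numerically stable logistic function
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)


def _bce_with_logits(logits, targets):
    import math

    # stable form of -[y*log(sigmoid(x)) + (1 - y)*log(1 - sigmoid(x))],
    # averaged over all pixels
    total = sum(
        max(x, 0.0) - x * y + math.log1p(math.exp(-abs(x)))
        for x, y in zip(logits, targets)
    )
    return total / len(logits)


def _soft_jaccard_bce(logits, targets, alpha=0.7, eps=1e-7):
    probs = [_sigmoid(x) for x in logits]
    intersection = sum(p * t for p, t in zip(probs, targets))
    union = sum(probs) + sum(targets) - intersection
    jaccard = intersection / (union + eps)
    # alpha balances BCE against (1 - soft Jaccard index)
    return alpha * _bce_with_logits(logits, targets) + (1.0 - alpha) * (1.0 - jaccard)


def _mix_jac_loss(lab_logits, lab_targets, unl_logits, unl_targets,
                  lambda_u=0.05, jacalpha=0.7, ramp_up=1.0):
    labelled = _soft_jaccard_bce(lab_logits, lab_targets, alpha=jacalpha)
    unlabelled = _bce_with_logits(unl_logits, unl_targets)
    return labelled + lambda_u * ramp_up * unlabelled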

# scheduler
scheduler = MultiStepLR(
    optimizer, milestones=scheduler_milestones, gamma=scheduler_gamma
)
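
# Illustrative sketch (hypothetical helper, not used during training). It
# returns the learning rate MultiStepLR produces after ``epoch`` scheduler
# steps. Every milestone already reached multiplies the base rate by
# ``gamma`` once, so with the values above the rate is 1e-3 before step 900
# and 1e-4 from step 900 on. Steps equal epochs only if ``scheduler.step()``
# is called once per epoch, which is an assumption about the training loop.
def _multistep_lr(epoch, base_lr=1e-3, milestones=(900,), gamma=0.1):
    from bisect import bisect_right

    # count the milestones that are <= epoch and decay once for each
    return base_lr * gamma ** bisect_right(sorted(milestones), epoch)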