1#!/usr/bin/env python
2# -*- coding: utf-8 -*-
3
4"""DRIU Network for Vessel Segmentation with Batch Normalization
5
6Deep Retinal Image Understanding (DRIU), a unified framework of retinal image
7analysis that provides both retinal vessel and optic disc segmentation using
8deep Convolutional Neural Networks (CNNs). This implementation includes batch
9normalization as a regularization mechanism.
10
11Reference: [MANINIS-2016]_
12"""
13
14from torch.optim.lr_scheduler import MultiStepLR
15
16from bob.ip.binseg.engine.adabound import AdaBound
17from bob.ip.binseg.models.driu_bn import driu_bn
18from bob.ip.binseg.models.losses import SoftJaccardBCELogitsLoss
19
# config
#
# Optimizer hyper-parameters, all forwarded to ``AdaBound`` below.
# Each constant is defined exactly once, so editing it actually takes effect.
# (Previously ``eps`` was assigned twice, and the second assignment silently
# overrode the first.)
lr = 0.001  # initial learning rate
betas = (0.9, 0.999)  # running-average coefficients (Adam-style) for AdaBound
eps = 1e-08  # small term added for numerical stability inside AdaBound
weight_decay = 0  # no weight decay
final_lr = 0.1  # forwarded as AdaBound's ``final_lr``
# NOTE: this is AdaBound's ``gamma`` argument, not the LR-scheduler decay
# factor (that one is ``scheduler_gamma`` below).
gamma = 1e-3
amsbound = False  # forwarded as AdaBound's ``amsbound`` flag

# Learning-rate scheduler hyper-parameters, forwarded to ``MultiStepLR``.
scheduler_milestones = [900]  # step counts at which the LR is decayed
scheduler_gamma = 0.1  # multiplicative LR decay applied at each milestone
32
# Network being trained: DRIU variant with batch normalization.
model = driu_bn()

# optimizer
# AdaBound over every trainable parameter of ``model``, configured entirely
# from the module-level hyper-parameters defined earlier in this file.
optimizer = AdaBound(
    model.parameters(),
    lr=lr,
    betas=betas,
    eps=eps,
    weight_decay=weight_decay,
    final_lr=final_lr,
    gamma=gamma,
    amsbound=amsbound,
)

# criterion
# By its name, a soft-Jaccard + BCE-with-logits combined loss.
# NOTE(review): ``alpha=0.7`` presumably weights the two terms against each
# other; confirm in ``bob.ip.binseg.models.losses``.
criterion = SoftJaccardBCELogitsLoss(alpha=0.7)

# scheduler
# Multiply the optimizer's learning rate by ``scheduler_gamma`` each time the
# number of ``scheduler.step()`` calls reaches a value in
# ``scheduler_milestones``. Whether a step is one epoch depends on the
# training loop that drives this scheduler.
scheduler = MultiStepLR(
    optimizer,
    gamma=scheduler_gamma,
    milestones=scheduler_milestones,
)