Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""DRIU Network for Vessel Segmentation using SSL and Batch Normalization

Deep Retinal Image Understanding (DRIU), a unified framework of retinal image
analysis that provides both retinal vessel and optic disc segmentation using
deep Convolutional Neural Networks (CNNs). This version of our model includes
a loss that is suitable for Semi-Supervised Learning (SSL). This version also
includes batch normalization as a regularization mechanism.

Reference: [MANINIS-2016]_
"""
15from torch.optim.lr_scheduler import MultiStepLR
17from bob.ip.binseg.engine.adabound import AdaBound
18from bob.ip.binseg.models.driu_bn import driu_bn
19from bob.ip.binseg.models.losses import MixJacLoss
# config
#
# Hyper-parameters for the AdaBound optimizer; each value is forwarded
# unchanged as the keyword argument of the same name when the optimizer is
# constructed further down in this file.
lr = 0.001
betas = (0.9, 0.999)
eps = 1e-8  # assigned once; the earlier duplicate assignment was redundant
weight_decay = 0
final_lr = 0.1
gamma = 1e-3
amsbound = False

# Learning-rate schedule settings, forwarded to MultiStepLR as ``milestones``
# and ``gamma`` respectively.
scheduler_milestones = [900]
scheduler_gamma = 0.1
# model: the DRIU variant imported from ``bob.ip.binseg.models.driu_bn``,
# built with its default arguments.
model = driu_bn()

# optimizer: AdaBound over all model parameters, configured entirely from
# the module-level hyper-parameters defined in the config section.
optimizer = AdaBound(
    model.parameters(),
    lr=lr,
    betas=betas,
    final_lr=final_lr,
    gamma=gamma,
    eps=eps,
    weight_decay=weight_decay,
    amsbound=amsbound,
)
# criterion: the SSL-suitable loss mentioned in the module docstring.
# NOTE(review): ``lambda_u`` and ``jacalpha`` are presumably the weight of the
# unlabeled term and the Jaccard/BCE mixing factor -- confirm against
# ``bob.ip.binseg.models.losses.MixJacLoss``.
criterion = MixJacLoss(lambda_u=0.05, jacalpha=0.7)

# Marks this configuration as semi-supervised; presumably read by the
# training entry point -- verify against the consumer of this config.
ssl = True
# scheduler: MultiStepLR multiplies the optimizer's learning rate by
# ``scheduler_gamma`` each time the epoch count reaches one of
# ``scheduler_milestones``.
scheduler = MultiStepLR(
    optimizer, milestones=scheduler_milestones, gamma=scheduler_gamma
)