#!/usr/bin/env python
# -*- coding: utf-8 -*-

"""
Implementation of the `AdaBound optimizer
<https://github.com/Luolc/AdaBound/blob/master/adabound/adabound.py>`_::

    @inproceedings{Luo2019AdaBound,
        author = {Luo, Liangchen and Xiong, Yuanhao and Liu, Yan and Sun, Xu},
        title = {Adaptive Gradient Methods with Dynamic Bound of Learning Rate},
        booktitle = {Proceedings of the 7th International Conference on Learning Representations},
        month = {May},
        year = {2019},
        address = {New Orleans, Louisiana}
    }

"""

import math

import torch
import torch.optim


class AdaBound(torch.optim.Optimizer):
    """Implements the AdaBound algorithm.

    Parameters
    ----------

    params : list
        Iterable of parameters to optimize or dicts defining parameter groups

    lr : :obj:`float`, optional
        Adam learning rate

    betas : :obj:`tuple`, optional
        Coefficients (as a 2-tuple of floats) used for computing running
        averages of gradient and its square

    final_lr : :obj:`float`, optional
        Final (SGD) learning rate

    gamma : :obj:`float`, optional
        Convergence speed of the bound functions

    eps : :obj:`float`, optional
        Term added to the denominator to improve numerical stability

    weight_decay : :obj:`float`, optional
        Weight decay (L2 penalty)

    amsbound : :obj:`bool`, optional
        Whether to use the AMSBound variant of this algorithm

    """

    def __init__(
        self,
        params,
        lr=1e-3,
        betas=(0.9, 0.999),
        final_lr=0.1,
        gamma=1e-3,
        eps=1e-8,
        weight_decay=0,
        amsbound=False,
    ):
        if not 0.0 <= lr:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if not 0.0 <= eps:
            raise ValueError("Invalid epsilon value: {}".format(eps))
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError(
                "Invalid beta parameter at index 0: {}".format(betas[0])
            )
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError(
                "Invalid beta parameter at index 1: {}".format(betas[1])
            )
        if not 0.0 <= final_lr:
            raise ValueError("Invalid final learning rate: {}".format(final_lr))
        if not 0.0 <= gamma < 1.0:
            raise ValueError("Invalid gamma parameter: {}".format(gamma))
        defaults = dict(
            lr=lr,
            betas=betas,
            final_lr=final_lr,
            gamma=gamma,
            eps=eps,
            weight_decay=weight_decay,
            amsbound=amsbound,
        )
        super(AdaBound, self).__init__(params, defaults)

        self.base_lrs = list(map(lambda group: group["lr"], self.param_groups))

    def __setstate__(self, state):
        super(AdaBound, self).__setstate__(state)
        for group in self.param_groups:
            group.setdefault("amsbound", False)

    def step(self, closure=None):
        """Performs a single optimization step.

        Parameters
        ----------

        closure : :obj:`callable`, optional
            A closure that reevaluates the model and returns the loss.

        """

        loss = None
        if closure is not None:
            loss = closure()

        for group, base_lr in zip(self.param_groups, self.base_lrs):
            for p in group["params"]:
                if p.grad is None:
                    continue
                grad = p.grad.data
                if grad.is_sparse:
                    raise RuntimeError(
124 "Adam does not support sparse gradients, please consider SparseAdam instead" 

                    )
                amsbound = group["amsbound"]

                state = self.state[p]

                # State initialization
                if len(state) == 0:
                    state["step"] = 0
                    # Exponential moving average of gradient values
                    state["exp_avg"] = torch.zeros_like(p.data)
                    # Exponential moving average of squared gradient values
                    state["exp_avg_sq"] = torch.zeros_like(p.data)
                    if amsbound:
                        # Maintains max of all exp. moving avg. of sq. grad. values
                        state["max_exp_avg_sq"] = torch.zeros_like(p.data)

                exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
                if amsbound:
                    max_exp_avg_sq = state["max_exp_avg_sq"]
                beta1, beta2 = group["betas"]

                state["step"] += 1

                if group["weight_decay"] != 0:
                    grad = grad.add(p.data, alpha=group["weight_decay"])

                # Decay the first and second moment running average coefficient
                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
                if amsbound:
                    # Maintains the maximum of all 2nd moment running avg. till now
                    torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)
                    # Use the max. for normalizing running avg. of gradient
                    denom = max_exp_avg_sq.sqrt().add_(group["eps"])
                else:
                    denom = exp_avg_sq.sqrt().add_(group["eps"])

                bias_correction1 = 1 - beta1 ** state["step"]
                bias_correction2 = 1 - beta2 ** state["step"]
                step_size = (
                    group["lr"] * math.sqrt(bias_correction2) / bias_correction1
                )

                # Applies bounds on actual learning rate
                # lr_scheduler cannot affect final_lr, this is a workaround
                # to apply lr decay
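                # step_size above is the bias-corrected Adam step,
                #   lr * sqrt(1 - beta2**t) / (1 - beta1**t),
                # and the bound functions converge to final_lr as t grows:
                #   lower(t) = final_lr * (1 - 1 / (gamma * t + 1))
                #   upper(t) = final_lr * (1 + 1 / (gamma * t))
                # The element-wise update below is
                #   p <- p - clip(step_size / denom, lower, upper) * exp_avg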

                final_lr = group["final_lr"] * group["lr"] / base_lr
                lower_bound = final_lr * (
                    1 - 1 / (group["gamma"] * state["step"] + 1)
                )
                upper_bound = final_lr * (
                    1 + 1 / (group["gamma"] * state["step"])
                )
                step_size = torch.full_like(denom, step_size)
                step_size.div_(denom).clamp_(lower_bound, upper_bound).mul_(
                    exp_avg
                )

                p.data.add_(-step_size)

        return loss


class AdaBoundW(torch.optim.Optimizer):
    """Implements the AdaBound algorithm with decoupled weight decay
    (see https://arxiv.org/abs/1711.05101).

    Parameters
    ----------

    params : list
        Iterable of parameters to optimize or dicts defining parameter groups

    lr : :obj:`float`, optional
        Adam learning rate

    betas : :obj:`tuple`, optional
        Coefficients (as a 2-tuple of floats) used for computing running
        averages of gradient and its square

    final_lr : :obj:`float`, optional
        Final (SGD) learning rate

    gamma : :obj:`float`, optional
        Convergence speed of the bound functions

    eps : :obj:`float`, optional
        Term added to the denominator to improve numerical stability

    weight_decay : :obj:`float`, optional
        Weight decay (L2 penalty)

    amsbound : :obj:`bool`, optional
        Whether to use the AMSBound variant of this algorithm

    """

    def __init__(
        self,
        params,
        lr=1e-3,
        betas=(0.9, 0.999),
        final_lr=0.1,
        gamma=1e-3,
        eps=1e-8,
        weight_decay=0,
        amsbound=False,
    ):
        if not 0.0 <= lr:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if not 0.0 <= eps:
            raise ValueError("Invalid epsilon value: {}".format(eps))
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError(
                "Invalid beta parameter at index 0: {}".format(betas[0])
            )
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError(
                "Invalid beta parameter at index 1: {}".format(betas[1])
            )
        if not 0.0 <= final_lr:
            raise ValueError("Invalid final learning rate: {}".format(final_lr))
        if not 0.0 <= gamma < 1.0:
            raise ValueError("Invalid gamma parameter: {}".format(gamma))
        defaults = dict(
            lr=lr,
            betas=betas,
            final_lr=final_lr,
            gamma=gamma,
            eps=eps,
            weight_decay=weight_decay,
            amsbound=amsbound,
        )
        super(AdaBoundW, self).__init__(params, defaults)

        self.base_lrs = list(map(lambda group: group["lr"], self.param_groups))

    def __setstate__(self, state):
        super(AdaBoundW, self).__setstate__(state)
        for group in self.param_groups:
            group.setdefault("amsbound", False)

    def step(self, closure=None):
        """Performs a single optimization step.

        Parameters
        ----------

        closure : :obj:`callable`, optional
            A closure that reevaluates the model and returns the loss.

        """
        loss = None
        if closure is not None:
            loss = closure()

        for group, base_lr in zip(self.param_groups, self.base_lrs):
            for p in group["params"]:
                if p.grad is None:
                    continue
                grad = p.grad.data
                if grad.is_sparse:
                    raise RuntimeError(
289 "Adam does not support sparse gradients, please consider SparseAdam instead" 

                    )
                amsbound = group["amsbound"]

                state = self.state[p]

                # State initialization
                if len(state) == 0:
                    state["step"] = 0
                    # Exponential moving average of gradient values
                    state["exp_avg"] = torch.zeros_like(p.data)
                    # Exponential moving average of squared gradient values
                    state["exp_avg_sq"] = torch.zeros_like(p.data)
                    if amsbound:
                        # Maintains max of all exp. moving avg. of sq. grad. values
                        state["max_exp_avg_sq"] = torch.zeros_like(p.data)

                exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
                if amsbound:
                    max_exp_avg_sq = state["max_exp_avg_sq"]
                beta1, beta2 = group["betas"]

                state["step"] += 1

                # Decay the first and second moment running average coefficient
                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
                if amsbound:
                    # Maintains the maximum of all 2nd moment running avg. till now
                    torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)
                    # Use the max. for normalizing running avg. of gradient
                    denom = max_exp_avg_sq.sqrt().add_(group["eps"])
                else:
                    denom = exp_avg_sq.sqrt().add_(group["eps"])

                bias_correction1 = 1 - beta1 ** state["step"]
                bias_correction2 = 1 - beta2 ** state["step"]
                step_size = (
                    group["lr"] * math.sqrt(bias_correction2) / bias_correction1
                )

                # Applies bounds on actual learning rate
                # lr_scheduler cannot affect final_lr, this is a workaround to
                # apply lr decay
                final_lr = group["final_lr"] * group["lr"] / base_lr
                lower_bound = final_lr * (
                    1 - 1 / (group["gamma"] * state["step"] + 1)
                )
                upper_bound = final_lr * (
                    1 + 1 / (group["gamma"] * state["step"])
                )
                step_size = torch.full_like(denom, step_size)
                step_size.div_(denom).clamp_(lower_bound, upper_bound).mul_(
                    exp_avg
                )

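                # Decoupled weight decay: subtract weight_decay * p (taken
                # before the update) from the parameters directly, instead of
                # adding an L2 term to the gradient as AdaBound does.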

                if group["weight_decay"] != 0:
                    decayed_weights = torch.mul(p.data, group["weight_decay"])
                    p.data.add_(-step_size)
                    p.data.sub_(decayed_weights)
                else:
                    p.data.add_(-step_size)

        return loss
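

if __name__ == "__main__":
    # Minimal smoke test (illustrative only, not part of the referenced
    # AdaBound source): fit a tiny linear regression for a few steps with
    # each optimizer and print the final loss.
    import torch.nn

    torch.manual_seed(0)
    inputs = torch.randn(64, 3)
    targets = inputs @ torch.tensor([[1.0], [-2.0], [0.5]]) + 0.1

    for optimizer_cls in (AdaBound, AdaBoundW):
        model = torch.nn.Linear(3, 1)
        optimizer = optimizer_cls(
            model.parameters(), lr=1e-2, final_lr=0.1, weight_decay=1e-4
        )
        for _ in range(200):
            optimizer.zero_grad()
            loss = (model(inputs) - targets).pow(2).mean()
            loss.backward()
            optimizer.step()
        print(optimizer_cls.__name__, "final MSE:", loss.item())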