Coverage for src/deepdraw/engine/adabound.py: 16%

132 statements  

coverage.py v7.3.1, created at 2023-11-30 15:00 +0100

# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later

"""Implementation of the AdaBound optimizer.

<https://github.com/Luolc/AdaBound/blob/master/adabound/adabound.py>::

    @inproceedings{Luo2019AdaBound,
      author = {Luo, Liangchen and Xiong, Yuanhao and Liu, Yan and Sun, Xu},
      title = {Adaptive Gradient Methods with Dynamic Bound of Learning Rate},
      booktitle = {Proceedings of the 7th International Conference on Learning Representations},
      month = {May},
      year = {2019},
      address = {New Orleans, Louisiana}
    }
"""

import math

import torch
import torch.optim
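# Illustrative sketch (not part of the original module): AdaBound clamps each
# element of the Adam step into a band that shrinks towards ``final_lr``.
# Following the bound formulas used in ``AdaBound.step()`` below, the band at
# iteration ``t`` is
#
#     lower(t) = final_lr * (1 - 1 / (gamma * t + 1))
#     upper(t) = final_lr * (1 + 1 / (gamma * t))
#
# so the optimizer starts out Adam-like (a very wide band) and gradually
# behaves like SGD with learning rate ``final_lr``.  A hypothetical helper,
# for intuition only:


def _adabound_band(t: int, final_lr: float = 0.1, gamma: float = 1e-3):
    """Returns the (lower, upper) clamp values AdaBound would use at step ``t``."""
    lower = final_lr * (1 - 1 / (gamma * t + 1))
    upper = final_lr * (1 + 1 / (gamma * t))
    return lower, upper


# With the defaults: _adabound_band(1) ~ (1.0e-4, 100.1) and
# _adabound_band(100_000) ~ (0.099, 0.101).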

class AdaBound(torch.optim.Optimizer):
    """Implements the AdaBound algorithm.

    Parameters
    ----------

    params : list
        Iterable of parameters to optimize or dicts defining parameter groups

    lr : :obj:`float`, optional
        Adam learning rate

    betas : :obj:`tuple`, optional
        Coefficients (as a 2-tuple of floats) used for computing running
        averages of gradient and its square

    final_lr : :obj:`float`, optional
        Final (SGD) learning rate

    gamma : :obj:`float`, optional
        Convergence speed of the bound functions

    eps : :obj:`float`, optional
        Term added to the denominator to improve numerical stability

    weight_decay : :obj:`float`, optional
        Weight decay (L2 penalty)

    amsbound : :obj:`bool`, optional
        Whether to use the AMSBound variant of this algorithm
    """
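    # A minimal usage sketch (illustrative only; ``model``, ``loader`` and
    # ``criterion`` are assumed names, not provided by this module):
    #
    #     optimizer = AdaBound(model.parameters(), lr=1e-3, final_lr=0.1)
    #     for inputs, targets in loader:
    #         optimizer.zero_grad()
    #         criterion(model(inputs), targets).backward()
    #         optimizer.step()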

    def __init__(
        self,
        params,
        lr=1e-3,
        betas=(0.9, 0.999),
        final_lr=0.1,
        gamma=1e-3,
        eps=1e-8,
        weight_decay=0,
        amsbound=False,
    ):
        if not 0.0 <= lr:
            raise ValueError(f"Invalid learning rate: {lr}")
        if not 0.0 <= eps:
            raise ValueError(f"Invalid epsilon value: {eps}")
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}")
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}")
        if not 0.0 <= final_lr:
            raise ValueError(f"Invalid final learning rate: {final_lr}")
        if not 0.0 <= gamma < 1.0:
            raise ValueError(f"Invalid gamma parameter: {gamma}")
        defaults = dict(
            lr=lr,
            betas=betas,
            final_lr=final_lr,
            gamma=gamma,
            eps=eps,
            weight_decay=weight_decay,
            amsbound=amsbound,
        )
        super().__init__(params, defaults)

        self.base_lrs = list(map(lambda group: group["lr"], self.param_groups))

    def __setstate__(self, state):
        super().__setstate__(state)
        for group in self.param_groups:
            group.setdefault("amsbound", False)

    def step(self, closure=None):
        """Performs a single optimization step.

        Parameters
        ----------

        closure : :obj:`callable`, optional
            A closure that reevaluates the model and returns the loss.
        """
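        # Illustrative closure usage (``compute_loss`` is an assumed name,
        # not part of this module):
        #
        #     def closure():
        #         optimizer.zero_grad()
        #         loss = compute_loss()
        #         loss.backward()
        #         return loss
        #
        #     loss = optimizer.step(closure)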

        loss = None
        if closure is not None:
            loss = closure()

        for group, base_lr in zip(self.param_groups, self.base_lrs):
            for p in group["params"]:
                if p.grad is None:
                    continue
                grad = p.grad.data
                if grad.is_sparse:
                    raise RuntimeError(
                        "AdaBound does not support sparse gradients, please consider SparseAdam instead"
                    )
                amsbound = group["amsbound"]

                state = self.state[p]

                # State initialization
                if len(state) == 0:
                    state["step"] = 0
                    # Exponential moving average of gradient values
                    state["exp_avg"] = torch.zeros_like(p.data)
                    # Exponential moving average of squared gradient values
                    state["exp_avg_sq"] = torch.zeros_like(p.data)
                    if amsbound:
                        # Maintains max of all exp. moving avg. of sq. grad. values
                        state["max_exp_avg_sq"] = torch.zeros_like(p.data)

                exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
                if amsbound:
                    max_exp_avg_sq = state["max_exp_avg_sq"]
                beta1, beta2 = group["betas"]

                state["step"] += 1

                if group["weight_decay"] != 0:
                    grad = grad.add(p.data, alpha=group["weight_decay"])

                # Decay the first and second moment running average coefficient
                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
                if amsbound:
                    # Maintains the maximum of all 2nd moment running avg. till now
                    torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)
                    # Use the max. for normalizing running avg. of gradient
                    denom = max_exp_avg_sq.sqrt().add_(group["eps"])
                else:
                    denom = exp_avg_sq.sqrt().add_(group["eps"])

                bias_correction1 = 1 - beta1 ** state["step"]
                bias_correction2 = 1 - beta2 ** state["step"]
                step_size = (
                    group["lr"] * math.sqrt(bias_correction2) / bias_correction1
                )

                # Applies bounds on actual learning rate
                # lr_scheduler cannot affect final_lr, this is a workaround to apply lr decay
                final_lr = group["final_lr"] * group["lr"] / base_lr
                lower_bound = final_lr * (
                    1 - 1 / (group["gamma"] * state["step"] + 1)
                )
                upper_bound = final_lr * (
                    1 + 1 / (group["gamma"] * state["step"])
                )
                step_size = torch.full_like(denom, step_size)
                step_size.div_(denom).clamp_(lower_bound, upper_bound).mul_(
                    exp_avg
                )

                p.data.add_(-step_size)

        return loss

class AdaBoundW(torch.optim.Optimizer):
    """Implements the AdaBound algorithm with decoupled weight decay (see
    https://arxiv.org/abs/1711.05101).

    Parameters
    ----------

    params : list
        Iterable of parameters to optimize or dicts defining parameter groups

    lr : :obj:`float`, optional
        Adam learning rate

    betas : :obj:`tuple`, optional
        Coefficients (as a 2-tuple of floats) used for computing running
        averages of gradient and its square

    final_lr : :obj:`float`, optional
        Final (SGD) learning rate

    gamma : :obj:`float`, optional
        Convergence speed of the bound functions

    eps : :obj:`float`, optional
        Term added to the denominator to improve numerical stability

    weight_decay : :obj:`float`, optional
        Weight decay (decoupled, applied directly to the parameters)

    amsbound : :obj:`bool`, optional
        Whether to use the AMSBound variant of this algorithm
    """
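    # Difference from ``AdaBound`` above, summarised (illustrative): AdaBound
    # folds the penalty into the gradient before the moment estimates are
    # updated (``grad = grad + weight_decay * p``), while AdaBoundW applies
    # the decay directly to the parameters after the bounded update, roughly
    #
    #     p <- p - bounded_step - weight_decay * p
    #
    # which is the decoupled scheme of arXiv:1711.05101.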

    def __init__(
        self,
        params,
        lr=1e-3,
        betas=(0.9, 0.999),
        final_lr=0.1,
        gamma=1e-3,
        eps=1e-8,
        weight_decay=0,
        amsbound=False,
    ):
        if not 0.0 <= lr:
            raise ValueError(f"Invalid learning rate: {lr}")
        if not 0.0 <= eps:
            raise ValueError(f"Invalid epsilon value: {eps}")
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}")
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}")
        if not 0.0 <= final_lr:
            raise ValueError(f"Invalid final learning rate: {final_lr}")
        if not 0.0 <= gamma < 1.0:
            raise ValueError(f"Invalid gamma parameter: {gamma}")
        defaults = dict(
            lr=lr,
            betas=betas,
            final_lr=final_lr,
            gamma=gamma,
            eps=eps,
            weight_decay=weight_decay,
            amsbound=amsbound,
        )
        super().__init__(params, defaults)

        self.base_lrs = list(map(lambda group: group["lr"], self.param_groups))

    def __setstate__(self, state):
        super().__setstate__(state)
        for group in self.param_groups:
            group.setdefault("amsbound", False)

    def step(self, closure=None):
        """Performs a single optimization step.

        Parameters
        ----------

        closure : :obj:`callable`, optional
            A closure that reevaluates the model and returns the loss.
        """

        loss = None
        if closure is not None:
            loss = closure()

        for group, base_lr in zip(self.param_groups, self.base_lrs):
            for p in group["params"]:
                if p.grad is None:
                    continue
                grad = p.grad.data
                if grad.is_sparse:
                    raise RuntimeError(
                        "AdaBoundW does not support sparse gradients, please consider SparseAdam instead"
                    )
                amsbound = group["amsbound"]

                state = self.state[p]

                # State initialization
                if len(state) == 0:
                    state["step"] = 0
                    # Exponential moving average of gradient values
                    state["exp_avg"] = torch.zeros_like(p.data)
                    # Exponential moving average of squared gradient values
                    state["exp_avg_sq"] = torch.zeros_like(p.data)
                    if amsbound:
                        # Maintains max of all exp. moving avg. of sq. grad. values
                        state["max_exp_avg_sq"] = torch.zeros_like(p.data)

                exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
                if amsbound:
                    max_exp_avg_sq = state["max_exp_avg_sq"]
                beta1, beta2 = group["betas"]

                state["step"] += 1

                # Decay the first and second moment running average coefficient
                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
                if amsbound:
                    # Maintains the maximum of all 2nd moment running avg. till now
                    torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)
                    # Use the max. for normalizing running avg. of gradient
                    denom = max_exp_avg_sq.sqrt().add_(group["eps"])
                else:
                    denom = exp_avg_sq.sqrt().add_(group["eps"])

                bias_correction1 = 1 - beta1 ** state["step"]
                bias_correction2 = 1 - beta2 ** state["step"]
                step_size = (
                    group["lr"] * math.sqrt(bias_correction2) / bias_correction1
                )

                # Applies bounds on actual learning rate
                # lr_scheduler cannot affect final_lr, this is a workaround to
                # apply lr decay
                final_lr = group["final_lr"] * group["lr"] / base_lr
                lower_bound = final_lr * (
                    1 - 1 / (group["gamma"] * state["step"] + 1)
                )
                upper_bound = final_lr * (
                    1 + 1 / (group["gamma"] * state["step"])
                )
                step_size = torch.full_like(denom, step_size)
                step_size.div_(denom).clamp_(lower_bound, upper_bound).mul_(
                    exp_avg
                )

                if group["weight_decay"] != 0:
                    decayed_weights = torch.mul(p.data, group["weight_decay"])
                    p.data.add_(-step_size)
                    p.data.sub_(decayed_weights)
                else:
                    p.data.add_(-step_size)

        return loss
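

# Illustrative smoke test (not part of the original module): fits a 1-D
# linear model y = 2x + 1 with AdaBound on synthetic data, using only
# standard torch calls.  All names below are local to this example.
if __name__ == "__main__":
    torch.manual_seed(0)
    x = torch.linspace(-1.0, 1.0, 64).unsqueeze(1)
    y = 2.0 * x + 1.0

    model = torch.nn.Linear(1, 1)
    optimizer = AdaBound(model.parameters(), lr=1e-2, final_lr=0.1)

    for _ in range(500):
        optimizer.zero_grad()
        loss = torch.nn.functional.mse_loss(model(x), y)
        loss.backward()
        optimizer.step()

    # After a few hundred steps the weight/bias should approach 2 and 1.
    print("final loss:", float(loss))
    print("weight:", float(model.weight), "bias:", float(model.bias))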