#!/usr/bin/env python
# -*- coding: utf-8 -*-

"""
Implementation of the `AdaBound optimizer
<https://github.com/Luolc/AdaBound/blob/master/adabound/adabound.py>`_::

    @inproceedings{Luo2019AdaBound,
      author = {Luo, Liangchen and Xiong, Yuanhao and Liu, Yan and Sun, Xu},
      title = {Adaptive Gradient Methods with Dynamic Bound of Learning Rate},
      booktitle = {Proceedings of the 7th International Conference on Learning Representations},
      month = {May},
      year = {2019},
      address = {New Orleans, Louisiana}
    }

"""

import math

import torch
import torch.optim


class AdaBound(torch.optim.Optimizer):
    """Implements the AdaBound algorithm.

    Parameters
    ----------
    params : list
        Iterable of parameters to optimize or dicts defining parameter groups

    lr : :obj:`float`, optional
        Adam learning rate

    betas : :obj:`tuple`, optional
        Coefficients (as a 2-tuple of floats) used for computing running
        averages of gradient and its square

    final_lr : :obj:`float`, optional
        Final (SGD) learning rate

    gamma : :obj:`float`, optional
        Convergence speed of the bound functions

    eps : :obj:`float`, optional
        Term added to the denominator to improve numerical stability

    weight_decay : :obj:`float`, optional
        Weight decay (L2 penalty)

    amsbound : :obj:`bool`, optional
        Whether to use the AMSBound variant of this algorithm
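
    Examples
    --------
    A minimal, illustrative usage sketch; the linear model and random data
    below are placeholders assumed only for this example:

    >>> model = torch.nn.Linear(10, 1)
    >>> optimizer = AdaBound(model.parameters(), lr=1e-3, final_lr=0.1)
    >>> loss = model(torch.randn(4, 10)).sum()
    >>> loss.backward()
    >>> optimizer.step()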
56 """
58 def __init__(
59 self,
60 params,
61 lr=1e-3,
62 betas=(0.9, 0.999),
63 final_lr=0.1,
64 gamma=1e-3,
65 eps=1e-8,
66 weight_decay=0,
67 amsbound=False,
68 ):
69 if not 0.0 <= lr:
70 raise ValueError("Invalid learning rate: {}".format(lr))
71 if not 0.0 <= eps:
72 raise ValueError("Invalid epsilon value: {}".format(eps))
73 if not 0.0 <= betas[0] < 1.0:
74 raise ValueError(
75 "Invalid beta parameter at index 0: {}".format(betas[0])
76 )
77 if not 0.0 <= betas[1] < 1.0:
78 raise ValueError(
79 "Invalid beta parameter at index 1: {}".format(betas[1])
80 )
81 if not 0.0 <= final_lr:
82 raise ValueError("Invalid final learning rate: {}".format(final_lr))
83 if not 0.0 <= gamma < 1.0:
84 raise ValueError("Invalid gamma parameter: {}".format(gamma))
85 defaults = dict(
86 lr=lr,
87 betas=betas,
88 final_lr=final_lr,
89 gamma=gamma,
90 eps=eps,
91 weight_decay=weight_decay,
92 amsbound=amsbound,
93 )
94 super(AdaBound, self).__init__(params, defaults)
96 self.base_lrs = list(map(lambda group: group["lr"], self.param_groups))
98 def __setstate__(self, state):
99 super(AdaBound, self).__setstate__(state)
100 for group in self.param_groups:
101 group.setdefault("amsbound", False)
103 def step(self, closure=None):
104 """Performs a single optimization step.
106 Parameters
107 ----------
109 closure : :obj:`callable`, optional
110 A closure that reevaluates the model and returns the loss.
112 """
113 loss = None
114 if closure is not None:
115 loss = closure()

        for group, base_lr in zip(self.param_groups, self.base_lrs):
            for p in group["params"]:
                if p.grad is None:
                    continue
                grad = p.grad.data
                if grad.is_sparse:
                    raise RuntimeError(
                        "AdaBound does not support sparse gradients, "
                        "please consider SparseAdam instead"
                    )
                amsbound = group["amsbound"]

                state = self.state[p]

                # State initialization
                if len(state) == 0:
                    state["step"] = 0
                    # Exponential moving average of gradient values
                    state["exp_avg"] = torch.zeros_like(p.data)
                    # Exponential moving average of squared gradient values
                    state["exp_avg_sq"] = torch.zeros_like(p.data)
                    if amsbound:
                        # Maintains max of all exp. moving avg. of sq. grad. values
                        state["max_exp_avg_sq"] = torch.zeros_like(p.data)

                exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
                if amsbound:
                    max_exp_avg_sq = state["max_exp_avg_sq"]
                beta1, beta2 = group["betas"]

                state["step"] += 1

                if group["weight_decay"] != 0:
                    grad = grad.add(p.data, alpha=group["weight_decay"])

                # Decay the first and second moment running average coefficient
                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
                if amsbound:
                    # Maintains the maximum of all 2nd moment running avg. till now
                    torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)
                    # Use the max. for normalizing running avg. of gradient
                    denom = max_exp_avg_sq.sqrt().add_(group["eps"])
                else:
                    denom = exp_avg_sq.sqrt().add_(group["eps"])

                bias_correction1 = 1 - beta1 ** state["step"]
                bias_correction2 = 1 - beta2 ** state["step"]
                step_size = (
                    group["lr"] * math.sqrt(bias_correction2) / bias_correction1
                )

                # Applies bounds on the actual learning rate.
                # lr_scheduler cannot affect final_lr; this is a workaround
                # to apply lr decay.
                final_lr = group["final_lr"] * group["lr"] / base_lr
                lower_bound = final_lr * (
                    1 - 1 / (group["gamma"] * state["step"] + 1)
                )
                upper_bound = final_lr * (
                    1 + 1 / (group["gamma"] * state["step"])
                )
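                # The bounds start wide apart (lower near 0, upper large for
                # small gamma) and both converge to final_lr as the step count
                # grows, so the update gradually transitions from an Adam-like
                # adaptive step to an SGD-like step at final_lr.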
                step_size = torch.full_like(denom, step_size)
                step_size.div_(denom).clamp_(lower_bound, upper_bound).mul_(
                    exp_avg
                )

                p.data.add_(-step_size)

        return loss


class AdaBoundW(torch.optim.Optimizer):
    """Implements the AdaBound algorithm with decoupled weight decay
    (see https://arxiv.org/abs/1711.05101).

    Parameters
    ----------
    params : list
        Iterable of parameters to optimize or dicts defining parameter groups

    lr : :obj:`float`, optional
        Adam learning rate

    betas : :obj:`tuple`, optional
        Coefficients (as a 2-tuple of floats) used for computing running
        averages of gradient and its square

    final_lr : :obj:`float`, optional
        Final (SGD) learning rate

    gamma : :obj:`float`, optional
        Convergence speed of the bound functions

    eps : :obj:`float`, optional
        Term added to the denominator to improve numerical stability

    weight_decay : :obj:`float`, optional
        Weight decay (L2 penalty)

    amsbound : :obj:`bool`, optional
        Whether to use the AMSBound variant of this algorithm
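
    Examples
    --------
    A minimal, illustrative usage sketch with decoupled weight decay; the
    linear model and random data below are placeholders assumed only for
    this example:

    >>> model = torch.nn.Linear(10, 1)
    >>> optimizer = AdaBoundW(
    ...     model.parameters(), lr=1e-3, final_lr=0.1, weight_decay=1e-2
    ... )
    >>> loss = model(torch.randn(4, 10)).sum()
    >>> loss.backward()
    >>> optimizer.step()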
219 """
221 def __init__(
222 self,
223 params,
224 lr=1e-3,
225 betas=(0.9, 0.999),
226 final_lr=0.1,
227 gamma=1e-3,
228 eps=1e-8,
229 weight_decay=0,
230 amsbound=False,
231 ):
233 if not 0.0 <= lr:
234 raise ValueError("Invalid learning rate: {}".format(lr))
235 if not 0.0 <= eps:
236 raise ValueError("Invalid epsilon value: {}".format(eps))
237 if not 0.0 <= betas[0] < 1.0:
238 raise ValueError(
239 "Invalid beta parameter at index 0: {}".format(betas[0])
240 )
241 if not 0.0 <= betas[1] < 1.0:
242 raise ValueError(
243 "Invalid beta parameter at index 1: {}".format(betas[1])
244 )
245 if not 0.0 <= final_lr:
246 raise ValueError("Invalid final learning rate: {}".format(final_lr))
247 if not 0.0 <= gamma < 1.0:
248 raise ValueError("Invalid gamma parameter: {}".format(gamma))
249 defaults = dict(
250 lr=lr,
251 betas=betas,
252 final_lr=final_lr,
253 gamma=gamma,
254 eps=eps,
255 weight_decay=weight_decay,
256 amsbound=amsbound,
257 )
258 super(AdaBoundW, self).__init__(params, defaults)
260 self.base_lrs = list(map(lambda group: group["lr"], self.param_groups))
262 def __setstate__(self, state):
263 super(AdaBoundW, self).__setstate__(state)
264 for group in self.param_groups:
265 group.setdefault("amsbound", False)
267 def step(self, closure=None):
268 """Performs a single optimization step.
270 Parameters
271 ----------
273 closure : :obj:`callable`, optional
274 A closure that reevaluates the model and returns the loss.
276 """
278 loss = None
279 if closure is not None:
280 loss = closure()

        for group, base_lr in zip(self.param_groups, self.base_lrs):
            for p in group["params"]:
                if p.grad is None:
                    continue
                grad = p.grad.data
                if grad.is_sparse:
                    raise RuntimeError(
                        "AdaBoundW does not support sparse gradients, "
                        "please consider SparseAdam instead"
                    )
                amsbound = group["amsbound"]

                state = self.state[p]

                # State initialization
                if len(state) == 0:
                    state["step"] = 0
                    # Exponential moving average of gradient values
                    state["exp_avg"] = torch.zeros_like(p.data)
                    # Exponential moving average of squared gradient values
                    state["exp_avg_sq"] = torch.zeros_like(p.data)
                    if amsbound:
                        # Maintains max of all exp. moving avg. of sq. grad. values
                        state["max_exp_avg_sq"] = torch.zeros_like(p.data)

                exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
                if amsbound:
                    max_exp_avg_sq = state["max_exp_avg_sq"]
                beta1, beta2 = group["betas"]

                state["step"] += 1

                # Decay the first and second moment running average coefficient
                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
                if amsbound:
                    # Maintains the maximum of all 2nd moment running avg. till now
                    torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)
                    # Use the max. for normalizing running avg. of gradient
                    denom = max_exp_avg_sq.sqrt().add_(group["eps"])
                else:
                    denom = exp_avg_sq.sqrt().add_(group["eps"])

                bias_correction1 = 1 - beta1 ** state["step"]
                bias_correction2 = 1 - beta2 ** state["step"]
                step_size = (
                    group["lr"] * math.sqrt(bias_correction2) / bias_correction1
                )

                # Applies bounds on the actual learning rate.
                # lr_scheduler cannot affect final_lr; this is a workaround
                # to apply lr decay.
                final_lr = group["final_lr"] * group["lr"] / base_lr
                lower_bound = final_lr * (
                    1 - 1 / (group["gamma"] * state["step"] + 1)
                )
                upper_bound = final_lr * (
                    1 + 1 / (group["gamma"] * state["step"])
                )
                step_size = torch.full_like(denom, step_size)
                step_size.div_(denom).clamp_(lower_bound, upper_bound).mul_(
                    exp_avg
                )
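
                # Decoupled weight decay (https://arxiv.org/abs/1711.05101):
                # the decay term is applied directly to the weights after the
                # bounded update, instead of being folded into the gradient
                # as in AdaBound above.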
                if group["weight_decay"] != 0:
                    decayed_weights = torch.mul(p.data, group["weight_decay"])
                    p.data.add_(-step_size)
                    p.data.sub_(decayed_weights)
                else:
                    p.data.add_(-step_size)

        return loss
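

# ---------------------------------------------------------------------------
# Minimal smoke-test sketch (illustrative only, not part of the original
# module): fits a throwaway linear model on random data to show how either
# optimizer plugs into a standard PyTorch training loop.  The model, data,
# and hyperparameters below are placeholder assumptions.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    torch.manual_seed(0)
    model = torch.nn.Linear(10, 1)
    criterion = torch.nn.MSELoss()
    optimizer = AdaBoundW(
        model.parameters(), lr=1e-3, final_lr=0.1, weight_decay=1e-2
    )
    inputs = torch.randn(64, 10)
    targets = torch.randn(64, 1)
    for _ in range(5):
        optimizer.zero_grad()
        loss = criterion(model(inputs), targets)
        loss.backward()
        optimizer.step()
        print(loss.item())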