import numpy as np
import math
from .base import BaseScheduler
class CosineScheduler(BaseScheduler):
    """Cosine-annealing learning-rate schedule with warmup, hold and restarts.

    The schedule produced by :meth:`get_lr` has up to three phases:

    1. Warmup: linear ramp from ``min_lr`` to ``max_lr`` over
       ``self.warmup_iters`` iterations (skipped when that is 0).
    2. Hold: the rate stays at ``max_lr`` for ``hold_iters`` iterations.
    3. Cosine decay: the remaining iterations are divided into
       ``restarts + 1`` periods of equal length; inside each period the rate
       follows half a cosine from that period's peak down towards ``min_lr``.
       The peak of period ``k`` is ``max_lr * restart_decay ** k``.

    Args:
        epoch_iters: Number of iterations in one epoch.
        epochs: Total number of epochs; the schedule length is
            ``epoch_iters * epochs`` iterations.
        warmup_iters: Forwarded to ``BaseScheduler``.
        warmup_epochs: Forwarded to ``BaseScheduler``.
        max_lr: Peak learning rate (end of warmup / start of first cosine).
        min_lr: Lower bound the cosine anneals towards; also forwarded to
            ``BaseScheduler``.
        restarts: Number of warm restarts (must be >= 0).
        hold_epochs: Length of the hold phase in epochs. Mutually exclusive
            with ``hold_iters``.
        hold_iters: Length of the hold phase in iterations.
        restart_decay: Multiplier applied to the peak rate at every restart.
        noice_std: Forwarded to ``BaseScheduler`` (presumably the std of
            noise added to the rate -- confirm against the base class).
        last_epoch: Forwarded to ``BaseScheduler``.
        **kwargs: Accepted for call-site compatibility and ignored here.

    Note:
        This class reads ``self.warmup_iters`` but never assigns it, so
        ``BaseScheduler.__init__`` is expected to set it (presumably from
        ``warmup_iters`` / ``warmup_epochs`` -- confirm against the base).
    """

    def __init__(
            self,
            epoch_iters: int,
            epochs: int,
            warmup_iters: int = 0,
            warmup_epochs: int = 0,
            max_lr: float = 0.1,
            min_lr: float = 0,
            restarts: int = 0,
            hold_epochs: int = 0,
            hold_iters: int = 0,
            restart_decay: float = 0.1,
            noice_std: float = 0,
            last_epoch: int = -1, **kwargs):
        super().__init__(
            epoch_iters,
            warmup_iters,
            warmup_epochs,
            min_lr,
            noice_std,
            last_epoch)
        assert not (hold_epochs and hold_iters), \
            'hold_iters and hold_epochs cannot be set at the same time.'
        # Normalise the hold length to iterations.
        if hold_epochs > 0:
            self.hold_iters = hold_epochs * epoch_iters
        else:
            self.hold_iters = hold_iters
        assert restarts >= 0
        self.max_iters = epoch_iters * epochs
        self.max_lr = max_lr
        self.min_lr = min_lr
        self.restarts = restarts
        self.restart_decay = restart_decay
        # Length of one cosine period: whatever is left after warmup and hold,
        # split evenly (rounded up) across the initial run plus each restart.
        self.period = math.ceil(
            (self.max_iters - self.warmup_iters - self.hold_iters)
            / (self.restarts + 1))

    def get_lr(self, iter):
        """Return the learning rate for iteration ``iter``.

        Args:
            iter: Iteration index counted from the start of training.

        Returns:
            The learning rate for that iteration (see the class docstring for
            the shape of the schedule).

        Raises:
            AssertionError: If the decayed peak rate of the current cosine
                period is not strictly greater than ``min_lr``.

        Note:
            The cycle index is not clamped to ``restarts``; for iterations
            past the end of the schedule it simply keeps increasing.
        """
        flat_end = self.warmup_iters + self.hold_iters

        # Phase 1: linear warmup from min_lr (iter 0) to max_lr.
        if self.warmup_iters > 0 and iter <= self.warmup_iters:
            return self.min_lr + (iter / self.warmup_iters) * \
                (self.max_lr - self.min_lr)

        # Phase 2: hold at the peak. This branch also covers iter <= 0 when
        # there is no warmup; previously that case matched no branch and the
        # method failed with UnboundLocalError on the first iteration.
        if iter <= flat_end:
            return self.max_lr

        # Phase 3: cosine decay with warm restarts.
        elapsed = iter - flat_end
        cycle = elapsed // self.period
        step = elapsed % self.period
        base_lr = self.max_lr * (self.restart_decay ** cycle)
        assert base_lr > self.min_lr
        return (base_lr - self.min_lr) * (1 + math.cos((step / self.period) * math.pi)) / 2 + self.min_lr