import numpy as np
import math
from typing import Union, List
from .base import BaseScheduler
class MultiStepScheduler(BaseScheduler):
    """Piecewise-constant learning-rate schedule with optional linear warmup.

    The learning rate starts at ``base_lr``. Each time the iteration counter
    reaches one of ``milestones``, the rate is multiplied by that milestone's
    factor from ``gammas``. Milestones are given in epochs and are converted
    to iterations with ``epoch_iters``.

    Args:
        epoch_iters: Number of iterations per epoch.
        milestones: Epoch indices at which the learning rate decays.
            Expected in increasing order; ``np.digitize`` requires monotonic
            bins and interprets decreasing bins differently.
        gamma: A single decay factor applied at every milestone. Mutually
            exclusive with ``gammas``.
        gammas: Per-milestone decay factors, one per entry of ``milestones``.
            Mutually exclusive with ``gamma``.
        base_lr: Learning rate before the first milestone. It is also the
            end point of the warmup ramp.
        warmup_iters, warmup_epochs, min_lr, noice_std, last_epoch:
            Forwarded positionally to ``BaseScheduler.__init__``.
        **kwargs: Accepted and ignored.

    Raises:
        ValueError: If both ``gamma`` and ``gammas`` are set; if neither is
            set while ``milestones`` is non-empty; or if ``gammas`` and
            ``milestones`` differ in length.
    """

    def __init__(
        self,
        epoch_iters: int,
        milestones: List[int],
        gamma: Union[float, None] = None,
        gammas: Union[List[float], None] = None,
        base_lr: float = 0.1,
        warmup_iters: int = 0,
        warmup_epochs: int = 0,
        min_lr: float = 0,
        noice_std: float = 0.0,
        last_epoch: int = -1,
        **kwargs):
        super().__init__(
            epoch_iters,
            warmup_iters,
            warmup_epochs,
            min_lr,
            noice_std,
            last_epoch)
        # Materialize once so len() works and iterators are not consumed twice.
        milestones = list(milestones)
        if gamma is not None and gammas is not None:
            raise ValueError('\'gamma\' and \'gammas\' cannot be set at the same time!')
        if gamma is not None:
            # A scalar gamma means "decay by the same factor at every milestone".
            gammas = [gamma] * len(milestones)
        elif gammas is None:
            # Without any decay factor, get_lr would fail on None[:stage] at the
            # first milestone; fail here instead. No milestones means no decay.
            if milestones:
                raise ValueError(
                    'Either \'gamma\' or \'gammas\' must be set when milestones are given!')
            gammas = []
        else:
            # Copy so later mutation of the caller's list cannot alter the schedule.
            gammas = list(gammas)
        if len(gammas) != len(milestones):
            raise ValueError(
                '\'gammas\' must have the same length as \'milestones\' '
                '(got %d and %d)' % (len(gammas), len(milestones)))
        self.base_lr = base_lr
        # Milestones converted from epochs to iteration indices.
        self.milestones = [epoch_iters * m for m in milestones]
        self.gammas = gammas
        # Not read by this class; kept for compatibility with external readers.
        self.milestone_counter = 0

    def get_lr(self, iter):
        """Return the learning rate at iteration ``iter``.

        If ``self.warmup_iters > 0`` and ``iter <= self.warmup_iters``, the
        rate is interpolated linearly from ``self.min_lr`` (at ``iter == 0``)
        to ``self.base_lr`` (at ``iter == self.warmup_iters``). Otherwise it
        is ``base_lr`` multiplied by the gammas of every milestone already
        reached.

        ``self.warmup_iters`` and ``self.min_lr`` are set by
        ``BaseScheduler``. Presumably ``warmup_iters`` accounts for
        ``warmup_epochs`` — TODO confirm against base.py.
        """
        if self.warmup_iters > 0 and iter <= self.warmup_iters:
            return self.min_lr + (iter / self.warmup_iters) * (self.base_lr - self.min_lr)
        # For increasing milestones, np.digitize (right=False) returns the
        # number of milestones m with m <= iter, i.e. how many decays have fired.
        stage = np.digitize(iter, self.milestones)
        if stage == 0:
            return self.base_lr
        return self.base_lr * np.prod(self.gammas[:stage])