|
|
|
|
|
|
|
|
|
|
|
import math |
|
|
|
|
|
class WSDSchedule(object):
    """Warmup-Stable-Decay learning-rate schedule.

    The learning rate moves through three phases:

    1. It ramps linearly from ``start_lr`` to ``ref_lr`` over ``warmup_steps`` steps.
    2. It holds at ``ref_lr`` for ``T_max - warmup_steps - anneal_steps`` steps.
    3. It decays linearly to ``final_lr`` over ``anneal_steps`` steps.

    After the decay phase the learning rate stays at ``final_lr``.

    Args:
        optimizer: object exposing ``param_groups`` (an iterable of dicts).
            Each group's ``"lr"`` entry is overwritten on every ``step()``.
        warmup_steps: number of linear warmup steps.
        anneal_steps: number of linear decay steps at the end of the schedule.
        T_max: total schedule length in steps (warmup + stable + anneal).
        start_lr: learning rate at the beginning of warmup.
        ref_lr: peak learning rate used during the stable phase.
        final_lr: learning rate reached at the end of annealing.
    """

    def __init__(self, optimizer, warmup_steps, anneal_steps, T_max, start_lr, ref_lr, final_lr=0.0):
        self.optimizer = optimizer
        self.start_lr = start_lr
        self.ref_lr = ref_lr
        self.final_lr = final_lr
        self.anneal_steps = anneal_steps
        self.warmup_steps = warmup_steps
        # This stores the length of the stable (constant-LR) phase, not the
        # total schedule length.
        # NOTE(review): negative if warmup_steps + anneal_steps > T_max.
        self.T_max = T_max - warmup_steps - anneal_steps
        self._step = 0.0

    def step(self):
        """Advance one step, update every param group, and return the new base LR.

        A group that contains ``"lr_scale"`` receives ``new_lr * lr_scale``.
        The returned value is always the unscaled learning rate.
        """
        self._step += 1
        decay_start = self.T_max + self.warmup_steps

        if self._step < self.warmup_steps:
            progress = float(self._step) / float(max(1, self.warmup_steps))
            new_lr = self.start_lr + progress * (self.ref_lr - self.start_lr)
        elif self._step < decay_start:
            new_lr = self.ref_lr
        else:
            progress = float(self._step - decay_start) / float(max(1, self.anneal_steps))
            # Clamp so the LR holds at final_lr once annealing is complete,
            # instead of extrapolating past it (and eventually below zero).
            progress = min(1.0, progress)
            new_lr = self.ref_lr + progress * (self.final_lr - self.ref_lr)

        for group in self.optimizer.param_groups:
            group["lr"] = new_lr
            if "lr_scale" in group:
                group["lr"] *= group["lr_scale"]

        return new_lr
|
|
|
|
|
class WarmupCosineSchedule(object):
    """Linear warmup followed by cosine decay of the learning rate.

    The learning rate ramps linearly from ``start_lr`` to ``ref_lr`` over
    ``warmup_steps`` steps. It then follows a half-cosine from ``ref_lr`` down
    to ``final_lr`` over the remaining ``T_max - warmup_steps`` steps. After
    ``T_max`` steps it is held at ``final_lr``.

    Args:
        optimizer: object exposing ``param_groups`` (an iterable of dicts).
            Each group's ``"lr"`` entry is overwritten on every ``step()``.
        warmup_steps: number of linear warmup steps.
        start_lr: learning rate at the beginning of warmup.
        ref_lr: peak learning rate reached at the end of warmup.
        T_max: total schedule length in steps (warmup + cosine phase).
        last_epoch: accepted for signature compatibility; not used.
        final_lr: learning rate reached at the end of the cosine phase.
    """

    def __init__(self, optimizer, warmup_steps, start_lr, ref_lr, T_max, last_epoch=-1, final_lr=0.0):
        self.optimizer = optimizer
        self.start_lr = start_lr
        self.ref_lr = ref_lr
        self.final_lr = final_lr
        self.warmup_steps = warmup_steps
        # Length of the cosine phase only.
        self.T_max = T_max - warmup_steps
        self._step = 0.0

    def step(self):
        """Advance one step, write the new LR into every param group, and return it."""
        self._step += 1
        if self._step < self.warmup_steps:
            progress = float(self._step) / float(max(1, self.warmup_steps))
            new_lr = self.start_lr + progress * (self.ref_lr - self.start_lr)
        else:
            progress = float(self._step - self.warmup_steps) / float(max(1, self.T_max))
            # Without this clamp, cos(pi * progress) rises again for
            # progress > 1, so the LR would climb back toward ref_lr after T_max.
            progress = min(1.0, progress)
            new_lr = max(
                self.final_lr,
                self.final_lr + (self.ref_lr - self.final_lr) * 0.5 * (1.0 + math.cos(math.pi * progress)),
            )

        for group in self.optimizer.param_groups:
            group["lr"] = new_lr

        return new_lr
|
|
|
|
|
class CosineWDSchedule(object):
    """Cosine schedule for weight decay, moving from ``ref_wd`` to ``final_wd``.

    The value follows a half-cosine over ``T_max`` steps and is then held at
    ``final_wd``. The result is clamped so it never overshoots ``final_wd``.
    It works whether weight decay is decreasing (``final_wd <= ref_wd``) or
    increasing (``final_wd > ref_wd``).

    Args:
        optimizer: object exposing ``param_groups`` (an iterable of dicts).
        ref_wd: weight decay at the start of the schedule.
        T_max: number of steps over which the cosine runs. If ``T_max <= 0``,
            the schedule is treated as already finished.
        final_wd: weight decay reached after ``T_max`` steps.
    """

    def __init__(self, optimizer, ref_wd, T_max, final_wd=0.0):
        self.optimizer = optimizer
        self.ref_wd = ref_wd
        self.final_wd = final_wd
        self.T_max = T_max
        self._step = 0.0

    def step(self):
        """Advance one step and return the new weight decay.

        The new value is written to ``"weight_decay"`` in every param group
        whose ``"WD_exclude"`` entry is missing or falsy. Groups with a truthy
        ``"WD_exclude"`` are left untouched.
        """
        self._step += 1
        if self.T_max <= 0:
            # Nothing to anneal over: jump straight to the final value
            # (previously this raised ZeroDivisionError for T_max == 0).
            progress = 1.0
        else:
            # Clamp so the cosine does not swing back toward ref_wd after T_max.
            progress = min(1.0, self._step / self.T_max)
        new_wd = self.final_wd + (self.ref_wd - self.final_wd) * 0.5 * (1.0 + math.cos(math.pi * progress))

        # Never overshoot final_wd, whichever direction the schedule moves.
        if self.final_wd <= self.ref_wd:
            new_wd = max(self.final_wd, new_wd)
        else:
            new_wd = min(self.final_wd, new_wd)

        for group in self.optimizer.param_groups:
            if ("WD_exclude" not in group) or not group["WD_exclude"]:
                group["weight_decay"] = new_wd
        return new_wd
|
|