PyTorch DistributedDataParallel: an analysis of why multi-GPU training results get worse


Setting up data shuffling with DDP

When using DDP you need to pass a sampler to the DataLoader (torch.utils.data.distributed.DistributedSampler(dataset, num_replicas=None, rank=None, shuffle=True, seed=0, drop_last=False)). The default is shuffle=True, but according to the implementation of PyTorch's DistributedSampler:

    def __iter__(self) -> Iterator[T_co]:
        if self.shuffle:
            # deterministically shuffle based on epoch and seed
            g = torch.Generator()
            g.manual_seed(self.seed + self.epoch)
            indices = torch.randperm(len(self.dataset), generator=g).tolist()  # type: ignore
        else:
            indices = list(range(len(self.dataset)))  # type: ignore

The seed used to generate the random indices depends on the current epoch, so you have to set the epoch manually during training to get a genuinely different shuffle every epoch:

for epoch in range(start_epoch, n_epochs):
    if is_distributed:
        sampler.set_epoch(epoch)
    train(loader)
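Putting the pieces together, here is a minimal sketch of the sampler setup (the toy TensorDataset and the explicit num_replicas/rank values are assumptions so the snippet runs outside an initialized process group; in real DDP training they are inferred from the default process group):

import torch
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler

dataset = TensorDataset(torch.arange(16).float())   # toy dataset for illustration
# num_replicas/rank given explicitly so this runs without init_process_group
sampler = DistributedSampler(dataset, num_replicas=2, rank=0, shuffle=True, seed=0)
# shuffle stays False on the DataLoader itself: the sampler already shuffles
loader = DataLoader(dataset, batch_size=4, sampler=sampler, shuffle=False)

for epoch in range(3):
    sampler.set_epoch(epoch)    # re-seed the shuffle for this epoch
    for (batch,) in loader:
        pass                    # training step would go here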
DDP: worse results after increasing the batch size

Large batch size:
Theoretical advantages:

  • The noise in the data has less influence, which may make it easier to approach the optimum;

Disadvantages and problems:

  • It reduces the variance of the gradients. (In theory, lower gradient variance gives better optimization behaviour for convex problems; in practice, however, Keskar et al. showed that increasing the batch size leads to worse generalization.)
  • For non-convex problems the loss function has many local minima; the noise that comes with a small batch size can help the optimizer jump out of a local minimum, whereas with a large batch size it may get stuck there.

Solutions:

  • Increase the learning rate. This has its own pitfalls: using a very large learning rate right from the start of training may keep the model from converging (.04836). A sketch of the usual linear scaling rule follows this list.
  • Use warmup (.02677), described in the next section.
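For the first point, here is a minimal sketch of the linear scaling rule popularized by the "Accurate, Large Minibatch SGD" recipe (the concrete batch sizes and learning rate below are illustrative assumptions, not values from the original post):

# Linear scaling rule: scale the base learning rate with the global batch size.
base_lr = 0.1             # learning rate tuned for the reference batch size
ref_batch_size = 256      # batch size that base_lr was tuned for
world_size = 8            # number of DDP processes (GPUs)
per_gpu_batch_size = 64
global_batch_size = world_size * per_gpu_batch_size   # 512
scaled_lr = base_lr * global_batch_size / ref_batch_size
print(scaled_lr)          # 0.2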
warmup

Using a very large learning rate at the very beginning of training can prevent it from converging. The idea of warmup is to start with a small learning rate, gradually increase it over the early epochs until it reaches the base learning rate, and then continue training with some other decay schedule (e.g. CosineAnnealingLR).

# copy from .py
from torch.optim.lr_scheduler import _LRScheduler
from torch.optim.lr_scheduler import ReduceLROnPlateau


class GradualWarmupScheduler(_LRScheduler):
    """ Gradually warm-up(increasing) learning rate in optimizer.
    Proposed in 'Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour'.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        multiplier: target learning rate = base lr * multiplier if multiplier > 1.0. if multiplier = 1.0, lr starts from 0 and ends up with the base_lr.
        total_epoch: target learning rate is reached at total_epoch, gradually
        after_scheduler: after target_epoch, use this scheduler(eg. ReduceLROnPlateau)
    """

    def __init__(self, optimizer, multiplier, total_epoch, after_scheduler=None):
        self.multiplier = multiplier
        if self.multiplier < 1.:
            raise ValueError('multiplier should be greater than or equal to 1.')
        self.total_epoch = total_epoch
        self.after_scheduler = after_scheduler
        self.finished = False
        super(GradualWarmupScheduler, self).__init__(optimizer)

    def get_lr(self):
        if self.last_epoch > self.total_epoch:
            if self.after_scheduler:
                if not self.finished:
                    self.after_scheduler.base_lrs = [base_lr * self.multiplier for base_lr in self.base_lrs]
                    self.finished = True
                return self.after_scheduler.get_last_lr()
            return [base_lr * self.multiplier for base_lr in self.base_lrs]

        if self.multiplier == 1.0:
            return [base_lr * (float(self.last_epoch) / self.total_epoch) for base_lr in self.base_lrs]
        else:
            return [base_lr * ((self.multiplier - 1.) * self.last_epoch / self.total_epoch + 1.) for base_lr in self.base_lrs]

    def step_ReduceLROnPlateau(self, metrics, epoch=None):
        if epoch is None:
            epoch = self.last_epoch + 1
        self.last_epoch = epoch if epoch != 0 else 1  # ReduceLROnPlateau is called at the end of epoch, whereas others are called at beginning
        if self.last_epoch <= self.total_epoch:
            warmup_lr = [base_lr * ((self.multiplier - 1.) * self.last_epoch / self.total_epoch + 1.) for base_lr in self.base_lrs]
            for param_group, lr in zip(self.optimizer.param_groups, warmup_lr):
                param_group['lr'] = lr
        else:
            if epoch is None:
                self.after_scheduler.step(metrics, None)
            else:
                self.after_scheduler.step(metrics, epoch - self.total_epoch)

    def step(self, epoch=None, metrics=None):
        if type(self.after_scheduler) != ReduceLROnPlateau:
            if self.finished and self.after_scheduler:
                if epoch is None:
                    self.after_scheduler.step(None)
                else:
                    self.after_scheduler.step(epoch - self.total_epoch)
                self._last_lr = self.after_scheduler.get_last_lr()
            else:
                return super(GradualWarmupScheduler, self).step(epoch)
        else:
            self.step_ReduceLROnPlateau(metrics, epoch)
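A minimal usage sketch for the scheduler above (the toy model, optimizer and epoch counts are illustrative assumptions): warm up to the base learning rate over the first 5 epochs, then hand off to CosineAnnealingLR as described earlier.

import torch
from torch.optim.lr_scheduler import CosineAnnealingLR

model = torch.nn.Linear(10, 2)                            # toy model
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

cosine = CosineAnnealingLR(optimizer, T_max=95)           # decay used after warmup
scheduler = GradualWarmupScheduler(optimizer, multiplier=1.0,
                                   total_epoch=5, after_scheduler=cosine)

optimizer.zero_grad()
optimizer.step()   # dummy step so PyTorch does not warn about scheduler/optimizer order
for epoch in range(1, 100):
    scheduler.step(epoch)
    # ... run one training epoch here, calling optimizer.step() per batch ...
    print(epoch, optimizer.param_groups[0]['lr'])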

References:

/
