状态 恢复中断训练"/>
保存dataloader状态 恢复中断训练
对于pytorch恢复一个epoch中的中断的训练时,通常dataloader都会从头加载,对于大型数据集不友好,loss又重新下降了
这时候可以自定义sampler
import random
from torch.utils.data.dataloader import Samplerrandom.seed(224) # use a fixed numberclass MySampler(Sampler):def __init__(self, data, i=0):random.shuffle(data)#自定义shuffleself.seq = list(range(len(data)))[i * batch_size:]def __iter__(self):return iter(self.seq)def __len__(self):return len(self.seq)
调用dataloader时传入自定义sampler,指定恢复的step
train_dataset = MyDataset(train_data)
train_sampler = MySampler(train_dataset, last_i)
train_data_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, sampler=train_sampler,shuffle=False) # don't forget to set DataLoader's shuffle to False
就可以啦!ref
也可以用笨方法,空跑到指定的step:
for batch in train_loader:if restart_step<global_step:restart_step+=1pbar.update(1)continue
更多推荐
保存dataloader状态 恢复中断训练
发布评论