是怎样计算loss的"/>
如何快速找到项目里是怎样计算loss的
上代码:
for epoch in range(Epochs):for i, (inputData, target) in enumerate(train_loader):# -------------------------------- compute loss ------------------------# breakinputData = inputData.cuda()target = target.cuda() # (batch,3,num_classes)target = target.max(dim=1)[0]with autocast(): # mixed precisionoutput = model(inputData).float() # sigmoid will be done in loss !loss = criterion(output, target)# ----------------------------------------------------------------------model.zero_grad()
直接搜索enumerate
,从处理input, targets
开始,一直到开始更新模型,如model.zero_grad()
结束;
这个例子比较规范,做了这些事:
- 处理输入;
- 把输入喂给模型,进行预测;
- 算loss;
学习实际项目,比如论文里附带的代码,是很好的学习方式,比单单看Pytorch教程要强!
完事!
更多推荐
如何快速找到项目里是怎样计算loss的
发布评论