Training a Custom Network with nnUNet


This is a rough working note; for reference only.
My hand-written UNet has no deep supervision, but nnUNet's default training uses deep supervision, and the corresponding model is expected to produce deep-supervision outputs. Helpfully, nnUNetV2 also ships a trainer without deep supervision, among the trainer variants in its source tree (under nnunetv2/training/nnUNetTrainer/variants/).
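The mismatch matters because the two setups disagree on the shape of the network's output: with deep supervision the model returns one prediction per decoder scale and the loss is a weighted sum over those scales, while a plain UNet returns a single tensor. A minimal sketch of the difference (illustrative only, not nnUNet's actual loss code):

```python
# Illustrative sketch, not nnUNet code: a deep-supervision loss expects lists of
# predictions/targets (one per decoder scale) and weights their losses, while a
# network without deep supervision produces a single prediction tensor.
def deep_supervision_loss(outputs, targets, loss_fn, weights):
    # outputs, targets: lists of tensors, highest resolution first
    return sum(w * loss_fn(o, t) for w, o, t in zip(weights, outputs, targets))


def plain_loss(output, target, loss_fn):
    return loss_fn(output, target)
```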

Put another way, if we want to define our own nnUNetTrainer to extend nnUNet, the .py files in that directory are good templates. In any case, it is recommended to inherit from nnUNetTrainer as the base class, just as the nnUNetTrainerNoDeepSupervision class does (this class implements training of networks without deep supervision).
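The general pattern, as a rough sketch (hypothetical class name; the trainer-discovery detail is my understanding, so verify against your nnUNet install), is just a subclass that overrides what it needs:

```python
# Hypothetical sketch of the extension pattern: subclass nnUNetTrainer and override
# only the methods you need. nnUNet looks trainers up by class name, so the file
# should live where nnUNet searches for trainers (e.g. under
# nnunetv2/training/nnUNetTrainer/), after which -tr MyShallowTrainer selects it.
from nnunetv2.training.nnUNetTrainer.nnUNetTrainer import nnUNetTrainer


class MyShallowTrainer(nnUNetTrainer):
    def _get_deep_supervision_scales(self):
        # returning None disables deep-supervision loss weighting
        return None
```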
Here is that file in full, along with the places that need changing for your own network:
```python
import torch
from torch import autocast

from nnunetv2.training.loss.compound_losses import DC_and_BCE_loss, DC_and_CE_loss
from nnunetv2.training.loss.dice import get_tp_fp_fn_tn, MemoryEfficientSoftDiceLoss
from nnunetv2.training.nnUNetTrainer.nnUNetTrainer import nnUNetTrainer
from nnunetv2.utilities.helpers import dummy_context
from nnunetv2.utilities.label_handling.label_handling import determine_num_input_channels
from torch.nn.parallel import DistributedDataParallel as DDP
from nnunetv2.Network.UNet import UNet


class nnUNetTrainerNoDeepSupervision(nnUNetTrainer):
    def _build_loss(self):
        if self.label_manager.has_regions:
            loss = DC_and_BCE_loss({},
                                   {'batch_dice': self.configuration_manager.batch_dice,
                                    'do_bg': True, 'smooth': 1e-5, 'ddp': self.is_ddp},
                                   use_ignore_label=self.label_manager.ignore_label is not None,
                                   dice_class=MemoryEfficientSoftDiceLoss)
        else:
            loss = DC_and_CE_loss({'batch_dice': self.configuration_manager.batch_dice,
                                   'smooth': 1e-5, 'do_bg': False, 'ddp': self.is_ddp}, {},
                                  weight_ce=1, weight_dice=1,
                                  ignore_label=self.label_manager.ignore_label,
                                  dice_class=MemoryEfficientSoftDiceLoss)
        return loss

    def _get_deep_supervision_scales(self):
        return None

    def initialize(self):
        if not self.was_initialized:
            self.num_input_channels = determine_num_input_channels(self.plans_manager,
                                                                   self.configuration_manager,
                                                                   self.dataset_json)
            # self.network = self.build_network_architecture(self.plans_manager, self.dataset_json,
            #                                                self.configuration_manager,
            #                                                self.num_input_channels,
            #                                                enable_deep_supervision=False).to(self.device)
            self.network = UNet(self.num_input_channels, 2, base_c=32).to(self.device)
            print("=" * 20)
            print("now use our unet")
            print("=" * 20)
            self.optimizer, self.lr_scheduler = self.configure_optimizers()
            # if ddp, wrap in DDP wrapper
            if self.is_ddp:
                self.network = torch.nn.SyncBatchNorm.convert_sync_batchnorm(self.network)
                self.network = DDP(self.network, device_ids=[self.local_rank])
            self.loss = self._build_loss()
            self.was_initialized = True
        else:
            raise RuntimeError("You have called self.initialize even though the trainer was "
                               "already initialized. That should not happen.")

    def set_deep_supervision_enabled(self, enabled: bool):
        pass

    def validation_step(self, batch: dict) -> dict:
        data = batch['data']
        target = batch['target']

        data = data.to(self.device, non_blocking=True)
        if isinstance(target, list):
            target = [i.to(self.device, non_blocking=True) for i in target]
        else:
            target = target.to(self.device, non_blocking=True)

        self.optimizer.zero_grad(set_to_none=True)

        # autocast is only worthwhile on CUDA: it is slow on CPU and raises errors on MPS
        # (even with enabled=False), so it is active only for cuda devices
        with autocast(self.device.type, enabled=True) if self.device.type == 'cuda' else dummy_context():
            output = self.network(data)
            del data
            l = self.loss(output, target)

        # the following is needed for online evaluation. Fake dice (green line)
        axes = [0] + list(range(2, output.ndim))

        if self.label_manager.has_regions:
            predicted_segmentation_onehot = (torch.sigmoid(output) > 0.5).long()
        else:
            # no need for softmax
            output_seg = output.argmax(1)[:, None]
            predicted_segmentation_onehot = torch.zeros(output.shape, device=output.device,
                                                        dtype=torch.float32)
            predicted_segmentation_onehot.scatter_(1, output_seg, 1)
            del output_seg

        if self.label_manager.has_ignore_label:
            if not self.label_manager.has_regions:
                mask = (target != self.label_manager.ignore_label).float()
                # CAREFUL that you don't rely on target after this line!
                target[target == self.label_manager.ignore_label] = 0
            else:
                mask = 1 - target[:, -1:]
                # CAREFUL that you don't rely on target after this line!
                target = target[:, :-1]
        else:
            mask = None

        tp, fp, fn, _ = get_tp_fp_fn_tn(predicted_segmentation_onehot, target, axes=axes, mask=mask)

        tp_hard = tp.detach().cpu().numpy()
        fp_hard = fp.detach().cpu().numpy()
        fn_hard = fn.detach().cpu().numpy()
        if not self.label_manager.has_regions:
            # when training with regions, all segmentation heads predict some kind of foreground;
            # in conventional (softmax) training there needs to be one output for the background,
            # whose Dice we are not interested in. [1:] removes the background channel
            tp_hard = tp_hard[1:]
            fp_hard = fp_hard[1:]
            fn_hard = fn_hard[1:]

        return {'loss': l.detach().cpu().numpy(), 'tp_hard': tp_hard,
                'fp_hard': fp_hard, 'fn_hard': fn_hard}
```

To plug in your own non-deep-supervision network, replace the assignment to self.network in initialize(). For example, I swapped in my own UNet:

```python
self.network = UNet(self.num_input_channels, 2, base_c=32).to(self.device)
# the prints below are just a reminder that this trainer is the one actually being used
print("=" * 20)
print("now use our unet")
print("=" * 20)
```

Finally, add -tr followed by your trainer class name to the training command; here that is -tr nnUNetTrainerNoDeepSupervision. The full training command (002 is the dataset ID, 2d the configuration, 0 the fold) is:

nnUNetv2_train 002 2d 0 -tr nnUNetTrainerNoDeepSupervision

PS: You can achieve the same thing by editing run_training.py and changing the default value of this command-line argument.
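For example, the -tr argument in run_training.py is declared roughly like this (paraphrased; verify against your installed nnUNetV2 version):

```python
# nnunetv2/run/run_training.py (approximate):
parser.add_argument('-tr', type=str, required=False, default='nnUNetTrainer',
                    help='[OPTIONAL] Use this flag to specify a custom trainer.')
# changing default='nnUNetTrainer' to default='nnUNetTrainerNoDeepSupervision'
# makes every nnUNetv2_train run use the custom trainer without passing -tr
```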
Alright, note complete; back to training.
