PyTorch DataLoader: Reading TIFF Files (Complete Code)

Updated: 2024-10-15 18:24:51


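Reading TIFF files in Python requires libtiff; the code below also requires inferno.
This article applies to reading three-dimensional (multi-page) TIFF files.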
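
A three-dimensional TIFF is stored as a sequence of 2D pages. As a rough sketch of what reading one produces, the snippet below stacks synthetic uint16 pages into a single (depth, height, width) float array. The pages and their sizes are made-up stand-ins, and only NumPy is used; no TIFF file is actually read.

```python
import numpy as np

# Synthetic stand-in for the pages of a 3D TIFF: 51 pages of 101 x 101 uint16 pixels.
# (Illustrative values only; a real file would supply these pages.)
pages = [np.full((101, 101), i, dtype=np.uint16) for i in range(51)]

# Stack the 2D pages along a new leading axis, then convert to float so that
# the array is no longer uint16 when it is later turned into a tensor.
volume = np.stack(pages, axis=0).astype(np.float64)

print(volume.shape, volume.dtype)
```

The leading axis indexes the pages, so volume[k] is the k-th 2D slice of the stack.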

from torch.utils.data import DataLoader, Dataset
from inferno.io.transform.base import Transform, Compose
from inferno.io.transform.generic import Normalize, AsTorchBatch
from inferno.io.transform.image import RandomCrop, RandomRotate, RandomFlip 
from libtiff import TIFF
import os    
import torch
import numpy as np

Define a MyDataSet class

class MyDataSet(Dataset):
    def __init__(self, pathLst, transform):
        # Parameters and their form vary according to program needs
        dataPath, labelPath = pathLst
        self.tifStreamData, self.tifStreamLabel = [], []
        dataFiles, labelFiles = os.listdir(dataPath), os.listdir(labelPath)

        # Sort numerically by the index in the file name, e.g. LR_20.tif -> 20
        # Note: tiff2Stack applies the transform once at load time,
        # in separate calls for data and label
        dataFiles.sort(key=lambda x: int(x[3:-4]))
        for dataFile in dataFiles:
            dataFileName = os.path.join(dataPath, dataFile)
            self.tifStreamData.append(tiff2Stack(dataFileName, transform))

        labelFiles.sort(key=lambda x: int(x[3:-4]))
        for labelFile in labelFiles:
            labelFileName = os.path.join(labelPath, labelFile)
            self.tifStreamLabel.append(tiff2Stack(labelFileName, transform))

        # Every data stack needs a matching label stack
        assert len(self.tifStreamData) == len(self.tifStreamLabel)

    def __len__(self):
        return len(self.tifStreamData)

    def __getitem__(self, idx):
        data, label = self.tifStreamData, self.tifStreamLabel
        return data[idx], label[idx]


def tiff2Stack(fileName, transform=None):
    # Read a tif, apply the data transform, output a tensor
    tif = TIFF.open(fileName, mode='r')
    tifLst = list(tif.iter_images())  # e.g. (51, 101, 101)
    tifArr = np.zeros((len(tifLst), tifLst[0].shape[0], tifLst[0].shape[1]))
    for i, img in enumerate(tifLst):
        # Convert to float to avoid "can't convert np.ndarray of type numpy.uint16."
        tifArr[i, :, :] = img / 1.0
    if transform:
        tifArr = transform(tifArr)
    return tifArr
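
The sort key int(x[3:-4]) matters because os.listdir returns names in arbitrary order, and a plain string sort would put LR_10.tif before LR_2.tif. The small sketch below compares the two orderings on hypothetical file names that follow the LR_<index>.tif pattern; names with a different prefix length would need a different slice.

```python
# Hypothetical file names following the LR_<index>.tif pattern.
files = ["LR_10.tif", "LR_2.tif", "LR_1.tif", "LR_20.tif"]

# Default string sort compares character by character.
lexicographic = sorted(files)

# Strip the 3-character prefix "LR_" and the 4-character suffix ".tif",
# then compare the remaining digits as integers.
numeric = sorted(files, key=lambda name: int(name[3:-4]))

print(lexicographic)  # ['LR_1.tif', 'LR_10.tif', 'LR_2.tif', 'LR_20.tif']
print(numeric)        # ['LR_1.tif', 'LR_2.tif', 'LR_10.tif', 'LR_20.tif']
```

Sorting both the data and label folders with the same numeric key is what keeps the i-th data stack aligned with the i-th label stack.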

Usage

def main():
    transform = Compose(RandomRotate(), RandomFlip(), Normalize(), AsTorchBatch(2))
    pathLst = ["/your/tif/image/Data/path/", "/your/tif/image/Label/path/"]
    myTrainData = MyDataSet(pathLst, transform=transform)
    trainData = DataLoader(dataset=myTrainData, batch_size=4, shuffle=True)
    for i, j in enumerate(trainData):
        print(i)
        data, label = j
        print("data.shape", data.shape, "label.shape", label.shape)


if __name__ == "__main__":
    main()
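
To try the batching loop without any TIFF files, libtiff, or inferno, the following sketch may help. It swaps in a hypothetical SyntheticStackDataset that serves random in-memory tensors; the class name, sample count, and stack shape are assumptions chosen for illustration, not part of the original code.

```python
import torch
from torch.utils.data import DataLoader, Dataset


class SyntheticStackDataset(Dataset):
    """Hypothetical stand-in for MyDataSet that serves random in-memory stacks."""

    def __init__(self, num_samples=10, shape=(5, 16, 16)):
        # One (data, label) pair of equally shaped stacks per sample.
        self.data = [torch.rand(shape) for _ in range(num_samples)]
        self.labels = [torch.rand(shape) for _ in range(num_samples)]

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx], self.labels[idx]


loader = DataLoader(SyntheticStackDataset(), batch_size=4, shuffle=True)

# Collect the shape of every batch; the default collate function adds
# a leading batch dimension to each stack.
batch_shapes = [(tuple(data.shape), tuple(label.shape)) for data, label in loader]
for i, (data_shape, label_shape) in enumerate(batch_shapes):
    print(i, "data.shape", data_shape, "label.shape", label_shape)
```

With 10 samples and batch_size=4, the loader yields three batches, and the last one holds only the 2 remaining samples. Passing drop_last=True to DataLoader would discard that smaller batch.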


Published: 2024-02-06 16:14:16