完整代码)"/>
Pytorch DataLoader 读取tif(完整代码)
Python 读取 tif 格式文件需要安装 libtiff（Python 包名为 pylibtiff），此外还需要安装 inferno。
本文适用于读取三维（多页）tif 文件。
from torch.utils.data import DataLoader, Dataset
from inferno.io.transform.base import Transform, Compose
from inferno.io.transform.generic import Normalize, AsTorchBatch
from inferno.io.transform.image import RandomCrop, RandomRotate, RandomFlip
from libtiff import TIFF
import os
import torch
import numpy as np
定义一个 MyDataSet 数据集类：
class MyDataSet(Dataset):
    """Paired data/label dataset built from two directories of multi-page TIFFs.

    All TIFF stacks are read eagerly in ``__init__``. Item ``i`` is the ``i``-th
    data stack paired with the ``i``-th label stack, where both directories are
    ordered by the numeric index embedded in the file name (e.g. ``LR_20.tif``).
    """

    def __init__(self, pathLst, transform=None):
        """Load and pair every data and label stack.

        Args:
            pathLst: ``(dataPath, labelPath)`` -- the two directories to read.
                Parameters and their form vary according to program needs.
            transform: Optional callable applied to each loaded stack
                (passed through to ``tiff2Stack``).

        Raises:
            ValueError: if the two directories contain a different number of
                TIFF files.
        """
        dataPath, labelPath = pathLst
        dataFiles = _sortedTifPaths(dataPath)
        labelFiles = _sortedTifPaths(labelPath)
        # Check the pairing before doing any (expensive) TIFF decoding; raise
        # instead of assert so the check survives `python -O`.
        if len(dataFiles) != len(labelFiles):
            raise ValueError(
                f"found {len(dataFiles)} data TIFFs in {dataPath!r} but "
                f"{len(labelFiles)} label TIFFs in {labelPath!r}"
            )
        # NOTE(review): `transform` is applied to data and labels in separate
        # calls, once, at construction time. If it contains random ops (the
        # demo passes RandomRotate/RandomFlip) data and label are presumably
        # augmented independently and the augmentation is frozen across
        # epochs; Normalize is also applied to labels. Consider applying a
        # joint transform in __getitem__ -- verify against inferno semantics.
        self.tifStreamData = [tiff2Stack(f, transform) for f in dataFiles]
        self.tifStreamLabel = [tiff2Stack(f, transform) for f in labelFiles]

    def __len__(self):
        """Return the number of data/label pairs."""
        return len(self.tifStreamData)

    def __getitem__(self, idx):
        """Return the ``(data, label)`` pair at position ``idx``."""
        return self.tifStreamData[idx], self.tifStreamLabel[idx]


def _sortedTifPaths(dirPath):
    """Return full paths of the regular .tif/.tiff files in dirPath, sorted by index."""
    names = [
        name
        for name in os.listdir(dirPath)
        # Skip stray entries such as .DS_Store or sub-directories, which would
        # otherwise crash the numeric sort key or the TIFF reader.
        if name.lower().endswith((".tif", ".tiff"))
        and os.path.isfile(os.path.join(dirPath, name))
    ]
    # Numeric sort so LR_2 precedes LR_10. Assumes a 3-character prefix such as
    # "LR_" before the index; the extension is stripped so .tiff also works.
    names.sort(key=lambda name: int(os.path.splitext(name)[0][3:]))
    return [os.path.join(dirPath, name) for name in names]


def tiff2Stack(fileName, transform=None):
    """Read every page of a TIFF file into one stacked array, then transform it.

    Args:
        fileName: Path of the TIFF file to read.
        transform: Optional callable applied to the stacked array.

    Returns:
        A float64 ``ndarray`` of shape ``(pages, height, width)`` (e.g.
        ``(51, 101, 101)``; assumes each page is a 2-D image), or whatever
        ``transform`` returns for it.

    Raises:
        ValueError: if the file yields no images or its pages differ in shape.
    """
    tif = TIFF.open(fileName, mode='r')
    try:
        # Decode the pages exactly once.
        pages = list(tif.iter_images())
    finally:
        tif.close()
    if not pages:
        raise ValueError(f"no images found in {fileName!r}")
    # Cast to float64 to avoid torch's "can't convert np.ndarray of type
    # numpy.uint16" error when the stack is later turned into a tensor.
    tifArr = np.stack(pages, axis=0).astype(np.float64)
    if transform is not None:
        tifArr = transform(tifArr)
    return tifArr
调用示例：
def main():
    """Build the TIFF dataset, wrap it in a DataLoader and print each batch's shapes."""
    # Augment, normalise, then convert to a torch batch (inferno transforms).
    pipeline = Compose(RandomRotate(), RandomFlip(), Normalize(), AsTorchBatch(2))
    # Placeholder directories: replace with the real data/label folders.
    dirs = ["/your/tif/image/Data/path/", "/your/tif/image/Label/path/"]
    dataset = MyDataSet(dirs, transform=pipeline)
    loader = DataLoader(dataset=dataset, batch_size=4, shuffle=True)
    for batchIdx, batch in enumerate(loader):
        print(batchIdx)
        data, label = batch
        print("data.shape", data.shape, "label.shape", label.shape)


if __name__ == "__main__":
    main()
更多推荐
Pytorch DataLoader 读取tif(完整代码)
发布评论