PYTORCH LIGHTNING DATAMODULES and SimCLR Source Code Walkthrough


Official documentation
Project repository

PYTORCH LIGHTNING DATAMODULES

A DataModule decouples the data from the model, so you can focus on the model itself instead of the data handling.
To define your own DataModule, subclass LightningDataModule and implement the following methods:

def __init__(self):          # usually sets data_dir, defines the transforms and a default self.dims for later use
def prepare_data(self):      # downloads the data; do not assign state or modify the data here
def setup(self, stage):      # loads the downloaded data and assigns the train/val/test splits;
                             # stage is 'fit' (assign train/val), 'test' (assign test) or None (assign all)
def train_dataloader(self):  # returns the training DataLoader
def val_dataloader(self):    # returns the validation DataLoader
def test_dataloader(self):   # returns the test DataLoader

A custom DataModule for CIFAR-10:

from pytorch_lightning import LightningDataModule
from torch.utils.data import DataLoader, random_split
from torchvision import transforms
from torchvision.datasets import CIFAR10

# BATCH_SIZE is assumed to be defined elsewhere at module level


class CIFAR10DataModule(LightningDataModule):
    def __init__(self, data_dir: str = "./"):
        super().__init__()
        self.data_dir = data_dir
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ])
        self.dims = (3, 32, 32)
        self.num_classes = 10

    def prepare_data(self):
        # download
        CIFAR10(self.data_dir, train=True, download=True)
        CIFAR10(self.data_dir, train=False, download=True)

    def setup(self, stage=None):
        # Assign train/val datasets for use in dataloaders
        if stage == "fit" or stage is None:
            cifar_full = CIFAR10(self.data_dir, train=True, transform=self.transform)
            self.cifar_train, self.cifar_val = random_split(cifar_full, [45000, 5000])

        # Assign test dataset for use in dataloader(s)
        if stage == "test" or stage is None:
            self.cifar_test = CIFAR10(self.data_dir, train=False, transform=self.transform)

    def train_dataloader(self):
        return DataLoader(self.cifar_train, batch_size=BATCH_SIZE)

    def val_dataloader(self):
        return DataLoader(self.cifar_val, batch_size=BATCH_SIZE)

    def test_dataloader(self):
        return DataLoader(self.cifar_test, batch_size=BATCH_SIZE)
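
With the DataModule in place, training only needs a Trainer and a model. A minimal sketch of how it is wired together, where LitModel is a hypothetical LightningModule assumed to be defined elsewhere:

from pytorch_lightning import Trainer

dm = CIFAR10DataModule(data_dir="./data")
model = LitModel()            # hypothetical LightningModule, not part of this article's code
trainer = Trainer(max_epochs=5)
trainer.fit(model, datamodule=dm)    # uses dm.train_dataloader() / dm.val_dataloader()
trainer.test(model, datamodule=dm)   # uses dm.test_dataloader()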

Reading the SimCLR source code

The project implements several self-supervised learning methods side by side, so it defines a common vision base class:

import os
from abc import abstractmethod
from typing import Any, Callable, List, Optional, Union

import torch
from pytorch_lightning import LightningDataModule
from torch.utils.data import DataLoader, Dataset, random_split


class VisionDataModule(LightningDataModule):  # subclasses LightningDataModule
    EXTRA_ARGS: dict = {}
    name: str = ""
    #: Dataset class to use
    dataset_cls: type  # downloads and provides the data; for CIFAR-10 this is the CIFAR10 dataset class,
                       # for your own data you need a corresponding implementation
    #: A tuple describing the shape of the data
    dims: tuple

    def __init__(
        self,
        data_dir: Optional[str] = None,
        val_split: Union[int, float] = 0.2,  # int: number of validation samples; float: fraction of the training set
        num_workers: int = 0,
        normalize: bool = False,
        batch_size: int = 32,
        seed: int = 42,
        shuffle: bool = True,
        pin_memory: bool = True,
        drop_last: bool = False,
        *args: Any,
        **kwargs: Any,
    ) -> None:
        """
        Args:
            data_dir: Where to save/load the data
            val_split: Percent (float) or number (int) of samples to use for the validation split
            num_workers: How many workers to use for loading data
            normalize: If true applies image normalize
            batch_size: How many samples per batch to load
            seed: Random seed to be used for train/val/test splits
            shuffle: If true shuffles the train data every epoch
            pin_memory: If true, the data loader will copy Tensors into CUDA pinned memory before
                returning them
            drop_last: If true drops the last incomplete batch
        """
        super().__init__(*args, **kwargs)

        self.data_dir = data_dir if data_dir is not None else os.getcwd()
        self.val_split = val_split
        self.num_workers = num_workers
        self.normalize = normalize
        self.batch_size = batch_size
        self.seed = seed
        self.shuffle = shuffle
        self.pin_memory = pin_memory
        self.drop_last = drop_last

    def prepare_data(self, *args: Any, **kwargs: Any) -> None:  # download the data
        """Saves files to data_dir."""
        self.dataset_cls(self.data_dir, train=True, download=True)
        self.dataset_cls(self.data_dir, train=False, download=True)

    def setup(self, stage: Optional[str] = None) -> None:
        """Creates train, val, and test dataset."""
        if stage == "fit" or stage is None:  # assign the train/val datasets
            train_transforms = self.default_transforms() if self.train_transforms is None else self.train_transforms
            val_transforms = self.default_transforms() if self.val_transforms is None else self.val_transforms

            dataset_train = self.dataset_cls(self.data_dir, train=True, transform=train_transforms, **self.EXTRA_ARGS)
            dataset_val = self.dataset_cls(self.data_dir, train=True, transform=val_transforms, **self.EXTRA_ARGS)

            # Split the dataset
            self.dataset_train = self._split_dataset(dataset_train)
            self.dataset_val = self._split_dataset(dataset_val, train=False)

        if stage == "test" or stage is None:  # assign the test dataset
            test_transforms = self.default_transforms() if self.test_transforms is None else self.test_transforms
            self.dataset_test = self.dataset_cls(
                self.data_dir, train=False, transform=test_transforms, **self.EXTRA_ARGS
            )

    def _split_dataset(self, dataset: Dataset, train: bool = True) -> Dataset:
        """Splits the dataset into train and validation set."""
        len_dataset = len(dataset)  # type: ignore[arg-type]
        splits = self._get_splits(len_dataset)  # lengths of the train and validation splits
        dataset_train, dataset_val = random_split(dataset, splits, generator=torch.Generator().manual_seed(self.seed))
        if train:
            return dataset_train
        return dataset_val

    def _get_splits(self, len_dataset: int) -> List[int]:
        """Computes split lengths for train and validation set."""
        if isinstance(self.val_split, int):
            train_len = len_dataset - self.val_split   # length of the training split
            splits = [train_len, self.val_split]       # [train length, validation length]
        elif isinstance(self.val_split, float):
            val_len = int(self.val_split * len_dataset)
            train_len = len_dataset - val_len
            splits = [train_len, val_len]
        else:
            raise ValueError(f"Unsupported type {type(self.val_split)}")

        return splits

    @abstractmethod
    def default_transforms(self) -> Callable:  # implemented by subclasses
        """Default transform for the dataset."""

    # dataloaders
    def train_dataloader(self, *args: Any, **kwargs: Any) -> DataLoader:
        """The train dataloader."""
        return self._data_loader(self.dataset_train, shuffle=self.shuffle)

    def val_dataloader(self, *args: Any, **kwargs: Any) -> Union[DataLoader, List[DataLoader]]:
        """The val dataloader."""
        return self._data_loader(self.dataset_val)

    def test_dataloader(self, *args: Any, **kwargs: Any) -> Union[DataLoader, List[DataLoader]]:
        """The test dataloader."""
        return self._data_loader(self.dataset_test)

    def _data_loader(self, dataset: Dataset, shuffle: bool = False) -> DataLoader:
        return DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=shuffle,
            num_workers=self.num_workers,
            drop_last=self.drop_last,
            pin_memory=self.pin_memory,
        )
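
To make the val_split semantics concrete, here is a standalone sketch that mirrors the _get_splits logic above (not part of the project): an int is an absolute number of validation samples, a float is a fraction of the full training set.

from typing import List, Union

def get_splits(len_dataset: int, val_split: Union[int, float]) -> List[int]:
    # mirrors VisionDataModule._get_splits: returns [train_len, val_len]
    if isinstance(val_split, int):
        return [len_dataset - val_split, val_split]
    if isinstance(val_split, float):
        val_len = int(val_split * len_dataset)
        return [len_dataset - val_len, val_len]
    raise ValueError(f"Unsupported type {type(val_split)}")

print(get_splits(50_000, 5_000))  # [45000, 5000]  -> 5000 validation samples
print(get_splits(50_000, 0.2))    # [40000, 10000] -> 20% of the data goes to validation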

CIFAR10DataModule is then implemented by subclassing VisionDataModule:

from typing import Any, Callable, Optional, Sequence, Union

from pl_bolts.datamodules.vision_datamodule import VisionDataModule
from pl_bolts.datasets import TrialCIFAR10
from pl_bolts.transforms.dataset_normalizations import cifar10_normalization
from pl_bolts.utils import _TORCHVISION_AVAILABLE
from pl_bolts.utils.warnings import warn_missing_pkg

if _TORCHVISION_AVAILABLE:
    from torchvision import transforms as transform_lib
    from torchvision.datasets import CIFAR10
else:  # pragma: no cover
    warn_missing_pkg("torchvision")
    CIFAR10 = None


class CIFAR10DataModule(VisionDataModule):
    """
    .. figure:: .png
        :width: 400
        :alt: CIFAR-10

    Specs:
        - 10 classes (1 per class)
        - Each image is (3 x 32 x 32)

    Standard CIFAR10, train, val, test splits and transforms

    Transforms::

        mnist_transforms = transform_lib.Compose([
            transform_lib.ToTensor(),
            transforms.Normalize(
                mean=[x / 255.0 for x in [125.3, 123.0, 113.9]],
                std=[x / 255.0 for x in [63.0, 62.1, 66.7]]
            )
        ])

    Example::

        from pl_bolts.datamodules import CIFAR10DataModule

        dm = CIFAR10DataModule(PATH)
        model = LitModel()

        Trainer().fit(model, datamodule=dm)

    Or you can set your own transforms

    Example::

        dm.train_transforms = ...
        dm.test_transforms = ...
        dm.val_transforms  = ...
    """

    name = "cifar10"
    dataset_cls = CIFAR10  # the class that downloads and provides the data
    dims = (3, 32, 32)

    def __init__(
        self,
        data_dir: Optional[str] = None,
        val_split: Union[int, float] = 0.2,
        num_workers: int = 0,
        normalize: bool = False,
        batch_size: int = 32,
        seed: int = 42,
        shuffle: bool = True,
        pin_memory: bool = True,
        drop_last: bool = False,  # if True, drop the last incomplete batch
        *args: Any,
        **kwargs: Any,
    ) -> None:
        """
        Args:
            data_dir: Where to save/load the data
            val_split: Percent (float) or number (int) of samples to use for the validation split
            num_workers: How many workers to use for loading data
            normalize: If true applies image normalize
            batch_size: How many samples per batch to load
            seed: Random seed to be used for train/val/test splits
            shuffle: If true shuffles the train data every epoch
            pin_memory: If true, the data loader will copy Tensors into CUDA pinned memory before
                returning them
            drop_last: If true drops the last incomplete batch
        """
        super().__init__(  # type: ignore[misc]
            data_dir=data_dir,
            val_split=val_split,
            num_workers=num_workers,
            normalize=normalize,
            batch_size=batch_size,
            seed=seed,
            shuffle=shuffle,
            pin_memory=pin_memory,
            drop_last=drop_last,
            *args,
            **kwargs,
        )

    @property
    def num_samples(self) -> int:
        train_len, _ = self._get_splits(len_dataset=50_000)
        return train_len

    @property
    def num_classes(self) -> int:
        """
        Return:
            10
        """
        return 10

    def default_transforms(self) -> Callable:  # overrides the base-class method
        if self.normalize:
            cf10_transforms = transform_lib.Compose([transform_lib.ToTensor(), cifar10_normalization()])
        else:
            cf10_transforms = transform_lib.Compose([transform_lib.ToTensor()])

        return cf10_transforms
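
A minimal usage sketch, assuming the pl_bolts and torchvision versions this article is based on are installed; the printed values follow from the default val_split=0.2 and the 50,000 CIFAR-10 training images:

from pl_bolts.datamodules import CIFAR10DataModule

dm = CIFAR10DataModule(data_dir="./data", batch_size=256, normalize=True)
print(dm.num_samples)   # 40000 -- training samples left after the default 20% validation split
print(dm.num_classes)   # 10

dm.prepare_data()       # downloads CIFAR-10 through dataset_cls
dm.setup(stage="fit")   # builds dm.dataset_train and dm.dataset_val
train_loader = dm.train_dataloader()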

The CIFAR10 class used above is the one from torchvision. The project also ships its own implementation, which is a useful reference if you need to load a dataset that torchvision does not provide:

import logging
import os
import pickle
import tarfile
import urllib.request
from abc import ABC
from typing import Callable, Optional, Sequence, Tuple
from urllib.error import HTTPError

import torch
from torch import Tensor
from torch.utils.data import Dataset

from pl_bolts.utils import _PIL_AVAILABLE

if _PIL_AVAILABLE:
    from PIL import Image


class LightDataset(ABC, Dataset):
    data: Tensor
    targets: Tensor
    normalize: tuple
    dir_path: str
    cache_folder_name: str
    DATASET_NAME = "light"

    def __len__(self) -> int:
        return len(self.data)

    @property
    def cached_folder_path(self) -> str:
        return os.path.join(self.dir_path, self.DATASET_NAME, self.cache_folder_name)

    @staticmethod
    def _prepare_subset(
        full_data: Tensor,
        full_targets: Tensor,
        num_samples: int,
        labels: Sequence,
    ) -> Tuple[Tensor, Tensor]:
        """Prepare a subset of a common dataset."""
        classes = {d: 0 for d in labels}
        indexes = []
        for idx, target in enumerate(full_targets):
            label = target.item()
            if classes.get(label, float("inf")) >= num_samples:
                continue
            indexes.append(idx)
            classes[label] += 1
            if all(classes[k] >= num_samples for k in classes):
                break
        data = full_data[indexes]
        targets = full_targets[indexes]
        return data, targets

    def _download_from_url(self, base_url: str, data_folder: str, file_name: str):
        url = os.path.join(base_url, file_name)
        logging.info(f"Downloading {url}")
        fpath = os.path.join(data_folder, file_name)
        try:
            urllib.request.urlretrieve(url, fpath)
        except HTTPError as err:
            raise RuntimeError(f"Failed download from {url}") from err


class CIFAR10(LightDataset):
    """Customized `CIFAR10 <.html>`_ dataset for testing Pytorch Lightning
    without the torchvision dependency.

    Part of the code was copied from .5.0/torchvision/datasets/

    Args:
        data_dir: Root directory of dataset where ``CIFAR10/processed/training.pt``
            and  ``CIFAR10/processed/test.pt`` exist.
        train: If ``True``, creates dataset from ``training.pt``,
            otherwise from ``test.pt``.
        download: If true, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again.

    Examples:

        >>> from torchvision import transforms
        >>> from pl_bolts.transforms.dataset_normalizations import cifar10_normalization
        >>> cf10_transforms = transforms.Compose([transforms.ToTensor(), cifar10_normalization()])
        >>> dataset = CIFAR10(download=True, transform=cf10_transforms, data_dir="datasets")
        >>> len(dataset)
        50000
        >>> torch.bincount(dataset.targets)
        tensor([5000, 5000, 5000, 5000, 5000, 5000, 5000, 5000, 5000, 5000])
        >>> data, label = dataset[0]
        >>> data.shape
        torch.Size([3, 32, 32])
        >>> label
        6

    Labels::

        airplane: 0
        automobile: 1
        bird: 2
        cat: 3
        deer: 4
        dog: 5
        frog: 6
        horse: 7
        ship: 8
        truck: 9
    """

    BASE_URL = "/"  # download URL
    FILE_NAME = "cifar-10-python.tar.gz"  # archive to download
    cache_folder_name = "complete"  # folder holding the extracted .pt files
    TRAIN_FILE_NAME = "training.pt"  # training set
    TEST_FILE_NAME = "test.pt"  # test set
    DATASET_NAME = "CIFAR10"  # root folder
    labels = set(range(10))
    relabel = False

    def __init__(
        self, data_dir: str = ".", train: bool = True, transform: Optional[Callable] = None, download: bool = True
    ):
        super().__init__()
        self.dir_path = data_dir
        self.train = train  # training set or test set
        self.transform = transform

        if not _PIL_AVAILABLE:
            raise ImportError("You want to use PIL.Image for loading but it is not installed yet.")

        os.makedirs(self.cached_folder_path, exist_ok=True)
        self.prepare_data(download)

        if not self._check_exists(self.cached_folder_path, (self.TRAIN_FILE_NAME, self.TEST_FILE_NAME)):
            raise RuntimeError("Dataset not found.")

        data_file = self.TRAIN_FILE_NAME if self.train else self.TEST_FILE_NAME
        self.data, self.targets = torch.load(os.path.join(self.cached_folder_path, data_file))

    def __getitem__(self, idx: int) -> Tuple[Tensor, int]:
        img = self.data[idx].reshape(3, 32, 32)
        target = int(self.targets[idx])

        if self.transform is not None:
            img = img.numpy().transpose((1, 2, 0))  # convert to HWC
            img = self.transform(Image.fromarray(img))
        if self.relabel:
            target = list(self.labels).index(target)
        return img, target

    @classmethod
    def _check_exists(cls, data_folder: str, file_names: Sequence[str]) -> bool:
        if isinstance(file_names, str):
            file_names = [file_names]
        return all(os.path.isfile(os.path.join(data_folder, fname)) for fname in file_names)

    def _unpickle(self, path_folder: str, file_name: str) -> Tuple[Tensor, Tensor]:
        with open(os.path.join(path_folder, file_name), "rb") as fo:
            pkl = pickle.load(fo, encoding="bytes")
        return torch.tensor(pkl[b"data"]), torch.tensor(pkl[b"labels"])

    def _extract_archive_save_torch(self, download_path):
        # extract archive
        with tarfile.open(os.path.join(download_path, self.FILE_NAME), "r:gz") as tar:
            tar.extractall(path=download_path)
        # this is internal path in the archive
        path_content = os.path.join(download_path, "cifar-10-batches-py")

        # load Test and save as PT
        torch.save(
            self._unpickle(path_content, "test_batch"),
            os.path.join(self.cached_folder_path, self.TEST_FILE_NAME),
        )

        # load Train and save as PT
        data, labels = [], []
        for i in range(5):
            fname = f"data_batch_{i + 1}"
            _data, _labels = self._unpickle(path_content, fname)
            data.append(_data)
            labels.append(_labels)
        # stash all to one
        data = torch.cat(data, dim=0)
        labels = torch.cat(labels, dim=0)
        # and save as PT
        torch.save((data, labels), os.path.join(self.cached_folder_path, self.TRAIN_FILE_NAME))

    def prepare_data(self, download: bool):
        if self._check_exists(self.cached_folder_path, (self.TRAIN_FILE_NAME, self.TEST_FILE_NAME)):
            return

        base_path = os.path.join(self.dir_path, self.DATASET_NAME)
        if download:
            self.download(base_path)
        self._extract_archive_save_torch(base_path)

    def download(self, data_folder: str) -> None:
        """Download the data if it doesn't exist in cached_folder_path already."""
        if self._check_exists(data_folder, self.FILE_NAME):
            return
        self._download_from_url(self.BASE_URL, data_folder, self.FILE_NAME)
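
A short sketch of feeding this torchvision-free CIFAR10 to a plain DataLoader; the import path assumes the class is exposed as pl_bolts.datasets.CIFAR10, as in the project layout referenced above:

from torch.utils.data import DataLoader
from pl_bolts.datasets import CIFAR10  # the lightweight implementation shown above

dataset = CIFAR10(data_dir="datasets", train=True, download=True)
loader = DataLoader(dataset, batch_size=32, shuffle=True, num_workers=2)
images, labels = next(iter(loader))
print(images.shape)  # torch.Size([32, 3, 32, 32]) -- raw uint8 tensors, since no transform was passed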
