.\lucidrains\x-transformers\x_transformers\__init__.py

# import the following classes from the x_transformers.x_transformers module
from x_transformers.x_transformers import (
    XTransformer,  # XTransformer, the full encoder/decoder Transformer model
    Encoder,  # Encoder, the encoder attention stack
    Decoder,  # Decoder, the causal decoder attention stack
    PrefixDecoder,  # PrefixDecoder, a prefix-LM style decoder
    CrossAttender,  # CrossAttender, an attention stack built around cross attention
    Attention,  # Attention, the attention module itself
    TransformerWrapper,  # TransformerWrapper, wraps attention layers with token embeddings for discrete sequences
    ViTransformerWrapper  # ViTransformerWrapper, wraps attention layers as a Vision Transformer
)

# import the AutoregressiveWrapper class from the x_transformers.autoregressive_wrapper module
from x_transformers.autoregressive_wrapper import AutoregressiveWrapper

# import the NonAutoregressiveWrapper class from the x_transformers.nonautoregressive_wrapper module
from x_transformers.nonautoregressive_wrapper import NonAutoregressiveWrapper

# import the following classes from the x_transformers.continuous module
from x_transformers.continuous import (
    ContinuousTransformerWrapper,  # ContinuousTransformerWrapper, a Transformer wrapper for continuous (non-token) inputs
    ContinuousAutoregressiveWrapper  # ContinuousAutoregressiveWrapper, an autoregressive wrapper for continuous models
)

# import the following classes from the x_transformers.xval module
from x_transformers.xval import (
    XValTransformerWrapper,  # XValTransformerWrapper, a Transformer wrapper using the xVal continuous number encoding
    XValAutoregressiveWrapper  # XValAutoregressiveWrapper, an autoregressive wrapper for xVal models
)

# import the XLAutoregressiveWrapper class (Transformer-XL style recurrence) from the x_transformers.xl_autoregressive_wrapper module
from x_transformers.xl_autoregressive_wrapper import XLAutoregressiveWrapper

# import the DPO class from the x_transformers.dpo module
from x_transformers.dpo import (
    DPO  # DPO, an implementation of Direct Preference Optimization
)
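
As a quick orientation, here is a minimal decoder-only sketch of how these exports fit together (the argument values are illustrative, not taken from this file):

import torch
from x_transformers import TransformerWrapper, Decoder, AutoregressiveWrapper

model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Decoder(dim = 512, depth = 6, heads = 8)
)

# the autoregressive wrapper handles input/target shifting and loss computation
model = AutoregressiveWrapper(model)

seq = torch.randint(0, 20000, (1, 1024))
loss = model(seq)
loss.backward()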

x-unet

Implementation of a U-net complete with efficient attention as well as the latest research findings

Install

$ pip install x-unet

Usage

import torch
from x_unet import XUnet

unet = XUnet(
    dim = 64,
    channels = 3,
    dim_mults = (1, 2, 4, 8),
    nested_unet_depths = (7, 4, 2, 1),     # nested unet depths, from unet-squared paper
    consolidate_upsample_fmaps = True,     # whether to consolidate outputs from all upsample blocks, used in unet-squared paper
)

img = torch.randn(1, 3, 256, 256)
out = unet(img) # (1, 3, 256, 256)

For 3d (video or CT / MRI scans)

import torch
from x_unet import XUnet

unet = XUnet(
    dim = 64,
    frame_kernel_size = 3,                 # set this to greater than 1
    channels = 3,
    dim_mults = (1, 2, 4, 8),
    nested_unet_depths = (5, 4, 2, 1),     # nested unet depths, from unet-squared paper
    consolidate_upsample_fmaps = True,     # whether to consolidate outputs from all upsample blocks, used in unet-squared paper
    weight_standardize = True
)

video = torch.randn(1, 3, 10, 128, 128)    # (batch, channels, frames, height, width)
out = unet(video) # (1, 3, 10, 128, 128)

Todo

  • memory efficiency for 3d - reversible blocks, checkpointing, memory efficient unet
  • offer option for axial convolutions (placing frame convolutions at end of the resnet chain)

Citations

@article{Ronneberger2015UNetCN,
    title   = {U-Net: Convolutional Networks for Biomedical Image Segmentation},
    author  = {Olaf Ronneberger and Philipp Fischer and Thomas Brox},
    journal = {ArXiv},
    year    = {2015},
    volume  = {abs/1505.04597}
}
@article{Qin2020U2NetGD,
    title   = {U2-Net: Going Deeper with Nested U-Structure for Salient Object Detection},
    author  = {Xuebin Qin and Zichen Vincent Zhang and Chenyang Huang and Masood Dehghan and Osmar R Zaiane and Martin J{\"a}gersand},
    journal = {ArXiv},
    year    = {2020},
    volume  = {abs/2005.09007}
}
@inproceedings{Henry2020QueryKeyNF,
    title   = {Query-Key Normalization for Transformers},
    author  = {Alex Henry and Prudhvi Raj Dachapally and Shubham Vivek Pawar and Yuxuan Chen},
    booktitle = {FINDINGS},
    year    = {2020}
}
@article{Qiao2019WeightS,
    title   = {Weight Standardization},
    author  = {Siyuan Qiao and Huiyu Wang and Chenxi Liu and Wei Shen and Alan Loddon Yuille},
    journal = {ArXiv},
    year    = {2019},
    volume  = {abs/1903.10520}
}
@article{Shleifer2021NormFormerIT,
    title   = {NormFormer: Improved Transformer Pretraining with Extra Normalization},
    author  = {Sam Shleifer and Jason Weston and Myle Ott},
    journal = {ArXiv},
    year    = {2021},
    volume  = {abs/2110.09456}
}
@article{Sunkara2022NoMS,
    title   = {No More Strided Convolutions or Pooling: A New CNN Building Block for Low-Resolution Images and Small Objects},
    author  = {Raja Sunkara and Tie Luo},
    journal = {ArXiv},
    year    = {2022},
    volume  = {abs/2208.03641}
}
@inproceedings{Woo2023ConvNeXtVC,
    title   = {ConvNeXt V2: Co-designing and Scaling ConvNets with Masked Autoencoders},
    author  = {Sanghyun Woo and Shoubhik Debnath and Ronghang Hu and Xinlei Chen and Zhuang Liu and In-So Kweon and Saining Xie},
    year    = {2023}
}

.\lucidrains\x-unet\setup.py

# import setup and find_packages from setuptools
from setuptools import setup, find_packages

# package metadata
setup(
  name = 'x-unet',  # package name
  packages = find_packages(exclude=[]),  # find all packages
  version = '0.3.1',  # version number
  license='MIT',  # license
  description = 'X-Unet',  # description
  long_description_content_type = 'text/markdown',  # long description content type
  author = 'Phil Wang',  # author
  author_email = 'lucidrains@gmail.com',  # author email
  url = 'https://github.com/lucidrains/x-unet',  # project url
  keywords = [  # keywords
    'artificial intelligence',
    'deep learning',
    'biomedical segmentation',
    'medical deep learning',
    'unets',
  ],
  install_requires=[  # dependencies
    'beartype',
    'einops>=0.4',
    'torch>=1.6',
  ],
  classifiers=[  # classifiers
    'Development Status :: 4 - Beta',
    'Intended Audience :: Developers',
    'Topic :: Scientific/Engineering :: Artificial Intelligence',
    'License :: OSI Approved :: MIT License',
    'Programming Language :: Python :: 3.6',
  ],
)

.\lucidrains\x-unet\x_unet\x_unet.py

# import required libraries
from functools import partial
import math
import torch
from torch import nn, einsum
import torch.nn.functional as F
# functions and layers from einops
from einops import rearrange, repeat, reduce
from einops.layers.torch import Rearrange
# runtime type checking from beartype
from beartype import beartype
from beartype.typing import Tuple, Union, Optional

# helper functions

# check whether a value exists (is not None)
def exists(val):
    return val is not None

# return the value if it exists, otherwise a default
def default(val, d):
    return val if exists(val) else d

# check whether a number is a power of two
def is_power_two(n):
    return math.log2(n).is_integer()

# check whether num is divisible by denom
def divisible_by(num, denom):
    return (num % denom) == 0

# cast a value to a tuple, optionally of a given length
def cast_tuple(val, length = None):
    if isinstance(val, list):
        val = tuple(val)

    output = val if isinstance(val, tuple) else ((val,) * default(length, 1))

    if exists(length):
        assert len(output) == length

    return output

# helper classes

# upsample: transposed conv that doubles height and width, leaving frames untouched
def Upsample(dim, dim_out):
    return nn.ConvTranspose3d(dim, dim_out, (1, 4, 4), (1, 2, 2), (0, 1, 1))

# downsample: space-to-depth rearrange followed by a 1x1x1 conv, avoiding strided convs and pooling
def Downsample(dim, dim_out):
    return nn.Sequential(
        Rearrange('b c f (h s1) (w s2) -> b (c s1 s2) f h w', s1 = 2, s2 = 2),
        nn.Conv3d(dim * 4, dim_out, 1)
    )
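
# For example (a hypothetical shape check, not part of the original file): with
# dim = 8 and dim_out = 16, an input of shape (b, 8, f, 16, 16) is rearranged
# space-to-depth into (b, 32, f, 8, 8), then projected by the 1x1x1 conv to
# (b, 16, f, 8, 8) - halving resolution without strided convolutions or pooling,
# per the Sunkara & Luo paper cited in the readme.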

# norms and residuals

# residual connection
class Residual(nn.Module):
    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x):
        return self.fn(x) + x

# layer norm over the channel dimension (scale only, no bias)
class LayerNorm(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(1, dim, 1, 1, 1))

    def forward(self, x):
        eps = 1e-5 if x.dtype == torch.float32 else 1e-3
        var = torch.var(x, dim = 1, unbiased = False, keepdim = True)
        mean = torch.mean(x, dim = 1, keepdim = True)
        return (x - mean) / (var + eps).sqrt() * self.gamma

# weight standardized convolution
class WeightStandardizedConv3d(nn.Conv3d):
    def forward(self, x):
        eps = 1e-5 if x.dtype == torch.float32 else 1e-3

        weight = self.weight

        mean = reduce(weight, 'o ... -> o 1 1 1 1', 'mean')
        var = reduce(weight, 'o ... -> o 1 1 1 1', partial(torch.var, unbiased = False))
        weight = (weight - mean) * (var + eps).rsqrt()

        return F.conv3d(x, weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
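
# Weight standardization normalizes each output filter of the kernel to zero
# mean and unit variance before the convolution is applied; per the Qiao et al.
# paper cited in the readme, it pairs especially well with the group norms used
# in the blocks below.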

# resnet blocks

# basic conv -> group norm -> SiLU block
class Block(nn.Module):
    def __init__(
        self,
        dim,
        dim_out,
        groups = 8,
        weight_standardize = False,
        frame_kernel_size = 1
    ):
        super().__init__()
        kernel_conv_kwargs = partial(kernel_and_same_pad, frame_kernel_size)
        conv = nn.Conv3d if not weight_standardize else WeightStandardizedConv3d

        self.proj = conv(dim, dim_out, **kernel_conv_kwargs(3, 3))
        self.norm = nn.GroupNorm(groups, dim_out)
        self.act = nn.SiLU()

    def forward(self, x):
        x = self.proj(x)
        x = self.norm(x)
        return self.act(x)

# resnet block, whose second block may optionally be a nested unet
class ResnetBlock(nn.Module):
    def __init__(
        self,
        dim,
        dim_out,
        groups = 8,
        frame_kernel_size = 1,
        nested_unet_depth = 0,
        nested_unet_dim = 32,
        weight_standardize = False
    ):
        super().__init__()
        self.block1 = Block(dim, dim_out, groups = groups, weight_standardize = weight_standardize, frame_kernel_size = frame_kernel_size)

        if nested_unet_depth > 0:
            self.block2 = NestedResidualUnet(dim_out, depth = nested_unet_depth, M = nested_unet_dim, frame_kernel_size = frame_kernel_size, weight_standardize = weight_standardize, add_residual = True)
        else:
            self.block2 = Block(dim_out, dim_out, groups = groups, weight_standardize = weight_standardize, frame_kernel_size = frame_kernel_size)

        self.res_conv = nn.Conv3d(dim, dim_out, 1) if dim != dim_out else nn.Identity()

    def forward(self, x):
        h = self.block1(x)
        h = self.block2(h)
        return h + self.res_conv(x)
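
# For example (hypothetical): ResnetBlock(64, 128) maps (b, 64, f, h, w) to
# (b, 128, f, h, w); since dim != dim_out, the residual path uses a 1x1x1
# convolution to match channel counts before the addition.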

# ConvNeXt V2

# global response normalization
class GRN(nn.Module):
    """ global response normalization, proposed in updated convnext paper """

    # initialize with the feature dimension and an epsilon for numerical stability
    def __init__(self, dim, eps = 1e-5):
        super().__init__()
        self.eps = eps
        # learnable scale parameter, initialized to zero
        self.gamma = nn.Parameter(torch.zeros(dim, 1, 1, 1))
        # learnable bias parameter, initialized to zero
        self.bias = nn.Parameter(torch.zeros(dim, 1, 1, 1))

    def forward(self, x):
        # per-channel L2 norm over the frame and spatial dimensions
        spatial_l2_norm = x.norm(p = 2, dim = (2, 3, 4), keepdim = True)
        # normalize each channel's norm by the mean norm across channels
        feat_norm = spatial_l2_norm / spatial_l2_norm.mean(dim = 1, keepdim = True).clamp(min = self.eps)
        # scale features by their relative global response, with a residual
        return x * feat_norm * self.gamma + self.bias + x
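
# note that at initialization gamma and bias are zero, so GRN begins as the
# identity mapping; the network learns how much response normalization to apply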

# ConvNeXt block
class ConvNextBlock(nn.Module):
    def __init__(
        self,
        dim,
        dim_out,
        *,
        mult = 2,
        frame_kernel_size = 1,
        nested_unet_depth = 0,
        nested_unet_dim = 32
    ):
        super().__init__()
        kernel_conv_kwargs = partial(kernel_and_same_pad, frame_kernel_size)

        # depthwise convolution (7x7 spatial kernel)
        self.ds_conv = nn.Conv3d(dim, dim, **kernel_conv_kwargs(7, 7), groups = dim)

        inner_dim = dim_out * mult

        # pointwise network: norm -> conv -> GELU -> GRN -> conv
        self.net = nn.Sequential(
            LayerNorm(dim),
            nn.Conv3d(dim, inner_dim, **kernel_conv_kwargs(3, 3), groups = dim_out),
            nn.GELU(),
            GRN(inner_dim),
            nn.Conv3d(inner_dim, dim_out, **kernel_conv_kwargs(3, 3), groups = dim_out)
        )

        # optional nested residual unet
        self.nested_unet = NestedResidualUnet(dim_out, depth = nested_unet_depth, M = nested_unet_dim, add_residual = True) if nested_unet_depth > 0 else nn.Identity()

        # residual projection
        self.res_conv = nn.Conv3d(dim, dim_out, 1) if dim != dim_out else nn.Identity()

    def forward(self, x, time_emb = None):
        h = self.ds_conv(x)
        h = self.net(h)
        h = self.nested_unet(h)
        return h + self.res_conv(x)

# feedforward
def FeedForward(dim, mult = 4.):
    inner_dim = int(dim * mult)
    return Residual(nn.Sequential(
        LayerNorm(dim),
        nn.Conv3d(dim, inner_dim, 1, bias = False),
        nn.GELU(),
        LayerNorm(inner_dim),   # properly credit assign normformer
        nn.Conv3d(inner_dim, dim, 1, bias = False)
    ))

# attention
class Attention(nn.Module):
    def __init__(
        self,
        dim,
        heads = 4,
        dim_head = 64
    ):
        super().__init__()
        self.scale = dim_head ** -0.5
        self.heads = heads
        inner_dim = heads * dim_head
        self.norm = LayerNorm(dim)

        self.to_qkv = nn.Conv3d(dim, inner_dim * 3, 1, bias = False)
        self.to_out = nn.Conv3d(inner_dim, dim, 1, bias = False)

    def forward(self, x):
        f, h, w = x.shape[-3:]

        residual = x.clone()

        x = self.norm(x)

        q, k, v = self.to_qkv(x).chunk(3, dim = 1)
        q, k, v = map(lambda t: rearrange(t, 'b (h c) ... -> b h (...) c', h = self.heads), (q, k, v))

        q = q * self.scale
        sim = einsum('b h i d, b h j d -> b h i j', q, k)
        attn = sim.softmax(dim = -1)

        out = einsum('b h i j, b h j d -> b h i d', attn, v)

        out = rearrange(out, 'b h (f x y) d -> b (h d) f x y', f = f, x = h, y = w)
        return self.to_out(out) + residual

# transformer block
class TransformerBlock(nn.Module):
    def __init__(
        self,
        dim,
        *,
        depth,
        **kwargs
    ):
        super().__init__()
        self.attn = Attention(dim, **kwargs)
        self.ff = FeedForward(dim)

    def forward(self, x):
        x = self.attn(x)
        x = self.ff(x)
        return x

# feature map consolidator
class FeatureMapConsolidator(nn.Module):
    def __init__(
        self,
        dim,
        *,
        dim_ins = tuple(),
        dim_outs = tuple(),
        resize_fmap_before = True,
        conv_block_fn = None
    ):
        super().__init__()
        assert len(dim_ins) == len(dim_outs)
        self.needs_consolidating = len(dim_ins) > 0

        block_fn = default(conv_block_fn, Block)

        # one conv block per incoming feature map
        self.fmap_convs = nn.ModuleList([block_fn(dim_in, dim_out) for dim_in, dim_out in zip(dim_ins, dim_outs)])
        self.resize_fmap_before = resize_fmap_before

        self.final_dim_out = dim + (sum(dim_outs) if len(dim_outs) > 0 else 0)

    # resize feature maps to a target height and width
    def resize_fmaps(self, fmaps, height, width):
        return [F.interpolate(fmap, (fmap.shape[-3], height, width)) for fmap in fmaps]

    # forward pass, taking the input x and an optional list of feature maps
    def forward(self, x, fmaps = None):
        # target height and width come from x
        target_height, target_width = x.shape[-2:]

        # default the feature maps to an empty tuple
        fmaps = default(fmaps, tuple())

        # if there is nothing to consolidate, return x as-is
        if not self.needs_consolidating:
            return x

        # resize the feature maps before the convs, if configured to
        if self.resize_fmap_before:
            fmaps = self.resize_fmaps(fmaps, target_height, target_width)

        # run each feature map through its conv block
        outs = []
        for fmap, conv in zip(fmaps, self.fmap_convs):
            outs.append(conv(fmap))

        # otherwise resize the conv outputs afterwards
        if not self.resize_fmap_before:
            outs = self.resize_fmaps(outs, target_height, target_width)

        # concatenate x with all the consolidated feature maps along channels
        return torch.cat((x, *outs), dim = 1)

# returns Union[type, Tuple[type, ...]], for beartype annotations
def MaybeTuple(type):
    return Union[type, Tuple[type, ...]]

# build kwargs for a conv with 'same' padding derived from the kernel size
def kernel_and_same_pad(*kernel_size):
    paddings = tuple(map(lambda k: k // 2, kernel_size))
    return dict(kernel_size = kernel_size, padding = paddings)
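
# For example: kernel_and_same_pad(1, 3, 3) returns
# dict(kernel_size = (1, 3, 3), padding = (0, 1, 1)) - 'same' padding for odd
# kernel sizes, so resolution is preserved through the convolution.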

# the main XUnet class
class XUnet(nn.Module):

    # constructor
    @beartype
    def __init__(
        self,
        dim,
        init_dim = None,
        out_dim = None,
        frame_kernel_size = 1,
        dim_mults: MaybeTuple(int) = (1, 2, 4, 8),
        num_blocks_per_stage: MaybeTuple(int) = (2, 2, 2, 2),
        num_self_attn_per_stage: MaybeTuple(int) = (0, 0, 0, 1),
        nested_unet_depths: MaybeTuple(int) = (0, 0, 0, 0),
        nested_unet_dim = 32,
        channels = 3,
        use_convnext = False,
        resnet_groups = 8,
        consolidate_upsample_fmaps = True,
        skip_scale = 2 ** -0.5,
        weight_standardize = False,
        attn_heads: MaybeTuple(int) = 8,
        attn_dim_head: MaybeTuple(int) = 32
    ):
        ...  # constructor body omitted in this excerpt

    def forward(self, x):
        is_image = x.ndim == 4

        # validation

        assert not (is_image and not self.train_as_images), 'you specified a frame kernel size for the convolutions in this unet, but you are passing in images'
        assert not (not is_image and self.train_as_images), 'you specified no frame kernel size dimension, yet you are passing in a video. fold the frame dimension into the batch'

        # treat an image as a single-frame video

        if is_image:
            x = rearrange(x, 'b c h w -> b c 1 h w')

        # initial convolution

        x = self.init_conv(x)

        # save the residual

        r = x.clone()

        # downsampling and upsampling stages

        down_hiddens = []
        up_hiddens = []

        for init_block, blocks, attn_blocks, downsample in self.downs:
            x = init_block(x)

            for block in blocks:
                x = block(x)

            for attn_block in attn_blocks:
                x = attn_block(x)

            down_hiddens.append(x)
            x = downsample(x)

        x = self.mid(x)
        x = self.mid_attn(x) + x
        x = self.mid_after(x)

        up_hiddens.append(x)
        x = self.mid_upsample(x)


        for init_block, blocks, attn_blocks, upsample in self.ups:
            x = torch.cat((x, down_hiddens.pop() * self.skip_scale), dim=1)

            x = init_block(x)

            for block in blocks:
                x = block(x)

            for attn_block in attn_blocks:
                x = attn_block(x)

            up_hiddens.insert(0, x)
            x = upsample(x)

        # consolidate feature maps

        x = self.consolidator(x, up_hiddens)

        # final residual

        x = torch.cat((x, r), dim = 1)

        # final convolution

        out = self.final_conv(x)

        if is_image:
            out = rearrange(out, 'b c 1 h w -> b c h w')

        return out

# pixel shuffle upsample
class PixelShuffleUpsample(nn.Module):
    def __init__(
        self,
        dim,
        dim_out = None,
        scale_factor = 2
    ):
        super().__init__()
        self.scale_squared = scale_factor ** 2
        dim_out = default(dim_out, dim)
        conv = nn.Conv3d(dim, dim_out * self.scale_squared, 1)

        self.net = nn.Sequential(
            conv,
            nn.SiLU(),
            Rearrange('b (c r s) f h w -> b c f (h r) (w s)', r = scale_factor, s = scale_factor)
        )

        self.init_conv_(conv)

    # initialize the conv weights, repeated across the subpixel groups
    def init_conv_(self, conv):
        o, i, *rest_dims = conv.weight.shape
        conv_weight = torch.empty(o // self.scale_squared, i, *rest_dims)
        nn.init.kaiming_uniform_(conv_weight)
        conv_weight = repeat(conv_weight, 'o ... -> (o r) ...', r = self.scale_squared)

        conv.weight.data.copy_(conv_weight)
        nn.init.zeros_(conv.bias.data)

    def forward(self, x):
        x = self.net(x)
        return x
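
# For example (hypothetical): PixelShuffleUpsample(32) maps (b, 32, f, h, w) to
# (b, 32, f, 2h, 2w). Repeating one kaiming-initialized kernel across the
# scale_squared output groups (an ICNR-style init) makes the pixel shuffle start
# out like nearest-neighbor upsampling, which helps avoid checkerboard artifacts.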

# nested residual unet, from the unet-squared paper
class NestedResidualUnet(nn.Module):
    # constructor, setting up the model hyperparameters
    def __init__(
        self,
        dim,
        *,
        depth,
        M = 32,
        frame_kernel_size = 1,
        add_residual = False,
        groups = 4,
        skip_scale = 2 ** -0.5,
        weight_standardize = False
    ):
        super().__init__()

        # depth, plus module lists for the down and up paths
        self.depth = depth
        self.downs = nn.ModuleList([])
        self.ups = nn.ModuleList([])

        # choose the conv type based on whether weights are standardized
        conv = WeightStandardizedConv3d if weight_standardize else nn.Conv3d

        # build the downsampling modules
        for ind in range(depth):
            is_first = ind == 0
            dim_in = dim if is_first else M

            down = nn.Sequential(
                conv(dim_in, M, (1, 4, 4), stride = (1, 2, 2), padding = (0, 1, 1)),
                nn.GroupNorm(groups, M),
                nn.SiLU()
            )

            self.downs.append(down)

            # build the matching upsampling module (input channels are doubled by the skip connection)
            up = nn.Sequential(
                PixelShuffleUpsample(2 * M, dim_in),
                nn.GroupNorm(groups, dim_in),
                nn.SiLU()
            )

            self.ups.append(up)

        # middle block
        self.mid = nn.Sequential(
            conv(M, M, **kernel_and_same_pad(frame_kernel_size, 3, 3)),
            nn.GroupNorm(groups, M),
            nn.SiLU()
        )

        # skip connection scale, and whether to add a final residual
        self.skip_scale = skip_scale
        self.add_residual = add_residual

    # forward pass
    def forward(self, x, residual = None):
        # check whether the input is a video
        is_video = x.ndim == 5

        # if a final residual is requested, default it to a copy of the input
        if self.add_residual:
            residual = default(residual, x.clone())

        # input height and width
        *_, h, w = x.shape

        # number of up/down stages
        layers = len(self.ups)

        # height and width must be divisible by 2 ** depth
        for dim_name, size in (('height', h), ('width', w)):
            assert divisible_by(size, 2 ** layers), f'{dim_name} dimension {size} must be divisible by {2 ** layers} ({layers} layers in nested unet)'
            assert (size % (2 ** self.depth)) == 0, f'the unet has too much depth for the image {dim_name} ({size}) being passed in'

        # hiddens

        hiddens = []

        # unet

        # downsampling path
        for down in self.downs:
            x = down(x)
            hiddens.append(x.clone().contiguous())

        # middle block
        x = self.mid(x)

        # upsampling path, consuming the scaled skip connections
        for up in reversed(self.ups):
            x = torch.cat((x, hiddens.pop() * self.skip_scale), dim = 1)
            x = up(x)

        # add the final residual
        if self.add_residual:
            x = x + residual
            x = F.silu(x)

        return x
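
# For example (hypothetical): NestedResidualUnet(dim = 64, depth = 2) requires
# height and width divisible by 2 ** 2 = 4, and maps an input of shape
# (1, 64, 1, 32, 32) back to (1, 64, 1, 32, 32), optionally adding the input
# back as a residual at the end.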

.\lucidrains\x-unet\x_unet\__init__.py

# import the XUnet and NestedResidualUnet classes from the x_unet module
from x_unet.x_unet import XUnet, NestedResidualUnet

Zorro - Pytorch

Implementation of Zorro, Masked Multimodal Transformer, in Pytorch. This is a DeepMind work that claims a special masking strategy within a transformer helps them achieve SOTA on a few multimodal benchmarks.

Appreciation

  • Stability.ai for the generous sponsorship to work and open source cutting edge artificial intelligence research

Install

$ pip install zorro-pytorch

Usage

import torch
from zorro_pytorch import Zorro, TokenTypes as T

model = Zorro(
    dim = 512,                        # model dimensions
    depth = 6,                        # depth
    dim_head = 64,                    # attention dimension heads
    heads = 8,                        # attention heads
    ff_mult = 4,                      # feedforward multiple
    num_fusion_tokens = 16,           # number of fusion tokens
    audio_patch_size = 16,            # audio patch size, can also be Tuple[int, int]
    video_patch_size = 16,            # video patch size, can also be Tuple[int, int]
    video_temporal_patch_size = 2,    # video temporal patch size
    video_channels = 3,               # video channels
    return_token_types = (
        T.AUDIO,
        T.AUDIO,
        T.FUSION,
        T.GLOBAL,
        T.VIDEO,
        T.VIDEO,
        T.VIDEO,
    ) # say you want to return 2 tokens for audio, 1 for fusion, 1 global, and 3 for video - for whatever self-supervised learning, supervised learning, etc etc
)

video = torch.randn(2, 3, 8, 32, 32) # (batch, channels, time, height, width)
audio = torch.randn(2, 1024 * 10)    # (batch, time)

return_tokens = model(audio = audio, video = video) # (2, 6, 512) - all 6 tokens indicated above are returned

# say you only want 1 audio and 1 video token, for contrastive learning

audio_token, video_token = model(audio = audio, video = video, return_token_indices = (0, 3)).unbind(dim = -2) # (2, 512), (2, 512)
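
To make the masking strategy concrete, below is a minimal sketch (with toy, made-up token counts) of the "zorro" attention mask constructed inside the model: each modality attends only to itself, while fusion tokens attend to everything, so information flows into the fusion tokens but never leaks between modalities.

import torch

AUDIO, VIDEO, FUSION = 0, 1, 2
token_types = torch.tensor([AUDIO] * 3 + [FUSION] * 2 + [VIDEO] * 4)

attend_from = token_types[:, None]   # shape (i, 1)
attend_to = token_types[None, :]     # shape (1, j)

# each modality (including fusion) attends to itself; fusion attends to all
zorro_mask = (attend_from == attend_to) | (attend_from == FUSION)
print(zorro_mask.int())   # (9, 9) boolean mask, True where attention is allowed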

Citations

@inproceedings{Recasens2023ZorroTM,
  title  = {Zorro: the masked multimodal transformer},
  author = {Adri{\`a} Recasens and Jason Lin and Jo{\~a}o Carreira and Drew Jaegle and Luyu Wang and Jean-Baptiste Alayrac and Pauline Luc and Antoine Miech and Lucas Smaira and Ross Hemsley and Andrew Zisserman},
  year   = {2023}
}

.\lucidrains\zorro-pytorch\setup.py

# import setup and find_packages from setuptools
from setuptools import setup, find_packages

# package metadata
setup(
  # package name
  name = 'zorro-pytorch',
  # find all packages, excluding none
  packages = find_packages(exclude=[]),
  # version number
  version = '0.1.1',
  # license
  license='MIT',
  # description
  description = 'Zorro - Pytorch',
  # author
  author = 'Phil Wang',
  # author email
  author_email = 'lucidrains@gmail.com',
  # long description content type
  long_description_content_type = 'text/markdown',
  # project url
  url = 'https://github.com/lucidrains/zorro-pytorch',
  # keywords
  keywords = [
    'artificial intelligence',
    'deep learning',
    'transformers',
    'attention mechanism',
    'multimodal fusion'
  ],
  # dependencies
  install_requires=[
    'beartype',
    'einops>=0.4',
    'torch>=1.6',
    'torchaudio'
  ],
  # classifiers
  classifiers=[
    'Development Status :: 4 - Beta',
    'Intended Audience :: Developers',
    'Topic :: Scientific/Engineering :: Artificial Intelligence',
    'License :: OSI Approved :: MIT License',
    'Programming Language :: Python :: 3.6',
  ],
)

.\lucidrains\zorro-pytorch\zorro_pytorch\zorro_pytorch.py

# imports
from enum import Enum
import functools
from functools import wraps

import torch
import torch.nn.functional as F
from torch import nn, einsum

from einops import rearrange, repeat, pack, unpack
from einops.layers.torch import Rearrange

from beartype import beartype
from beartype.typing import Tuple, Optional, Union

from torchaudio.transforms import Spectrogram

# token types enum: audio, video, fusion and global
class TokenTypes(Enum):
    AUDIO = 0
    VIDEO = 1
    FUSION = 2
    GLOBAL = 3

# helper functions

# check whether a value exists (is not None)
def exists(val):
    return val is not None

# return the first argument that exists, else None
def default(*args):
    for arg in args:
        if exists(arg):
            return arg
    return None

# round n down to the nearest multiple of divisor
def round_down_nearest_multiple(n, divisor):
    return n // divisor * divisor

# cast to a pair: (t, t) if t is not already a tuple
def pair(t):
    return (t, t) if not isinstance(t, tuple) else t

# cumulative product of an iterable
def cum_mul(it):
    return functools.reduce(lambda x, y: x * y, it, 1)

# check whether numer is divisible by denom
def divisible_by(numer, denom):
    return (numer % denom) == 0

# decorators

# decorator that ensures a function runs only once
def once(fn):
    called = False
    @wraps(fn)
    def inner(x):
        nonlocal called
        if called:
            return
        called = True
        return fn(x)
    return inner

# print wrapped with once, so a message is only printed a single time
print_once = once(print)

# bias-less layernorm; beta is registered as a non-learnable zero buffer
class LayerNorm(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(dim))
        self.register_buffer("beta", torch.zeros(dim))

    def forward(self, x):
        return F.layer_norm(x, x.shape[-1:], self.gamma, self.beta)

# GEGLU activation
class GEGLU(nn.Module):
    def forward(self, x):
        x, gate = x.chunk(2, dim = -1)
        return F.gelu(gate) * x

# feedforward with GEGLU
def FeedForward(dim, mult = 4):
    inner_dim = int(dim * mult * 2 / 3)
    return nn.Sequential(
        LayerNorm(dim),
        nn.Linear(dim, inner_dim * 2, bias = False),
        GEGLU(),
        nn.Linear(inner_dim, dim, bias = False)
    )
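
# the inner dimension is scaled by 2/3 so the gated (GEGLU) feedforward keeps
# roughly the same parameter count as an ungated feedforward with the same mult,
# following the convention from the GLU-variants paper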

# attention
class Attention(nn.Module):
    def __init__(
        self,
        dim,
        dim_head = 64,
        heads = 8
    ):
        super().__init__()
        self.scale = dim_head ** -0.5
        self.heads = heads
        inner_dim = dim_head * heads

        self.norm = LayerNorm(dim)

        self.to_q = nn.Linear(dim, inner_dim, bias = False)
        self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)
        self.to_out = nn.Linear(inner_dim, dim, bias = False)

    def forward(
        self,
        x,
        context = None,
        attn_mask = None
    ):
        x = self.norm(x)
        kv_x = default(context, x)

        q, k, v = (self.to_q(x), *self.to_kv(kv_x).chunk(2, dim = -1))

        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), (q, k, v))

        q = q * self.scale
        sim = einsum('b h i d, b h j d -> b h i j', q, k)

        if exists(attn_mask):
            sim = sim.masked_fill(~attn_mask, -torch.finfo(sim.dtype).max)

        attn = sim.softmax(dim = -1)

        out = einsum('b h i j, b h j d -> b h i d', attn, v)

        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)

# main Zorro class
class Zorro(nn.Module):
    def __init__(
        self,
        dim,
        depth,
        dim_head = 64,
        heads = 8,
        ff_mult = 4,
        num_fusion_tokens = 16,
        audio_patch_size: Union[int, Tuple[int, int]] = 16,
        video_patch_size: Union[int, Tuple[int, int]] = 16,
        video_temporal_patch_size = 2,
        video_channels = 3,
        spec_n_fft = 128,
        spec_power = 2,
        spec_win_length = 24,
        spec_hop_length = None,
        spec_pad = 0,
        spec_center = True,
        spec_pad_mode = 'reflect',
        spec_aug_stretch_factor = 0.8,
        spec_aug_freq_mask = 80,
        spec_aug_time_mask = 80,
        return_token_types: Tuple[TokenTypes] = (TokenTypes.AUDIO, TokenTypes.VIDEO, TokenTypes.FUSION)
        ):
        # call the parent constructor
        super().__init__()
        # the maximum number of return tokens equals the length of the return token type list
        self.max_return_tokens = len(return_token_types)

        # store the return token types, and register a tensor version as a buffer
        self.return_token_types = return_token_types
        return_token_types_tensor = torch.tensor(list(map(lambda t: t.value, return_token_types)))
        self.register_buffer('return_token_types_tensor', return_token_types_tensor, persistent=False)

        # learned return tokens, and the attention pool that fills them
        self.return_tokens = nn.Parameter(torch.randn(self.max_return_tokens, dim))
        self.attn_pool = Attention(dim=dim, dim_head=dim_head, heads=heads)

        # audio input

        # audio patch size
        self.audio_patch_size = audio_patch_height, audio_patch_width = pair(audio_patch_size)

        # spectrogram transform
        self.spec = Spectrogram(
            n_fft=spec_n_fft,
            power=spec_power,
            win_length=spec_win_length,
            hop_length=spec_hop_length,
            pad=spec_pad,
            center=spec_center,
            pad_mode=spec_pad_mode
        )

        # audio patches -> tokens
        audio_input_dim = cum_mul(self.audio_patch_size)
        self.audio_to_tokens = nn.Sequential(
            Rearrange('b (h p1) (w p2) -> b h w (p1 p2)', p1=audio_patch_height, p2=audio_patch_width),
            nn.LayerNorm(audio_input_dim),
            nn.Linear(audio_input_dim, dim),
            nn.LayerNorm(dim)
        )

        # video input

        # video patch size (temporal, height, width)
        self.video_patch_size = (video_temporal_patch_size, *pair(video_patch_size))

        # video patches -> tokens
        video_input_dim = cum_mul(self.video_patch_size) * video_channels
        video_patch_time, video_patch_height, video_patch_width = self.video_patch_size

        self.video_to_tokens = nn.Sequential(
            Rearrange('b c (t p1) (h p2) (w p3) -> b t h w (c p1 p2 p3)', p1=video_patch_time, p2=video_patch_height, p3=video_patch_width),
            nn.LayerNorm(video_input_dim),
            nn.Linear(video_input_dim, dim),
            nn.LayerNorm(dim)
        )

        # learned fusion tokens

        self.fusion_tokens = nn.Parameter(torch.randn(num_fusion_tokens, dim))

        # transformer layers

        self.layers = nn.ModuleList([])

        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                Attention(dim=dim, dim_head=dim_head, heads=heads),
                FeedForward(dim=dim, mult=ff_mult)
            ]))

        # final layer norm
        self.norm = LayerNorm(dim)

    def forward(
        self,
        *,
        audio,
        video,
        return_token_indices: Optional[Tuple[int]] = None
        ):
        # batch size and device, taken from the audio tensor
        batch, device = audio.shape[0], audio.device

        # validate that the video can be patchified
        assert all([divisible_by(numer, denom) for denom, numer in zip(self.video_patch_size, tuple(video.shape[-3:]))]), f'video shape {video.shape[-3:]} needs to be divisible by {self.video_patch_size}'

        # if the 2d spectrogram produced from the audio is not a multiple of the patch size, crop it automatically
        audio = self.spec(audio)

        height, width = audio.shape[-2:]
        patch_height, patch_width = self.audio_patch_size

        rounded_height, rounded_width = map(lambda args: round_down_nearest_multiple(*args), ((height, patch_height), (width, patch_width)))

        if (height, width) != (rounded_height, rounded_width): # just print once, until this is fixed
            print_once(f'spectrogram yielded shape of {(height, width)}, but had to be cropped to {(rounded_height, rounded_width)} to be patchified for transformer')

        audio = audio[..., :rounded_height, :rounded_width]

        # convert each modality to tokens
        audio_tokens = self.audio_to_tokens(audio)
        video_tokens = self.video_to_tokens(video)
        fusion_tokens = repeat(self.fusion_tokens, 'n d -> b n d', b = batch)

        # pack all tokens into a single sequence
        audio_tokens, fusion_tokens, video_tokens = map(lambda t: rearrange(t, 'b ... d -> b (...) d'), (audio_tokens, fusion_tokens, video_tokens))
        tokens, ps = pack((
            audio_tokens,
            fusion_tokens,
            video_tokens
        ), 'b * d')

        # construct the attention mask (i.e. the zorro)
        token_types = torch.tensor(list((
            *((TokenTypes.AUDIO.value,) * audio_tokens.shape[-2]),
            *((TokenTypes.FUSION.value,) * fusion_tokens.shape[-2]),
            *((TokenTypes.VIDEO.value,) * video_tokens.shape[-2]),
        )), device = device, dtype = torch.long)

        token_types_attend_from = rearrange(token_types, 'i -> i 1')
        token_types_attend_to = rearrange(token_types, 'j -> 1 j')

        # the logic is that every modality, including fusion, can attend to itself
        zorro_mask = token_types_attend_from == token_types_attend_to

        # fusion tokens can attend to everything
        zorro_mask = zorro_mask | (token_types_attend_from == TokenTypes.FUSION.value)

        # attention and feedforward layers
        for attn, ff in self.layers:
            tokens = attn(tokens, attn_mask = zorro_mask) + tokens
            tokens = ff(tokens) + tokens

        tokens = self.norm(tokens)

        # final attention pooling - each modality pool token can only attend to its own tokens
        return_tokens = self.return_tokens
        return_token_types_tensor = self.return_token_types_tensor

        if exists(return_token_indices):
            assert len(set(return_token_indices)) == len(return_token_indices), 'all indices must be unique'
            assert all([indice < self.max_return_tokens for indice in return_token_indices]), 'indices must range from 0 to max_num_return_tokens - 1'

            return_token_indices = torch.tensor(return_token_indices, dtype = torch.long, device = device)

            return_token_types_tensor = return_token_types_tensor[return_token_indices]
            return_tokens = return_tokens[return_token_indices]
        return_tokens = repeat(return_tokens, 'n d -> b n d', b = batch)
        pool_mask = rearrange(return_token_types_tensor, 'i -> i 1') == token_types_attend_to
        # global return tokens can attend to all tokens
        pool_mask = pool_mask | (rearrange(return_token_types_tensor, 'i -> i 1') == torch.ones_like(token_types_attend_to, dtype = torch.long) * TokenTypes.GLOBAL.value)

        pooled_tokens = self.attn_pool(return_tokens, context = tokens, attn_mask = pool_mask) + return_tokens

        return pooled_tokens

.\lucidrains\zorro-pytorch\zorro_pytorch\__init__.py

# import the Zorro class and the TokenTypes enum from the zorro_pytorch.zorro_pytorch module
from zorro_pytorch.zorro_pytorch import Zorro, TokenTypes
