[AI达人特训营] Reproducing the RiR Paper


Abstract

        Residual networks (ResNets) have achieved state-of-the-art results on computer vision tasks. We present ResNet in ResNet (RiR): a deep dual-stream architecture that generalizes both ResNets and standard CNNs and is easy to implement, with no extra computational cost. RiR further improves on the performance of ResNets (on CIFAR-10, using the same data augmentation as the ResNet work) and achieves a new state of the art on CIFAR-100.

1. RiR

        This paper proposes a generalized residual network architecture that subsumes both ResNets and standard CNNs. Its modular unit is a generalized residual block consisting of two parallel streams: a residual stream $\mathrm{r}$, which carries an identity shortcut as in ResNet, and a transient stream $\mathrm{t}$, which is a standard convolutional path. In addition, two sets of filters cross-convolve the two streams ($W_{l,\mathrm{r}\to\mathrm{t}}$ and $W_{l,\mathrm{t}\to\mathrm{r}}$):
$$
\begin{aligned}
\mathrm{r}_{l+1} &= \sigma\big(\operatorname{conv}(\mathrm{r}_l, W_{l,\mathrm{r}\to\mathrm{r}}) + \operatorname{conv}(\mathrm{t}_l, W_{l,\mathrm{t}\to\mathrm{r}}) + \operatorname{shortcut}(\mathrm{r}_l)\big) \\
\mathrm{t}_{l+1} &= \sigma\big(\operatorname{conv}(\mathrm{r}_l, W_{l,\mathrm{r}\to\mathrm{t}}) + \operatorname{conv}(\mathrm{t}_l, W_{l,\mathrm{t}\to\mathrm{t}})\big)
\end{aligned}
$$
        The residual stream preserves the optimization properties of residual units, while the transient stream allows features computed by earlier layers to be discarded. The structure of the generalized residual block is shown below.

        If the weights of the residual stream are zero, the generalized residual block reduces to a standard convolutional layer; if the weights of the transient stream are zero, it reduces to a standard residual block. Stacking generalized residual blocks lets the network learn any of the structures in Figure 1b (for example, Figure 1c), which increases its expressive power. The generalized residual block is not limited to CNNs and can be used in other kinds of networks as well. Replacing each convolution inside an original residual block with a generalized residual block (Figure 1b) yields a new architecture, ResNet in ResNet (RiR, Figure 1d). Figure 2 summarizes the relationships among the CNN, ResNet Init, ResNet, and RiR architectures.
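To make these reductions concrete, here is a minimal runnable sketch of the generalized residual update written in plain Paddle ops. The function and weight names are ours, chosen to mirror the equations above; they do not come from the paper, and batch norm is omitted for brevity.

import paddle
import paddle.nn.functional as F

# One generalized residual step: r and t are the residual and transient streams;
# w_rr, w_tr, w_rt, w_tt are the four 3x3 filter banks from the equations above.
def generalized_residual_step(r, t, w_rr, w_tr, w_rt, w_tt):
    r_next = F.relu(F.conv2d(r, w_rr, padding=1) + F.conv2d(t, w_tr, padding=1) + r)  # identity shortcut on r
    t_next = F.relu(F.conv2d(r, w_rt, padding=1) + F.conv2d(t, w_tt, padding=1))      # no shortcut on t
    return r_next, t_next

r = paddle.randn([1, 8, 32, 32])
t = paddle.randn([1, 8, 32, 32])
w_rr, w_tr, w_rt, w_tt = (paddle.randn([8, 8, 3, 3]) * 0.1 for _ in range(4))
zeros = paddle.zeros([8, 8, 3, 3])

# Zeroing every filter that reads from or writes to t leaves a standard residual block on r;
# zeroing w_rr, w_tr, and w_rt instead leaves a plain convolutional layer acting on t.
r_next, t_next = generalized_residual_step(r, t, w_rr, zeros, zeros, zeros)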

2. Code Reproduction

2.1 Install and import the required packages

!pip install paddlex
%matplotlib inline
import os
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.pyplot import figure
import paddle
from paddle import nn
import paddle.nn.functional as F
from paddle.io import Dataset, DataLoader
from paddle.vision.datasets import Cifar10
from paddle.vision.transforms import Transpose
import paddle.vision.transforms as transforms
import paddlex

2.2 Create the datasets

train_tfm = transforms.Compose([
    transforms.Resize((40, 40)),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
    transforms.RandomResizedCrop(32, scale=(0.6, 1.0)),
    transforms.RandomHorizontalFlip(0.5),
    transforms.RandomRotation(20),
    paddlex.transforms.MixupImage(),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
])

test_tfm = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
])
paddle.vision.set_image_backend('cv2')
# Use the Cifar10 dataset
train_dataset = Cifar10(data_file='data/data152754/cifar-10-python.tar.gz', mode='train', transform=train_tfm)
val_dataset = Cifar10(data_file='data/data152754/cifar-10-python.tar.gz', mode='test', transform=test_tfm)
print("train_dataset: %d" % len(train_dataset))
print("val_dataset: %d" % len(val_dataset))
train_dataset: 50000
val_dataset: 10000
batch_size=128
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, drop_last=True, num_workers=2)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, drop_last=False, num_workers=2)
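As a quick sanity check (our addition, not part of the original notebook), pull one batch from the loader and confirm the tensor shapes:

x, y = next(iter(train_loader))
print(x.shape)  # expected: [128, 3, 32, 32]
print(y.shape)  # expected: [128]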

2.3 Label smoothing

class LabelSmoothingCrossEntropy(nn.Layer):
    def __init__(self, smoothing=0.1):
        super().__init__()
        self.smoothing = smoothing

    def forward(self, pred, target):
        confidence = 1. - self.smoothing
        log_probs = F.log_softmax(pred, axis=-1)
        idx = paddle.stack([paddle.arange(log_probs.shape[0]), target], axis=1)
        nll_loss = paddle.gather_nd(-log_probs, index=idx)
        smooth_loss = paddle.mean(-log_probs, axis=-1)
        loss = confidence * nll_loss + self.smoothing * smooth_loss
        return loss.mean()
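A small check of the loss (our sketch, not in the original notebook): with smoothing=0 the result should coincide with plain cross entropy, since the loss is a convex combination of the NLL term and a uniform penalty.

logits = paddle.randn([4, 10])
labels = paddle.randint(0, 10, [4])
print(LabelSmoothingCrossEntropy(smoothing=0.0)(logits, labels).numpy())  # should match the next line
print(F.cross_entropy(logits, labels).numpy())
print(LabelSmoothingCrossEntropy(smoothing=0.1)(logits, labels).numpy())  # slightly higher on confident predictions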

2.4 Building the RiR model

        The experiments in this reproduction use RiR-18; its network structure is shown in the figure below:

class RiR_Init(nn.Layer):
    def __init__(self, in_channel, out_channel=None, stride=1):
        super().__init__()
        self.in_channel = in_channel
        self.out_channel = out_channel if out_channel is not None else in_channel
        self.stride = stride
        # The four filter banks of the generalized residual block:
        self.conv_res1 = nn.Conv2D(in_channel, out_channel, 3, stride, 1)  # W_{r->r}
        self.conv_res2 = nn.Conv2D(in_channel, out_channel, 3, stride, 1)  # W_{r->t}
        self.conv1 = nn.Conv2D(in_channel, out_channel, 3, stride, 1)      # W_{t->r}
        self.conv2 = nn.Conv2D(in_channel, out_channel, 3, stride, 1)      # W_{t->t}
        self.bnres = nn.BatchNorm2D(out_channel)
        self.relures = nn.ReLU()
        self.bn = nn.BatchNorm2D(out_channel)
        self.relu = nn.ReLU()
        self.resize_identity = (in_channel != out_channel) or (stride != 1)
        self.identity_connection = nn.Conv2D(in_channel, out_channel, 1, stride) if self.resize_identity else nn.Identity()

    def forward(self, x_res, x_tran):
        x_shortcut = self.identity_connection(x_res)
        x_res1 = self.conv_res1(x_res)
        x_res2 = self.conv_res2(x_res)
        x1 = self.conv1(x_tran)
        x2 = self.conv2(x_tran)
        out_res = x_res1 + x1 + x_shortcut   # residual stream: r->r + t->r + shortcut
        out_tran = x_res2 + x2               # transient stream: r->t + t->t
        out_res = self.relures(self.bnres(out_res))
        out_tran = self.relu(self.bn(out_tran))
        return out_res, out_tran
class RiRBlock(nn.Layer):
    def __init__(self, in_channel, out_channel, stride=1):
        super().__init__()
        self.rir_init1 = RiR_Init(in_channel, out_channel, stride)
        self.rir_init2 = RiR_Init(out_channel, out_channel, 1)
        self.resize_identity = (in_channel != out_channel) or (stride != 1)
        self.identity_connection1 = nn.Conv2D(in_channel, out_channel, 1, stride) if self.resize_identity else nn.Identity()
        self.identity_connection2 = nn.Conv2D(in_channel, out_channel, 1, stride) if self.resize_identity else nn.Identity()

    def forward(self, x_res, x_tran):
        # Outer shortcuts around two generalized residual units; each stream uses its own projection.
        x_shortcut1 = self.identity_connection1(x_res)
        x_shortcut2 = self.identity_connection2(x_tran)
        out_res, out_tran = self.rir_init1(x_res, x_tran)
        out_res, out_tran = self.rir_init2(out_res, out_tran)
        out_res = x_shortcut1 + out_res
        out_tran = x_shortcut2 + out_tran
        return out_res, out_tran
class RiRInitBlock(nn.Layer):
    def __init__(self, in_channel, out_channel, stride=1):
        super().__init__()
        self.conv_res = nn.Conv2D(in_channel, out_channel, 3, stride, 1)
        self.conv = nn.Conv2D(in_channel, out_channel, 3, stride, 1)
        self.bnres = nn.BatchNorm2D(out_channel)
        self.relures = nn.ReLU()
        self.bn = nn.BatchNorm2D(out_channel)
        self.relu = nn.ReLU()

    def forward(self, x):
        x_res = self.conv_res(x)
        x_tran = self.conv(x)
        x_res = self.relures(self.bnres(x_res))
        x_tran = self.relu(self.bn(x_tran))
        return x_res, x_tran
class RiRFinalBlock(nn.Layer):
    def __init__(self):
        super().__init__()

    def forward(self, x_res, x_tran):
        return paddle.concat([x_res, x_tran], axis=1)
class RiR(nn.Layer):
    def __init__(self, channels, blocks, in_channels=3, in_size=(32, 32), num_classes=10):
        super().__init__()
        assert len(channels) == len(blocks), 'the length of channels is not the same as the length of blocks'
        self.init = RiRInitBlock(in_channels, channels[0], 1)
        self.stage = nn.LayerList()
        for i in range(len(blocks)):
            if i == 0:
                for j in range(blocks[i]):
                    self.stage.append(RiRBlock(channels[i], channels[i]))
            else:
                # The first block of each later stage halves the resolution and widens the channels.
                for j in range(blocks[i]):
                    self.stage.append(RiRBlock(channels[i - 1] if j == 0 else channels[i], channels[i],
                                               stride=2 if j == 0 else 1))
        self.final = RiRFinalBlock()
        self.classifier = nn.Sequential(nn.Conv2D(channels[-1] * 2, num_classes, 1),
                                        nn.AdaptiveAvgPool2D(1),
                                        nn.Flatten(1))
        self.apply(self._init_weights)

    def _init_weights(self, m):
        zeros_ = nn.initializer.Constant(value=0.)
        ones_ = nn.initializer.Constant(value=1.)
        kaiming_ = nn.initializer.KaimingNormal()
        if isinstance(m, (nn.Linear, nn.Conv2D)):
            kaiming_(m.weight)
            if m.bias is not None:
                zeros_(m.bias)
        elif isinstance(m, nn.BatchNorm2D):
            zeros_(m.bias)
            ones_(m.weight)

    def forward(self, x):
        x_res, x_tran = self.init(x)
        for i in range(len(self.stage)):
            x_res, x_tran = self.stage[i](x_res, x_tran)
        out = self.final(x_res, x_tran)
        out = self.classifier(out)
        return out
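Before training, a one-off forward pass on random input (our addition) confirms the wiring of the three stages and the output shape:

net = RiR([48, 96, 192], [2, 3, 3])
x = paddle.randn([2, 3, 32, 32])
print(net(x).shape)  # expected: [2, 10]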

2.5 Model parameters and FLOPs

model = RiR([48, 96, 192], [2, 3, 3])
paddle.summary(model, (batch_size, 3, 32, 32))

model = RiR([48, 96, 192], [2, 3, 3])
paddle.flops(model, (batch_size, 3, 32, 32))
<class 'paddle.nn.layer.conv.Conv2D'>'s flops has been counted
<class 'paddle.nn.layer.norm.BatchNorm2D'>'s flops has been counted
<class 'paddle.nn.layer.activation.ReLU'>'s flops has been counted
Cannot find suitable count function for <class 'paddle.nn.layer.common.Identity'>. Treat it as zero FLOPs.
Cannot find suitable count function for <class '__main__.RiRFinalBlock'>. Treat it as zero FLOPs.
<class 'paddle.nn.layer.pooling.AdaptiveAvgPool2D'>'s flops has been counted
Cannot find suitable count function for <class 'paddle.fluid.dygraph.nn.Flatten'>. Treat it as zero FLOPs.
Total Flops: 164831593728     Total Params: 9532234
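Note that paddle.flops was given the full batch shape (batch_size, 3, 32, 32), so the 164831593728 FLOPs reported above cover a batch of 128 images; per image this works out to 164831593728 / 128 ≈ 1.29 GFLOPs (our reading of the output, not a number printed by the notebook).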

2.6 Training

learning_rate = 0.001
n_epochs = 50
paddle.seed(42)
np.random.seed(42)
work_path = 'work/model'

# RiR-18
model = RiR([48, 96, 192], [2, 3, 3])
criterion = LabelSmoothingCrossEntropy()
scheduler = paddle.optimizer.lr.CosineAnnealingDecay(learning_rate=learning_rate,
                                                     T_max=50000 // batch_size * n_epochs, verbose=False)
optimizer = paddle.optimizer.Adam(parameters=model.parameters(), learning_rate=scheduler, weight_decay=1e-5)

gate = 0.0
threshold = 0.0
best_acc = 0.0
val_acc = 0.0
loss_record = {'train': {'loss': [], 'iter': []}, 'val': {'loss': [], 'iter': []}}   # for recording loss
acc_record = {'train': {'acc': [], 'iter': []}, 'val': {'acc': [], 'iter': []}}      # for recording accuracy
loss_iter = 0
acc_iter = 0

for epoch in range(n_epochs):
    # ---------- Training ----------
    model.train()
    train_num = 0.0
    train_loss = 0.0
    val_num = 0.0
    val_loss = 0.0
    accuracy_manager = paddle.metric.Accuracy()
    val_accuracy_manager = paddle.metric.Accuracy()
    print("#===epoch: {}, lr={:.10f}===#".format(epoch, optimizer.get_lr()))
    for batch_id, data in enumerate(train_loader):
        x_data, y_data = data
        labels = paddle.unsqueeze(y_data, axis=1)
        logits = model(x_data)
        loss = criterion(logits, y_data)
        acc = paddle.metric.accuracy(logits, labels)
        accuracy_manager.update(acc)
        if batch_id % 10 == 0:
            loss_record['train']['loss'].append(loss.numpy())
            loss_record['train']['iter'].append(loss_iter)
            loss_iter += 1
        loss.backward()
        optimizer.step()
        scheduler.step()
        optimizer.clear_grad()
        train_loss += loss
        train_num += len(y_data)

    total_train_loss = (train_loss / train_num) * batch_size
    train_acc = accuracy_manager.accumulate()
    acc_record['train']['acc'].append(train_acc)
    acc_record['train']['iter'].append(acc_iter)
    acc_iter += 1
    print("#===epoch: {}, train loss is: {}, train acc is: {:2.2f}%===#".format(epoch, total_train_loss.numpy(), train_acc * 100))

    # ---------- Validation ----------
    model.eval()
    for batch_id, data in enumerate(val_loader):
        x_data, y_data = data
        labels = paddle.unsqueeze(y_data, axis=1)
        with paddle.no_grad():
            logits = model(x_data)
            loss = criterion(logits, y_data)
            acc = paddle.metric.accuracy(logits, labels)
        val_accuracy_manager.update(acc)
        val_loss += loss
        val_num += len(y_data)

    total_val_loss = (val_loss / val_num) * batch_size
    loss_record['val']['loss'].append(total_val_loss.numpy())
    loss_record['val']['iter'].append(loss_iter)
    val_acc = val_accuracy_manager.accumulate()
    acc_record['val']['acc'].append(val_acc)
    acc_record['val']['iter'].append(acc_iter)
    print("#===epoch: {}, val loss is: {}, val acc is: {:2.2f}%===#".format(epoch, total_val_loss.numpy(), val_acc * 100))

    # ===================save====================
    if val_acc > best_acc:
        best_acc = val_acc
        paddle.save(model.state_dict(), os.path.join(work_path, 'best_model.pdparams'))
        paddle.save(optimizer.state_dict(), os.path.join(work_path, 'best_optimizer.pdopt'))
        print(best_acc)
paddle.save(model.state_dict(), os.path.join(work_path, 'final_model.pdparams'))
paddle.save(optimizer.state_dict(), os.path.join(work_path, 'final_optimizer.pdopt'))
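If training is interrupted, the saved checkpoint can be restored before continuing. This is a hedged sketch of ours (the file names match the save calls above; the rest is standard Paddle usage):

model.set_state_dict(paddle.load(os.path.join(work_path, 'final_model.pdparams')))
optimizer.set_state_dict(paddle.load(os.path.join(work_path, 'final_optimizer.pdopt')))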

2.7 Experimental results

def plot_learning_curve(record, title='loss', ylabel='CE Loss'):
    ''' Plot the learning curve of the CNN '''
    maxtrain = max(map(float, record['train'][title]))
    maxval = max(map(float, record['val'][title]))
    ymax = max(maxtrain, maxval) * 1.1
    mintrain = min(map(float, record['train'][title]))
    minval = min(map(float, record['val'][title]))
    ymin = min(mintrain, minval) * 0.9
    total_steps = len(record['train'][title])
    x_1 = list(map(int, record['train']['iter']))
    x_2 = list(map(int, record['val']['iter']))
    figure(figsize=(10, 6))
    plt.plot(x_1, record['train'][title], c='tab:red', label='train')
    plt.plot(x_2, record['val'][title], c='tab:cyan', label='val')
    plt.ylim(ymin, ymax)
    plt.xlabel('Training steps')
    plt.ylabel(ylabel)
    plt.title('Learning curve of {}'.format(title))
    plt.legend()
    plt.show()
plot_learning_curve(loss_record, title='loss', ylabel='CE Loss')

plot_learning_curve(acc_record, title='acc', ylabel='Accuracy')

import time
work_path = 'work/model'
model = RiR([48, 96, 192], [2, 3, 3])
model_state_dict = paddle.load(os.path.join(work_path, 'best_model.pdparams'))
model.set_state_dict(model_state_dict)
model.eval()
start = time.time()
for batch_id, data in enumerate(val_loader):
    x_data, y_data = data
    labels = paddle.unsqueeze(y_data, axis=1)
    with paddle.no_grad():
        logits = model(x_data)
end = time.time()
print("Throughput: {}".format(int(len(val_dataset) // (end - start))))
Throughput: 1462
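The wall-clock measurement above includes the first-batch warm-up. A slightly more careful sketch (our addition, reusing the same val_loader) starts the clock only after a few warm-up batches:

warmup_batches = 5
n_images = 0
for batch_id, (x_data, y_data) in enumerate(val_loader):
    with paddle.no_grad():
        model(x_data)
    if batch_id + 1 == warmup_batches:
        start = time.time()          # start timing only once warmed up
    elif batch_id + 1 > warmup_batches:
        n_images += len(y_data)
print("Throughput: {} images/s".format(int(n_images / (time.time() - start))))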
def get_cifar10_labels(labels):
    """Return the text labels for the CIFAR-10 dataset."""
    text_labels = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog',
                   'horse', 'ship', 'truck']
    return [text_labels[int(i)] for i in labels]
def show_images(imgs, num_rows, num_cols, pred=None, gt=None, scale=1.5):
    """Plot a list of images."""
    figsize = (num_cols * scale, num_rows * scale)
    _, axes = plt.subplots(num_rows, num_cols, figsize=figsize)
    axes = axes.flatten()
    for i, (ax, img) in enumerate(zip(axes, imgs)):
        if paddle.is_tensor(img):
            ax.imshow(img.numpy())
        else:
            ax.imshow(img)
        ax.axes.get_xaxis().set_visible(False)
        ax.axes.get_yaxis().set_visible(False)
        if pred or gt:
            ax.set_title("pt: " + pred[i] + "\ngt: " + gt[i])
    return axes
work_path = 'work/model'
X, y = next(iter(DataLoader(val_dataset, batch_size=18)))
model = RiR([48, 96, 192], [2, 3, 3])
model_state_dict = paddle.load(os.path.join(work_path, 'best_model.pdparams'))
model.set_state_dict(model_state_dict)
model.eval()
logits = model(X)
y_pred = paddle.argmax(logits, -1)
X = paddle.transpose(X, [0, 2, 3, 1])
axes = show_images(X.reshape((18, 32, 32, 3)), 1, 18, pred=get_cifar10_labels(y_pred), gt=get_cifar10_labels(y))
plt.show()
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). (repeated once per subplot, 18 times)

Comparison experiment

The ResNet-18 baseline (configured as described in the RiR paper) is evaluated in main-copy1.ipynb.

Model     | Acc     | Params    | FLOPs
RiR-18    | 0.92514 | 9,532,234 | 164,831,593,728
ResNet-18 | 0.91980 | 9,574,474 | 165,021,910,272

Summary

        This paper proposed a generalized residual architecture that can be obtained from the original design with only a simple modification (ResNet Init). Applying ResNet Init inside an original ResNet yields the RiR architecture, which achieves very strong results.
        The comparison experiment shows that RiR outperforms ResNet with fewer parameters and fewer FLOPs, improving accuracy by about 0.6%.
        Future work: improve RiR further and compare its performance on larger datasets.

