Problem description
I'm trying to replace the swish activations with relu in the pretrained TF model EfficientNetB0. EfficientNetB0 uses swish activation in its Conv2D and Activation layers. This SO post is very similar to what I'm looking for. I also found an answer that works for models without skip connections. Below is the code:
```python
import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import ReLU

def replace_swish_with_relu(model):
    '''
    Modify passed model by replacing swish activation with relu
    '''
    for layer in tuple(model.layers):
        layer_type = type(layer).__name__
        if hasattr(layer, 'activation') and layer.activation.__name__ == 'swish':
            print(layer_type, layer.activation.__name__)
            if layer_type == "Conv2D":
                # conv layer with swish activation.
                # Do something
                layer.activation = ReLU() # This didn't work
            else:
                # activation layer
                # Do something
                layer = tf.keras.layers.Activation('relu', name=layer.name + "_relu") # This didn't work
    return model

# load pretrained efficientNet
model = tf.keras.applications.EfficientNetB0(
    include_top=True, weights='imagenet', input_tensor=None,
    input_shape=(224, 224, 3), pooling=None, classes=1000,
    classifier_activation='softmax')

# convert swish activation to relu activation
model = replace_swish_with_relu(model)
model.save("efficientNet-relu")
```
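The failure of the `else` branch has a plain Python explanation rather than a Keras one: `layer = tf.keras.layers.Activation(...)` only rebinds the loop variable `layer` to a brand-new object, so the layer stored in the model is never touched. Assigning to an attribute of the object the variable refers to, by contrast, changes the object that the model itself holds. The sketch below shows the difference with a hypothetical `Box` class standing in for a Keras layer:

```python
class Box:
    """Hypothetical stand-in for a layer object that has an `activation` attribute."""
    def __init__(self, activation):
        self.activation = activation

boxes = [Box("swish"), Box("swish")]

# Rebinding the loop variable creates new objects, but the list still holds the old ones.
for box in boxes:
    box = Box("relu")
after_rebind = [b.activation for b in boxes]
print(after_rebind)

# Assigning to an attribute mutates the objects that the list actually holds.
for box in boxes:
    box.activation = "relu"
after_mutate = [b.activation for b in boxes]
print(after_mutate)
```

The first print still shows the swish values; only the second loop changes what the list (and, by analogy, the model) contains.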
How can I modify `replace_swish_with_relu` to replace the swish activations with relu in the passed model?
Thanks for any pointers/help.
Recommended answer
`layer.activation` holds a reference to the `tf.keras.activations.swish` function. We can reassign it to `tf.keras.activations.relu` instead. Below is the modified `replace_swish_with_relu`:
```python
import tensorflow as tf

def replace_swish_with_relu(model):
    '''
    Modify passed model by replacing swish activation with relu
    '''
    for layer in tuple(model.layers):
        layer_type = type(layer).__name__
        if hasattr(layer, 'activation') and layer.activation.__name__ == 'swish':
            print(layer_type, layer.activation.__name__)
            # Conv2D and Activation layers both keep their activation function in
            # the `activation` attribute and call it when the layer runs, so
            # reassigning the attribute changes the existing layer in place.
            layer.activation = tf.keras.activations.relu
    return model
```
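To check that reassigning the attribute is enough even when the graph contains a skip connection, here is a small self-contained sketch. It uses a hypothetical toy model (one Conv2D and one Activation layer using swish, joined by an `Add` skip connection) instead of EfficientNetB0, so no weights need to be downloaded. It also assumes that newer Keras releases may report the swish function under the name `silu`, so it matches both names. Because the layers are changed in place, the model's connectivity never has to be rebuilt.

```python
import numpy as np
import tensorflow as tf

SWISH_NAMES = ("swish", "silu")  # assumption: newer Keras may name swish "silu"

def replace_swish_with_relu(model):
    """Reassign every swish activation in `model` to relu, in place."""
    for layer in model.layers:
        act = getattr(layer, "activation", None)  # None for layers without one
        if getattr(act, "__name__", None) in SWISH_NAMES:
            layer.activation = tf.keras.activations.relu
    return model

# Hypothetical toy model: swish in a Conv2D and an Activation layer, plus a skip connection.
inputs = tf.keras.Input(shape=(8, 8, 3))
x = tf.keras.layers.Conv2D(3, 3, padding="same", activation="swish")(inputs)
x = tf.keras.layers.Add()([inputs, x])  # skip connection
outputs = tf.keras.layers.Activation("swish")(x)
model = tf.keras.Model(inputs, outputs)

model = replace_swish_with_relu(model)

# Names of any layers that still use swish after the replacement.
remaining = [
    layer.name
    for layer in model.layers
    if getattr(getattr(layer, "activation", None), "__name__", None) in SWISH_NAMES
]
# The final layer now applies relu, so every output value should be >= 0.
y = model(np.random.randn(2, 8, 8, 3).astype("float32")).numpy()
print(remaining, float(y.min()))
```

The same `remaining` check can be run on the EfficientNetB0 model from the question after calling `replace_swish_with_relu` on it.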
Note: The pretrained ImageNet weights were learned with swish, so after switching the activations to relu you need to retrain the model for it to work well with the new activation. Related.
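The reason retraining is needed can be seen numerically. Swish is swish(x) = x * sigmoid(x): it is smooth, slightly negative for negative inputs, and a bit below x for positive inputs, while relu clips negatives to 0 and passes positives through unchanged. Every layer after a swapped activation therefore receives values different from the ones the pretrained weights were fitted to. A minimal pure-Python comparison:

```python
import math

def swish(x):
    # swish(x) = x * sigmoid(x) = x / (1 + exp(-x))
    return x / (1.0 + math.exp(-x))

def relu(x):
    # relu(x) = max(0, x)
    return max(0.0, x)

for x in (-2.0, -1.0, 0.0, 1.0, 2.0):
    print(f"x={x:+.1f}  swish={swish(x):+.4f}  relu={relu(x):+.4f}")
```

The two functions agree only at x = 0; elsewhere the gap is large enough that a network's activations shift noticeably after the swap.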