结构图及其keras实现"/>
GoogleNetV3详细结构图及其keras实现
目录
GoogleNetV3结构图
Keras完整代码
GoogleNetV3结构图
该结构图是由Google论文中给出的表格细化而来,程序是自己编写的,Inception层的输出shape与论文中给出的有细微差别,如有错漏请多多指教。
上图为论文中给出的结构表格,下图是由keras的plot_model生成的详细结构:
Keras完整代码:
将想要进行训练的图片放在对应文件夹中即可进行训练:
例如,目录结构如下:
./data/train/cat
./data/train/dog
./data/validation/cat
./data/validation/dog
其中cat、dog为需要网络分辨的类别,存储在不同文件夹中;flow_from_directory会自动对图片进行标记,有几个类别就建立几个文件夹。train文件夹下的图片用于训练,validation文件夹下的图片用于验证准确性。关于ImageDataGenerator的详细内容可参考Keras官方文档(原文此处的链接已丢失)。
from keras.models import Model
from keras.layers import Input, Dense, Dropout, BatchNormalization, Flatten, concatenate, Activation, Reshape
from keras.layers.convolutional import Conv2D, MaxPooling2D, AveragePooling2D
from keras.preprocessing.image import ImageDataGenerator
from keras.utils import plot_model
from keras.utils import np_utils
import numpy as np
from time import ctime
from PIL import Image
import pickle
import glob
import os
import tensorflow as tf


def Conv2d_BN(x, nb_filter, kernel_size, padding='same', strides=(1, 1), name=None):
    """Apply Conv2D -> BatchNormalization -> ReLU to ``x``.

    Args:
        x: Input Keras tensor (channels-last; batch norm is applied on axis 3).
        nb_filter: Number of convolution filters (output channels).
        kernel_size: Convolution kernel size, e.g. ``(3, 3)``.
        padding: ``'same'`` or ``'valid'``, forwarded to ``Conv2D``.
        strides: Convolution strides, forwarded to ``Conv2D``.
        name: Optional base name; when given, the conv layer is named
            ``name + '_conv'`` and the batch-norm layer ``name + '_bn'``.

    Returns:
        The activated output tensor.
    """
    conv_name = name + '_conv' if name is not None else None
    bn_name = name + '_bn' if name is not None else None

    x = Conv2D(nb_filter, kernel_size, padding=padding, strides=strides, name=conv_name)(x)
    x = BatchNormalization(axis=3, name=bn_name)(x)
    x = Activation('relu')(x)
    return x


def InceptionT1(x, nb_filter, strides=(1, 1)):
    """Build the first Inception block type (stacked 3x3 convolutions).

    Four parallel branches are concatenated on the channel axis, so the
    output has ``4 * nb_filter`` channels:

    * 1x1 -> 3x3 -> 3x3 (``strides`` applied on the last conv)
    * 1x1 -> 3x3 (``strides`` applied on the last conv)
    * 3x3 max-pool (``strides``) -> 1x1
    * 1x1 (``strides``)

    Args:
        x: Input Keras tensor (channels-last).
        nb_filter: Number of filters used by every convolution in the block.
        strides: Down-sampling strides applied once per branch.

    Returns:
        The concatenated output tensor.
    """
    branch1 = Conv2d_BN(x, nb_filter, (1, 1), padding='same', strides=(1, 1))
    branch1 = Conv2d_BN(branch1, nb_filter, (3, 3), padding='same', strides=(1, 1))
    branch1 = Conv2d_BN(branch1, nb_filter, (3, 3), padding='same', strides=strides)

    branch2 = Conv2d_BN(x, nb_filter, (1, 1), padding='same', strides=(1, 1))
    branch2 = Conv2d_BN(branch2, nb_filter, (3, 3), padding='same', strides=strides)

    branch3 = MaxPooling2D(pool_size=(3, 3), strides=strides, padding='same')(x)
    branch3 = Conv2d_BN(branch3, nb_filter, (1, 1), padding='same', strides=(1, 1))

    # With a 1x1 kernel, 'valid' and 'same' padding yield the same output
    # size, so this branch still lines up with the others for concatenation.
    branch4 = Conv2d_BN(x, nb_filter, (1, 1), padding='valid', strides=strides)

    return concatenate([branch1, branch2, branch3, branch4], axis=3)


def InceptionT2(x, nb_filter, strides=(1, 1), stride1=(1, 1), stride2=(1, 1)):
    """Build the second Inception block type (factorized 1x7 / 7x1 convolutions).

    Four parallel branches are concatenated on the channel axis, so the
    output has ``4 * nb_filter`` channels:

    * 1x1 -> 1x7 -> 7x1 -> 1x7 (``stride1``) -> 7x1 (``stride2``)
    * 1x1 -> 1x7 (``stride1``) -> 7x1 (``stride2``)
    * 3x3 max-pool (``strides``) -> 1x1
    * 1x1 (``strides``)

    For the branches to be concatenable, the combined effect of ``stride1``
    followed by ``stride2`` must match ``strides`` (e.g. ``(2, 1)`` then
    ``(1, 2)`` for ``strides=(2, 2)``).

    Args:
        x: Input Keras tensor (channels-last).
        nb_filter: Number of filters used by every convolution in the block.
        strides: Strides of the pooling branch and the plain 1x1 branch.
        stride1: Strides of the final 1x7 convolution in the factorized branches.
        stride2: Strides of the final 7x1 convolution in the factorized branches.

    Returns:
        The concatenated output tensor.
    """
    branch1 = Conv2d_BN(x, nb_filter, (1, 1), padding='same', strides=(1, 1))
    branch1 = Conv2d_BN(branch1, nb_filter, (1, 7), padding='same', strides=(1, 1))
    branch1 = Conv2d_BN(branch1, nb_filter, (7, 1), padding='same', strides=(1, 1))
    branch1 = Conv2d_BN(branch1, nb_filter, (1, 7), padding='same', strides=stride1)
    branch1 = Conv2d_BN(branch1, nb_filter, (7, 1), padding='same', strides=stride2)

    branch2 = Conv2d_BN(x, nb_filter, (1, 1), padding='same', strides=(1, 1))
    branch2 = Conv2d_BN(branch2, nb_filter, (1, 7), padding='same', strides=stride1)
    branch2 = Conv2d_BN(branch2, nb_filter, (7, 1), padding='same', strides=stride2)

    branch3 = MaxPooling2D(pool_size=(3, 3), strides=strides, padding='same')(x)
    branch3 = Conv2d_BN(branch3, nb_filter, (1, 1), padding='same', strides=(1, 1))

    branch4 = Conv2d_BN(x, nb_filter, (1, 1), padding='same', strides=strides)

    return concatenate([branch1, branch2, branch3, branch4], axis=3)


def InceptionT3(x, nb_filter, strides=(1, 1)):
    """Build the third Inception block type (parallel 1x3 / 3x1 expansions).

    Six tensors are concatenated on the channel axis, so the output has
    ``6 * nb_filter`` channels:

    * 1x1 -> 3x3 (``strides``) -> {1x3, 3x1}
    * 1x1 (``strides``) -> {1x3, 3x1}
    * 3x3 max-pool (``strides``) -> 1x1
    * 1x1 (``strides``)

    Args:
        x: Input Keras tensor (channels-last).
        nb_filter: Number of filters used by every convolution in the block.
        strides: Down-sampling strides applied once per branch.

    Returns:
        The concatenated output tensor.
    """
    branch1 = Conv2d_BN(x, nb_filter, (1, 1), padding='same', strides=(1, 1))
    branch1 = Conv2d_BN(branch1, nb_filter, (3, 3), padding='same', strides=strides)
    branch1_1 = Conv2d_BN(branch1, nb_filter, (1, 3), padding='same', strides=(1, 1))
    branch1_2 = Conv2d_BN(branch1, nb_filter, (3, 1), padding='same', strides=(1, 1))

    branch2 = Conv2d_BN(x, nb_filter, (1, 1), padding='same', strides=strides)
    branch2_1 = Conv2d_BN(branch2, nb_filter, (1, 3), padding='same', strides=(1, 1))
    branch2_2 = Conv2d_BN(branch2, nb_filter, (3, 1), padding='same', strides=(1, 1))

    branch3 = MaxPooling2D(pool_size=(3, 3), strides=strides, padding='same')(x)
    branch3 = Conv2d_BN(branch3, nb_filter, (1, 1), padding='same', strides=(1, 1))

    branch4 = Conv2d_BN(x, nb_filter, (1, 1), padding='same', strides=strides)

    return concatenate(
        [branch1_1, branch1_2, branch2_1, branch2_2, branch3, branch4], axis=3)


def creatcnn():
    """Build the Inception-V3-style binary classifier.

    The network takes 299x299 RGB images and ends in a single sigmoid unit,
    matching the ``binary_crossentropy`` loss and ``class_mode='binary'``
    generators used by the training script below.

    Returns:
        An uncompiled ``keras.models.Model`` named ``'inception'``.
    """
    inpt = Input(shape=(299, 299, 3))

    # Stem: plain convolutions and one max-pool.
    x = Conv2d_BN(inpt, 32, (3, 3), strides=(2, 2), padding='valid')  # (149, 149) 32
    x = Conv2d_BN(x, 32, (3, 3), strides=(1, 1), padding='valid')  # (147, 147) 32
    x = Conv2d_BN(x, 64, (3, 3), strides=(1, 1), padding='same')  # (147, 147) 64
    x = MaxPooling2D(pool_size=(3, 3), strides=(2, 2), padding='valid')(x)  # 64
    x = Conv2d_BN(x, 80, (3, 3), strides=(1, 1), padding='valid')  # 80
    x = Conv2d_BN(x, 192, (3, 3), strides=(2, 2), padding='valid')  # 192
    x = Conv2d_BN(x, 288, (3, 3), strides=(1, 1), padding='same')  # 288

    # 3 x Inception type 1; the last one down-samples. 4 * 192 = 768 channels.
    x = InceptionT1(x, 192)
    x = InceptionT1(x, 192)
    x = InceptionT1(x, 192, strides=(2, 2))

    # 5 x Inception type 2; the last one down-samples. 4 * 320 = 1280 channels.
    x = InceptionT2(x, 320)
    x = InceptionT2(x, 320)
    x = InceptionT2(x, 320)
    x = InceptionT2(x, 320)
    x = InceptionT2(x, 320, strides=(2, 2), stride1=(2, 1), stride2=(1, 2))

    # 2 x Inception type 3. 6 * 341 = 2046 channels (the paper uses 2048).
    x = InceptionT3(x, 341)
    x = InceptionT3(x, 341, strides=(1, 1))

    x = AveragePooling2D(pool_size=(9, 9), strides=(9, 9), padding='same')(x)
    x = Flatten()(x)
    x = Dense(1000, activation='relu')(x)
    x = Dense(500, activation='relu')(x)
    # A softmax over a single unit always outputs 1.0, which makes the model
    # untrainable; a single-unit binary classifier needs a sigmoid.
    x = Dense(1, activation='sigmoid')(x)

    model = Model(inpt, x, name='inception')
    return model


if __name__ == "__main__":
    model = creatcnn()
    plot_model(model, to_file='GoogleNetV3_model.png', show_shapes=True)
    model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])

    # Augment only the training images; validation images are used unchanged.
    train_datagen = ImageDataGenerator(
        shear_range=0.2,
        zoom_range=0.2,
        horizontal_flip=True)
    test_datagen = ImageDataGenerator()

    # Class labels are inferred from the sub-directory names.
    train_generator = train_datagen.flow_from_directory(
        'data/train',
        target_size=(299, 299),
        batch_size=10,
        class_mode='binary')
    validation_generator = test_datagen.flow_from_directory(
        'data/validation',
        target_size=(299, 299),
        batch_size=10,
        class_mode='binary')

    model.fit_generator(
        train_generator,
        steps_per_epoch=97,
        epochs=50,
        validation_data=validation_generator,
        validation_steps=7,
        verbose=2)
    model.save_weights('GoogleNet_weight.h5')
更多推荐
GoogleNetV3详细结构图及其keras实现
发布评论