手写体数字识别
Tensorflow实战 LeNet-5神经网络进行手写体数字识别
完整代码:
关于LeNet-5网络的介绍,可以参考我的上一篇博客《浅谈LeNet-5》。
本文基于Tensorflow对LeNet5进行复现并进行手写体数字识别。
运行环境:win10+tensorflow1.8.0+cuda9.0+cudnn7.0
一、项目结构
文件介绍
文件 | 功能 |
---|---|
lenet.py | 定义LeNet模型 |
layer_util.py | 封装一些常用函数 |
train.py | 配置一些训练参数以及数据的读取 |
二、代码详解
layer_util.py
为了降低代码耦合度,我将卷积、池化、全连接等操作抽象封装成函数,便于定义网络时直接调用。这样在定义卷积层、池化层和全连接层时,只需要关注输出是多少(输出通道数或输出维度),输入的通道数/维度会从输入张量的形状中自动获取,不需要再进行繁琐的计算。
参数的获取
定义了获取变量和获取常量的函数
def get_variable(shape, stddev=0.1):
    """Create a trainable tf.Variable initialised from a truncated normal.

    Args:
        shape: shape of the variable to create.
        stddev: standard deviation passed to ``tf.truncated_normal``.

    Returns:
        A new ``tf.Variable``.
    """
    return tf.Variable(tf.truncated_normal(shape, stddev=stddev))


def get_constant_variable(shape, value=0.1):
    """Create a trainable tf.Variable whose entries all start at ``value``.

    Args:
        shape: shape of the variable to create.
        value: constant fill value for every element.

    Returns:
        A new ``tf.Variable``.
    """
    return tf.Variable(tf.constant(value, shape=shape))
conv2d
具体参数含义见注释。
def conv2d(inputs,
           out_channels,
           kernel_size,
           scope,
           stride=(1, 1),
           padding='SAME',
           stddev=1e-1,
           activation_fn=tf.nn.relu):
    """2-D convolution (without a bias term) followed by an optional activation.

    Args:
        inputs: 4-D tensor BxHxWxC. The channel count C is read statically
            via ``inputs.shape[-1].value``.
        out_channels: int, number of output channels.
        kernel_size: sequence of 2 ints, (kernel_h, kernel_w).
        scope: string, name of the variable scope to create.
        stride: sequence of 2 ints, (stride_h, stride_w). A tuple default is
            used so no mutable object is shared between calls.
        padding: 'SAME' or 'VALID'.
        stddev: float, stddev for the truncated-normal kernel initialiser.
        activation_fn: callable applied to the result, or None for linear.

    Returns:
        Output tensor.
    """
    with tf.variable_scope(scope):
        kernel_h, kernel_w = kernel_size
        in_channels = inputs.shape[-1].value
        kernel_shape = [kernel_h, kernel_w, in_channels, out_channels]
        kernel = get_variable(kernel_shape, stddev)
        stride_h, stride_w = stride
        # tf.nn.conv2d expects strides for all four NHWC axes.
        outputs = tf.nn.conv2d(inputs, kernel,
                               strides=[1, stride_h, stride_w, 1],
                               padding=padding)
        if activation_fn is not None:
            outputs = activation_fn(outputs)
        return outputs
max_pool2d
def max_pool2d(inputs,
               kernel_size,
               scope,
               stride=(2, 2),
               padding='SAME'):
    """2-D max pooling.

    Args:
        inputs: 4-D tensor BxHxWxC.
        kernel_size: sequence of 2 ints, (kernel_h, kernel_w).
        scope: string, name of the variable scope to create.
        stride: sequence of 2 ints, (stride_h, stride_w). A tuple default is
            used so no mutable object is shared between calls.
        padding: 'SAME' or 'VALID'.

    Returns:
        Pooled tensor (no variables are created here).
    """
    with tf.variable_scope(scope):
        kernel_h, kernel_w = kernel_size
        stride_h, stride_w = stride
        outputs = tf.nn.max_pool(inputs,
                                 ksize=[1, kernel_h, kernel_w, 1],
                                 strides=[1, stride_h, stride_w, 1],
                                 padding=padding)
        return outputs
full_connection
def full_connection(inputs,
                    num_outputs,
                    scope,
                    stddev=1e-1,
                    activation_fn=tf.nn.relu):
    """Fully connected layer: ``activation(inputs @ W + b)``.

    An L2 penalty (scale 0.01) on ``W`` is appended to the 'losses'
    collection. It is only recorded there and is not added to the returned
    tensor.

    Args:
        inputs: 2-D tensor BxN; N is read statically via ``.value``.
        num_outputs: int, width of the layer.
        scope: string, name of the variable scope to create.
        stddev: float, stddev for the truncated-normal weight initialiser.
        activation_fn: callable applied to the result, or None for linear.

    Returns:
        Tensor of size B x num_outputs.
    """
    with tf.variable_scope(scope):
        fan_in = inputs.shape[-1].value
        # Weights are created before biases, which keeps the auto-generated
        # variable names stable.
        w = get_variable(shape=[fan_in, num_outputs], stddev=stddev)
        net = tf.matmul(inputs, w)
        b = get_constant_variable(shape=[num_outputs])
        net = tf.nn.bias_add(net, b)
        l2_penalty = tf.contrib.layers.l2_regularizer(0.01)(w)
        tf.add_to_collection('losses', l2_penalty)
        return net if activation_fn is None else activation_fn(net)
dropout
虽然本项目中没用到,不过还是写一下。
def dropout(inputs, scope, keep_prob, is_training=False):
    """Apply ``tf.nn.dropout`` to ``inputs`` inside variable scope ``scope``.

    Args:
        inputs: input tensor.
        scope: string, name of the variable scope to create.
        keep_prob: probability of keeping each element.
        is_training: NOTE(review): accepted but currently ignored. Dropout
            is applied on every call regardless of this flag.

    Returns:
        The dropped-out tensor.
    """
    with tf.variable_scope(scope):
        return tf.nn.dropout(inputs, keep_prob=keep_prob)
layer_util.py完整代码
import tensorflow as tf


def get_variable(shape, stddev=0.1):
    """Create a trainable tf.Variable initialised from a truncated normal."""
    initial = tf.truncated_normal(shape, stddev=stddev)
    return tf.Variable(initial)


def get_constant_variable(shape, value=0.1):
    """Create a trainable tf.Variable whose entries all start at ``value``."""
    initial = tf.constant(value, shape=shape)
    return tf.Variable(initial)


def conv2d(inputs,
           out_channels,
           kernel_size,
           scope,
           stride=(1, 1),
           padding='SAME',
           stddev=1e-1,
           activation_fn=tf.nn.relu):
    """2-D convolution (without a bias term) followed by an optional activation.

    Args:
        inputs: 4-D tensor BxHxWxC. C is read statically via ``.value``.
        out_channels: int, number of output channels.
        kernel_size: sequence of 2 ints, (kernel_h, kernel_w).
        scope: string, name of the variable scope to create.
        stride: sequence of 2 ints, (stride_h, stride_w). A tuple default
            avoids a shared mutable default.
        padding: 'SAME' or 'VALID'.
        stddev: float, stddev for the truncated-normal kernel initialiser.
        activation_fn: callable applied to the result, or None for linear.

    Returns:
        Output tensor.
    """
    with tf.variable_scope(scope):
        kernel_h, kernel_w = kernel_size
        in_channels = inputs.shape[-1].value
        kernel_shape = [kernel_h, kernel_w, in_channels, out_channels]
        kernel = get_variable(kernel_shape, stddev)
        stride_h, stride_w = stride
        outputs = tf.nn.conv2d(inputs, kernel,
                               strides=[1, stride_h, stride_w, 1],
                               padding=padding)
        if activation_fn is not None:
            outputs = activation_fn(outputs)
        return outputs


def max_pool2d(inputs,
               kernel_size,
               scope,
               stride=(2, 2),
               padding='SAME'):
    """2-D max pooling.

    Args:
        inputs: 4-D tensor BxHxWxC.
        kernel_size: sequence of 2 ints, (kernel_h, kernel_w).
        scope: string, name of the variable scope to create.
        stride: sequence of 2 ints, (stride_h, stride_w).
        padding: 'SAME' or 'VALID'.

    Returns:
        Pooled tensor (no variables are created here).
    """
    with tf.variable_scope(scope):
        kernel_h, kernel_w = kernel_size
        stride_h, stride_w = stride
        outputs = tf.nn.max_pool(inputs,
                                 ksize=[1, kernel_h, kernel_w, 1],
                                 strides=[1, stride_h, stride_w, 1],
                                 padding=padding)
        return outputs


def full_connection(inputs,
                    num_outputs,
                    scope,
                    stddev=1e-1,
                    activation_fn=tf.nn.relu):
    """Fully connected layer: ``activation(inputs @ W + b)``.

    An L2 penalty (scale 0.01) on ``W`` is appended to the 'losses'
    collection. It is only recorded there and is not added to the output.

    Args:
        inputs: 2-D tensor BxN; N is read statically via ``.value``.
        num_outputs: int, width of the layer.
        scope: string, name of the variable scope to create.
        stddev: float, stddev for the truncated-normal weight initialiser.
        activation_fn: callable applied to the result, or None for linear.

    Returns:
        Tensor of size B x num_outputs.
    """
    with tf.variable_scope(scope):
        num_inputs = inputs.shape[-1].value
        weights = get_variable(shape=[num_inputs, num_outputs], stddev=stddev)
        outputs = tf.matmul(inputs, weights)
        biases = get_constant_variable(shape=[num_outputs])
        outputs = tf.nn.bias_add(outputs, biases)
        tf.add_to_collection('losses',
                             tf.contrib.layers.l2_regularizer(0.01)(weights))
        if activation_fn is not None:
            outputs = activation_fn(outputs)
        return outputs


def dropout(inputs, scope, keep_prob, is_training=False):
    """Apply ``tf.nn.dropout`` inside variable scope ``scope``.

    NOTE(review): ``is_training`` is accepted but currently ignored; dropout
    is applied on every call regardless of its value.
    """
    with tf.variable_scope(scope):
        outputs = tf.nn.dropout(inputs, keep_prob=keep_prob)
        return outputs
lenet.py
根据LeNet-5网络结构定义模型,返回网络的原始输出(logits,未经softmax)以及经过softmax后的结果。
import tensorflow as tf

from utils import layer_util


def LeNet(inputs, keep_prob=None):
    """Build a LeNet-5 style network on ``inputs``.

    Layer stack (scope names in brackets):
        [Conv1] 5x5 conv, 6 channels, VALID
        [S2]    2x2 max-pool
        [Conv3] 5x5 conv, 16 channels, VALID
        [S4]    2x2 max-pool
        [Conv5] 5x5 conv, 120 channels, SAME (layer_util default padding)
        flatten
        [F6]    fully connected, 84 units
        [output] fully connected, 10 units, no activation

    Args:
        inputs: 4-D image tensor, presumably NHWC as layer_util expects.
        keep_prob: currently unused; no dropout layer is built.

    Returns:
        Tuple ``(outputs, prediction)``: the un-normalised 10-way logits and
        their softmax.
    """
    with tf.variable_scope('Conv1'):
        net = layer_util.conv2d(inputs, 6, [5, 5], 'conv1', padding='VALID')
    with tf.variable_scope('S2'):
        net = layer_util.max_pool2d(net, [2, 2], 'S2')
    with tf.variable_scope('Conv3'):
        net = layer_util.conv2d(net, 16, [5, 5], 'conv3', padding='VALID')
    with tf.variable_scope('S4'):
        net = layer_util.max_pool2d(net, [2, 2], 's4')
    with tf.variable_scope('Conv5'):
        net = layer_util.conv2d(net, 120, [5, 5], 'conv5')

    # Collapse H x W x C into one feature axis for the dense layers; the
    # spatial dims must be statically known here.
    _, height, width, channels = net.shape.as_list()
    net = tf.reshape(net, [-1, height * width * channels])

    with tf.variable_scope('F6'):
        net = layer_util.full_connection(net, 84, 'f6')
    with tf.variable_scope('output'):
        logits = layer_util.full_connection(net, 10, 'outputs',
                                            activation_fn=None)
    return logits, tf.nn.softmax(logits)


def get_model(inputs, keep_prob=None):
    """Public entry point; simply forwards to ``LeNet``."""
    return LeNet(inputs, keep_prob)
train.py
加载数据。如果加载目录不存在数据会自动下载。如果网络问题下载失败或者下载速度很慢,可以从我的GitHub项目中获取。
from tensorflow.examples.tutorials.mnist import input_data

# Reads MNIST from ./data/MNIST_data, downloading it there first if absent.
mnist = input_data.read_data_sets(
    "./data/MNIST_data",
    one_hot=True,  # labels are one-hot encoded vectors
)
定义超参数
BATCH_SIZE = 100
# Number of whole batches per epoch; any remainder is not counted.
N_BATCH = divmod(mnist.train.num_examples, BATCH_SIZE)[0]
定义输入
with tf.variable_scope('inputs'):
    # Flattened images, 784 values each (reshaped to 28x28x1 below).
    x = tf.placeholder(dtype=tf.float32, shape=[None, 784])
    # One-hot labels over 10 classes.
    y = tf.placeholder(dtype=tf.float32, shape=[None, 10])
因为网络输入为图片形式,所以需要对数据进行reshape
# Restore the image layout expected by the conv layers:
# batch x 28 x 28 x 1 channel.
x_image = tf.reshape(x, shape=[-1, 28, 28, 1])
获取模型的输出
# `outputs` are the raw logits; `prediction` is their softmax.
outputs, prediction = lenet.get_model(inputs=x_image)
损失函数的定义,这里用了交叉熵损失
# softmax_cross_entropy_with_logits applies softmax internally, so it must be
# given the raw logits (`outputs`). Passing the already-softmaxed `prediction`
# would apply softmax twice and flatten the gradients.
loss = tf.reduce_mean(
    tf.nn.softmax_cross_entropy_with_logits(labels=y, logits=outputs))
使用梯度下降优化器
# Plain SGD with a fixed learning rate of 0.1.
optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.1)
train_step = optimizer.minimize(loss)
进行准确率的计算
# Compare predicted class indices with the true (one-hot) label indices.
label_ids = tf.argmax(y, axis=1)
predicted_ids = tf.argmax(prediction, axis=1)
correct_prediction = tf.equal(label_ids, predicted_ids)
# Fraction of correct predictions in the fed batch.
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
train.py完整代码
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

from moudle import lenet

mnist = input_data.read_data_sets("./data/MNIST_data", one_hot=True)
BATCH_SIZE = 100
N_BATCH = mnist.train.num_examples // BATCH_SIZE


def train():
    """Train LeNet on MNIST for 20 epochs and save a checkpoint.

    Prints test-set accuracy after each epoch, writes the graph to ``logs/``
    for TensorBoard and saves the final weights to ``logs/train.ckpt``.
    """
    with tf.variable_scope('inputs'):
        x = tf.placeholder(tf.float32, [None, 784])
        y = tf.placeholder(tf.float32, [None, 10])
    x_image = tf.reshape(x, [-1, 28, 28, 1])
    outputs, prediction = lenet.get_model(x_image)

    # softmax_cross_entropy_with_logits applies softmax itself, so it must get
    # the raw logits (`outputs`), not the already-softmaxed `prediction`.
    loss = tf.reduce_mean(
        tf.nn.softmax_cross_entropy_with_logits(labels=y, logits=outputs))
    train_step = tf.train.GradientDescentOptimizer(0.1).minimize(loss)
    init = tf.global_variables_initializer()

    correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(prediction, 1))
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
    saver = tf.train.Saver()

    test_feed = {x: mnist.test.images, y: mnist.test.labels}
    with tf.Session() as sess:
        writer = tf.summary.FileWriter('logs/', sess.graph)
        try:
            sess.run(init)
            for epoch in range(20):
                for _ in range(N_BATCH):
                    batch_xs, batch_ys = mnist.train.next_batch(BATCH_SIZE)
                    sess.run(train_step, feed_dict={x: batch_xs, y: batch_ys})
                # Evaluate only the accuracy; a separate run of `prediction`
                # was an unused full forward pass over the test set.
                acc = sess.run(accuracy, feed_dict=test_feed)
                print('Iter' + str(epoch) + ",Testing Accuracy " + str(acc))
            saver.save(sess, 'logs/train.ckpt')
        finally:
            # Flush and release the event file even if training fails.
            writer.close()


if __name__ == '__main__':
    train()
网络结构图
三、训练结果
可以尝试一些对比实验,例如更换优化器或者多训练几轮,测试集准确率应该可以达到99%左右甚至更高。
更多推荐
Tensorflow实战 LeNet-5神经网络进行手写体数字识别
发布评论