Tensorflow实战 LeNet-5神经网络进行手写体数字识别

编程入门 行业动态 更新时间:2024-10-28 15:24:31

Tensorflow实战 LeNet-5神经网络进行手写体数字识别

Tensorflow实战 LeNet-5神经网络进行手写体数字识别

完整代码:

关于LeNet-5网络的介绍,可以参考我的上一篇博客《浅谈LeNet-5》。

本文基于Tensorflow对LeNet5进行复现并进行手写体数字识别。

运行环境:Windows 10 + TensorFlow 1.8.0 + CUDA 9.0 + cuDNN 7.0

一、项目结构

文件介绍

文件功能
lenet.py:定义LeNet-5模型
layer_util.py:封装一些常用的网络层函数(卷积、池化、全连接、dropout)
train.py:配置训练参数、读取数据并进行训练

二、代码详解

layer_util.py

为了降低代码耦合度,我将卷积、池化、全连接等操作抽象封装成函数,便于定义网络时直接调用。这样在定义卷积层、池化层和全连接层时,只需要关注输出的通道数(或输出维度),输入维度由函数根据输入张量自动推断,不需要再进行繁琐的手动计算。

参数的获取

定义了获取变量和获取常量的函数

def get_variable(shape, stddev=0.1):
    """Create a trainable variable drawn from a truncated normal distribution.

    Args:
        shape: shape of the new variable.
        stddev: standard deviation passed to ``tf.truncated_normal``.

    Returns:
        A new ``tf.Variable`` holding the random initial values.
    """
    return tf.Variable(tf.truncated_normal(shape, stddev=stddev))


def get_constant_variable(shape, value=0.1):
    """Create a trainable variable whose every element starts at ``value``.

    Args:
        shape: shape of the new variable.
        value: constant fill value (used for biases in this project).

    Returns:
        A new ``tf.Variable`` filled with ``value``.
    """
    return tf.Variable(tf.constant(value, shape=shape))

conv2d

具体参数含义见注释。

def conv2d(inputs,
           out_channels,
           kernel_size,
           scope,
           stride=(1, 1),
           padding='SAME',
           stddev=1e-1,
           activation_fn=tf.nn.relu):
    """2-D convolution followed by an optional activation.

    No bias term is added (unlike ``full_connection``).

    Args:
        inputs: 4-D tensor laid out as BxHxWxC; the channel dimension must be
            statically known because it sizes the kernel.
        out_channels: int, number of output channels.
        kernel_size: a sequence of 2 ints, (kernel_h, kernel_w).
        scope: string, name of the variable scope wrapping the layer.
        stride: a sequence of 2 ints, (stride_h, stride_w). A tuple default is
            used instead of a list so no mutable object is shared across calls.
        padding: 'SAME' or 'VALID'.
        stddev: float, stddev for the truncated-normal kernel initializer.
        activation_fn: callable applied to the result, or None for linear output.

    Returns:
        The convolved (and optionally activated) output tensor.
    """
    with tf.variable_scope(scope):
        kernel_h, kernel_w = kernel_size
        # Infer the input depth so callers only specify the output depth.
        in_channels = inputs.shape[-1].value
        kernel_shape = [kernel_h, kernel_w, in_channels, out_channels]
        kernel = get_variable(kernel_shape, stddev)
        stride_h, stride_w = stride
        outputs = tf.nn.conv2d(inputs, kernel,
                               strides=[1, stride_h, stride_w, 1],
                               padding=padding)
        if activation_fn is not None:
            outputs = activation_fn(outputs)
        return outputs

max_pool2d

def max_pool2d(inputs,
               kernel_size,
               scope,
               stride=(2, 2),
               padding='SAME'):
    """2-D max pooling.

    Args:
        inputs: 4-D tensor laid out as BxHxWxC.
        kernel_size: a sequence of 2 ints, (kernel_h, kernel_w).
        scope: string, name of the variable scope wrapping the op.
        stride: a sequence of 2 ints, (stride_h, stride_w). A tuple default is
            used instead of a list so no mutable object is shared across calls.
        padding: 'SAME' or 'VALID'.

    Returns:
        The pooled output tensor (this op creates no variables).
    """
    with tf.variable_scope(scope):
        kernel_h, kernel_w = kernel_size
        stride_h, stride_w = stride
        return tf.nn.max_pool(inputs,
                              ksize=[1, kernel_h, kernel_w, 1],
                              strides=[1, stride_h, stride_w, 1],
                              padding=padding)

full_connection

def full_connection(inputs,
                    num_outputs,
                    scope,
                    stddev=1e-1,
                    activation_fn=tf.nn.relu,
                    weight_decay=0.01):
    """Fully connected layer: ``activation_fn(inputs @ W + b)``.

    Args:
        inputs: 2-D tensor, BxN; N must be statically known.
        num_outputs: int, size of the output dimension.
        scope: string, name of the variable scope wrapping the layer.
        stddev: float, stddev for the truncated-normal weight initializer.
        activation_fn: callable applied to the result, or None for linear output.
        weight_decay: float scale of the L2 penalty on the weights. The penalty
            is added to the ``'losses'`` graph collection. A falsy value
            (0 or None) skips it. The default of 0.01 matches the previously
            hard-coded value.

    Returns:
        Tensor of shape B x num_outputs.
    """
    with tf.variable_scope(scope):
        num_inputs = inputs.shape[-1].value
        weights = get_variable(shape=[num_inputs, num_outputs], stddev=stddev)
        outputs = tf.matmul(inputs, weights)
        biases = get_constant_variable(shape=[num_outputs])
        outputs = tf.nn.bias_add(outputs, biases)
        if weight_decay:
            # l2_regularizer(0.0) yields None; skipping avoids putting None
            # into the collection, which would break a later tf.add_n.
            tf.add_to_collection(
                'losses',
                tf.contrib.layers.l2_regularizer(weight_decay)(weights))
        if activation_fn is not None:
            outputs = activation_fn(outputs)
        return outputs

dropout

虽然本项目中没用到,不过还是写一下。

def dropout(inputs, scope, keep_prob, is_training=False):
    """Apply dropout only while training.

    Previously ``is_training`` was ignored, so dropout was applied even with
    the default ``is_training=False``.

    Args:
        inputs: input tensor.
        scope: string, name of the variable scope wrapping the op.
        keep_prob: probability of keeping each element (float or scalar tensor).
        is_training: Python bool, or a scalar boolean tensor (for example a
            placeholder) to choose at run time. When false, ``inputs`` is
            passed through unchanged.

    Returns:
        Tensor with the same shape as ``inputs``.
    """
    with tf.variable_scope(scope):
        if isinstance(is_training, tf.Tensor):
            # The flag is only known when the graph runs, so branch in-graph.
            return tf.cond(is_training,
                           lambda: tf.nn.dropout(inputs, keep_prob=keep_prob),
                           lambda: tf.identity(inputs))
        if not is_training:
            return inputs
        return tf.nn.dropout(inputs, keep_prob=keep_prob)

layer_util.py完整代码

import tensorflow as tf


def get_variable(shape, stddev=0.1):
    """Create a trainable variable drawn from a truncated normal distribution."""
    initial = tf.truncated_normal(shape, stddev=stddev)
    return tf.Variable(initial)


def get_constant_variable(shape, value=0.1):
    """Create a trainable variable whose every element starts at ``value``."""
    initial = tf.constant(value, shape=shape)
    return tf.Variable(initial)


def conv2d(inputs,
           out_channels,
           kernel_size,
           scope,
           stride=(1, 1),
           padding='SAME',
           stddev=1e-1,
           activation_fn=tf.nn.relu):
    """2-D convolution (no bias) followed by an optional activation.

    Args:
        inputs: 4-D tensor laid out as BxHxWxC; the channel dimension must be
            statically known because it sizes the kernel.
        out_channels: int, number of output channels.
        kernel_size: a sequence of 2 ints, (kernel_h, kernel_w).
        scope: string, name of the variable scope wrapping the layer.
        stride: a sequence of 2 ints, (stride_h, stride_w).
        padding: 'SAME' or 'VALID'.
        stddev: float, stddev for the truncated-normal kernel initializer.
        activation_fn: callable applied to the result, or None for linear output.

    Returns:
        The convolved (and optionally activated) output tensor.
    """
    with tf.variable_scope(scope):
        kernel_h, kernel_w = kernel_size
        # Infer the input depth so callers only specify the output depth.
        in_channels = inputs.shape[-1].value
        kernel_shape = [kernel_h, kernel_w, in_channels, out_channels]
        kernel = get_variable(kernel_shape, stddev)
        stride_h, stride_w = stride
        outputs = tf.nn.conv2d(inputs, kernel,
                               strides=[1, stride_h, stride_w, 1],
                               padding=padding)
        if activation_fn is not None:
            outputs = activation_fn(outputs)
        return outputs


def max_pool2d(inputs,
               kernel_size,
               scope,
               stride=(2, 2),
               padding='SAME'):
    """2-D max pooling.

    Args:
        inputs: 4-D tensor laid out as BxHxWxC.
        kernel_size: a sequence of 2 ints, (kernel_h, kernel_w).
        scope: string, name of the variable scope wrapping the op.
        stride: a sequence of 2 ints, (stride_h, stride_w).
        padding: 'SAME' or 'VALID'.

    Returns:
        The pooled output tensor (this op creates no variables).
    """
    with tf.variable_scope(scope):
        kernel_h, kernel_w = kernel_size
        stride_h, stride_w = stride
        return tf.nn.max_pool(inputs,
                              ksize=[1, kernel_h, kernel_w, 1],
                              strides=[1, stride_h, stride_w, 1],
                              padding=padding)


def full_connection(inputs,
                    num_outputs,
                    scope,
                    stddev=1e-1,
                    activation_fn=tf.nn.relu,
                    weight_decay=0.01):
    """Fully connected layer: ``activation_fn(inputs @ W + b)``.

    Args:
        inputs: 2-D tensor, BxN; N must be statically known.
        num_outputs: int, size of the output dimension.
        scope: string, name of the variable scope wrapping the layer.
        stddev: float, stddev for the truncated-normal weight initializer.
        activation_fn: callable applied to the result, or None for linear output.
        weight_decay: float scale of the L2 penalty on the weights. The penalty
            is added to the ``'losses'`` graph collection. A falsy value
            (0 or None) skips it.

    Returns:
        Tensor of shape B x num_outputs.
    """
    with tf.variable_scope(scope):
        num_inputs = inputs.shape[-1].value
        weights = get_variable(shape=[num_inputs, num_outputs], stddev=stddev)
        outputs = tf.matmul(inputs, weights)
        biases = get_constant_variable(shape=[num_outputs])
        outputs = tf.nn.bias_add(outputs, biases)
        if weight_decay:
            # l2_regularizer(0.0) yields None; skipping avoids putting None
            # into the collection, which would break a later tf.add_n.
            tf.add_to_collection(
                'losses',
                tf.contrib.layers.l2_regularizer(weight_decay)(weights))
        if activation_fn is not None:
            outputs = activation_fn(outputs)
        return outputs


def dropout(inputs, scope, keep_prob, is_training=False):
    """Apply dropout only while training.

    Args:
        inputs: input tensor.
        scope: string, name of the variable scope wrapping the op.
        keep_prob: probability of keeping each element (float or scalar tensor).
        is_training: Python bool, or a scalar boolean tensor to choose at run
            time. When false, ``inputs`` is passed through unchanged.

    Returns:
        Tensor with the same shape as ``inputs``.
    """
    with tf.variable_scope(scope):
        if isinstance(is_training, tf.Tensor):
            # The flag is only known when the graph runs, so branch in-graph.
            return tf.cond(is_training,
                           lambda: tf.nn.dropout(inputs, keep_prob=keep_prob),
                           lambda: tf.identity(inputs))
        if not is_training:
            return inputs
        return tf.nn.dropout(inputs, keep_prob=keep_prob)

lenet.py

根据LeNet-5的网络结构定义模型,返回未经softmax的输出(logits)以及经过softmax后的预测结果。

import tensorflow as tf
from utils import layer_util


def LeNet(inputs, keep_prob=None):
    """Build the LeNet-5 style network.

    Layer stack: conv(6, 5x5, VALID) -> max-pool 2x2 -> conv(16, 5x5, VALID)
    -> max-pool 2x2 -> conv(120, 5x5, SAME) -> flatten -> fc(84, relu)
    -> fc(10, linear).

    For a 28x28x1 input (as fed by train.py) the spatial sizes should be
    24 -> 12 -> 8 -> 4 -> 4, so the flattened vector has 4*4*120 values.
    NOTE(review): confirm if a different input size is used.

    Args:
        inputs: 4-D image tensor, BxHxWxC.
        keep_prob: accepted for API compatibility but currently unused; no
            dropout is applied.

    Returns:
        (logits, probabilities): the linear output of the last layer and its
        softmax.
    """
    with tf.variable_scope('Conv1'):
        c1 = layer_util.conv2d(inputs, 6, [5, 5], 'conv1', padding='VALID')
    with tf.variable_scope('S2'):
        p2 = layer_util.max_pool2d(c1, [2, 2], 'S2')
    with tf.variable_scope('Conv3'):
        c3 = layer_util.conv2d(p2, 16, [5, 5], 'conv3', padding='VALID')
    with tf.variable_scope('S4'):
        p4 = layer_util.max_pool2d(c3, [2, 2], 's4')
    with tf.variable_scope('Conv5'):
        c5 = layer_util.conv2d(p4, 120, [5, 5], 'conv5')

    # Collapse H, W and C into a single feature axis for the dense layers.
    _, height, width, depth = c5.shape.as_list()
    flat = tf.reshape(c5, [-1, height * width * depth])

    with tf.variable_scope('F6'):
        hidden = layer_util.full_connection(flat, 84, 'f6')
    with tf.variable_scope('output'):
        logits = layer_util.full_connection(hidden, 10, 'outputs',
                                            activation_fn=None)
    probabilities = tf.nn.softmax(logits)
    return logits, probabilities


def get_model(inputs, keep_prob=None):
    """Entry point used by the training script; delegates to ``LeNet``."""
    return LeNet(inputs, keep_prob)

train.py

加载数据。如果加载目录中不存在数据,程序会自动下载;如果因网络问题下载失败或下载速度很慢,可以从我的GitHub项目中获取。

from tensorflow.examples.tutorials.mnist import input_data

# one_hot=True: each label is returned as a one-hot vector.
mnist = input_data.read_data_sets("./data/MNIST_data", one_hot=True)

定义超参数

BATCH_SIZE = 100
# Integer division: a partial final batch is not counted per epoch.
N_BATCH = mnist.train.num_examples // BATCH_SIZE

定义输入

with tf.variable_scope('inputs'):
    # Flattened 28x28 images (784 pixels) and their 10-class one-hot labels.
    x = tf.placeholder(tf.float32, shape=[None, 784])
    y = tf.placeholder(tf.float32, shape=[None, 10])

因为网络输入为图片形式,所以需要对数据进行reshape

# Restore the image layout (batch, height, width, channels) for the conv layers.
x_image = tf.reshape(x, shape=[-1, 28, 28, 1])

获取模型的输出

# outputs: linear logits of the last layer; prediction: softmax(outputs).
outputs, prediction = lenet.get_model(x_image, keep_prob=None)

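损失函数的定义,这里使用交叉熵损失。注意:tf.nn.softmax_cross_entropy_with_logits 内部会先对 logits 做 softmax,因此 logits 参数应传入模型未经 softmax 的输出 outputs,而不是已经 softmax 过的 prediction,否则相当于做了两次 softmax。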

# softmax_cross_entropy_with_logits applies softmax internally, so it must get
# the raw pre-softmax `outputs`. Passing `prediction` would apply softmax twice.
loss = tf.reduce_mean(
    tf.nn.softmax_cross_entropy_with_logits(labels=y, logits=outputs))

使用梯度下降优化器

# Plain SGD with a fixed learning rate of 0.1.
optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.1)
train_step = optimizer.minimize(loss)

进行准确率的计算

# A sample is correct when the most probable class matches the one-hot label.
correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(prediction, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

train.py完整代码

import tensorflow as tf
from moudle import lenet
from tensorflow.examples.tutorials.mnist import input_data

mnist = input_data.read_data_sets("./data/MNIST_data", one_hot=True)
BATCH_SIZE = 100
N_BATCH = mnist.train.num_examples // BATCH_SIZE


def train():
    """Build the LeNet graph, train it on MNIST and save a checkpoint.

    Trains for 20 epochs of N_BATCH mini-batches with SGD (learning rate 0.1).
    After each epoch it prints test-set accuracy. The graph is written to
    ``logs/`` for TensorBoard, and the final weights go to
    ``logs/train.ckpt``.
    """
    with tf.variable_scope('inputs'):
        x = tf.placeholder(tf.float32, [None, 784])
        y = tf.placeholder(tf.float32, [None, 10])
    x_image = tf.reshape(x, [-1, 28, 28, 1])

    outputs, prediction = lenet.get_model(x_image)
    # The loss must be computed on the raw logits (`outputs`):
    # softmax_cross_entropy_with_logits applies softmax itself, so passing
    # `prediction` would apply softmax twice.
    # NOTE(review): full_connection adds L2 penalties to the 'losses'
    # collection, but they are not included in this loss.
    loss = tf.reduce_mean(
        tf.nn.softmax_cross_entropy_with_logits(labels=y, logits=outputs))
    train_step = tf.train.GradientDescentOptimizer(0.1).minimize(loss)
    init = tf.global_variables_initializer()

    correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(prediction, 1))
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
    saver = tf.train.Saver()

    test_feed = {x: mnist.test.images, y: mnist.test.labels}
    with tf.Session() as sess:
        writer = tf.summary.FileWriter('logs/', sess.graph)
        try:
            sess.run(init)
            for epoch in range(20):
                for _ in range(N_BATCH):
                    batch_xs, batch_ys = mnist.train.next_batch(BATCH_SIZE)
                    sess.run(train_step, feed_dict={x: batch_xs, y: batch_ys})
                acc = sess.run(accuracy, feed_dict=test_feed)
                print('Iter' + str(epoch) + ",Testing Accuracy " + str(acc))
            saver.save(sess, 'logs/train.ckpt')
        finally:
            # Flush and release the event file even if training fails.
            writer.close()


if __name__ == '__main__':
    train()

网络结构图

三、训练结果

可以做一些对比实验,例如更换优化器或者多训练几轮,测试准确率有望进一步提升到99%以上。

更多推荐

Tensorflow实战 LeNet-5神经网络进行手写体数字识别

本文发布于:2023-07-28 20:49:06,感谢您对本站的认可!
本文链接:https://www.elefans.com/category/jswz/34/1305223.html
版权声明:本站内容均来自互联网,仅供演示用,请勿用于商业和其他非法用途。如果侵犯了您的权益请与我们联系,我们将在24小时内删除。
本文标签:手写体   神经网络   实战   数字   Tensorflow

发布评论

评论列表 (有 0 条评论)
草根站长

>www.elefans.com

编程频道|电子爱好者 - 技术资讯及电子产品介绍!