使用CNN实现Google你画电脑猜

编程入门 行业动态 更新时间:2024-10-10 08:20:56

使用CNN实现Google你画<a href=https://www.elefans.com/category/jswz/34/1770036.html style=电脑猜"/>

使用CNN实现Google你画电脑猜

使用CNN 实现小型Google你画电脑猜

你画电脑猜

Google 的猜画小歌风靡一时,使用 CNN 我们也可以搞一个出来。

数据集

Google AI Lab 已经开源了用来训练的数据集:
里边有 354 类共5千万张涂鸦图片,并提供多种格式,从原始的笔迹数据到转换好的位图都有。我们将使用其中的 numpy 格式版本。

如果能够访问 Google Drive 的话,可以从 下载或者有装过 gsutil 的话,可以 gsutil -m rsync gs://quickdraw_dataset/full/numpy_bitmap/ .

由于这个真实项目的数据集规模相当大,有文件小很多的矢量格式可以下载,比如 binary 里就是一种。另外 这里提供了一个npz格式的子集,只包含7万5千张图片,numpy.load() 既能加载 .npy 也能加载 .npz 倒是方便。不过所有的矢量格式,下载回来都要预处理转成位图,不然没法喂给咱们的 CNN 网络。这两个文件格式可以参考 binaryparser.ipbn

模型

Google AI Lab 在 .03477 这篇论文里解释了他们是怎么干的,并且在 github 上提供了实现,那是一个相对复杂的 RNN 模型。这里我们用一个较为简单的 CNN 模型来解决这个问题。

模型结构如下:

导入

import tensorflow as tf
import numpy as np
import os
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
加载数据集
names = []
# 加载数据集
def load_dataset(dafile):count = 0# 每次加载多少张图片images_per_class = 10000# 大小为(0,784) 784的来源就是上面的原因X = np.empty([0, 784])Y = np.empty([0, 1])# 加载数据集[butterfly.npy , clock.npy, door.npy, ...]for file in os.listdir(dafile):fn = dafile + filenames.append(file)# 加载npy, 提取10000张图片分析images = np.load(fn).astype('float32')# (10000,784)images = images[0:images_per_class, :]# 返回给定形状和类型的新数组# 比如 count = 3时, label = [[3],[3],[3],...], (10000,1)labels = np.full((images_per_class, 1), count)# 连接,count = 2时,X(30000, 784), Y(30000, 1)X = np.concatenate((X, images), axis = 0)Y = np.concatenate((Y, labels), axis = 0)count += 1# 洗牌, 打乱顺序,0 到 Y.shape[0], order = np.random.permutation(Y.shape[0])X = X[order, :]Y = Y[order, :]# 拆分测试数据集和训练集,1:9X_train, X_test, Y_train, Y_test = train_test_split(X, Y, random_state = 0, test_size = 0.1)return X_train, Y_train, X_test, Y_test, count

下载小型的对应文件夹,有12种类别,放在data的目录下

dataset_files = "./data/"
X_train, Y_train, X_test, Y_test, count = load_dataset(dataset_files)
模型结构:

对应模型代码实现
# 重置图,可以反复运行,
tf.reset_default_graph()# 所需变量
x = None
y = None
# 三层卷积,三层池化,三层激活层,两层全连接层
# 权重,
w_conv = [None, None, None]
# 偏差值
b_conv = [None, None, None]
# 卷积
r_conv = [None, None, None]
# 激活
h_conv = [None, None, None]
# 池化层
h_pool = [None, None, None]
keep_prob = None
# 全连接层
w_fc = [None, None]
b_fc = [None, None]
h_fc = [None, None]def build_model():global x, yglobal w_conv, b_conv, r_conv, h_convglobal h_pool, keep_probglobal w_fc, b_fc, h_fc# 输入层 with tf.variable_scope('input'):x = tf.placeholder(tf.float32, shape=[None, 28, 28, 1], name = 'x')# 卷积层 1: conv1 [None, 28, 28, 16]with tf.variable_scope('conv1'):# truncated_normal 从截断的正态分布中输出随机值w_conv[0] = tf.Variable(tf.truncated_normal([3, 3, 1, 16], stddev = 0.1))b_conv[0] = tf.Variable(tf.truncated_normal([16], stddev = 0.1))r_conv[0] = tf.nn.conv2d(x, w_conv[0], strides=[1,1,1,1], padding="SAME")# wx + br_conv[0] = r_conv[0] + b_conv[0]# 激活层 1: h1with tf.variable_scope('h1'):h_conv[0] = tf.nn.relu(r_conv[0])# 池化层 1: pool1 [None, 14, 14, 16]with tf.variable_scope('pool1'):h_pool[0] = tf.nn.max_pool(h_conv[0], ksize = [1,2,2,1], strides = [1,2,2,1], padding = 'SAME')# 卷积层 2: conv2 with tf.variable_scope('conv2'):w_conv[1] = tf.Variable(tf.truncated_normal([3, 3, 16, 32], stddev = 0.1))b_conv[1] = tf.Variable(tf.truncated_normal([32], stddev = 0.1))r_conv[1] = tf.nn.conv2d(h_pool[0], w_conv[1], strides=[1,1,1,1], padding = 'SAME')r_conv[1] = r_conv[1] + b_conv[1]# 激活层 2: h2with tf.variable_scope('h2'):h_conv[1] = tf.nn.relu(r_conv[1])# 池化层 2:  pool2 [None, 7, 7, 32]with tf.variable_scope('pool2'):h_pool[1] = tf.nn.max_pool(h_conv[1], ksize=[1,2,2,1], strides = [1,2,2,1], padding='SAME')# 卷积层 3: conv3with tf.variable_scope('conv3'):w_conv[2] = tf.Variable(tf.truncated_normal([3, 3, 32, 64], stddev = 0.1))b_conv[2] = tf.Variable(tf.truncated_normal([64], stddev = 0.1))r_conv[2] = tf.nn.conv2d(h_pool[1], w_conv[2], strides=[1,1,1,1], padding = 'SAME')r_conv[2] = r_conv[2] + b_conv[2]# 激活层 3: h3with tf.variable_scope('h3'):h_conv[2] = tf.nn.relu(r_conv[2])# 池化层 3: pool3 [None, 4, 4, 64]with tf.variable_scope('pool3'):h_pool[2] = tf.nn.max_pool(h_conv[2], ksize=[1,2,2,1], strides = [1,2,2,1], padding='SAME')# 全连接层 1 : fc1with tf.variable_scope('fc1'):keep_prob = tf.placeholder(tf.float32)h_pool3_flat = tf.reshape(h_pool[2], [-1, 4*4*64])w_fc[0] = tf.Variable(tf.truncated_normal([4*4*64, 1024], stddev = 0.1))b_fc[0] = tf.Variable(tf.truncated_normal([1024], stddev = 0.1))h_fc[0] = tf.nn.relu(tf.matmul(h_pool3_flat, w_fc[0]))h_fc[0] = h_fc[0] + b_fc[0]# 正则化h_fc[0] = tf.nn.dropout(h_fc[0], keep_prob)# 全连接层 2 :  fc2with tf.variable_scope('fc2'):# 分类 count 12种w_fc[1] = tf.Variable(tf.truncated_normal([1024, count], stddev = 0.1))b_fc[1] = tf.Variable(tf.truncated_normal([count], stddev = 0.1))h_fc[1] = tf.matmul(h_fc[0], w_fc[1])h_fc[1] = h_fc[1] + b_fc[1]# output,多分类softmax函数with tf.variable_scope('output'):prediction = tf.nn.softmax(h_fc[1])return prediction
训练和保存模型
# 训练模型
y = tf.placeholder(tf.float32, shape=[None, count], name = 'y')
batch_size = 512
train_step = Nonemodel = build_model()
reduce_sum = -tf.reduce_sum(y * tf.log(model), reduction_indices = [1])loss = tf.reduce_mean(reduce_sum)train_step = tf.train.AdamOptimizer(1e-4).minimize(loss)saver = tf.train.Saver()with tf.Session() as sess:init = tf.global_variables_initializer()sess.run(init)for i in range(10):train_batch = zip(range(0, len(X_train), batch_size), range(batch_size, len(X_train) + 1, batch_size))for start, end in train_batch:sess.run(train_step, feed_dict = {x:X_train[start:end],y:Y_train[start:end],keep_prob:0.75})saver.save(sess, './model/mymodel')
加载恢复模型
def predictByModel(img):with tf.Session() as sess:saver = tf.train.import_meta_graph('model/mymodel.meta')saver.restore(sess, tf.train.latest_checkpoint("model/"))result = sess.run(model, feed_dict={x:img, keep_prob:1})rIdx = np.argmax(result)pred = names[rIdx]return pred
实现预测
from matplotlib import image
mytest = image.imread('./test.png')
mytest_gray = mytest[:, :, 0]
mytest_img = np.reshape(mytest_gray, [1, 28, 28, 1])
print('your draw is ', predictByModel(mytest_img))
plt.imshow(mytest)

your draw is eraser.npy
<matplotlib.image.AxesImage at 0x7f3d4b6e0160>

总结:这里只是使用CNN实现一个小项目,达到使用CNN分类的情况,上面的步骤省去了,测试。直接训练完后,拿来预测分类。
数据集和源代码: GitHub

更多推荐

使用CNN实现Google你画电脑猜

本文发布于:2024-02-06 06:10:44,感谢您对本站的认可!
本文链接:https://www.elefans.com/category/jswz/34/1747349.html
版权声明:本站内容均来自互联网,仅供演示用,请勿用于商业和其他非法用途。如果侵犯了您的权益请与我们联系,我们将在24小时内删除。
本文标签:电脑   CNN   Google

发布评论

评论列表 (有 0 条评论)
草根站长

>www.elefans.com

编程频道|电子爱好者 - 技术资讯及电子产品介绍!