tensorflow入门线性回归

编程入门 行业动态 更新时间:2024-10-28 19:33:08

实际上编写tensorflow可以总结为两步.

       (1)组装一个graph;

       (2)使用session去执行graph中的operation。

 

 

当使用tensorflow进行graph构建时,大体可以分为五部分:

     1、为输入X输出y定义placeholder;

    2、定义权重W;

    3、定义模型结构;

    4、定义损失函数;

    5、定义优化算法

下面是手写识别字程序:

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_datamnist = input_data.read_data_sets("MNIST_data/",one_hot=True)
#导入数据集
x = tf.placeholder(shape=[None,784],dtype=tf.float32)
y = tf.placeholder(shape=[None,10],dtype=tf.float32)
#为输入输出定义placehloderw = tf.Variable(tf.truncated_normal(shape=[784,10],mean=0,stddev=0.5))
b = tf.Variable(tf.zeros([10]))
#定义权重
y_pred = tf.nn.softmax(tf.matmul(x,w)+b)
#定义模型结构
loss =tf.reduce_mean(-tf.reduce_sum(y*tf.log(y_pred),reduction_indices=[1]))
#定义损失函数
opt = tf.train.GradientDescentOptimizer(0.05).minimize(loss)
#定义优化算法
sess =tf.Session()
sess.run(tf.global_variables_initializer())
for each in range(1000):batch_xs,batch_ys = mnist.train.next_batch(100)loss1 = sess.run(loss,feed_dict={x:batch_xs,y:batch_ys})opt1 = sess.run(opt,feed_dict={x:batch_xs,y:batch_ys})print(loss1)

更多推荐

线性,入门,tensorflow

本文发布于:2023-05-27 16:59:27,感谢您对本站的认可!
本文链接:https://www.elefans.com/category/jswz/34/300605.html
版权声明:本站内容均来自互联网,仅供演示用,请勿用于商业和其他非法用途。如果侵犯了您的权益请与我们联系,我们将在24小时内删除。
本文标签:线性   入门   tensorflow

发布评论

评论列表 (有 0 条评论)
草根站长

>www.elefans.com

编程频道|电子爱好者 - 技术资讯及电子产品介绍!