使用 TensorFlowJS 进行简单的线性回归

编程入门 行业动态 更新时间:2024-10-11 23:25:31
本文介绍了使用 TensorFlowJS 进行简单的线性回归的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

限时送ChatGPT账号..

我正在尝试为一个项目进行一些线性回归.由于习惯了 Javascript,我决定尝试使用 TensorFlowJS.

I'm trying to get some linear regression for a project. As I'm used to Javascript, I decided to try and use TensorFlowJS.

我正在关注他们网站上的教程,并观看了一些解释其工作原理的视频,但我仍然不明白为什么我的算法没有返回我期望的结果.

I'm following the tutorial from their website and have watched some videos explaining how it works, but I still can't understand why my algorithm doesn't return the result I expect.

这是我正在做的:

// Define a model for linear regression.
const model = tf.sequential();
model.add(tf.layers.dense({units: 1, inputShape: [1]}));

// Prepare the model for training: Specify the loss and the optimizer.
modelpile({loss: 'meanSquaredError', optimizer: 'sgd'});

// Generate some synthetic data for training.
const xs = tf.tensor1d([1, 2, 3, 4]);
const ys = tf.tensor1d([1, 2, 3, 4]);

// Train the model using the data.
model.fit(xs, ys).then(() => {
  // Use the model to do inference on a data point the model hasn't seen before:
  // Open the browser devtools to see the output
  const output = model.predict(tf.tensor2d([5], [1,1]));
  console.log(Array.from(output.dataSync())[0]);
});

我在这里尝试制作一个线性图,其中输入应始终等于输出.

I'm trying here to have a linear graph, where the input should always be equal to the output.

我试图预测输入 5 会得到什么,但输出似乎是随机的.

I'm trying to predict what I would get with an input of 5, however it seems that the output is random.

这是在 codepen 上,您可以尝试:https://codepen.io/anon/pen/RJJNeO?editors=0011

Here it is on codepen so you can try: https://codepen.io/anon/pen/RJJNeO?editors=0011

推荐答案

您的模型仅在 一个时期(一个训练周期)后进行预测.结果损失仍然很大,导致预测不准确.

Your model is making prediction after only one epoch (one cyle of training). As a result the loss is still big which leads to unaccurate prediction.

模型的权重随机初始化.所以只有一个时期,预测是非常随机的.这就是为什么,一个人需要训练不止一个时期,或者在每批之后更新权重(这里你也只有一批).要查看训练期间的损失,您可以通过这种方式更改拟合方法:

The weights of the model are initialized randomly. So with only one epoch, the prediction is very random. That's why, one needs to train for more than one epoch, or update weights after each batch (here you have only one batch also). To have a look at the loss during training, you can change your fit method that way:

model.fit(xs, ys, { 
  callbacks: {
      onEpochEnd: (epoch, log) => {
        // display loss
        console.log(epoch, log.loss);
      }
    }}).then(() => {
  // make the prediction after one epoch
})

为了获得准确的预测,您可以增加 epochs 的数量

To get accurate prediction, you can increase the number of epochs

 model.fit(xs, ys, {
      epochs: 50,  
      callbacks: {
          onEpochEnd: (epoch, log) => {
            // display loss
            console.log(epoch, log.loss);
          }
        }}).then(() => {
      // make the prediction after one epoch
    })

这是一个片段,展示了增加 epoch 数将如何帮助模型表现良好

Here is a snippet which shows how increasing the number of epochs will help the model to perform well

// Define a model for linear regression.
const model = tf.sequential();
model.add(tf.layers.dense({units: 1, inputShape: [1]}));

// Prepare the model for training: Specify the loss and the optimizer.
modelpile({loss: 'meanSquaredError', optimizer: 'sgd'});

// Generate some synthetic data for training.
const xs = tf.tensor1d([1, 2, 3, 4]);
const ys = tf.tensor1d([1, 2, 3, 4]);

// Train the model using the data.
model.fit(xs, ys, {
  epochs: 50,
  callbacks: {
      onEpochEnd: (epoch, log) => {
        console.log(epoch, log.loss);
      }
    }}).then(() => {
  // Use the model to do inference on a data point the model hasn't seen before:
  // Open the browser devtools to see the output
  const output = model.predict(tf.tensor2d([6], [1,1]));
  output.print();
});

<html>
  <head>
    <!-- Load TensorFlow.js -->
    <script src="https://cdnjs.cloudflare/ajax/libs/tensorflow/0.12.4/tf.js"> </script>
  </head>

  <body>
  </body>
</html>

这篇关于使用 TensorFlowJS 进行简单的线性回归的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!

更多推荐

[db:关键词]

本文发布于:2023-05-01 01:45:58,感谢您对本站的认可!
本文链接:https://www.elefans.com/category/jswz/34/1402369.html
版权声明:本站内容均来自互联网,仅供演示用,请勿用于商业和其他非法用途。如果侵犯了您的权益请与我们联系,我们将在24小时内删除。
本文标签:线性   简单   TensorFlowJS

发布评论

评论列表 (有 0 条评论)
草根站长

>www.elefans.com

编程频道|电子爱好者 - 技术资讯及电子产品介绍!