deeplearning4j Training and Inference Example 2023: Handwritten Digit Recognition

Updated: 2024-10-24 21:22:17


Contents

  • 1. MNIST dataset
  • 2. Dependencies
  • 3. Handwritten digit training and inference
  • 4. Further reading: deeplearning4j's official example project deeplearning4j-examples

1. MNIST dataset

Download link: 60,000 training images and 10,000 test images.
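The training code in section 3 expects the PNG version of the dataset extracted as `mnist_png/training/<digit>/*.png` and `mnist_png/testing/<digit>/*.png`, because each image's label is taken from the name of the folder it sits in. As a quick sanity check before training, a minimal plain-Java sketch can count images per label folder. The class name `MnistLayoutCheck`, the `.png` filter and the default path are assumptions for illustration; this is not part of DL4J.

```java
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.Map;
import java.util.TreeMap;
import java.util.stream.Stream;

// Hypothetical helper: counts PNG images per label folder under a split directory
// such as mnist_png/training, deriving the label from each image's parent folder name.
class MnistLayoutCheck {

    /** Returns label -> image count, where the label is the name of the image's parent folder. */
    static Map<String, Long> countPerLabel(Path splitDir) throws IOException {
        Map<String, Long> counts = new TreeMap<>();
        try (Stream<Path> files = Files.walk(splitDir)) {
            files.filter(Files::isRegularFile)
                 .filter(p -> p.getFileName().toString().toLowerCase().endsWith(".png"))
                 .forEach(p -> counts.merge(p.getParent().getFileName().toString(), 1L, Long::sum));
        }
        return counts;
    }

    public static void main(String[] args) throws IOException {
        // Assumed default location, matching the article's D:\ extraction path
        Path training = Path.of(args.length > 0 ? args[0] : "D:\\mnist_png\\training");
        Map<String, Long> counts = countPerLabel(training);
        long total = counts.values().stream().mapToLong(Long::longValue).sum();
        System.out.println("labels=" + counts.keySet() + " total=" + total);
    }
}
```

For a correct extraction the labels should be `[0, 1, ..., 9]` and the totals should match the 60,000 / 10,000 split above.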

2. Dependencies

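The main dependencies are deeplearning4j and several JavaCV packages; the jar built for this example is 1.3 GB. The pom is adapted from the dl4j-examples module of deeplearning4j-examples, a sub-project of deeplearning4j on GitHub.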

<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 https://maven.apache.org/xsd/maven-4.0.0.xsd">
    <modelVersion>4.0.0</modelVersion>
    <parent>
        <groupId>org.springframework.boot</groupId>
        <artifactId>spring-boot-starter-parent</artifactId>
        <version>2.7.9</version>
        <relativePath/>
    </parent>
    <groupId>com.example</groupId>
    <artifactId>demo</artifactId>
    <version>0.0.1-SNAPSHOT</version>
    <name>demo</name>
    <description>demo</description>
    <properties>
        <dl4j-master.version>1.0.0-M2.1</dl4j-master.version>
        <nd4j.backend>nd4j-native</nd4j.backend>
        <java.version>17</java.version>
        <maven-compiler-plugin.version>3.8.1</maven-compiler-plugin.version>
        <maven.minimum.version>3.3.1</maven.minimum.version>
        <exec-maven-plugin.version>1.4.0</exec-maven-plugin.version>
        <maven-shade-plugin.version>2.4.3</maven-shade-plugin.version>
        <jcommon.version>1.0.23</jcommon.version>
        <jfreechart.version>1.0.13</jfreechart.version>
        <logback.version>1.1.7</logback.version>
        <project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
        <junit.version>5.8.0-M1</junit.version>
        <javacv.version>1.5.9</javacv.version>
    </properties>
    <dependencyManagement>
        <dependencies>
            <dependency>
                <groupId>org.bytedeco</groupId>
                <artifactId>javacv-platform</artifactId>
                <version>${javacv.version}</version>
            </dependency>
        </dependencies>
    </dependencyManagement>
    <dependencies>
        <dependency>
            <groupId>org.springframework.boot</groupId>
            <artifactId>spring-boot-starter</artifactId>
        </dependency>
        <dependency>
            <groupId>org.projectlombok</groupId>
            <artifactId>lombok</artifactId>
        </dependency>
        <dependency>
            <groupId>org.springframework.boot</groupId>
            <artifactId>spring-boot-starter-test</artifactId>
            <scope>test</scope>
        </dependency>
        <dependency>
            <groupId>org.nd4j</groupId>
            <artifactId>${nd4j.backend}</artifactId>
            <version>${dl4j-master.version}</version>
        </dependency>
        <dependency>
            <groupId>org.datavec</groupId>
            <artifactId>datavec-api</artifactId>
            <version>${dl4j-master.version}</version>
        </dependency>
        <dependency>
            <groupId>org.datavec</groupId>
            <artifactId>datavec-data-image</artifactId>
            <version>${dl4j-master.version}</version>
        </dependency>
        <dependency>
            <groupId>org.datavec</groupId>
            <artifactId>datavec-local</artifactId>
            <version>${dl4j-master.version}</version>
        </dependency>
        <dependency>
            <groupId>org.deeplearning4j</groupId>
            <artifactId>deeplearning4j-datasets</artifactId>
            <version>${dl4j-master.version}</version>
        </dependency>
        <dependency>
            <groupId>org.deeplearning4j</groupId>
            <artifactId>deeplearning4j-core</artifactId>
            <version>${dl4j-master.version}</version>
        </dependency>
        <dependency>
            <groupId>org.deeplearning4j</groupId>
            <artifactId>resources</artifactId>
            <version>${dl4j-master.version}</version>
        </dependency>
        <dependency>
            <groupId>org.deeplearning4j</groupId>
            <artifactId>deeplearning4j-ui</artifactId>
            <version>${dl4j-master.version}</version>
        </dependency>
        <dependency>
            <groupId>org.deeplearning4j</groupId>
            <artifactId>deeplearning4j-zoo</artifactId>
            <version>${dl4j-master.version}</version>
        </dependency>
        <!-- ParallelWrapper & ParallelInference live here -->
        <dependency>
            <groupId>org.deeplearning4j</groupId>
            <artifactId>deeplearning4j-parallel-wrapper</artifactId>
            <version>${dl4j-master.version}</version>
        </dependency>
        <!-- Used in the feedforward/classification/MLP* and feedforward/regression/RegressionMathFunctions example -->
        <dependency>
            <groupId>jfree</groupId>
            <artifactId>jfreechart</artifactId>
            <version>${jfreechart.version}</version>
        </dependency>
        <dependency>
            <groupId>org.jfree</groupId>
            <artifactId>jcommon</artifactId>
            <version>${jcommon.version}</version>
        </dependency>
        <!-- Used for downloading data in some of the examples -->
        <dependency>
            <groupId>org.apache.httpcomponents</groupId>
            <artifactId>httpclient</artifactId>
            <version>4.3.5</version>
        </dependency>
        <dependency>
            <groupId>ch.qos.logback</groupId>
            <artifactId>logback-classic</artifactId>
            <version>${logback.version}</version>
        </dependency>
        <dependency>
            <groupId>org.bytedeco</groupId>
            <artifactId>javacv-platform</artifactId>
        </dependency>
        <dependency>
            <groupId>org.nd4j</groupId>
            <artifactId>nd4j-api</artifactId>
            <version>${dl4j-master.version}</version>
        </dependency>
    </dependencies>
    <build>
        <plugins>
            <plugin>
                <groupId>org.springframework.boot</groupId>
                <artifactId>spring-boot-maven-plugin</artifactId>
            </plugin>
            <plugin>
                <groupId>org.apache.maven.plugins</groupId>
                <artifactId>maven-compiler-plugin</artifactId>
                <configuration>
                    <source>17</source>
                    <target>17</target>
                </configuration>
            </plugin>
        </plugins>
    </build>
</project>

3. Handwritten digit training and inference

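One epoch of training takes about 100 s and reaches 97% accuracy; see the code comments for details. The framework's API is fairly convenient to use.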
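"One epoch" is easier to relate to the learning-rate schedule in the code when it is counted in iterations: 60,000 training images in batches of 64 give 938 iterations per epoch (937 full batches plus one batch of 32). The plain-Java sketch below re-creates a step schedule with a `TreeMap`, assuming `MapSchedule` returns the value of the largest breakpoint not greater than the current iteration and that the iteration counter keeps running across epochs. The class `StepSchedule` is hypothetical and is not DL4J's implementation.

```java
import java.util.Map;
import java.util.TreeMap;

// Plain-Java illustration of a step learning-rate schedule keyed by iteration,
// using the breakpoints from the article's MapSchedule. Not DL4J code.
class StepSchedule {
    private final TreeMap<Integer, Double> breakpoints = new TreeMap<>();

    StepSchedule(Map<Integer, Double> values) {
        if (!values.containsKey(0)) {
            throw new IllegalArgumentException("a value for iteration 0 is required");
        }
        breakpoints.putAll(values);
    }

    /** Value at the largest breakpoint that is <= iteration (iteration must be >= 0). */
    double valueAt(int iteration) {
        if (iteration < 0) {
            throw new IllegalArgumentException("iteration must be >= 0");
        }
        return breakpoints.floorEntry(iteration).getValue();
    }

    /** Number of minibatches (iterations) in one pass over numExamples. */
    static int iterationsPerEpoch(int numExamples, int batchSize) {
        return (numExamples + batchSize - 1) / batchSize; // ceiling division
    }

    public static void main(String[] args) {
        StepSchedule lr = new StepSchedule(
                Map.of(0, 0.06, 200, 0.05, 600, 0.028, 800, 0.006, 1000, 0.001));
        int perEpoch = iterationsPerEpoch(60_000, 64);
        System.out.println("iterations per epoch: " + perEpoch);
        System.out.println("lr at the last iteration of epoch 1: " + lr.valueAt(perEpoch - 1));
        System.out.println("lr early in epoch 2: " + lr.valueAt(perEpoch + 100));
    }
}
```

Under these assumptions a single epoch never reaches the 1000-iteration breakpoint: the last iterations of epoch 1 run at 0.006, and 0.001 only takes effect in the second epoch.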

package ai;

import lombok.extern.slf4j.Slf4j;
import org.apache.commons.io.FileUtils;
import org.datavec.api.io.labels.ParentPathLabelGenerator;
import org.datavec.api.split.FileSplit;
import org.datavec.image.loader.NativeImageLoader;
import org.datavec.image.recordreader.ImageRecordReader;
import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.conf.layers.SubsamplingLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.deeplearning4j.ui.api.UIServer;
import org.deeplearning4j.ui.model.stats.StatsListener;
import org.deeplearning4j.ui.model.storage.FileStatsStorage;
import org.deeplearning4j.util.ModelSerializer;
import org.nd4j.common.io.Assert;
import org.nd4j.evaluation.classification.Evaluation;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.preprocessor.DataNormalization;
import org.nd4j.linalg.dataset.api.preprocessor.ImagePreProcessingScaler;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.config.Nesterovs;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import org.nd4j.linalg.schedule.MapSchedule;
import org.nd4j.linalg.schedule.ScheduleType;

import java.io.File;
import java.util.Random;

@Slf4j
public class LeNetMNISTReLu {

    private static final String DATASET_PATH_BASE = "D:\\";

    public static void main(String[] args) throws Exception {
        int height = 28;
        int width = 28;
        // Grayscale images have a single channel
        int channels = 1;
        // Ten classes: the digits 0-9
        int outputNum = 10;
        int batchSize = 64;
        // One epoch takes about 100 s; 3 epochs reach about 99% accuracy
        int nEpochs = 1;

        Assert.isTrue(new File(DATASET_PATH_BASE + "/mnist_png").exists(),
                "Please download the archive and extract it to " + DATASET_PATH_BASE);

        // Uses the name of each image's parent directory as its label; the MNIST PNGs are
        // stored in folders named 0-9, so the folder name is the digit
        ParentPathLabelGenerator labelMaker = new ParentPathLabelGenerator();
        // Scale pixel values to the range 0-1
        DataNormalization normalization = new ImagePreProcessingScaler();
        Random random = new Random(12345);

        log.info("Training set: 60,000 images...");
        File trainData = new File(DATASET_PATH_BASE + "/mnist_png/training");
        FileSplit trainSplit = new FileSplit(trainData, NativeImageLoader.ALLOWED_FORMATS, random);
        ImageRecordReader trainRecordReader = new ImageRecordReader(height, width, channels, labelMaker);
        trainRecordReader.initialize(trainSplit);
        DataSetIterator trainIter = new RecordReaderDataSetIterator(trainRecordReader, batchSize, 1, outputNum);
        normalization.fit(trainIter);
        trainIter.setPreProcessor(normalization); // normalize pixels before they reach the network

        log.info("Validation set: 10,000 images...");
        File validateData = new File(DATASET_PATH_BASE + "/mnist_png/testing");
        FileSplit validateSplit = new FileSplit(validateData, NativeImageLoader.ALLOWED_FORMATS, random);
        ImageRecordReader validateRecordReader = new ImageRecordReader(height, width, channels, labelMaker);
        validateRecordReader.initialize(validateSplit);
        DataSetIterator validateIter = new RecordReaderDataSetIterator(validateRecordReader, batchSize, 1, outputNum);
        validateIter.setPreProcessor(normalization);

        // 60,000 training images with batchSize = 64 give about 938 iterations per epoch (roughly 1,000)
        // Learning-rate (step size) schedule: start larger and lower it at iterations 200, 600, 800 and 1000;
        // the rate could also be changed once per epoch instead
        MapSchedule mapSchedule = new MapSchedule.Builder(ScheduleType.ITERATION)
                .add(0, 0.06)
                .add(200, 0.05)
                .add(600, 0.028)
                .add(800, 0.006)
                .add(1000, 0.001)
                .build();

        // Hyperparameters and network architecture
        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
                .seed(1)
                .l2(0.0005)
                .updater(new Nesterovs(mapSchedule))
                //.optimizationAlgo(OptimizationAlgorithm.LINE_GRADIENT_DESCENT) // this optimizer failed to converge for a long time
                .weightInit(WeightInit.XAVIER)
                .list()
                .layer(new ConvolutionLayer.Builder(5, 5)
                        .nIn(channels)
                        .stride(1, 1)
                        .nOut(20)
                        .activation(Activation.IDENTITY)
                        .build())
                .layer(new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX)
                        .kernelSize(2, 2)
                        .stride(2, 2)
                        .build())
                .layer(new ConvolutionLayer.Builder(5, 5)
                        .stride(1, 1)
                        .nOut(50)
                        .activation(Activation.IDENTITY)
                        .build())
                .layer(new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX)
                        .kernelSize(2, 2)
                        .stride(2, 2)
                        .build())
                .layer(new DenseLayer.Builder().activation(Activation.RELU).nOut(500).build())
                .layer(new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
                        .nOut(outputNum)
                        .activation(Activation.SOFTMAX)
                        .build())
                .setInputType(InputType.convolutionalFlat(height, width, channels)) // use InputType.convolutional for regular image input
                .build();

        // Build and initialize the network
        MultiLayerNetwork net = new MultiLayerNetwork(conf);
        net.init();
        // Training monitor: log the score (loss) every 10 iterations
        net.setListeners(new ScoreIterationListener(10));
        // Web UI for monitoring training
        //UIServer uiServer = UIServer.getInstance();
        //FileStatsStorage statsStorage = new FileStatsStorage(new File("D:\\ai-webui.dat"));
        //uiServer.attach(statsStorage);
        //net.setListeners(new StatsListener(statsStorage));
        log.info("Number of network parameters: {}", net.numParams());

        long startTime = System.currentTimeMillis();
        // Train for nEpochs epochs, evaluating on the validation set after each one
        for (int i = 0; i < nEpochs; i++) {
            log.info("Epoch={}", i);
            net.fit(trainIter);
            Evaluation eval = net.evaluate(validateIter);
            log.info(eval.stats());
            trainIter.reset();
            validateIter.reset();
        }
        log.info("Training took {} ms", System.currentTimeMillis() - startTime);

        // Save the model
        File mnistModelPath = new File(DATASET_PATH_BASE + "/mnistModel.zip");
        ModelSerializer.writeModel(net, mnistModelPath, true);

        // Inference: load the network (model) -> load the test images -> predict
        MultiLayerNetwork network = ModelSerializer.restoreMultiLayerNetwork(mnistModelPath);
        NativeImageLoader imageLoader = new NativeImageLoader(height, width, channels);
        int correct = 0;
        int total = 0;
        // Sequential loop: a MultiLayerNetwork is not meant to be shared across threads for output();
        // DL4J's ParallelInference (deeplearning4j-parallel-wrapper) is the tool for concurrent inference
        for (File file : FileUtils.listFiles(new File(DATASET_PATH_BASE + "/mnist_png/testing"), null, true)) {
            try {
                INDArray matrix = imageLoader.asMatrix(file);
                // asMatrix returns raw 0-255 pixels; apply the same 0-1 scaling used in training
                normalization.transform(matrix);
                INDArray output = network.output(matrix);
                // Take the most probable class
                int predictedValue = Nd4j.argMax(output, 1).getInt(0);
                // Images are stored in folders named after their digit, so the folder name is the true value
                String realValue = file.getParentFile().getName();
                log.info("True value: {}, predicted: {}", realValue, predictedValue);
                total++;
                if (predictedValue == Integer.parseInt(realValue)) {
                    correct++;
                } else {
                    log.warn("Wrong prediction: {}", file.getAbsolutePath());
                }
            } catch (Exception e) {
                log.warn(e.getMessage(), e);
            }
        }
        log.info("Inference accuracy: {}/{}", correct, total);
    }
}
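The network above is a LeNet-style CNN. Assuming DL4J's default convolution mode (truncate, no padding), the 28×28 input shrinks to 24×24 after the first 5×5 convolution, 12×12 after pooling, 8×8 after the second convolution and 4×4 after the second pooling, so the dense layer receives 4·4·50 = 800 inputs. The plain-Java sketch below does that bookkeeping and sums weights plus biases per layer; the class `LeNetShapes` is hypothetical, and the total should match what `net.numParams()` logs only if those assumptions hold.

```java
// Plain-Java bookkeeping for the article's LeNet-style network (not DL4J code).
// Assumes no padding and one bias per output channel/unit; pooling layers have no parameters.
class LeNetShapes {

    /** Output size of a convolution or pooling window without padding. */
    static int outSize(int in, int kernel, int stride) {
        return (in - kernel) / stride + 1;
    }

    static long convParams(int kernel, int nIn, int nOut) {
        return (long) kernel * kernel * nIn * nOut + nOut; // weights + biases
    }

    static long denseParams(long nIn, int nOut) {
        return nIn * nOut + nOut; // weights + biases
    }

    static long totalParams() {
        int size = 28;
        size = outSize(size, 5, 1); // conv1 (20 channels): 24
        size = outSize(size, 2, 2); // max pooling:         12
        size = outSize(size, 5, 1); // conv2 (50 channels): 8
        size = outSize(size, 2, 2); // max pooling:         4
        long flattened = (long) size * size * 50; // 800 inputs to the dense layer
        return convParams(5, 1, 20)          // 520
                + convParams(5, 20, 50)      // 25,050
                + denseParams(flattened, 500) // 400,500
                + denseParams(500, 10);      // 5,010 (output layer)
    }

    public static void main(String[] args) {
        System.out.println("total parameters: " + totalParams());
    }
}
```

The total comes to 431,080, and about 93% of it (400,500) sits in the 800-to-500 dense layer, which is why the dense layer, not the convolutions, dominates the model file size.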
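During training the pixels pass through `ImagePreProcessingScaler`, which by default maps 8-bit values 0-255 to 0-1, whereas `NativeImageLoader.asMatrix` returns raw 0-255 values. An image therefore needs the same scaling before `network.output(...)`, after which `Nd4j.argMax(output, 1)` picks the index of the largest class probability. The sketch below spells out both steps on plain arrays, assuming 8-bit grayscale input; `InferenceSteps` and its methods are hypothetical names, not DL4J API.

```java
import java.util.Arrays;

// Hypothetical plain-Java version of two inference steps: 0-1 pixel scaling and argmax.
class InferenceSteps {

    /** Scales 8-bit pixel values (0-255) to the range 0-1. */
    static double[] scale(int[] pixels) {
        double[] out = new double[pixels.length];
        for (int i = 0; i < pixels.length; i++) {
            out[i] = pixels[i] / 255.0;
        }
        return out;
    }

    /** Index of the largest score; for a softmax output over 0-9 this is the predicted digit. */
    static int argMax(double[] scores) {
        int best = 0;
        for (int i = 1; i < scores.length; i++) {
            if (scores[i] > scores[best]) {
                best = i;
            }
        }
        return best;
    }

    public static void main(String[] args) {
        System.out.println(Arrays.toString(scale(new int[] {0, 255})));
        double[] softmax = {0.01, 0.02, 0.9, 0.01, 0.01, 0.01, 0.01, 0.01, 0.01, 0.01};
        System.out.println("predicted digit: " + argMax(softmax));
    }
}
```

Skipping the scaling step does not raise an error; the network simply sees inputs 255 times larger than anything it was trained on, so predictions silently degrade.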
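`eval.stats()` reports accuracy together with per-class precision, recall, F1 and a confusion matrix. As a rough illustration of the first and last of those (not DL4J's `Evaluation` class), accuracy and a 10×10 confusion matrix can be accumulated from (true, predicted) pairs like the ones the inference loop compares. `SimpleEvaluation` is a hypothetical name.

```java
// Plain-Java illustration of accuracy and a confusion matrix
// (rows = true digit, columns = predicted digit). Not DL4J's Evaluation class.
class SimpleEvaluation {
    private final long[][] confusion;
    private long correct;
    private long total;

    SimpleEvaluation(int numClasses) {
        confusion = new long[numClasses][numClasses];
    }

    void record(int actual, int predicted) {
        confusion[actual][predicted]++;
        total++;
        if (actual == predicted) {
            correct++;
        }
    }

    double accuracy() {
        return total == 0 ? 0.0 : (double) correct / total;
    }

    long count(int actual, int predicted) {
        return confusion[actual][predicted];
    }

    public static void main(String[] args) {
        SimpleEvaluation eval = new SimpleEvaluation(10);
        eval.record(7, 7);
        eval.record(3, 3);
        eval.record(4, 9); // a 4 mistaken for a 9
        eval.record(1, 1);
        System.out.println("accuracy = " + eval.accuracy());
    }
}
```

The off-diagonal cells show which digits get confused with each other, which a single accuracy figure such as 97% hides.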
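The updater is `Nesterovs`, that is, SGD with Nesterov momentum, with its learning rate driven by the schedule. The textbook form of the update (look ahead along the current velocity, take the gradient there, then update velocity and weights) is sketched below on a one-dimensional toy loss f(x) = (x - 3)^2. The momentum value 0.9, the toy loss and the class `NesterovSketch` are assumptions for illustration; this is not DL4J's internal implementation.

```java
import java.util.function.DoubleUnaryOperator;

// Textbook Nesterov momentum on a 1-D toy loss; illustrative only, not DL4J code.
class NesterovSketch {

    /** Runs `steps` Nesterov updates from x0 with a fixed learning rate and momentum. */
    static double minimize(DoubleUnaryOperator grad, double x0, double lr, double momentum, int steps) {
        double x = x0;
        double v = 0.0;
        for (int t = 0; t < steps; t++) {
            double lookAhead = x + momentum * v;                   // peek along the current velocity
            v = momentum * v - lr * grad.applyAsDouble(lookAhead); // velocity update with look-ahead gradient
            x = x + v;                                             // parameter update
        }
        return x;
    }

    public static void main(String[] args) {
        DoubleUnaryOperator grad = x -> 2 * (x - 3); // derivative of (x - 3)^2
        double x = minimize(grad, 0.0, 0.06, 0.9, 500);
        System.out.println("x after 500 steps: " + x);
    }
}
```

With the schedule's starting rate of 0.06 the iterate settles at the minimum x = 3; in the real network the same update is applied to every one of the parameters, with the rate dropping as the schedule advances.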

4. Further reading: deeplearning4j's official example project deeplearning4j-examples

deeplearning4j-examples: see its README for details.


Published: 2023-12-03 12:30:03
Article link: https://www.elefans.com/category/jswz/34/1655364.html