python .npy文件自制数据集基本使用

编程入门 行业动态 更新时间:2024-10-06 10:27:54

python .npy<a href=https://www.elefans.com/category/jswz/34/1771438.html style=文件自制数据集基本使用"/>

python .npy文件自制数据集基本使用

读取和保存npy文件

有实习需要我们对数据集做出一系列处理,如归一化等等,如果每次从文件中读出来使用都要进行归一化、转换格式等操作会很浪费时间,于是可以自制数据集并保存到npy文件中以供随时调用。

import tensorflow as tf
from PIL import Image
import numpy as np
import ostrain_path = './mnist_image_label/mnist_train_jpg_60000/'
train_txt = './mnist_image_label/mnist_train_jpg_60000.txt'
x_train_savepath = './mnist_image_label/mnist_x_train.npy'
y_train_savepath = './mnist_image_label/mnist_y_train.npy'test_path = './mnist_image_label/mnist_test_jpg_10000/'
test_txt = './mnist_image_label/mnist_test_jpg_10000.txt'
x_test_savepath = './mnist_image_label/mnist_x_test.npy'
y_test_savepath = './mnist_image_label/mnist_y_test.npy'def generateds(path, txt):f = open(txt, 'r')  # 以只读形式打开txt文件contents = f.readlines()  # 读取文件中所有行f.close()  # 关闭txt文件x, y_ = [], []  # 建立空列表for content in contents:  # 逐行取出value = content.split()  # 以空格分开,图片路径为value[0] , 标签为value[1] , 存入列表img_path = path + value[0]  # 拼出图片路径和文件名img = Image.open(img_path)  # 读入图片img = np.array(img.convert('L'))  # 图片变为8位宽灰度值的np.array格式img = img / 255.  # 数据归一化 (实现预处理)x.append(img)  # 归一化后的数据,贴到列表xy_.append(value[1])  # 标签贴到列表y_print('loading : ' + content)  # 打印状态提示x = np.array(x)  # 变为np.array格式y_ = np.array(y_)  # 变为np.array格式y_ = y_.astype(np.int64)  # 变为64位整型return x, y_  # 返回输入特征x,返回标签y_if os.path.exists(x_train_savepath) and os.path.exists(y_train_savepath) and os.path.exists(x_test_savepath) and os.path.exists(y_test_savepath):print('-------------Load Datasets-----------------')x_train_save = np.load(x_train_savepath)y_train = np.load(y_train_savepath)x_test_save = np.load(x_test_savepath)y_test = np.load(y_test_savepath)x_train = np.reshape(x_train_save, (len(x_train_save), 28, 28))x_test = np.reshape(x_test_save, (len(x_test_save), 28, 28))
else:print('-------------Generate Datasets-----------------')x_train, y_train = generateds(train_path, train_txt)x_test, y_test = generateds(test_path, test_txt)print('-------------Save Datasets-----------------')x_train_save = np.reshape(x_train, (len(x_train), -1))x_test_save = np.reshape(x_test, (len(x_test), -1))np.save(x_train_savepath, x_train_save)np.save(y_train_savepath, y_train)np.save(x_test_savepath, x_test_save)np.save(y_test_savepath, y_test)model = tf.keras.models.Sequential([tf.keras.layers.Flatten(),tf.keras.layers.Dense(128, activation='relu'),tf.keras.layers.Dense(10, activation='softmax')
])modelpile(optimizer='adam',loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),metrics=['sparse_categorical_accuracy'])model.fit(x_train, y_train, batch_size=32, epochs=5, validation_data=(x_test, y_test), validation_freq=1)
model.summary()

例子中:

import numpy as np

# .npy文件是numpy专用的二进制文件
arr = np.array([[1, 2], [3, 4]])

# 保存.npy文件
np.save("../data/arr.npy", arr)
print("save .npy done")

# 读取.npy文件
np.load("../data/arr.npy")
print(arr)
print("load .npy done")

参考:

 

更多推荐

python .npy文件自制数据集基本使用

本文发布于:2024-02-28 05:59:12,感谢您对本站的认可!
本文链接:https://www.elefans.com/category/jswz/34/1768329.html
版权声明:本站内容均来自互联网,仅供演示用,请勿用于商业和其他非法用途。如果侵犯了您的权益请与我们联系,我们将在24小时内删除。
本文标签:文件   数据   python   npy

发布评论

评论列表 (有 0 条评论)
草根站长

>www.elefans.com

编程频道|电子爱好者 - 技术资讯及电子产品介绍!