admin管理员组

文章数量:1633991

1 #coding: utf-8

2 from PyQt5.QtWidgets import *

3 from PyQt5.QtGui import *

4 from PyQt5.QtCore import *

5 importsys6 sys.path.append(r'../ml/torch')7 from digit_recog importNet8 importtorch9 importos10 importnumpy as np11 importmatplotlib.pyplot as plt12 from PIL importImage13

14

# Module-level model setup: build the network on the chosen device and
# restore its trained parameters for inference.

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
net = Net().to(device)
# Load the saved parameters. map_location remaps every stored tensor onto
# `device`; without it, a checkpoint saved from CUDA tensors cannot be loaded
# on a CPU-only machine, defeating the CPU fallback chosen above.
# NOTE(review): the path is relative to the current working directory (as is
# the sys.path entry used to import Net), so the script must be launched from
# the expected directory.
nn_state = torch.load(
    os.path.join('../ml/torch/model/', 'net.pth'),
    map_location=device,
)
# Copy the loaded parameters into the model instance.
net.load_state_dict(nn_state)
# Put the module in evaluation mode before it is used for prediction
# (whether this changes behavior depends on the layers Net contains).
net.eval()

24 defpredict(img):25 #读取图片并重设尺寸

26 image = Image.open(img).resize((28, 28))27 #灰度图

28 gray_image = image.convert('L')29 #plt.imshow(gray_image)

30 #plt.show()

31 #图片数据处理

32 im_data =np.array(gray_image)33 im_data =torch.from_nump

本文标签: 输入法入门入道PythonPytorch