图神经网络实战——匈牙利水痘病例预测

编程入门 行业动态 更新时间:2024-10-20 15:54:03

图神经网络实战——<a href=https://www.elefans.com/category/jswz/34/1746003.html style=匈牙利水痘病例预测"/>

图神经网络实战——匈牙利水痘病例预测

目录

引言 

问题分析

图神经网络

PyG-Temporal        

代码和效果展示


引言 

        机器学习大作业要对匈牙利水痘病例进行预测,根据论文1《Chickenpox Cases in Hungary: a Benchmark Dataset for Spatiotemporal Signal Processing with Graph Neural Networks》中的分析说明,水痘病例数具有明显的季节周期性,数据集中的20个country的病例数如下图所示

        仅凭时间序列的规律进行预测,我也已经尝试过 XGBoost 和 LSTM 也进行了效果展示(可以直接点击超链接转到我之前的博客),在规模较小的country中,做出的预测偏差还可以接受,但是当规模较大时,预测的结果偏离和延迟都比较大。

问题分析

        原因在论文1中也描述了,原因有两个,第一:时间序列并不是很稳定,每个country的滚动平均值并不稳定,说明随着季节变化的序列噪声对预测结果的影响还是比较大的,其中的原因可能是人口转移的影响,也可能是内在流行病学动态的结果,Budapest(匈牙利首都布达佩斯)、Fejer、Szabolcs和Zala四个country的 Running mean 如下图

第二个原因就是受到空间连接的影响了,这很显而易见,两个country相邻,一个country的病例数肯定会对临近country的病例数有一些影响的,因为水痘是传染病嘛,而且,越大的country对周边country的影响越大,甚至,大country自身的病例数跟自身规模成指数相关。因此,在论文1中就使用 莫兰指数 评估了空间自适应性,如下图所示

 r可以理解为计算机网络中的跳数,r=1意味着两个country直接相邻,r=2意味着两个country之间隔了一个country,由此可见,相邻country之间的空间相关性还是很大的,当然不能忽视这种影响。

        因此在解决这个问题时,肯定要考虑将空间连接图加入到我们要训练的模型中,这就引出了今天的主角——图神经网络(Graph Neural Networks,GNN)。

图神经网络

        就本题而言,我们把把各个县视为节点,县之间的连接视为路径,就可以把给出的空间连接信息转化为无向图,也就是说,我们在训练模型的时候必须考虑不同县之间的联系,将这张空间连接图加入到网络中。而专门处理这类数据的有效方法就是图神经网络(Graph Neural Networks,GNN)。

        GCN包含节点和边信息的图加入到神经网络中,在数据结构中,我们使用邻接矩阵来表示图,一般的神经网络每层就是输入矩阵乘上权重矩阵再经过一个激活函数即可,而图神经网络就是在每次计算时多乘上一个邻接矩阵来将相应的结构信息加入到网络中一起训练

上面的理解是我最简单的理解,下面几篇博客对GCN从整体到局部的介绍还可以

首先是这篇介绍性的文章,看完后我不仅仅是了解了到底什么叫GCN,还弄清了它的发展、地位和应用场景,讲的确实很清楚:什么是图神经网络?有什么用?终于有人讲明白了

PyG-Temporal        

        更多的文章我也没有什么好的推荐了,因为只是为了做这个作业,对GCN有了一个大致的了解,然后从GitHub上下载了目前最火热的开源GCN的项目代码——PyG-Temporal,装好相关的包之后就能直接运行,连输入都给我写好了,根本就不需要调试。

        至于装包的过程确实是有几分艰辛,需要安装一整套PyG-Temporal相关包,分别是torch及对应版本的torch_cluster、torch_scatter、torch_sparse、torch_spline_conv,具体的安装方法可以直接访问我的安装torch_scatter,torch-sparse,torch-cluster,torch-spline-conv这篇博客,当然,如果你是python3.9或者直接安装一个python3.9,那就可以直接从下面的链接里直接下载

下载链接:     提取码:1234

代码和效果展示

from tqdm import tqdmimport torch
import torch.nn.functional as F
import unittest
from torch_geometric_temporal.nn.recurrent import DCRNNfrom torch_geometric_temporal.dataset import ChickenpoxDatasetLoader
from torch_geometric_temporal.signal import temporal_signal_splitloader = ChickenpoxDatasetLoader()dataset = loader.get_dataset()train_dataset, test_dataset = temporal_signal_split(dataset, train_ratio=0.2)
print(train_dataset,test_dataset)class RecurrentGCN(torch.nn.Module):def __init__(self, node_features):super(RecurrentGCN, self).__init__()self.recurrent = DCRNN(node_features, 32, 1)self.linear = torch.nn.Linear(32, 1)def forward(self, x, edge_index, edge_weight):h = self.recurrent(x, edge_index, edge_weight)h = F.relu(h)h = self.linear(h)return hmodel = RecurrentGCN(node_features = 4)optimizer = torch.optim.Adam(model.parameters(), lr=0.01)model.train()for epoch in tqdm(range(200)):cost = 0for time, snapshot in enumerate(train_dataset):y_hat = model(snapshot.x, snapshot.edge_index, snapshot.edge_attr)cost = cost + torch.mean((y_hat-snapshot.y)**2)cost = cost / (time+1)cost.backward()optimizer.step()optimizer.zero_grad()model.eval()
cost = 0
for time, snapshot in enumerate(test_dataset):y_hat = model(snapshot.x, snapshot.edge_index, snapshot.edge_attr)print(y_hat)print(snapshot.y)cost = cost + torch.mean((y_hat-snapshot.y)**2)print('cha:', y_hat - snapshot.y)print(cost)
cost = cost / (time+1)
cost = cost.item()
print("MSE: {:.4f}".format(cost))

        该代码是直接使用loader.get_dataset函数直接从对应数据集网址里直接下载的数据,也就是说,我们装好包之后,连输入都不用给它就能直接运行,最后输出的是预测结果的MSE,效果还是很不错的。具体的代码解释官方已经给出了,这是翻译版本PyTorch Geometric Temporal

 

更多推荐

图神经网络实战——匈牙利水痘病例预测

本文发布于:2024-02-26 19:08:55,感谢您对本站的认可!
本文链接:https://www.elefans.com/category/jswz/34/1703562.html
版权声明:本站内容均来自互联网,仅供演示用,请勿用于商业和其他非法用途。如果侵犯了您的权益请与我们联系,我们将在24小时内删除。
本文标签:匈牙利   水痘   神经网络   病例   实战

发布评论

评论列表 (有 0 条评论)
草根站长

>www.elefans.com

编程频道|电子爱好者 - 技术资讯及电子产品介绍!