注意力模型GAT"/>
Keras实现图注意力模型GAT
简介:本文实现了一个GAT图注意力机制的网络层,可以在Keras中像调用Dense网络层、Input网络层一样直接搭积木进行网络组合。
一,基本展示
如下图所示,我们输入邻接矩阵和节点特征矩阵之后,可以直接调用myGraphAttention网络层得到每一头的注意力输出(节点emdbeding),十分的方便。
注意:上图有个BUG,最终的输出层应该是8,和输入节点特征保持一致,上图只是举一个例子。
二,代码实现
2.1 GAT网络层
GAT网络层的代码如下。
from __future__ import absolute_import
from keras.activations import relu
from keras import activations, constraints, initializers, regularizers
from keras import backend as K
from keras.layers import Layer, Dropout, LeakyReLU# 定义图卷积层
class myGraphAttention(Layer):def __init__(self,F_,activation='relu',use_bias=True,drop_rate = 0,kernel_initializer='glorot_uniform',bias_initializer='zeros',attn_kernel_initializer='glorot_uniform',kernel_regularizer=None,bias_regularizer=None,attn_kernel_regularizer=None,activity_regularizer=None,kernel_constraint=None,bias_constraint=None,attn_kernel_constraint=None,**kwargs):self.F_ = F_ # 输出的节点embeding维度self.activation = activations.get(activation) # 输出结果之前的激活函数self.use_bias = use_bias"""其他代码………………"""super(myGraphAttention, self).__init__(**kwargs)def build(self, input_shape):"""其他代码………………"""def call(self, inputs):X = inputs[0] # 节点特征 (N x F)A = inputs[1] # 邻接矩阵 (N x N)"""其他代码………………"""# 加上偏置if self.use_bias:node_features = K.bias_add(node_features, self.bias)# 最终的输出之前得激活一下output = self.activation(node_features)return outputdef compute_output_shape(self, input_shape):output_shape = input_shape[0][0], self.output_dimreturn output_shape
2.2 模型搭建
模型搭建的代码如下。
from keras.layers import Layer,Input,Dense,add,Lambda
from keras.models import Modelinp_adj_martrix = Input(shape=(5,5),name='adj_martrix')
inp_node_features = Input(shape=(5,8),name='node_features_martrix')# 在这里直接调用网络层
flat0 = myGraphAttention(12,name="head_0")([inp_node_features,inp_adj_martrix])
flat1 = myGraphAttention(12,name="head_1")([inp_node_features,inp_adj_martrix])
flat2 = myGraphAttention(12,name="head_2")([inp_node_features,inp_adj_martrix])flat = add([flat0,flat1,flat2])lorder = 1
one_node = Lambda(lambda inp: inp[:,0,:],name = "the-first-node-feature")(flat)o1 = Dense(32,activation="relu")(one_node)
o2 = Dense(32,activation="relu")(o1)
out = Dense(12)(o2)model = Model([inp_adj_martrix,inp_node_features],[out])
创作不易,需要完整代码4_liao我哦。还有很多预测网络结构。
更多推荐
Keras实现图注意力模型GAT
发布评论