Torch的参数初始化
1.不需要初始化
- 调用 nn.Linear() 等封装好的模块时,不需要手动初始化:这些模块在构造函数内部已经调用了各自的 reset_parameters() 完成默认初始化。
def __init__(self, embed_size, heads, adj, dropout, forward_expansion):
    """Build the spatial transformer block from ready-made sub-modules.

    Every learnable layer created here is a library or project module
    (``nn.Linear``, ``nn.LayerNorm``, ``GCN``, ...), so no manual
    parameter initialisation is written in this constructor.

    Args:
        embed_size: Feature width used by every layer in the block.
        heads: Forwarded to ``SSelfAttention``.
        adj: Adjacency matrix tensor; ``adj.shape[0]`` sets the input
            width of ``embed_liner``.
        dropout: Dropout probability, forwarded to ``GCN`` and
            ``nn.Dropout``.
        forward_expansion: Multiplier applied to ``embed_size`` to get
            the hidden width of the feed-forward sub-network.
    """
    super(STransformer, self).__init__()

    # Spatial embedding
    self.adj = adj
    self.D_S = nn.Parameter(adj)  # adjacency registered as a learnable parameter
    self.embed_liner = nn.Linear(adj.shape[0], embed_size)

    self.attention = SSelfAttention(embed_size, heads)
    self.norm1 = nn.LayerNorm(embed_size)
    self.norm2 = nn.LayerNorm(embed_size)

    self.feed_forward = nn.Sequential(
        nn.Linear(embed_size, forward_expansion * embed_size),
        nn.ReLU(),
        nn.Linear(forward_expansion * embed_size, embed_size),
    )

    # GCN branch: input embed_size, hidden embed_size * 2, output embed_size
    self.gcn = GCN(embed_size, embed_size * 2, embed_size, dropout)
    self.norm_adj = nn.InstanceNorm2d(1)  # normalises the adjacency matrix

    self.dropout = nn.Dropout(dropout)
    self.fs = nn.Linear(embed_size, embed_size)
    self.fg = nn.Linear(embed_size, embed_size)
2.需要初始化
- 只有自己定义的参数(例如用 Parameter 创建的 weight 与 bias)才需要自定义初始化。一般在 __init__ 方法里调用 self.reset_parameters() 来实现。
def __init__(self, in_features, out_features, bias=True):
    """Create the layer's hand-defined parameters and initialise them.

    Args:
        in_features: First dimension of the weight matrix.
        out_features: Second dimension of the weight matrix and length
            of the bias vector.
        bias: When true a bias parameter is created; otherwise ``bias``
            is registered as ``None``.
    """
    # nn.Module refuses Parameter assignment until its own __init__ has
    # run; the use of register_parameter below implies an nn.Module.
    super().__init__()

    # torch.FloatTensor(sizes) only allocates storage; the actual values
    # are assigned by reset_parameters() at the end of the constructor.
    self.weight = Parameter(torch.FloatTensor(in_features, out_features))
    if bias:
        self.bias = Parameter(torch.FloatTensor(out_features))
    else:
        self.register_parameter('bias', None)
    self.reset_parameters()
- 在self.reset_parameters()函数里,使用两种方法
- 第一种:nn.init.xavier_uniform_(x, gain=nn.init.calculate_gain('relu'))。其中,gain 参数是缩放系数,用来根据所使用的激活函数调整初始化分布的标准差:
def reset_parameters(self, reset_mode='glorot_uniform'):
    """Re-initialise the layer's hand-defined parameters.

    ``bases``, ``comps`` and ``weights`` are filled with Xavier (Glorot)
    uniform values scaled by the gain recommended for ReLU; the bias,
    when present, is set to zero.

    Args:
        reset_mode: Accepted for interface compatibility. It is currently
            ignored: Glorot-uniform initialisation is always applied.
    """
    relu_gain = nn.init.calculate_gain('relu')
    for tensor in (self.bases, self.comps, self.weights):
        nn.init.xavier_uniform_(tensor, gain=relu_gain)
    if self.bias is not None:
        nn.init.zeros_(self.bias)
- 第二种:变量.data.uniform_(-stdv, stdv),即直接在 [-stdv, stdv] 区间内均匀采样:
def reset_parameters(self):
    """Fill weight and bias with uniform values in ``[-stdv, stdv]``.

    ``stdv`` is ``1 / sqrt(self.weight.size(1))``, i.e. it shrinks as the
    second dimension of the weight matrix grows. The bias is skipped when
    it is ``None``.
    """
    stdv = 1. / math.sqrt(self.weight.size(1))
    self.weight.data.uniform_(-stdv, stdv)
    if self.bias is not None:
        self.bias.data.uniform_(-stdv, stdv)
- 补充:nn.init.xavier_uniform_ 需要根据 fan_in 和 fan_out 计算取值范围,因此只能用于至少 2 维的张量。对维度不足的参数(例如 1 维的偏置 self.bv)调用它会报错:
ValueError: Fan in and fan out can not be computed for tensor with fewer than 2 dimensions
这类参数改用常数初始化即可:
# nn.init.xavier_uniform_(self.bv)  # raises the ValueError above
self.bv.data.fill_(0)
更多推荐
Torch的参数初始化
发布评论