Torch的参数初始化

编程入门 行业动态 更新时间:2024-10-28 18:31:01

Torch的参数<a href=https://www.elefans.com/category/jswz/34/1770206.html style=初始化"/>

Torch的参数初始化

1.不需要初始化

  • 调用nn.Linear()等封装好的模块,不需要初始化
    def __init__(self, embed_size, heads, adj, dropout, forward_expansion):super(STransformer, self).__init__()# Spatial Embeddingself.adj = adjself.D_S = nn.Parameter(adj)self.embed_liner = nn.Linear(adj.shape[0], embed_size)self.attention = SSelfAttention(embed_size, heads)self.norm1 = nn.LayerNorm(embed_size)self.norm2 = nn.LayerNorm(embed_size)self.feed_forward = nn.Sequential(nn.Linear(embed_size, forward_expansion * embed_size),nn.ReLU(),nn.Linear(forward_expansion * embed_size, embed_size),)# 调用GCN# input:embed_size;  hidden: embed_size*2;  outpt:embed_sizeself.gcn = GCN(embed_size, embed_size*2, embed_size, dropout)  self.norm_adj = nn.InstanceNorm2d(1)    # 对邻接矩阵归一化self.dropout = nn.Dropout(dropout)self.fs = nn.Linear(embed_size, embed_size)self.fg = nn.Linear(embed_size, embed_size)

2.需要初始化

  • 只有自己定义的参数,例如weight与bias才需要自定义初始化。一般在__init__层里,调用self.reset_parameters()来实现。
def __init__(self, in_features,out_features,bias=True):self.weight = Parameter(torch.FloatTensor(in_features, out_features))if bias:self.bias = Parameter(torch.FloatTensor(out_features))else:self.register_parameter('bias', None)self.reset_parameters()
  • 在self.reset_parameters()函数里,使用两种方法
    • 第一种:nn.init.xavier_uniform_(x, gain=nn.init.calculate_gain(‘relu’))。其中, gain 参数来自定义初始化的标准差来匹配特定的激活函数:
    def reset_parameters(self, reset_mode='glorot_uniform'):nn.init.xavier_uniform_(self.bases, gain=nn.init.calculate_gain('relu'))nn.init.xavier_uniform_(self.comps, gain=nn.init.calculate_gain('relu'))nn.init.xavier_uniform_(self.weights, gain=nn.init.calculate_gain('relu'))if self.bias is not None:torch.nn.init.zeros_(self.bias)

-第二种:变量.data.uniform_(-stdv, stdv)

    def reset_parameters(self):stdv = 1. / math.sqrt(self.weight.size(1))self.weight.data.uniform_(-stdv, stdv)if self.bias is not None:self.bias.data.uniform_(-stdv, stdv)
  • 补充:
    raise ValueError(“Fan in and fan out can not be computed for tensor with fewer than 2 dimensions”)
        # nn.init.xavier_uniform_(self.bv)   #  raise ValueError("Fan in and fan out can not be computed for tensor with fewer than 2 dimensions")self.bv.data.fill_(0)

更多推荐

Torch的参数初始化

本文发布于:2023-07-28 19:46:30,感谢您对本站的认可!
本文链接:https://www.elefans.com/category/jswz/34/1291695.html
版权声明:本站内容均来自互联网,仅供演示用,请勿用于商业和其他非法用途。如果侵犯了您的权益请与我们联系,我们将在24小时内删除。
本文标签:初始化   参数   Torch

发布评论

评论列表 (有 0 条评论)
草根站长

>www.elefans.com

编程频道|电子爱好者 - 技术资讯及电子产品介绍!