register_parameter() and nn.Parameter()
These two do essentially the same thing, with only minor differences in usage. Their main job is to convert a plain, non-trainable Tensor into a trainable nn.Parameter and bind that Parameter to the module, so it becomes part of the model and is updated during training.
Usage:
```python
import torch
import torch.nn as nn

class Example(nn.Module):
    def __init__(self):
        super(Example, self).__init__()
        print('Parameters before adding anything:\t', self._parameters)
        self.W1_params = nn.Parameter(torch.rand(2, 3))
        print('After adding W1:', self._parameters)
        self.register_parameter('W2_params', nn.Parameter(torch.rand(2, 3)))
        print('After adding W2:', self._parameters)

    def forward(self, x):
        return x
```
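To confirm that both parameters really become part of the model, one can instantiate the class above and iterate over named_parameters(); a minimal sketch continuing from the Example class:

```python
model = Example()

# Both W1_params and W2_params were registered, so they show up here
for name, p in model.named_parameters():
    print(name, p.shape, p.requires_grad)
# W1_params torch.Size([2, 3]) True
# W2_params torch.Size([2, 3]) True

# ...and they are picked up by an optimizer like any other model parameter
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
```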
The self._parameters dictionary used here is available inside any class that inherits from nn.Module. For example, nn.Linear contains the following code:
```python
def __init__(self, in_features, out_features, bias=True):
    super(Linear, self).__init__()
    self.in_features = in_features
    self.out_features = out_features
    self.weight = Parameter(torch.Tensor(out_features, in_features))
    if bias:
        self.bias = Parameter(torch.Tensor(out_features))
    else:
        self.register_parameter('bias', None)
    self.reset_parameters()
```
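Note the register_parameter('bias', None) branch: registering None keeps the self.bias attribute accessible but excludes it from the parameter list. A quick check with the real nn.Linear (a sketch, not from the original post):

```python
import torch.nn as nn

layer = nn.Linear(3, 2, bias=False)
print(layer.bias)                                      # None: the attribute exists but holds no Parameter
print([name for name, _ in layer.named_parameters()])  # ['weight'] only
```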
Reference:
register_parameter() and Parameter() in PyTorch
reset_parameters()
This initializes the parameters; in a network class that inherits from nn.Module it is called as self.reset_parameters(). PyTorch provides a variety of initialization functions (a combined sketch follows the list below):
```python
torch.nn.init.constant_(tensor, val)
torch.nn.init.normal_(tensor, mean=0, std=1)
torch.nn.init.xavier_uniform_(tensor, gain=1)
torch.nn.init.kaiming_uniform_(tensor, a=0, mode='fan_in', nonlinearity='leaky_relu')
```
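Putting the two ideas together, a typical pattern is to define the parameters in __init__ and initialize them in a custom reset_parameters(). A minimal sketch (MyLayer is a made-up module used only for illustration, loosely following what nn.Linear does):

```python
import math
import torch
import torch.nn as nn

class MyLayer(nn.Module):
    def __init__(self, in_features, out_features):
        super(MyLayer, self).__init__()
        self.weight = nn.Parameter(torch.empty(out_features, in_features))
        self.bias = nn.Parameter(torch.empty(out_features))
        self.reset_parameters()

    def reset_parameters(self):
        # Kaiming-uniform for the weight, a small uniform range for the bias
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        bound = 1.0 / math.sqrt(self.weight.size(1))
        nn.init.uniform_(self.bias, -bound, bound)

    def forward(self, x):
        return x @ self.weight.t() + self.bias
```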
Common initializations:
nn.init.uniform_()
```python
w = torch.empty(3, 5)
nn.init.uniform_(w, a=0.0, b=1.0)
```
nn.init.normal_()
```python
w = torch.empty(3, 5)
nn.init.normal_(w, mean=0.0, std=1.0)
```
nn.init.constant_()
Fill the input Tensor with the value val
```python
w = torch.empty(3, 5)
nn.init.constant_(w, 0.3)
```
nn.init.ones_()
```python
w = torch.empty(3, 5)
nn.init.ones_(w)
```
nn.init.zeros_()
```python
w = torch.empty(3, 5)
nn.init.zeros_(w)
```
nn.init.eye_()
```python
w = torch.empty(3, 5)
nn.init.eye_(w)
```
nn.init.xavier_uniform_()
```python
w = torch.empty(3, 5)
nn.init.xavier_uniform_(w, gain=nn.init.calculate_gain('relu'))
```
nn.init.kaiming_uniform_()
```python
w = torch.empty(3, 5)
nn.init.kaiming_uniform_(w, mode='fan_in', nonlinearity='relu')
```
nn.init.kaiming_normal_()
```python
w = torch.empty(3, 5)
nn.init.kaiming_normal_(w, mode='fan_out', nonlinearity='relu')
```
net.apply()
Usage:
```python
import torch.nn as nn

def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        nn.init.xavier_uniform_(m.weight.data)  # Xavier init needs a tensor with >= 2 dims
        nn.init.zeros_(m.bias.data)             # the bias is 1-D, so use a simple zero init

net = Net()  # Net is assumed to be defined elsewhere
net.apply(weights_init)
```
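net.apply(fn) walks every submodule recursively and calls fn on each of them (and finally on the module itself), which is why it is the usual way to run type-dependent initialization over a whole network. A self-contained sketch with made-up layer sizes, using isinstance checks instead of the class-name string match above:

```python
import torch.nn as nn

def init_weights(m):
    # isinstance is a bit more robust than matching the class name string
    if isinstance(m, (nn.Conv2d, nn.Linear)):
        nn.init.xavier_uniform_(m.weight)
        if m.bias is not None:
            nn.init.zeros_(m.bias)

net = nn.Sequential(
    nn.Conv2d(3, 16, kernel_size=3),
    nn.ReLU(),
    nn.Flatten(),
    nn.Linear(16, 10),
)
net.apply(init_weights)  # visits every submodule, then net itself
```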
Reference:
PyTorch parameter initialization – default and custom