WXL's blog

Talk is cheap, show me your work.

0%

pytorch的parameters相关函数

register_parameter()和parameter()

这两个都是一个东西,使用上有细微差别,主要作用是:将一个不可训练的类型Tensor转换成可以训练的类型parameter,并将这个parameter绑定到这个module里面,相当于变成了模型的一部分,成为了模型中可以根据训练进行变化的参数。

使用方法:

1
2
3
4
5
6
7
8
9
10
11
class Example(nn.Module):
def __init__(self):
super(Example, self).__init__()
print('看看我们的模型有哪些parameter:\t', self._parameters, end='\n')
self.W1_params = nn.Parameter(torch.rand(2,3))
print('增加W1后看看:',self._parameters, end='\n')

self.register_parameter('W2_params' , nn.Parameter(torch.rand(2,3)))
print('增加W2后看看:',self._parameters, end='\n')
def forward(self, x):
return x

这里还有一个self._parameters,是在继承了nn.Module的类内所使用的。nn.Linear类中的代码如下:

1
2
3
4
5
6
7
8
9
10
11
def __init__(self, in_features, out_features, bias=True):
super(Linear, self).__init__()
self.in_features = in_features
self.out_features = out_features
self.weight = Parameter(torch.Tensor(out_features, in_features))
if bias:
self.bias = Parameter(torch.Tensor(out_features))
else:
self.register_parameter('bias', None)
# 使用默认的方法初始化可训练参数
self.reset_parameters()

参考:

pytorch中的register_parameter()和parameter()

reset_parameters

初始化参数,在继承了nn.Module类的网络类中使用:self.reset_parameters(),pytorch提供了多种初始化函数:

1
2
3
4
torch.nn.init.constant(tensor, val)
torch.nn.init.normal(tensor, mean=0, std=1)
torch.nn.init.xavier_uniform(tensor, gain=1)
kaiming_uniform_(tensor, a=0, mode='fan_in', nonlinearity='leaky_relu')

常见的初始化:

nn.init.uniform_()

1
2
w = torch.empty(3, 5)
nn.init.uniform_(w, a=0.0, b=1.0)

nn.init.normal_()

1
2
w = torch.empty(3, 5)
nn.init.normal_(w, mean=0.0, std=1.0)

nn.init.constant_()

Fill the input Tensor with the value val

1
2
w = torch.empty(3, 5)
nn.init.constant_(w, 0.3)

nn.init.ones_()

1
2
w = torch.empty(3, 5)
nn.init.ones_(w)

nn.init.zeros_()

1
2
w = torch.empty(3, 5)
nn.init.zeros_(w)

nn.init.eye_()

1
2
w = torch.empty(3, 5)
nn.init.eye_(w)

nn.init.xavier_uniform_()

1
2
w = torch.empty(3, 5)
nn.init.xavier_uniform_(w, gain=nn.init.calculate_gain('relu'))

nn.init.kaiming_uniform_()

1
2
w = torch.empty(3, 5)
nn.init.kaiming_uniform_(w, mode='fan_in', nonlinearity='relu')

nn.init.kaiming_normal_()

1
2
w = torch.empty(3, 5)
nn.init.kaiming_normal_(w, mode='fan_out', nonlinearity='relu')

net.apply()

用法:

1
2
3
4
5
6
7
def weights_init(m):
classname=m.__class__.__name__
if classname.find('Conv') != -1:
xavier(m.weight.data)
xavier(m.bias.data)
net = Net()
net.apply(weights_init) #apply函数会递归地搜索网络内的所有module并把参数表示的函数应用到所有的module上。

参考:

Pytorch参数初始化–默认与自定义

行行好,赏一杯咖啡吧~