WXL's blog

Talk is cheap, show me your work.

0%

网络可视化

介绍常见的可视化网络模型和训练过程的方法。

网络模型可视化

HiddenLayer库可视化

参考:hiddenlayer

首先安装hiddenlayer库和graphviz库:

1
2
pip install hiddenlayer
conda install python-graphviz

用如下实例测试一下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
import hiddenlayer as hl
from torch import nn
import torch


class ConvNet(nn.Module):
def __init__(self):
super().__init__()
self.conv1 = nn.Sequential(
nn.Conv2d(in_channels=1, out_channels=16, kernel_size=3, stride=1, padding=1),
nn.ReLU(),
nn.AvgPool2d(kernel_size=2, stride=2)
)
self.conv2 = nn.Sequential(
nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, stride=1, padding=1),
nn.ReLU(),
nn.AvgPool2d(kernel_size=2, stride=2)
)
self.fc = nn.Sequential(
nn.Linear(in_features=32*7*7, out_features=128),
nn.ReLU(),
nn.Linear(128, 64),
nn.ReLU()
)
self.out = nn.Linear(64, 10)

def forward(self, X):
H1 = self.conv1(X)
H2 = self.conv2(H1)
H2 = H2.view(H2.size(0), -1)
H3 = self.fc(H2)
output = self.out(H3)
return output

net = ConvNet()
# print(net)
hl_graph = hl.build_graph(net, torch.zeros([1, 1, 28, 28]))
hl_graph.theme = hl.graph.THEMES["blue"].copy()
hl_graph.save('hl.png', format='png')

PyTorchViz可视化

还是上面的net,可视化代码如下:

1
2
3
4
5
6
7
x = torch.randn(1, 1, 28, 28).requires_grad_(True)
y = net(x)
viz = make_dot(y, params=dict(list(net.named_parameters()) + [('x', x)]))
# 将viz保存为图片
viz.format = 'png' # 保存的格式
viz.directory = 'img/myvis' # 保存的路径
viz.view() # 自动在当前文件夹生成相应目录和文件

更高级的方法:

参考这篇文章:

23 款神经网络的设计和可视化工具

行行好,赏一杯咖啡吧~