Visualization
Haiku supports two ways of visualizing your program. To use these features you will need to install two additional dependencies:
[1]:
!pip install dm-tree graphviz
Requirement already satisfied: dm-tree in /tmp/haiku-env/lib/python3.11/site-packages (0.1.8)
Requirement already satisfied: graphviz in /tmp/haiku-env/lib/python3.11/site-packages (0.20.1)
[2]:
import jax
import jax.numpy as jnp
import haiku as hk
Tabulate
Like many neural network libraries, Haiku supports showing a summary of the execution of your program as a table of modules. Haiku's approach is to trace the execution of your program and to produce a table of the (interesting) module method calls.
For example, for a 3-layer MLP the interesting methods are MLP.__call__, which in turn calls Linear.__call__ on the three inner modules. For each module method we show columns relating to the input/output sizes of arrays, along with details of the module parameters and where the module sits in the module hierarchy.
[3]:
def f(x):
  return hk.nets.MLP([300, 100, 10])(x)

f = hk.transform(f)
x = jnp.ones([8, 28 * 28])
print(hk.experimental.tabulate(f)(x))
No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
+-------------------------+------------------------------------------+-----------------+------------+------------+---------------+---------------+
| Module | Config | Module params | Input | Output | Param count | Param bytes |
+=========================+==========================================+=================+============+============+===============+===============+
| mlp (MLP) | MLP(output_sizes=[300, 100, 10]) | | f32[8,784] | f32[8,10] | 266,610 | 1.07 MB |
+-------------------------+------------------------------------------+-----------------+------------+------------+---------------+---------------+
| mlp/~/linear_0 (Linear) | Linear(output_size=300, name='linear_0') | w: f32[784,300] | f32[8,784] | f32[8,300] | 235,500 | 942.00 KB |
| └ mlp (MLP) | | b: f32[300] | | | | |
+-------------------------+------------------------------------------+-----------------+------------+------------+---------------+---------------+
| mlp/~/linear_1 (Linear) | Linear(output_size=100, name='linear_1') | w: f32[300,100] | f32[8,300] | f32[8,100] | 30,100 | 120.40 KB |
| └ mlp (MLP) | | b: f32[100] | | | | |
+-------------------------+------------------------------------------+-----------------+------------+------------+---------------+---------------+
| mlp/~/linear_2 (Linear) | Linear(output_size=10, name='linear_2') | w: f32[100,10] | f32[8,100] | f32[8,10] | 1,010 | 4.04 KB |
| └ mlp (MLP) | | b: f32[10] | | | | |
+-------------------------+------------------------------------------+-----------------+------------+------------+---------------+---------------+
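tabulate also accepts optional arguments to customize the output. As a sketch (the columns argument and the column names below are recalled from Haiku's tabulate API, so treat them as assumptions to check against your installed version), you could restrict the table to just the module names and parameter sizes:
[ ]:
# Hypothetical column selection: show only each module, its parameter
# count and its parameter size in bytes.
print(hk.experimental.tabulate(f, columns=('module', 'params_size', 'params_bytes'))(x))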
If you want to create your own summary, we also provide access to the raw data used to build this table:
[4]:
for method_invocation in hk.experimental.eval_summary(f)(x):
  print(method_invocation)
MethodInvocation(module_details=ModuleDetails(module=MLP(output_sizes=[300, 100, 10]), method_name='__call__', params={'mlp/~/linear_0/b': f32[300], 'mlp/~/linear_0/w': f32[784,300], 'mlp/~/linear_1/b': f32[100], 'mlp/~/linear_1/w': f32[300,100], 'mlp/~/linear_2/b': f32[10], 'mlp/~/linear_2/w': f32[100,10]}, state={}), args_spec=(f32[8,784],), kwargs_spec={}, output_spec=f32[8,10], context=MethodContext(module=MLP(output_sizes=[300, 100, 10]), method_name='__call__', orig_method=functools.partial(<function MLP.__call__ at 0x7f173d83f600>, MLP(output_sizes=[300, 100, 10])), orig_class=<class 'haiku._src.nets.mlp.MLP'>), call_stack=(ModuleDetails(module=MLP(output_sizes=[300, 100, 10]), method_name='__call__', params={'mlp/~/linear_0/b': f32[300], 'mlp/~/linear_0/w': f32[784,300], 'mlp/~/linear_1/b': f32[100], 'mlp/~/linear_1/w': f32[300,100], 'mlp/~/linear_2/b': f32[10], 'mlp/~/linear_2/w': f32[100,10]}, state={}),))
MethodInvocation(module_details=ModuleDetails(module=Linear(output_size=300, name='linear_0'), method_name='__call__', params={'mlp/~/linear_0/b': f32[300], 'mlp/~/linear_0/w': f32[784,300]}, state={}), args_spec=(f32[8,784],), kwargs_spec={}, output_spec=f32[8,300], context=MethodContext(module=Linear(output_size=300, name='linear_0'), method_name='__call__', orig_method=functools.partial(<function Linear.__call__ at 0x7f173d927e20>, Linear(output_size=300, name='linear_0')), orig_class=<class 'haiku._src.basic.Linear'>), call_stack=(ModuleDetails(module=Linear(output_size=300, name='linear_0'), method_name='__call__', params={'mlp/~/linear_0/b': f32[300], 'mlp/~/linear_0/w': f32[784,300]}, state={}), ModuleDetails(module=MLP(output_sizes=[300, 100, 10]), method_name='__call__', params={'mlp/~/linear_0/b': f32[300], 'mlp/~/linear_0/w': f32[784,300], 'mlp/~/linear_1/b': f32[100], 'mlp/~/linear_1/w': f32[300,100], 'mlp/~/linear_2/b': f32[10], 'mlp/~/linear_2/w': f32[100,10]}, state={})))
MethodInvocation(module_details=ModuleDetails(module=Linear(output_size=100, name='linear_1'), method_name='__call__', params={'mlp/~/linear_1/b': f32[100], 'mlp/~/linear_1/w': f32[300,100]}, state={}), args_spec=(f32[8,300],), kwargs_spec={}, output_spec=f32[8,100], context=MethodContext(module=Linear(output_size=100, name='linear_1'), method_name='__call__', orig_method=functools.partial(<function Linear.__call__ at 0x7f173d927e20>, Linear(output_size=100, name='linear_1')), orig_class=<class 'haiku._src.basic.Linear'>), call_stack=(ModuleDetails(module=Linear(output_size=100, name='linear_1'), method_name='__call__', params={'mlp/~/linear_1/b': f32[100], 'mlp/~/linear_1/w': f32[300,100]}, state={}), ModuleDetails(module=MLP(output_sizes=[300, 100, 10]), method_name='__call__', params={'mlp/~/linear_0/b': f32[300], 'mlp/~/linear_0/w': f32[784,300], 'mlp/~/linear_1/b': f32[100], 'mlp/~/linear_1/w': f32[300,100], 'mlp/~/linear_2/b': f32[10], 'mlp/~/linear_2/w': f32[100,10]}, state={})))
MethodInvocation(module_details=ModuleDetails(module=Linear(output_size=10, name='linear_2'), method_name='__call__', params={'mlp/~/linear_2/b': f32[10], 'mlp/~/linear_2/w': f32[100,10]}, state={}), args_spec=(f32[8,100],), kwargs_spec={}, output_spec=f32[8,10], context=MethodContext(module=Linear(output_size=10, name='linear_2'), method_name='__call__', orig_method=functools.partial(<function Linear.__call__ at 0x7f173d927e20>, Linear(output_size=10, name='linear_2')), orig_class=<class 'haiku._src.basic.Linear'>), call_stack=(ModuleDetails(module=Linear(output_size=10, name='linear_2'), method_name='__call__', params={'mlp/~/linear_2/b': f32[10], 'mlp/~/linear_2/w': f32[100,10]}, state={}), ModuleDetails(module=MLP(output_sizes=[300, 100, 10]), method_name='__call__', params={'mlp/~/linear_0/b': f32[300], 'mlp/~/linear_0/w': f32[784,300], 'mlp/~/linear_1/b': f32[100], 'mlp/~/linear_1/w': f32[300,100], 'mlp/~/linear_2/b': f32[10], 'mlp/~/linear_2/w': f32[100,10]}, state={})))
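For example, here is a minimal sketch of a custom summary built on this raw data: it prints one line per traced method call with the number of parameters owned by that module. It only uses fields visible in the output above; treating the parameter specs as objects with a .shape attribute is an assumption based on that output.
[ ]:
import numpy as np

# A sketch of a custom summary: one line per traced method call, showing
# the module's name, the method invoked and the module's parameter count.
for inv in hk.experimental.eval_summary(f)(x):
  details = inv.module_details
  param_count = sum(int(np.prod(spec.shape)) for spec in details.params.values())
  print(f'{details.module.module_name}.{details.method_name}: {param_count:,} params')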
Graphviz (aka to_dot)
Haiku supports rendering your program as a graphviz graph. We show all of the JAX primitives involved in a given computation, clustered by Haiku module.
Let's start by visualizing a simple program that does not use Haiku modules:
[5]:
def f(a):
  b = jnp.sin(a)
  c = jnp.cos(b)
  d = b + c
  e = a + d
  return e

x = jnp.ones([1])
dot = hk.to_dot(f)(x)

import graphviz
graphviz.Source(dot)
[5]:
(graphviz rendering of the dataflow graph)
The visualization above shows our program as a simple dataflow graph: our single input (args[0]) is highlighted in orange, it passes through some operations, and a result is produced (highlighted in blue). Primitive operations (such as sin, cos and add) are highlighted in yellow.
Real Haiku programs are usually far more complex, involving many modules and many more primitive operations. For these programs it is often useful to visualize the program module by module. to_dot supports this by clustering operations per module. Again, this is probably easiest to see with an example:
[6]:
def f(x):
  return hk.nets.MLP([300, 100, 10])(x)

f = hk.transform(f)
rng = jax.random.PRNGKey(42)
x = jnp.ones([8, 28 * 28])
params = f.init(rng, x)
dot = hk.to_dot(f.apply)(params, None, x)
graphviz.Source(dot)
[6]:
(graphviz rendering of the MLP, clustered per module)
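Outside of a notebook you can write the same dot source to disk instead of displaying it inline. A minimal sketch using graphviz's Source.render (the filename mlp_graph is arbitrary):
[ ]:
# Write the rendered graph to disk rather than displaying it inline;
# this produces 'mlp_graph.png' (cleanup=True deletes the intermediate
# dot source file after rendering).
graphviz.Source(dot).render('mlp_graph', format='png', cleanup=True)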