Interactive online version: Open In Colab

Visualization#

Haiku supports two ways to visualize your program. To use these you need to install two extra dependencies:

[1]:
!pip install dm-tree graphviz
Requirement already satisfied: dm-tree in /tmp/haiku-env/lib/python3.11/site-packages (0.1.8)
Requirement already satisfied: graphviz in /tmp/haiku-env/lib/python3.11/site-packages (0.20.1)
[2]:
import jax
import jax.numpy as jnp
import haiku as hk

Tabulate#

Like many neural network libraries, Haiku supports showing a summary of the execution of your program as a table of modules. Haiku's approach is to trace the execution of your program and to produce a table of the (interesting) module method calls.

For example, for a 3-layer MLP the interesting methods are MLP.__call__, which in turn calls Linear.__call__ on three inner modules. For each module method we show columns relating to the input/output sizes of arrays, along with details of the module parameters and where the module sits in the module hierarchy.

[3]:
def f(x):
  return hk.nets.MLP([300, 100, 10])(x)

f = hk.transform(f)
x = jnp.ones([8, 28 * 28])

print(hk.experimental.tabulate(f)(x))
No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
+-------------------------+------------------------------------------+-----------------+------------+------------+---------------+---------------+
| Module                  | Config                                   | Module params   | Input      | Output     |   Param count |   Param bytes |
+=========================+==========================================+=================+============+============+===============+===============+
| mlp (MLP)               | MLP(output_sizes=[300, 100, 10])         |                 | f32[8,784] | f32[8,10]  |       266,610 |       1.07 MB |
+-------------------------+------------------------------------------+-----------------+------------+------------+---------------+---------------+
| mlp/~/linear_0 (Linear) | Linear(output_size=300, name='linear_0') | w: f32[784,300] | f32[8,784] | f32[8,300] |       235,500 |     942.00 KB |
|  └ mlp (MLP)            |                                          | b: f32[300]     |            |            |               |               |
+-------------------------+------------------------------------------+-----------------+------------+------------+---------------+---------------+
| mlp/~/linear_1 (Linear) | Linear(output_size=100, name='linear_1') | w: f32[300,100] | f32[8,300] | f32[8,100] |        30,100 |     120.40 KB |
|  └ mlp (MLP)            |                                          | b: f32[100]     |            |            |               |               |
+-------------------------+------------------------------------------+-----------------+------------+------------+---------------+---------------+
| mlp/~/linear_2 (Linear) | Linear(output_size=10, name='linear_2')  | w: f32[100,10]  | f32[8,100] | f32[8,10]  |         1,010 |       4.04 KB |
|  └ mlp (MLP)            |                                          | b: f32[10]      |            |            |               |               |
+-------------------------+------------------------------------------+-----------------+------------+------------+---------------+---------------+
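
hk.experimental.tabulate also takes optional arguments to customize the table. As a rough sketch (the column names below are assumptions based on tabulate's columns argument; check the API reference for the exact set supported by your version), you could restrict the output to a subset of columns:

# Sketch: show only a subset of columns. The column names are assumed from
# the tabulate API and may differ across Haiku versions.
print(hk.experimental.tabulate(f, columns=('module', 'output', 'params_size'))(x))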

If you want to create your own summaries, we also provide access to the raw data used to build this table:

[4]:
for method_invocation in hk.experimental.eval_summary(f)(x):
  print(method_invocation)
MethodInvocation(module_details=ModuleDetails(module=MLP(output_sizes=[300, 100, 10]), method_name='__call__', params={'mlp/~/linear_0/b': f32[300], 'mlp/~/linear_0/w': f32[784,300], 'mlp/~/linear_1/b': f32[100], 'mlp/~/linear_1/w': f32[300,100], 'mlp/~/linear_2/b': f32[10], 'mlp/~/linear_2/w': f32[100,10]}, state={}), args_spec=(f32[8,784],), kwargs_spec={}, output_spec=f32[8,10], context=MethodContext(module=MLP(output_sizes=[300, 100, 10]), method_name='__call__', orig_method=functools.partial(<function MLP.__call__ at 0x7f173d83f600>, MLP(output_sizes=[300, 100, 10])), orig_class=<class 'haiku._src.nets.mlp.MLP'>), call_stack=(ModuleDetails(module=MLP(output_sizes=[300, 100, 10]), method_name='__call__', params={'mlp/~/linear_0/b': f32[300], 'mlp/~/linear_0/w': f32[784,300], 'mlp/~/linear_1/b': f32[100], 'mlp/~/linear_1/w': f32[300,100], 'mlp/~/linear_2/b': f32[10], 'mlp/~/linear_2/w': f32[100,10]}, state={}),))
MethodInvocation(module_details=ModuleDetails(module=Linear(output_size=300, name='linear_0'), method_name='__call__', params={'mlp/~/linear_0/b': f32[300], 'mlp/~/linear_0/w': f32[784,300]}, state={}), args_spec=(f32[8,784],), kwargs_spec={}, output_spec=f32[8,300], context=MethodContext(module=Linear(output_size=300, name='linear_0'), method_name='__call__', orig_method=functools.partial(<function Linear.__call__ at 0x7f173d927e20>, Linear(output_size=300, name='linear_0')), orig_class=<class 'haiku._src.basic.Linear'>), call_stack=(ModuleDetails(module=Linear(output_size=300, name='linear_0'), method_name='__call__', params={'mlp/~/linear_0/b': f32[300], 'mlp/~/linear_0/w': f32[784,300]}, state={}), ModuleDetails(module=MLP(output_sizes=[300, 100, 10]), method_name='__call__', params={'mlp/~/linear_0/b': f32[300], 'mlp/~/linear_0/w': f32[784,300], 'mlp/~/linear_1/b': f32[100], 'mlp/~/linear_1/w': f32[300,100], 'mlp/~/linear_2/b': f32[10], 'mlp/~/linear_2/w': f32[100,10]}, state={})))
MethodInvocation(module_details=ModuleDetails(module=Linear(output_size=100, name='linear_1'), method_name='__call__', params={'mlp/~/linear_1/b': f32[100], 'mlp/~/linear_1/w': f32[300,100]}, state={}), args_spec=(f32[8,300],), kwargs_spec={}, output_spec=f32[8,100], context=MethodContext(module=Linear(output_size=100, name='linear_1'), method_name='__call__', orig_method=functools.partial(<function Linear.__call__ at 0x7f173d927e20>, Linear(output_size=100, name='linear_1')), orig_class=<class 'haiku._src.basic.Linear'>), call_stack=(ModuleDetails(module=Linear(output_size=100, name='linear_1'), method_name='__call__', params={'mlp/~/linear_1/b': f32[100], 'mlp/~/linear_1/w': f32[300,100]}, state={}), ModuleDetails(module=MLP(output_sizes=[300, 100, 10]), method_name='__call__', params={'mlp/~/linear_0/b': f32[300], 'mlp/~/linear_0/w': f32[784,300], 'mlp/~/linear_1/b': f32[100], 'mlp/~/linear_1/w': f32[300,100], 'mlp/~/linear_2/b': f32[10], 'mlp/~/linear_2/w': f32[100,10]}, state={})))
MethodInvocation(module_details=ModuleDetails(module=Linear(output_size=10, name='linear_2'), method_name='__call__', params={'mlp/~/linear_2/b': f32[10], 'mlp/~/linear_2/w': f32[100,10]}, state={}), args_spec=(f32[8,100],), kwargs_spec={}, output_spec=f32[8,10], context=MethodContext(module=Linear(output_size=10, name='linear_2'), method_name='__call__', orig_method=functools.partial(<function Linear.__call__ at 0x7f173d927e20>, Linear(output_size=10, name='linear_2')), orig_class=<class 'haiku._src.basic.Linear'>), call_stack=(ModuleDetails(module=Linear(output_size=10, name='linear_2'), method_name='__call__', params={'mlp/~/linear_2/b': f32[10], 'mlp/~/linear_2/w': f32[100,10]}, state={}), ModuleDetails(module=MLP(output_sizes=[300, 100, 10]), method_name='__call__', params={'mlp/~/linear_0/b': f32[300], 'mlp/~/linear_0/w': f32[784,300], 'mlp/~/linear_1/b': f32[100], 'mlp/~/linear_1/w': f32[300,100], 'mlp/~/linear_2/b': f32[10], 'mlp/~/linear_2/w': f32[100,10]}, state={})))
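
As a minimal sketch of building a custom summary from this raw data (the helper below is illustrative, not part of Haiku, and assumes the parameter specs expose a .shape attribute, as in the output above), we can report the parameter count recorded for each method invocation:

import numpy as np

def param_summary(f, *args):
  # Illustrative helper: total the parameters recorded for each invoked
  # module method, using the specs returned by eval_summary.
  for inv in hk.experimental.eval_summary(f)(*args):
    details = inv.module_details
    n_params = sum(int(np.prod(spec.shape)) for spec in details.params.values())
    print(f'{details.module.module_name}.{details.method_name}: {n_params:,} params')

param_summary(f, x)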

Graphviz (aka to_dot)#

Haiku supports rendering your program as a graphviz graph. We show all of the JAX primitives involved in a given computation, clustered by Haiku module.

Let's start by visualizing a simple program that doesn't use Haiku modules:

[5]:
def f(a):
  b = jnp.sin(a)
  c = jnp.cos(b)
  d = b + c
  e = a + d
  return e

x = jnp.ones([1])
dot = hk.to_dot(f)(x)

import graphviz
graphviz.Source(dot)
[5]:
[Output: graphviz rendering of the dataflow graph for f]

The visualization above shows our program as a simple dataflow graph: our single input is highlighted in orange (args[0]), it passes through some operations, and produces the result (highlighted in blue). Primitive operations (e.g. sin, cos, and add) are highlighted in yellow.

Real Haiku programs are typically far more complex, involving many modules and many more primitive operations. For these programs it is often useful to visualize the program module by module.

to_dot supports this by clustering operations by module. Again, it is probably easiest to look at an example:

[6]:
def f(x):
  return hk.nets.MLP([300, 100, 10])(x)

f = hk.transform(f)

rng = jax.random.PRNGKey(42)
x = jnp.ones([8, 28 * 28])
params = f.init(rng, x)

dot = hk.to_dot(f.apply)(params, None, x)
graphviz.Source(dot)
[6]:
[Output: graphviz rendering of the MLP forward pass, clustered by module]
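
The dot source returned by hk.to_dot is ordinary graphviz, so rather than displaying the graph inline you can also write it to disk. A minimal sketch (the filename is illustrative):

# Sketch: render the graph to an SVG file (filename is illustrative).
# Source.render writes both the .dot source and the rendered output.
graphviz.Source(dot).render('mlp_forward', format='svg', cleanup=True)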