visualize_token2token_scores(norm_fn(output_attentions_all, dim=2).squeeze().detach().cpu().numpy(), x_label_name='Layer')
维度变化链路
output_attentions_all:(layer, batch, head, seq_len, seq_len)
→ norm_fn(dim=2):聚合head维度 → (layer, batch, seq_len, seq_len)
→ squeeze():删除batch维度 → (layer, seq_len, seq_len)
→ 最终用于可视化:每层的“token-token注意力强度矩阵”(汇总所有头的信息)
维度格式
output_attentions_all.shape = (layer, batch, head, seq_len, seq_len)<br />
(文档中通过代码 output_attentions_all = torch.stack(output_attentions) 明确堆叠逻辑,且在 [29] 单元格注释中验证了维度构成)
各维度含义
-
layer(维度索引 0,数值示例为 12) 表示 BERT 模型中编码器的层数。以bert-base-uncased为例,模型默认包含 12 个 Transformer 编码层。 -
batch(维度索引 1,数值示例为 1) 表示输入样本的批次大小。文档示例中仅使用了 1 个问答对作为输入,因此该维度取值为 1。 -
head(维度索引 2,数值示例为 12) 表示每一层中的多头注意力头数。对于bert-base-uncased,每个编码层默认包含 12 个注意力头。 -
seq_len(维度索引 3,行维度,数值示例为 26) 表示输入序列的长度,包括[CLS]、[SEP]等特殊 token。该维度对应注意力的“发出者”(query)token。 -
seq_len(维度索引 4,列维度,数值示例为 26) 与上一维度含义一致,同样表示序列长度,但对应注意力的“接收者”(key)token。张量中的每个元素 $[l, b, h, i, j]$ 表示在第 $l$ 层、第 $b$ 个样本、第 $h$ 个注意力头下,第 $i$ 个 token 对第 $j$ 个 token 分配的注意力权重(经 softmax 归一化)。
文档中 norm_fn 是 L2 范数计算函数(基于 PyTorch 版本选择 torch.linalg.norm 或 torch.norm),调用方式为 norm_fn(output_attentions_all, dim=2),核心是在“注意力头(head)”维度上计算范数,以汇总每层所有头的注意力信息。
操作逻辑
- 输入:output_attentions_all 维度为 (layer, batch, head, seq_len, seq_len)<br />
- 关键参数:dim=2 表示对第2维(head维度)计算L2范数——即对每层、每个样本、每个“发出者-接收者”token对(i,j),将12个注意力头的权重作为向量,计算其L2范数( \(\sqrt{\sum_{h=1}^{12} w_{l,b,h,i,j}^2}\) )。
输出维度与含义
- 输出维度(norm_fn 后):(layer, batch, seq_len, seq_len)<br />
(因在 head 维度(dim=2)上聚合,故维度数从5维减少为4维,删除了 head 维度)
- 后续处理:squeeze().detach().cpu().numpy() 是张量格式转换操作,不改变维度含义:
- squeeze():去除维度大小为1的维度(此处 batch=1,故删除 batch 维度),最终维度变为 (layer, seq_len, seq_len);
- detach().cpu().numpy():将PyTorch张量转为NumPy数组,用于后续可视化。
最终维度
-
layer(维度索引 0,数值示例为 12) 与输入保持一致,表示 BERT 的 12 个编码器层。 -
seq_len(维度索引 1,行维度,数值示例为 26) 表示输入序列的长度,对应注意力的“发出者”(query)token,与原始注意力张量的第 3 维一致。 -
seq_len(维度索引 2,列维度,数值示例为 26) 同样表示输入序列的长度,对应注意力的“接收者”(key)token,与原始注意力张量的第 4 维一致。张量中每个元素 $[l,i,j]$ 表示在第 $l$ 层中,第 $i$ 个 token 对第 $j$ 个 token 的多头注意力权重汇总范数,用于刻画该 token 对在该层上的整体注意力强度,而不区分具体注意力头。