11.5. 多头注意力

在实践中,当给定相同的查询、键和值的集合时, 我们希望模型可以基于相同的注意力机制学习到不同的行为, 然后将不同的行为作为知识组合起来, 捕获序列内各种范围的依赖关系 (例如,短距离依赖和长距离依赖关系)。 因此,允许注意力机制组合使用查询、键和值的不同 子空间表示(representation subspaces)可能是有益的。

为此,与其只使用单独一个注意力汇聚, 我们可以用独立学习得到的\(h\)组不同的 线性投影(linear projections)来变换查询、键和值。 然后,这\(h\)组变换后的查询、键和值将并行地送到注意力汇聚中。 最后,将这\(h\)个注意力汇聚的输出拼接在一起, 并且通过另一个可以学习的线性投影进行变换, 以产生最终输出。 这种设计被称为多头注意力(multihead attention) ()。 对于\(h\)个注意力汇聚输出,每一个注意力汇聚都被称作一个(head)。 图11.5.1 展示了使用全连接层来实现可学习的线性变换的多头注意力。

../_images/multi-head-attention.svg

图11.5.1 多头注意力:多个头连结然后线性变换

11.5.1. 模型

在实现多头注意力之前,让我们用数学语言将这个模型形式化地描述出来。 给定查询\(\mathbf{q} \in \mathbb{R}^{d_q}\)、 键\(\mathbf{k} \in \mathbb{R}^{d_k}\)和 值\(\mathbf{v} \in \mathbb{R}^{d_v}\), 每个注意力头\(\mathbf{h}_i\)\(i = 1, \ldots, h\))的计算方法为:

(11.5.1)\[\mathbf{h}_i = f(\mathbf W_i^{(q)}\mathbf q, \mathbf W_i^{(k)}\mathbf k,\mathbf W_i^{(v)}\mathbf v) \in \mathbb R^{p_v},\]

其中,可学习的参数包括 \(\mathbf W_i^{(q)}\in\mathbb R^{p_q\times d_q}\)\(\mathbf W_i^{(k)}\in\mathbb R^{p_k\times d_k}\)\(\mathbf W_i^{(v)}\in\mathbb R^{p_v\times d_v}\), 以及代表注意力汇聚的函数\(f\)\(f\)可以是 11.3节中的 加性注意力和缩放点积注意力。 多头注意力的输出需要经过另一个线性转换, 它对应着\(h\)个头连结后的结果,因此其可学习参数是 \(\mathbf W_o\in\mathbb R^{p_o\times h p_v}\)

(11.5.2)\[\begin{split}\mathbf W_o \begin{bmatrix}\mathbf h_1\\\vdots\\\mathbf h_h\end{bmatrix} \in \mathbb{R}^{p_o}.\end{split}\]

基于这种设计,每个头都可能会关注输入的不同部分, 可以表示比简单加权平均值更复杂的函数。

import math
import mlx.core as mx
import mlx.nn as nn
import numpy as np
from d2l import mlx as d2l

11.5.2. 实现

在实现过程中通常选择缩放点积注意力作为每一个注意力头。 为了避免计算代价和参数代价的大幅增长, 我们设定\(p_q = p_k = p_v = p_o / h\)。 值得注意的是,如果将查询、键和值的线性变换的输出数量设置为 \(p_q h = p_k h = p_v h = p_o\), 则可以并行计算\(h\)个头。 在下面的实现中,\(p_o\)是通过参数num_hiddens指定的。

#@save
class MultiHeadAttention(d2l.Module):  #@save
    """Multi-head attention."""
    def __init__(self, num_hiddens, num_heads, dropout, bias=False, **kwargs):
        super().__init__()
        self.num_heads = num_heads
        self.attention = d2l.DotProductAttention(dropout)
        self.W_q = nn.Linear(num_hiddens, num_hiddens, bias=bias)
        self.W_k = nn.Linear(num_hiddens, num_hiddens, bias=bias)
        self.W_v = nn.Linear(num_hiddens, num_hiddens, bias=bias)
        self.W_o = nn.Linear(num_hiddens, num_hiddens, bias=bias)

    def __call__(self, queries, keys, values, valid_lens):
        # Shape of queries, keys, or values:
        # (batch_size, no. of queries or key-value pairs, num_hiddens)
        # Shape of valid_lens: (batch_size,) or (batch_size, no. of queries)
        # After transposing, shape of output queries, keys, or values:
        # (batch_size * num_heads, no. of queries or key-value pairs,
        # num_hiddens / num_heads)
        queries = self.transpose_qkv(self.W_q(queries))
        keys = self.transpose_qkv(self.W_k(keys))
        values = self.transpose_qkv(self.W_v(values))

        if valid_lens is not None:
            # On axis 0, copy the first item (scalar or vector) for num_heads
            # times, then copy the next item, and so on
            valid_lens = np.array(valid_lens)
            valid_lens = mx.array(np.repeat(
                valid_lens, repeats=self.num_heads, axis=0))

        # Shape of output: (batch_size * num_heads, no. of queries,
        # num_hiddens / num_heads)
        output = self.attention(queries, keys, values, valid_lens)
        # Shape of output_concat: (batch_size, no. of queries, num_hiddens)
        output_concat = self.transpose_output(output)
        return self.W_o(output_concat)

为了能够使多个头并行计算, 上面的MultiHeadAttention类将使用下面定义的两个转置函数。 具体来说,transpose_output函数反转了transpose_qkv函数的操作。

#@save
@d2l.add_to_class(MultiHeadAttention)  #@save
def transpose_qkv(self, X):
    """为了多注意力头的并行计算而变换形状

    Defined in :numref:`sec_multihead-attention`"""
    # 输入X的形状:(batch_size,查询或者“键-值”对的个数,num_hiddens)
    # 输出X的形状:(batch_size,查询或者“键-值”对的个数,num_heads,
    # num_hiddens/num_heads)
    X = X.reshape(X.shape[0], X.shape[1], self.num_heads, -1)
    # 输出X的形状:(batch_size,num_heads,查询或者“键-值”对的个数,
    # num_hiddens/num_heads)
    X = X.transpose(0, 2, 1, 3)
    # 最终输出的形状:(batch_size*num_heads,查询或者“键-值”对的个数,
    # num_hiddens/num_heads)
    return X.reshape(-1, X.shape[2], X.shape[3])

@d2l.add_to_class(MultiHeadAttention)  #@save
def transpose_output(self, X):
    """逆转transpose_qkv函数的操作

    Defined in :numref:`sec_multihead-attention`"""
    X = X.reshape(-1, self.num_heads, X.shape[1], X.shape[2])
    X = X.transpose(0, 2, 1, 3)
    return X.reshape(X.shape[0], X.shape[1], -1)

下面使用键和值相同的小例子来测试我们编写的MultiHeadAttention类。 多头注意力输出的形状是(batch_sizenum_queriesnum_hiddens)。

num_hiddens, num_heads = 100, 5
attention = MultiHeadAttention(num_hiddens, num_heads, 0.5)

attention.eval()
MultiHeadAttention(
  (attention): DotProductAttention(
    (dropout): Dropout(p=0.5)
  )
  (W_q): Linear(input_dims=100, output_dims=100, bias=False)
  (W_k): Linear(input_dims=100, output_dims=100, bias=False)
  (W_v): Linear(input_dims=100, output_dims=100, bias=False)
  (W_o): Linear(input_dims=100, output_dims=100, bias=False)
)
batch_size, num_queries, num_kvpairs = 2, 4, 6
valid_lens = mx.array([3, 2])
X = mx.ones((batch_size, num_queries, num_hiddens))
Y = mx.ones((batch_size, num_kvpairs, num_hiddens))
attention(X, Y, Y, valid_lens).shape
(2, 4, 100)

11.5.3. 小结

  • 多头注意力融合了来自于多个注意力汇聚的不同知识,这些知识的不同来源于相同的查询、键和值的不同的子空间表示。

  • 基于适当的张量操作,可以实现多头注意力的并行计算。

11.5.4. 练习

  1. 分别可视化这个实验中的多个头的注意力权重。

  2. 假设有一个完成训练的基于多头注意力的模型,现在希望修剪最不重要的注意力头以提高预测速度。如何设计实验来衡量注意力头的重要性呢?