Skip to content

Commit 594e06b

Browse files
committed
docs: 初步添加 attention 算子定义
Signed-off-by: YdrMaster <ydrml@hotmail.com>
1 parent 39538e7 commit 594e06b

File tree

1 file changed

+36
-0
lines changed

1 file changed

+36
-0
lines changed

src/08-01llm/README.md

+36
Original file line numberDiff line numberDiff line change
@@ -25,3 +25,39 @@ y = (x^2 + δ)^(-1/2) * w * x
2525
1 Output:
2626

2727
- **Y(heterogeneous) - T**: 输出张量。形状与 `X` 相同。
28+
29+
## Attention
30+
31+
### Summary
32+
33+
Multi-head Self Attention 的封装形式,用于 transformer 模型。
34+
35+
支持使用 kv cache,使用条件由输入和属性综合决定。有以下 6 种情况:
36+
37+
| 序号 | 输入数量 | `max_seq_len` | 使用 kv cache | 输出数量 | cache s 维度 | 备注
38+
|:-:|:-:|:-----:|:-------:|:-:|:------------------------:|:-
39+
| 1 | 3 | 0 | none | 1 | - |
40+
| 2 | 3 | S > 0 | init | 3 | `S` | `assert(S >= seq_len)`
41+
| 3 | 4 | 0 | inplace | 3 | `past_seq_len + seq_len` | `past_seq_len` 必须是常量
42+
| 4 | 4 | S > 0 | inplace | 3 | `S` | `assert(S >= past_seq_len + seq_len)`
43+
| 5 | 6 | 0 | copy | 3 | `past_seq_len + seq_len` | `past_seq_len` 必须是常量
44+
| 6 | 6 | S > 0 | copy | 3 | `S` | `assert(S >= past_seq_len + seq_len)`
45+
46+
### Attributes
47+
48+
- **max_seq_len - INT** (default is `0`): 最大序列长度,用于初始化 kv cache。
49+
50+
### Inputs
51+
52+
- **query(heterogeneous) - T**: 形状为 `N x n_head x seq_len x head_dim`
53+
- **key(heterogeneous) - T**: 形状为 `N x n_kv_head x seq_len x head_dim`
54+
- **value(heterogeneous) - T**: 形状为 `N x n_kv_head x seq_len x head_dim`
55+
- **past_seq_len(optional) -int64**: 要连接的历史序列长度,必须为标量。不使用 kv cache 时留空。
56+
- **k_cache(optional, heterogeneous) -T**: k 缓存的初始值,形状为 `N x n_kv_head x s x head_dim``s` 为不小于 `past_seq_len` 的任意值。不使用或不重置 kv cache 时留空。
57+
- **v_cache(optional, heterogeneous) -T**: v 缓存的初始值,形状为 `N x n_kv_head x s x head_dim``s` 为不小于 `past_seq_len` 的任意值。不使用或不重置 kv cache 时留空。
58+
59+
### Outputs
60+
61+
- **output(heterogeneous) - T**: 形状与 `query` 相同。
62+
- **k_cache(optional, heterogeneous) - T**: 形状为 `N x n_kv_head x s x head_dim``s` 的值根据 `Summary` 的描述计算。
63+
- **v_cache(optional, heterogeneous) - T**: 形状为 `N x n_kv_head x s x head_dim``s` 的值根据 `Summary` 的描述计算。

0 commit comments

Comments
 (0)