PyTorch Interview Questions: Implementing Attention Structures
This post was last updated on July 8, 2024, at 3:37 AM.
PyTorch Interview Questions I: Key Modules in the Transformer
The attention mechanism is central to the Transformer, and interviews may ask you to implement it by hand.
This post collects the key Transformer-architecture knowledge points that interviews tend to cover:
- building Scaled Dot-Product Attention in PyTorch;
- building multi-head attention in PyTorch;
- building self-attention in PyTorch;
- implementing positional encoding with NumPy.
First, import the required libraries:

```python
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
```
1. Single-Head Attention
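Scaled dot-product attention computes a weighted sum of the value vectors, with the weights given by a softmax over the scaled query-key dot products:

$$
\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{QK^{\top}}{\sqrt{d_k}}\right)V
$$

Dividing by $\sqrt{d_k}$ keeps the logits from growing with the key dimension, which would otherwise push the softmax into a region of very small gradients.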
Implementing Scaled Dot-Product Attention in PyTorch:

```python
class ScaledDotProductAttention(nn.Module):
    """ Scaled Dot-Product Attention """

    def __init__(self, scale):
        super().__init__()
        self.scale = scale
        self.softmax = nn.Softmax(dim=2)  # softmax over the key dimension

    def forward(self, q, k, v, mask=None):
        u = torch.bmm(q, k.transpose(1, 2))   # 1. Matmul: (batch, n_q, n_k)
        u = u / self.scale                    # 2. Scale

        if mask is not None:
            u = u.masked_fill(mask, -np.inf)  # 3. Mask: positions where mask is True are set to -inf

        attn = self.softmax(u)                # 4. Softmax
        output = torch.bmm(attn, v)           # 5. Output: (batch, n_q, d_v)
        return attn, output


if __name__ == "__main__":
    n_q, n_k, n_v = 2, 4, 4
    d_q, d_k, d_v = 128, 128, 64
    batch = 4

    q = torch.randn(batch, n_q, d_q)
    k = torch.randn(batch, n_k, d_k)
    v = torch.randn(batch, n_v, d_v)
    mask = torch.zeros(batch, n_q, n_k).bool()

    attention = ScaledDotProductAttention(scale=np.power(d_k, 0.5))
    attn, output = attention(q, k, v, mask=mask)

    print(attn)
    print(output)
```
2. Multi-Head Attention
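Below is a minimal sketch of multi-head attention built on the ScaledDotProductAttention module above. It is one common interview-style implementation rather than a definitive one; the constructor arguments (n_head, d_k_, d_v_, d_k, d_v, d_o) and the head-splitting layout are my assumptions. Each head projects q/k/v into its own subspace, runs scaled dot-product attention, and the per-head outputs are concatenated and projected back to the output dimension.

```python
class MultiHeadAttention(nn.Module):
    """ Multi-Head Attention: run n_head scaled dot-product attentions in parallel subspaces. """

    def __init__(self, n_head, d_k_, d_v_, d_k, d_v, d_o):
        super().__init__()
        self.n_head = n_head
        self.d_k = d_k
        self.d_v = d_v

        # Project inputs into n_head sets of queries, keys and values
        self.fc_q = nn.Linear(d_k_, n_head * d_k)
        self.fc_k = nn.Linear(d_k_, n_head * d_k)
        self.fc_v = nn.Linear(d_v_, n_head * d_v)

        self.attention = ScaledDotProductAttention(scale=np.power(d_k, 0.5))

        # Merge the concatenated heads back to the output dimension
        self.fc_o = nn.Linear(n_head * d_v, d_o)

    def forward(self, q, k, v, mask=None):
        n_head, d_k, d_v = self.n_head, self.d_k, self.d_v
        batch, n_q, _ = q.size()
        batch, n_k, _ = k.size()
        batch, n_v, _ = v.size()

        # 1. Linear projections, then split into heads: (n_head * batch, n, d)
        q = self.fc_q(q).view(batch, n_q, n_head, d_k).permute(2, 0, 1, 3).reshape(-1, n_q, d_k)
        k = self.fc_k(k).view(batch, n_k, n_head, d_k).permute(2, 0, 1, 3).reshape(-1, n_k, d_k)
        v = self.fc_v(v).view(batch, n_v, n_head, d_v).permute(2, 0, 1, 3).reshape(-1, n_v, d_v)

        # 2. Scaled dot-product attention per head (the mask is shared across heads)
        if mask is not None:
            mask = mask.repeat(n_head, 1, 1)
        attn, output = self.attention(q, k, v, mask=mask)

        # 3. Concatenate heads and project to d_o
        output = output.view(n_head, batch, n_q, d_v).permute(1, 2, 0, 3).reshape(batch, n_q, -1)
        output = self.fc_o(output)
        return attn, output


if __name__ == "__main__":
    batch, n_q, n_k, n_v = 4, 2, 4, 4
    q = torch.randn(batch, n_q, 128)
    k = torch.randn(batch, n_k, 128)
    v = torch.randn(batch, n_v, 64)
    mask = torch.zeros(batch, n_q, n_k).bool()

    mha = MultiHeadAttention(n_head=8, d_k_=128, d_v_=64, d_k=256, d_v=128, d_o=128)
    attn, output = mha(q, k, v, mask=mask)
    print(attn.size())    # torch.Size([32, 2, 4])
    print(output.size())  # torch.Size([4, 2, 128])
```

Folding the head dimension into the batch dimension lets the single-head ScaledDotProductAttention above be reused unchanged for all heads.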
3. Self-Attention
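A minimal sketch of self-attention, assuming it is built on the MultiHeadAttention sketch above: q, k and v all come from the same input sequence x through learned projection matrices. The weight names (wq, wk, wv) and the dimensions are illustrative assumptions.

```python
class SelfAttention(nn.Module):
    """ Self-Attention: q, k, v are projections of the same input x. """

    def __init__(self, n_head, d_k, d_v, d_x, d_o):
        super().__init__()
        # Projection matrices mapping the input features to queries, keys and values
        self.wq = nn.Parameter(torch.empty(d_x, d_k))
        self.wk = nn.Parameter(torch.empty(d_x, d_k))
        self.wv = nn.Parameter(torch.empty(d_x, d_v))

        self.mha = MultiHeadAttention(n_head=n_head, d_k_=d_k, d_v_=d_v,
                                      d_k=d_k, d_v=d_v, d_o=d_o)
        self.init_parameters()

    def init_parameters(self):
        # Simple uniform initialization for the projection matrices
        for param in (self.wq, self.wk, self.wv):
            stdv = 1.0 / np.power(param.size(-1), 0.5)
            param.data.uniform_(-stdv, stdv)

    def forward(self, x, mask=None):
        q = torch.matmul(x, self.wq)   # (batch, n_x, d_k)
        k = torch.matmul(x, self.wk)   # (batch, n_x, d_k)
        v = torch.matmul(x, self.wv)   # (batch, n_x, d_v)
        attn, output = self.mha(q, k, v, mask=mask)
        return attn, output


if __name__ == "__main__":
    batch, n_x, d_x = 4, 6, 80
    x = torch.randn(batch, n_x, d_x)
    mask = torch.zeros(batch, n_x, n_x).bool()

    self_attn = SelfAttention(n_head=8, d_k=128, d_v=64, d_x=d_x, d_o=80)
    attn, output = self_attn(x, mask=mask)
    print(attn.size())    # torch.Size([32, 6, 6])
    print(output.size())  # torch.Size([4, 6, 80])
```

This is the configuration actually used inside a Transformer encoder layer: the same sequence supplies the queries, keys and values.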
4. Positional Encoding
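A NumPy sketch of the sinusoidal positional encoding from "Attention Is All You Need": even dimensions use sine, odd dimensions use cosine, with wavelengths forming a geometric progression over the embedding dimension:

$$
PE(pos, 2i) = \sin\!\left(\frac{pos}{10000^{2i/d_{\mathrm{model}}}}\right), \qquad
PE(pos, 2i+1) = \cos\!\left(\frac{pos}{10000^{2i/d_{\mathrm{model}}}}\right)
$$

The function name and the matplotlib visualization below are my own choices for illustration.

```python
def get_sinusoid_positional_encoding(n_position, d_model):
    """ Return a (n_position, d_model) sinusoidal positional encoding table. """
    def angle(pos, i):
        # Shared term pos / 10000^(2*(i//2)/d_model) for the sin/cos pair (2i, 2i+1)
        return pos / np.power(10000, 2 * (i // 2) / d_model)

    table = np.array([[angle(pos, i) for i in range(d_model)]
                      for pos in range(n_position)])
    table[:, 0::2] = np.sin(table[:, 0::2])  # even dimensions: sine
    table[:, 1::2] = np.cos(table[:, 1::2])  # odd dimensions: cosine
    return table


if __name__ == "__main__":
    pe = get_sinusoid_positional_encoding(n_position=100, d_model=128)
    plt.imshow(pe, aspect="auto", cmap="viridis")
    plt.xlabel("embedding dimension")
    plt.ylabel("position")
    plt.colorbar()
    plt.show()
```

Because the table is a fixed function of position, it can be precomputed once and simply added to the token embeddings.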