本项目使用PyTorch搭建Transformer模型,并给出尽可能详细的代码注释。在阅读代码之前,请先熟读经典论文Attention Is All You Need以及关于Transformer的解释。
关于代码的细节,也可以参考Transformer in PyTorch。
-
param.py
设置了相关参数。 -
utils.py
包括四个帮助函数:自注意力的计算、掩码的生成、克隆函数。 -
embeddings.py
实现了词嵌入。 -
positional_encoding.py
实现了添加位置编码。 -
multi_head_attn.py
是模型中最核心的部分,实现了多头注意力机制。 -
pos_feed_forward.py
实现了前馈网络。 -
encoder_layer.py
与decoder_layer.py
分别实现了编码器和解码器,包括了多头注意力机制和前馈网络。 -
encoder.py
与decoder.py
实现了编码器和解码器的堆栈,以及编码器和解码器之间的互动。 -
transformer.py
实现了整一个Transformer模型。
- 运行
easy_example.py
即可训练一个简易的德英翻译模型。
-
简易例子的测试
-
复杂例子