The code in this post is by 현정천.
GPT uses the decoder part of the transformer.
Therefore, the second multi-head attention block of the original transformer decoder, which attends between the encoder output and the decoder output, has to be removed.
1. Config
config = Config({
    "n_dec_vocab": len(vocab),
    "n_dec_seq": 256,
    "n_layer": 6,
    "d_hidn": 256,
    "i_pad": 0,
    "d_ff": 1024,
    "n_head": 4,
    "d_head": 64,
    "dropout": 0.1,
    "layer_norm_epsilon": 1e-12
})
print(config)
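For reference: Config and vocab are not defined in this post. Config is assumed to be a simple attribute-access wrapper around a dict (as in the earlier Transformer 코드 리뷰 post), and vocab is assumed to be a vocabulary loaded beforehand. A minimal sketch under those assumptions:

# Assumed helper (not part of this post): a dict whose keys can also be read
# as attributes, so config.d_hidn behaves like config["d_hidn"].
class Config(dict):
    __getattr__ = dict.__getitem__
    __setattr__ = dict.__setitem__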
2. Decoder
""" decoder layer """
class DecoderLayer(nn.Module):
def __init__(self, config):
super().__init__()
self.config = config
self.self_attn = MultiHeadAttention(self.config)
self.layer_norm1 = nn.LayerNorm(self.config.d_hidn, eps=self.config.layer_norm_epsilon)
self.pos_ffn = PoswiseFeedForwardNet(self.config)
self.layer_norm3 = nn.LayerNorm(self.config.d_hidn, eps=self.config.layer_norm_epsilon)
def forward(self, dec_inputs, self_attn_mask):
# (bs, n_dec_seq, d_hidn), (bs, n_head, n_dec_seq, n_dec_seq)
self_att_outputs, self_attn_prob = self.self_attn(dec_inputs, dec_inputs, dec_inputs, self_attn_mask)
self_att_outputs = self.layer_norm1(dec_inputs + self_att_outputs)
# (bs, n_dec_seq, d_hidn)
ffn_outputs = self.pos_ffn(self_att_outputs)
ffn_outputs = self.layer_norm3(self_att_outputs + ffn_outputs)
# (bs, n_dec_seq, d_hidn), (bs, n_head, n_dec_seq, n_dec_seq), (bs, n_head, n_dec_seq, n_enc_seq)
return ffn_outputs, self_attn_prob
- Only the decoder inputs go into the attention: dec_inputs is used as query, key, and value, and there is no encoder-decoder attention (the sub-modules assumed here are sketched below).
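MultiHeadAttention and PoswiseFeedForwardNet are reused from the Transformer 코드 리뷰 post and are not repeated here. The sketch below is an assumed minimal version of the two modules, only to make the interface DecoderLayer relies on concrete (linear projections plus scaled dot-product attention, and a GELU feed-forward); the original post's implementation may differ in detail.

import torch
import torch.nn as nn
import torch.nn.functional as F

""" sketch: multi-head attention interface assumed by DecoderLayer """
class MultiHeadAttention(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.W_Q = nn.Linear(config.d_hidn, config.n_head * config.d_head)
        self.W_K = nn.Linear(config.d_hidn, config.n_head * config.d_head)
        self.W_V = nn.Linear(config.d_hidn, config.n_head * config.d_head)
        self.linear = nn.Linear(config.n_head * config.d_head, config.d_hidn)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, Q, K, V, attn_mask):
        bs = Q.size(0)
        # project and split into heads: (bs, n_head, n_seq, d_head)
        q_s = self.W_Q(Q).view(bs, -1, self.config.n_head, self.config.d_head).transpose(1, 2)
        k_s = self.W_K(K).view(bs, -1, self.config.n_head, self.config.d_head).transpose(1, 2)
        v_s = self.W_V(V).view(bs, -1, self.config.n_head, self.config.d_head).transpose(1, 2)
        # scaled dot-product scores: (bs, n_head, n_q_seq, n_k_seq)
        scores = torch.matmul(q_s, k_s.transpose(-1, -2)) / (self.config.d_head ** 0.5)
        scores = scores.masked_fill(attn_mask.unsqueeze(1), -1e9)  # block padded / future positions
        attn_prob = F.softmax(scores, dim=-1)
        # weighted sum of values, heads merged back: (bs, n_q_seq, n_head * d_head)
        context = torch.matmul(self.dropout(attn_prob), v_s)
        context = context.transpose(1, 2).contiguous().view(bs, -1, self.config.n_head * self.config.d_head)
        # (bs, n_q_seq, d_hidn), (bs, n_head, n_q_seq, n_k_seq)
        return self.dropout(self.linear(context)), attn_prob

""" sketch: position-wise feed-forward network assumed by DecoderLayer """
class PoswiseFeedForwardNet(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.linear1 = nn.Linear(config.d_hidn, config.d_ff)
        self.linear2 = nn.Linear(config.d_ff, config.d_hidn)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, inputs):
        # (bs, n_seq, d_hidn) -> (bs, n_seq, d_ff) -> (bs, n_seq, d_hidn)
        return self.dropout(self.linear2(F.gelu(self.linear1(inputs))))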
""" decoder """
class Decoder(nn.Module):
def __init__(self, config):
super().__init__()
self.config = config
self.dec_emb = nn.Embedding(self.config.n_dec_vocab, self.config.d_hidn)
sinusoid_table = torch.FloatTensor(get_sinusoid_encoding_table(self.config.n_dec_seq + 1, self.config.d_hidn))
self.pos_emb = nn.Embedding.from_pretrained(sinusoid_table, freeze=True)
self.layers = nn.ModuleList([DecoderLayer(self.config) for _ in range(self.config.n_layer)])
def forward(self, dec_inputs):
positions = torch.arange(dec_inputs.size(1), device=dec_inputs.device, dtype=dec_inputs.dtype).expand(dec_inputs.size(0), dec_inputs.size(1)).contiguous() + 1
pos_mask = dec_inputs.eq(self.config.i_pad)
positions.masked_fill_(pos_mask, 0)
# (bs, n_dec_seq, d_hidn)
dec_outputs = self.dec_emb(dec_inputs) + self.pos_emb(positions)
# (bs, n_dec_seq, n_dec_seq)
dec_attn_pad_mask = get_attn_pad_mask(dec_inputs, dec_inputs, self.config.i_pad)
# (bs, n_dec_seq, n_dec_seq)
dec_attn_decoder_mask = get_attn_decoder_mask(dec_inputs)
# (bs, n_dec_seq, n_dec_seq)
dec_self_attn_mask = torch.gt((dec_attn_pad_mask + dec_attn_decoder_mask), 0)
self_attn_probs = []
for layer in self.layers:
# (bs, n_dec_seq, d_hidn), (bs, n_dec_seq, n_dec_seq)
dec_outputs, self_attn_prob = layer(dec_outputs, dec_self_attn_mask)
self_attn_probs.append(self_attn_prob)
# (bs, n_dec_seq, d_hidn), [(bs, n_dec_seq, n_dec_seq)]
return dec_outputs, self_attn_probs
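get_sinusoid_encoding_table, get_attn_pad_mask and get_attn_decoder_mask also come from the Transformer 코드 리뷰 post. A minimal sketch of what they are assumed to do (sinusoidal position table, padding mask, and upper-triangular look-ahead mask):

import numpy as np
import torch

def get_sinusoid_encoding_table(n_seq, d_hidn):
    # (n_seq, d_hidn) table of sinusoidal position encodings.
    position = np.arange(n_seq)[:, None]                           # (n_seq, 1)
    div = np.power(10000, 2 * (np.arange(d_hidn) // 2) / d_hidn)   # (d_hidn,)
    table = position / div
    table[:, 0::2] = np.sin(table[:, 0::2])  # even dims: sin
    table[:, 1::2] = np.cos(table[:, 1::2])  # odd dims: cos
    return table

def get_attn_pad_mask(seq_q, seq_k, i_pad):
    # True where the key position is padding: (bs, n_q_seq, n_k_seq).
    bs, n_q_seq = seq_q.size()
    bs, n_k_seq = seq_k.size()
    return seq_k.eq(i_pad).unsqueeze(1).expand(bs, n_q_seq, n_k_seq)

def get_attn_decoder_mask(seq):
    # Look-ahead mask: 1 where a position would attend to a future position.
    mask = torch.ones_like(seq).unsqueeze(-1).expand(seq.size(0), seq.size(1), seq.size(1))
    return mask.triu(diagonal=1)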
3. GPT
- GPT simply runs the transformer Decoder defined above.
- save and load are helper methods for saving and restoring a pretrained model (a short usage sketch follows the class below).
""" gpt """
class GPT(nn.Module):
def __init__(self, config):
super().__init__()
self.config = config
self.decoder = Decoder(self.config)
def forward(self, dec_inputs):
# (bs, n_seq, d_hidn), [(bs, n_head, n_dec_seq, n_dec_seq)]
dec_outputs, dec_self_attn_probs = self.decoder(dec_inputs)
# (bs, n_dec_seq, n_dec_vocab), [(bs, n_head, n_dec_seq, n_dec_seq)]
return dec_outputs, dec_self_attn_probs
def save(self, epoch, loss, path):
torch.save({
"epoch": epoch,
"loss": loss,
"state_dict": self.state_dict()
}, path)
def load(self, path):
save = torch.load(path)
self.load_state_dict(save["state_dict"])
return save["epoch"], save["loss"]