정리 노트

36일 차(2022/08/22) 본문

[TIL]국민대X프로그래머스 여름방학 인공지능 과정

36일 차(2022/08/22)

꿈만 꾸는 학부생 2022. 8. 22. 17:58
728x90

Attention

query와 비슷한 값을 가진 key를 찾아 value를 얻는 과정입니다. 여기서 key, value는 encoder의 각 time-step 별 출력(각 source language의 단어 또는 문장)을 의미하고 query는 현재 time-step의 decoder 출력(target language로 번역된 단어 또는 문장)을 의미합니다.

참고: https://hazel01.tistory.com/45

Attention 아키텍처

하나의 Attention은 전체 토큰에 대한 출력을 입력으로 받는 FC의 파라미터를 공유해 사용합니다. 전체 encoder의 출력 + 현재 decoder의 hidden이 decoder의 hidden으로 되고 이게 실제 Attention의 값입니다.

import torch.nn as nn
import torch.nn.functional as F


class Attention(nn.Module):
    def __init__(self, enc_hidden_dim, dec_hidden_dim):
        """
        :param enc_hidden_dim: Encoder의 hidden state 차원  
        :param dec_hidden_dim: Decoder의 hidden state 차원
        """
        super().__init__()
        
        self.attn = nn.Linear((enc_hidden_dim * 2) + dec_hidden_dim, dec_hidden_dim)
        self.v = nn.Linear(dec_hidden_dim, 1, bias=False)
        
    def forward(self, hidden, enc_outputs):
        """
        :param hidden: 현재까지의 모든 단어의 정보([배치 크기, 히든 차원])
        :param enc_outputs: 전체 단어의 출력 정보([단어 개수, 배치 크기, encoder hidden 차원 * 방향 수])
        """
        batch_size = enc_outputs.shape[1]
        src_len = enc_outputs.shape[0]
        
        hidden = hidden.unsqueeze(1).repeat(1, src_len, 1)
        # hidden: [배치 크기, 단어 개수, decoder hidden 차원]
        enc_outputs = enc_outputs.permute(1, 0, 2)
        # enc_outputs: [배치 크기, 단어 개수, encoder hidden 차원 * 방향 수]
        
        energy = torch.tanh(self.attn(torch.cat((hidden, enc_outputs), dim=2)))
        # energy: 현재 어떤 단어를 출력하기 위해 source 문장에서 어떤 단어에 초점을 둘 필요가 있는지 수치화한 값([배치 크기, 단어 개수, decoder hidden 차원])
        
        attention = self.v(energy).squeeze(2)
        # attention: 실제 각 단어에 대한 attention 값들([배치 크기, 단어 개수])
        
        return F.softmax(attention, dim=1)

Seq2Seq with Attention 아키텍처

  • Encoder: 주어진 source 문장을 context vector로 encoding
  • Decoder: 주어진 context vector를 target 문장으로 decoding
  • 여기서 decoder는 한 단어씩 넣어서 한 번씩 결과를 구합니다. 그리고 decoder는 context vector뿐만 아니라 encoder의 모든 출력을 참고해 attention을 진행합니다.
  • Teacher forcing: decoder의 예측을 다음 입력으로 사용하지 않고, 실제 목표 출력을 다음 입력으로 사용하는 기법입니다.

 

728x90

'[TIL]국민대X프로그래머스 여름방학 인공지능 과정' 카테고리의 다른 글

37일 차(2022/08/23)  (0) 2022.08.23
35일 차(2022/08/19)  (0) 2022.08.19
34일 차(2022/08/18)  (0) 2022.08.19
33일 차(2022/08/17)  (0) 2022.08.17
32일 차(2022/08/16)  (0) 2022.08.17