정리 노트

LSTM(Long Short Term Memory) 본문

개념 정리/머신러닝 & 딥러닝 & A.I

LSTM(Long Short Term Memory)

꿈만 꾸는 학부생 2022. 12. 13. 01:42
728x90

이 포스트는 국민대학교 소프트웨어학부 '인공지능' 강의를 듣고 요약하는 포스트입니다. 원하시는 정보가 없을 수도 있습니다. 이 점 유의 바랍니다. 오류 지적은 매우 환영합니다!


LSTM은 RNN에 대한 기초적인 지식을 요구하기 때문에 RNN에 관련한 글을 읽고 이 글을 읽으시는 것을 추천드립니다. 제가 저번에 썼던 글을 읽으셔도 괜찮습니다.

2022.12.12 - [개념 정리/머신러닝 & 딥러닝 & A.I] - RNN(Recurrent Neural Network)

LSTM을 사용하는 이유

저희에게 아주 긴 문장을 RNN의 입력으로 넣는다고 합시다. 예를 들어 "글쓴이는, 어제는 친구와 PC방을 다녀왔고, 그저께는 마트에 가서 라면을 사 오고, 그 전날은 대학교 친구들과 함께 밤을 새우면서 과제를 같이 했기 때문에 오늘은 집에서 푹 쉬기로 작정하였다."라는 문장이 있다고 합시다. 이 문장에서 '글쓴이는'과 '쉬기로'는 아주 밀접한 관련을 갖고 있는 것을 알 수 있습니다. 아주 밀접한 관련을 가지고 있지만 서로 간의 위치가 멀리 있습니다. 이렇듯 관련된 요소가 멀리 떨어져 있는 상황을 장기 문맥 의존성이라고 합니다.(https://velog.io/@peterpictor/%EB%94%A5%EB%9F%AC%EB%8B%9D-RNN%EC%9E%A5%EA%B8%B0-%EB%AC%B8%EB%A7%A5-%EC%9D%98%EC%A1%B4%EC%84%B1)

그리고 RNN에서도 gradient vanishing(경사 소멸) 또는 gradient exploding(경사 폭발) 문제점이 있습니다. 특히 RNN은 긴 입력 샘플이 자주 발생하고, 가중치를 공유하고 있어서 역전파 과정에서 같은 가중치를 계속 곱하기 때문에 CNN 같은 구조보다 심각합니다.

이러한 문제점을 해결하기 위해 사용되는 것이 LSTM입니다.

LSTM의 구조

LSTM의 구조를 RNN과 비슷하게 그려보면 다음과 같이 그려볼 수 있습니다.

LSTM을 RNN과 비슷하게 손으로 그린 그림

LSTM에서도 RNN과 동일하게 파라미터를 공유한다는 특징을 가지고 있습니다.

 

LSTM의 핵심적인 요소들은 크게 4가지에 집중해서 볼 수 있습니다.

  • 메모리 블록(cell): hidden state 장기 기억
  • 망각 개폐구(forget gate): 기억을 유지 혹은 제거(1: 유지, 0: 제거)
  • 입력 개폐구(input gate): 입력 연산
  • 출력 개폐구(output gate): 출력 연산

망각 개폐구(forget gate)

forget gate 구간

t 순간에서, 망각 개폐구는 t - 1 순간까지 기억하고 있던 정보( \( c_{t - 1} \) )를 t 순간에서 어느 정도 기억할지 결정하는 구간입니다. 결정 요인은 t 순간에 받은 입력 \( x_t \)와 t - 1 순간의 hidden state인 \( h_{t - 1} \)에 의해 결정됩니다. 입력은 입력과 forget gate를 연결하는 행렬( \( U^f \) )과 곱해지고, hidden state는 hidden state와 forget gate를 연결하는 행렬 ( \( W^f \) )과 곱해집니다. 곱해진 두 결과와 forget gate에서 사용되는 편향( \( b^f \) )을 더하고 더한 결과에 sigmoid 함수(그림에서 \( \sigma \) )를 적용시킵니다. 여기까지의 결과를 \( f^t \)라고 하면 \( f^t \)를 수식으로 다음과 같이 나타낼 수 있습니다.

$$ f_t = \text {sigmoid} \left( U^fx_t + W^fh_{t - 1} + b^f \right) $$

시그모이드 함수를 통과하였으므로 값은 0에서 1 사이입니다. 1에 가까울수록 온전히 기억하게 할 수 있고, 0에 가까울수록 완전히 잊어버리게 할 수 있습니다. 따라서 forget gate를 통과한 \( c_{t - 1} \)는 아래의 수식에 의해 계산됩니다.

$$ c_{t - 1} = c_{t - 1} \odot f_t $$

\( \odot \) 기호는 element-wise 곱셈을 의미합니다.

입력 개폐구(input gate)

input gate 구간

t 순간에서, 입력 개폐구는 t 순간에 받은 입력을 어느 만큼 cell에 저장할지 결정하는 구간입니다. 왼쪽의 파란 화살표는 forget gate에서 얻은 \( c_t \)입니다. 오른쪽 화살표는 \( U^fx_t + W^fh_{t - 1} + b^f \) 일 것 같지만 다른 값입니다. 입력 개폐구에서는 hidden state와 입력에게 다른 행렬들이 연결됩니다. sigmoid 함수를 거칠 때, tanh 함수를 거칠 때도 각자 다른 행렬들이 사용됩니다.

먼저 sigmoid 함수일 때를 보겠습니다. 입력 개폐구와 입력과 연결되는 행렬은 \( U^i \), 입력 개폐구와 hidden state와 연결되는 행렬은 \( W^i \), 계산 때 사용되는 편향은 \( b^i \)입니다. Sigmoid 함수(그림에서 \( \sigma \) )를 거친 값을 \( i_t \)라 하면 다음과 같이 계산됩니다.

$$ i_t = \text {sigmoid} \left( U^ix_t + W^ih_{t - 1} + b^i \right) $$

시그모이드 함수를 거친 결과는 0~1 이므로 0에 가까울수록 t 순간의 입력은 거의 반영되지 않을 것입니다. 이런 식으로 입력 개폐구의 개폐를 조절하는 역할을 수행하게 됩니다.

 

다음 tanh 함수일 때를 보겠습니다. 입력 개폐구와 입력과 연결되는 행렬은 \( U^c \), 입력 개폐구와 hidden state와 연결되는 행렬은 \( W^c \), 계산 때 사용되는 편향은 \( b^c \)입니다. Tanh 함수를 거친 값을 \( c'_t \)라 하면 다음과 같이 계산됩니다.

$$ c'_t = \text {tanh} \left( U^cx_t + W^ch_{t - 1} + b^c \right) $$

\( i_t \)와 \( c'_t \) 값을 곱한 결과를 forget gate를 통과한 cell과 더해서 t 순간에 받은 입력까지 cell에 저장합니다.

$$ c_t = c_{t - 1} \odot f_t + i_t \odot c'_t $$

출력 개폐구(output gate)

output gate 구간

t 순간에서, 출력 개폐구는 t 순간에 어떤 걸 출력으로 내보낼지 결정하는 구간입니다. 오른쪽에서 시그모이드를 향한 화살표에서는 예상하시는 대로 망각 개폐구, 입력 개폐구에서 쓰였던 행렬과 또 다른 행렬을 사용합니다. 입력과 출력 개폐구와 연결되는 행렬은 \( U^o \), hidden state와 출력 개폐구와 연결되는 행렬은 \( W^o \), 편향은 \( b^o \)입니다. 여기까지의 결과를 \( o_t \)라고 하면 \( o_t \)는 아래의 수식으로 표현할 수 있습니다.

$$ o_t = \text {sigmoid} \left( U^ox_t + W^oh_{t - 1} + b^o \right) $$

시그모이드 함수를 거친 결과는 0~1 이므로 0에 가까울수록 t 순간의 출력이 거의 나오지 않을 것입니다. 이런 식으로 출력 개폐구의 개폐를 조절하는 역할을 수행하게 됩니다.

 

\( o_t \)와 \( c_t \)에 tanh 함수가 적용된 결과를 곱한 결과가 t 순간에서의 출력(빨간 화살표)이 됩니다.

$$ h_t = \text {tanh} \left( c_t \right) \odot o_t $$

\( h_t \)는 그림에서 보이듯이, t 순간에서의 출력으로 사용될 뿐만 아니라 t + 1 순간에서 input으로 들어가는 hidden state로 사용됩니다. \( c_t \)도 t + 1 순간에서 input으로 들어가는 cell로 사용됩니다.

참고 사이트

https://dgkim5360.tistory.com/entry/understanding-long-short-term-memory-lstm-kr

 

Long Short-Term Memory (LSTM) 이해하기

이 글은 Christopher Olah가 2015년 8월에 쓴 글을 우리 말로 번역한 것이다. Recurrent neural network의 개념을 쉽게 설명했고, 그 중 획기적인 모델인 LSTM을 이론적으로 이해할 수 있도록 좋은 그림과 함께

dgkim5360.tistory.com

 

728x90

'개념 정리 > 머신러닝 & 딥러닝 & A.I' 카테고리의 다른 글

RNN(Recurrent Neural Network)  (2) 2022.12.12
GAN  (0) 2022.12.11
확률적 경사 하강법의 변형들  (0) 2022.12.10
배치 정규화(Batch Normalization)  (0) 2022.12.04
Computational Graph(연산 그래프)  (0) 2022.11.25