RNN과 LSTM
LSTM 은 RNN의 문제를 해결하기 위해 나온 방식중 하나이다. RNN 또한 망이 깊어지고 이전의 정보가 멀어질 경우 역전파시에 그래디언트가 줄어들어 학습능력이 저하 되는 것을 피할수 없었다. 이러한 문제를 우리는 gradient vanishing 이라고 했다.
일반적인 RNN의 경우에는 tanh 연산만 진행을 하였다면 RNN의 히든 state에 cell-state를 추가했다고 한다. 이러한 이유로인해 꽤 오랜 시간이 경과된다 하더라도 정확하게 전파를 할수 있는 모델을 만들었다고 한다.
LSTM도 RNN과 같은 체인 구조로 되어 있지만, 반복 모듈은 단순한 한 개의 tanh layer가 아닌 4개의 layer가 서로 정보를 주고받는 구조로 되어 있다. LSTM 셀에서는 상태(state)가 크게 두 개의 벡터로 나누어집니다.(위의 그림을 보면 알수 있다.)
ht를 단기 상태(short-term state), ct를 장기 상태(long-term state)라고 볼 수 있습니다.LSTM 셀의 수식은 아래와 같습니다. ⊙는 요소별 곱셈을 뜻하는 Hadamard product 연산자입니다.
1. Cell state
cell state는 정보가 바뀌지 않고 그대로 흐르도록 하는 역할을 한다.(ct를 장기 상태(long-term state))
2. Forget gate
Forget gate는 cell state에서 sigmoid layer를 거쳐 어떤 정보를 버릴 것인지 정합니다. (0일경우 버리고 1일 경우 전달을 한다고 생각을 하면 된다.) + ‘과거 정보를 잊기’를 위한 게이트입니다. ht−1과 xt를 받아 시그모이드를 취해준 값이 바로 forget gate가 내보내는 값이 됩니다. 시그모이드 함수의 출력 범위는 0에서 1 사이이기 때문에 그 값이 0이라면 이전 상태의 정보는 잊고, 1이라면 이전 상태의 정보를 온전히 기억하게 됩니다. (추가 설명이다)
3. Input gate
Input gate는 앞으로 들어오는 새로운 정보 중 어떤 것을 cell state에 저장할 것인지를 정합니다. 먼저 sigmoid layer를 거처 어떤 값을 업데이트할 것인지를 정한 후 tanh layer에서 새로운 후보 Vector를 만듭니다.
+ ‘현재 정보를 기억하기’ 위한 게이트입니다. ht−1ht−1과 xtxt를 받아 시그모이드를 취하고, 또 같은 입력으로 하이퍼볼릭탄젠트를 취해준 다음 Hadamard product 연산을 한 값이 바로 input gate가 내보내는 값이 됩니다. 개인적으로 it의 범위는 0~1, Ct^의 범위는 -1~1이기 때문에 각각 강도와 방향을 나타낸다고 이해했습니다.(내가 참고한 블로그의 분은 각각의 출력 값으로 강도와 방향을 나타낸다고 이해 했다고 하신다. 이 부분이 어떤 영향을 준다는 거지...?)
4. Cell state update
이전 gate에서 버릴 정보들과 업데이트할 정보들을 정했다면, Cell state update 과정에서 업데이트를 진행합니다. 기존에 계산해온 것들을 Cell state update에 붙인다.
5. Output gate
Output gate는 어떤 정보를 output으로 내보낼지 정하게 됩니다. 먼저 sigmoid layer에 input data를 넣어 output 정보를 정한 후 Cell state를 tanh layer에 넣어 sigmoid layer의 output과 곱하여 output으로 내보냅니다.
이 부분에서는 처음 들어온 입력을 시그모이드 연산을 진행한 후 0~1의 값을 곱한다. 근데 여기서 궁금한게...? Ct 값은 어디서 가져오는 거지? 라는 의문이 생겼다. 아래의 그림을 보니 이해가 갔다. Ct 의 연산은 마무리가 되고 아래로 전달 된는 것이었다...! 그러고 나면 최종적인 ht 가설 값이 전달 된다.
LSTM의 역전파를 작성해 볼거다.
사실 항상 이런건 그냥 코드에 맡겼기 때문에 직접 작성하는 것은 처음이라 꼼꼼히 곱씹어 봐야겠다.
이렇게 흘러가는 데 벌써 어지럽다!
우선 역전파란 지금의 파라미터가 결과에 어떻게 영향을 미칠까를 중요하게 생각하며된다고 배웠다,
따라서 chain rule로 계속해서 연결되면서 하면 쉽게 해결했다.
https://ratsgo.github.io/deep%20learning/2017/05/14/backprop/
오차 역전파 (backpropagation) · ratsgo's blog
이번 글에서는 오차 역전파법(backpropagation)에 대해 살펴보도록 하겠습니다. 이번 글은 미국 스탠포드대학의 CS231n 강의를 기본으로 하되, 고려대학교 데이터사이언스 연구실의 김해동 석사과정
ratsgo.github.io
이 부분은 꼭 공부하길 바란다. 이 분만큼 잘 정리해두신 분이 없는거 같다. 덧셈 노드와 곱셈 노드가 찢어지는걸로 보면 바로바로 이해가 가능하다.
이렇게 LSTM에 대하여 공부를 해보았다.
이러한 LSTM 의 등장 배경에는 RNN의 기울기 소실로 등장하여 cell state가 forget gate와 input gate의 역할을 잘했기 때문에 잘 전달 되는 성능을 나타내었다는 것을 배웠다.
이걸로 끝!
*****참고 자료 *****
학습에 도움을 주신분들 감사합니다. 꾸벅( __ __ )
https://ratsgo.github.io/natural%20language%20processing/2017/03/09/rnnlstm/