Lecture 6 - Simple and LSTM RNNs

작성자 작성일
신동수 2022.01.19

1. RNN(Recurrent Neural Network) Language Model

RNN을 훈련하는 법

\(x_{1}, x_{2}, ... ,x_{t}\)의 단어들을 준비한다.(big corpus of text)

RNN-LM에 단어들을 넣는다. 모든 단계 \(t\)마다 산출되는 분포 \(\hat{y}^{(t)}\)를 계산한다. (모든 단어에 대해서 해당 단어 다음에 올 단어의 확률을 계산한다.) \(t\)단계에서의 loss function은 예측한 확률 분포인 \(\hat{y}^{(t)}\)와 실제의 다음 단어 \(y^{(t)}\) (\(x^{(t+1)}\)의 one-hot vector)간의 cross-entropy를 사용한다. 만약 \(\hat{y}^{(t)}\)가 \(x^{(t+1)}\)에 확률 1을 주지 않았다면, loss를 적용하는 것이다.

각 단계에서의 loss들의 평균을 내어 전체 training set의 overall loss를 구한다.

이러한 RNN의 학습 과정은 teacher forcing 기법이 사용되는데, 이는 정해진 답인 corpus를 input으로 사용하여 학습하는 것이다.

하지만 전체 corpus의 loss와 grandients를 계산하는 것은 힘들기 때문에 실제로는 하나의 문장이나, 문장의 묶음(ex:32개) 단위로 계산하고, weight를 업데이트하는 것을 반복한다. (stochastic gradient descent가 데이터의 small chunk의 loss와 gradient를 계산하고, 업데이트하는 것을 가능하게 해준다.)

RNN에서 역전파할 때 구하는 \(J^{(\theta)}\)를 \(W_{h}\)에 대해 편미분했을 때의 결과인 gradient는 다음과 같고,

이는 \(t\)부터 그 이전에 대한 모든 단계에서의 gradient를 합한 것과 같다. (이 때, 모든 단계에서의 upstream gradient가 다르기 때문에 저 값이 \(J^{(t)}\)의 \(W_{h}*t\)가 아님을 유의한다)

역전파를 할 때, 각 단계에서 업데이트 값들의 총합을 나중에 전체적으로 적용하는데, 이 때 \(W_{h}\)는 일정한 행렬이어야 하므로 \(W_{h}\)를 업데이트하는 값의 합은 각 역전파단계에서 \(W_{h}\)에 바로 적용하는 것이 아니라 따로 저장한다. 그리고 단어들의 수가 너무 많으면 맨 끝 단어부터 처음 단어까지 역전파하는 시간이 너무 오래걸리기 때문에, truncated backpropagation through time이라는 알고리즘을 사용한다. 예를 들면 역전파를 20번만 진행하여, 20번만큼의 gradient들의 합을 적용시킨다.

RNN으로 문장을 만드는 방법

맨 처음에는 토큰을 입력하여, 각 단계별로 예측확률분포 \(\hat{y}\)을 계산하고, 이 중에서 하나(sample)을 선택한다. 그리고 이 sample을 다음 단계의 input으로 넣고 계산하며, 이 과정을 반복하는 것을 repeated sampling이라고 한다.

예시 -> 각각 해리포터와 레시피책을 학습하여 만들어낸 글이다.

Language Model의 성능을 평가하는 방법

Language Model의 성능을 평가하는 표준 기준은 Perplexity이다.

Perplexity는 확률들의 역수의 기하평균이라고 할 수 있다. 그리고 Perplexity는 cross-entropy loss \(J^{(\theta)}\)의 exponential과 같고, 그러므로 Perplexity가 낮은 것이 좋은 성능을 뜻한다.

Language modeling이 중요한 이유

Language modeling은 NLP에 있어서 대부분의 task를 해결하는 과정에 포함된다. 특히 텍스트를 만들어내고, 텍스트의 확률을 예측하는 것이 필요한 과정에 중요하다.

(ex: Predictive typing, Speech recognition, Handwriting recognition, Spelling/grammar correction, Authorship identification, Machine translation, Summarization, Dialogue 등)

2. Other uses of RNNs

RNN이 Language model을 만들 때 유용하게 사용되는 것을 앞에서 확인했는데, RNN은 다른 곳에도 활용된다.

Sequence tagging

텍스트의 단어들을 분류하는 데 활용된다. (예시는 the를 determiner, startled를 adjective, cat을 noun으로 tagging한 것이다.)

Sentiment classification

문장의 정서를 분류할 때, RNN의 각 과정에서의 hidden state를 sentence encoding 하는 데에 사용한다. 마지막 단계에서의 hidden state만을 sentence encoding하는 데에 사용할 수도 있다(마지막 hidden state가 앞 단계들의 영향을 받은 것이므로).

Language encoder module

Question answering에 해당하는 예시에서처럼, 질문의 representation을 RNN을 통해서 얻을 수 있다. RNN의 language encoder module로서의 역할은 machine translation 등 다른 여러 상황에서 사용되기도 한다.

Generating text

앞에서 RNN이 encoding을 했던 것과 다르게, 예시에서처럼 input이 음성신호를 decoding하여 문장을 생성하는데 RNN이 사용된다. 이외에도 machine translation, summarization에서 RNN이 문장을 생성하는데 사용된다.

3. Exploding and vanishing gradients

그런데 RNN에는 몇 가지 문제가 존재한다(이러한 문제점 때문에 더 발전된 RNN이 고안된다).

Vanishing gradients

gradient값이 작을 경우, 역전파하는 과정이 많아질수록 그 값이 점점 작아져, vanish하는 결과로 이어진다. 따라서 upstream gradient가 없어지기 때문에 parameter의 업데이트가 일어나지 않는 것이 문제가 된다.

조금 더 자세한 설명 ->

\(\sigma(x)\)가 identity function이면, 즉 \(\sigma(x)=x\)이면

\(i\)단계에서의 loss \(J^{(i)}(\theta)\) 와 \(i\)보다 이전의 단계인 \(j\)에서의 hidden state \(h^{(j)}\)에 대해 위와 같다.\((l=i-j)\) 그리고 여기서 \(W_h\)가 작으면, \(l\)이 커질수록 gradient가 작아지게 되는 것이므로 문제가 되는 것이다.

-> 여기서 \(W_h\)가 작은 경우의 하나로, \(W_h\)의 eigenvalue들이 전부 1보다 작을 때를 생각해보자.

왼쪽 항을 \(W_h\)의 eigenvector인 \(q_1,q_2,...,q_n\)들을 기저로 하여 오른쪽과 같이 표현할 수 있는데, 이때 eigenvalue들이 전부 1이하인데, \(l\)이 커지면 gradient가 사라지는 것이다.

그래서 gradient가 vanish하는 것은 가까운 단계에서의 정보는 잘 전달되게 하지만, 멀리 떨어진 단계에서의 정보는 전달되기 힘들게 한다.

Exploding gradients

Gradient가 너무 작아져서 생기는 문제를 알아봤는데, 반대로 gradient가 너무 커져서 생기는 문제도 있다.

gradient가 너무 크면 업데이트도 너무 크게 일어나고, 이는 원하는 결과로 이어지지 않을 수 있으며, 계산 과정에 INF나 NaN이 들어가 학습을 맨 처음부터 다시 해야할 수도 있다.

-> 이에 대한 해결 방안

Gradient clipping

gradient의 norm이 정해둔 기준인 threshold보다 크면, SGD 업데이트를 하기 전에 그 크기를 줄인다(크기만 줄고 진행하는 방향은 같다).

4. LSTM RNN

Simple RNN은 각 time step마다 hidden state가 \(W_h\)에 의해 곱해지고 \(\sigma\)함수에 들어가는 등 계속 변화한다.

-> 이를 해결하기 위해 RNN이 독립적인 메모리를 가지고 정보들을 보존하는 법을 떠올린 것이 LSTM(Long Short-Term Memory)이다.

LSTM RNN은 기존의 RNN이 hidden vector 인 hidden state \(h(t)\)를 갖고 있는 것에 추가로, cell state \(c(t)\)를 같이 사용한다(두 state 모두 길이 \(n\)의 벡터이다). Cell이 long-term information을 저장하고, LSTM은 마치 컴퓨터의 RAM처럼 cell로부터 정보를 읽고, 지우고, 작성할 수 있다.

-> 어떤 정보를 읽거나 지우거나 작성할 지는 probabilistic gates에 의해 결정된다. 이 gate들 또한 길이가 \(n\)인 벡터이고, 모델의 각 단계에서 gate vector를 위한 state를 계산한다. 이 state는 open(1), closed(0),혹은 0과 1사이의 값이 가능하며, 이 state에 따라 정보를 얼마나 읽거나 지우거나 작성할지 결정한다.

-> gate는 dynamic하며, 그 값은 current context에 기반하여 계산됨

forget gate \(f^{(t)}\) : 이전 cell state에서 무엇을 남길지와 지울지 결정

input gate \(i^{(t)}\) : new cell content에서 cell에 새롭게 작성될 부분이 무엇인지 결정

output gate \(o^{(t)}\) : cell의 어떤 부분이 hidden state로의 output이 될지 결정

\(\sigma\) : sigmoid function

new cell content \(\tilde{c}^{(t)}\) : cell에 새로 작성될 내용의 candidate이다.

cell state \(c^{(t)}\) : 이전 cell state로부터 지울 것은 지우고, new cell content로부터 새롭게 cell에 작성할 부분들을 작성한다.

hidden state \(h^{(t)}\) : cell로부터 일부 내용을 output으로 가져온다.

LSTM의 과정을 그림으로 나타내면 이와 같다. 여기서 주목할 점은 cell state에 new cell content로부터 새로 작성하는 부분이 덧셈으로 이뤄진다는 것이다. 이는 hidden state에서 곱셈을 이뤄지기 때문에 멀리 있는 단계에서 정보를 배우기 매우 힘든 것에 비해 수월하게 정보를 보존할 수 있다.

-> 따라서 simple RNN은 대략 7번의 단계까지 정보를 보존할 수 있지만 LSTM RNN은 100번 정도까지 보존할 수 있음

Vanishing/exploding gradients문제는 비단 RNN에서만 일어나는 문제는 아니다. FFNN, CNN과 같은 모든 신경망, 특히 깊이가 깊은 신경망에서 일어날 수 있다.

-> chain rule / choice of nonlinearity function 때문에 역전파 과정에서 gradient가 매우 작아지기 때문

-> 이를 해결하기 위해서 최근의 깊은 신경망들은 direct connection들을 추가한다.

ResNet                         DenseNet                    HighwayNet

따라서 vanishing/exploding gradient문제는 일반적인 문제지만, RNN은 같은 weight matrix를 반복적으로 곱하기 때문에 특히나 불안정한 것이다.

5. Bidirectional RNN

지금까지 본 RNN의 경우에는, 한 단어의 contextual representation이라고 할 수 있는 hidden state가 해당 단어의 왼쪽 단어들의 맥락에 관한 정보만 포함한다는 문제점이 있다. -> 이를 해결하기 위한 것이 bidirectional RNN이다.

양방향의 RNN은 한 단어에 대해 왼쪽에서의 맥락과 오른쪽에서의 맥락을 동시에 고려할 수 있다.

여기서 forward RNN이 우리가 앞에서 살펴본 simple RNN, LSTM RNN등에 해당한다. Forward RNN과 backward RNN은 보통 별개의 weight를 사용한다. 그리고 concatenated hidden state가 contextual representation으로서 신경망의 다른 부분으로 넘겨진다.

results matching ""

    No results matching ""