Generative Adversarial Nets
GAN: Generative Adversarial Nets
Introduction
기존 생성 모델은 주로 최대우도추정(Maximum Likelihood Estimation, MLE)를 기반으로 학습을 하였다. 모델이 생성하는 샘플의 확률 분포 $p_{\text{model}}$이 진짜 데이터의 확률 분포 $p_{\text{data}}$에 가까워지도록 수학적으로 계산하는 것이 목표였다.
하지만 이미지와 같은 고차원 데이터에서 이를 하기 위해선, 복잡한 적분 연산을 해야 하는데, 이 적분을 계산하는게 거의 불가능하기 때문에 근사하는 것 또한 어려웠다.
또한 생성 모델들은 확률 분포에 의한 제약으로, 내부 레이어에 ReLU와 같은 선형 유닛들을 사용하기가 어려웠다.
GAN(Generative Adversarial Nets)은 확률 분포에 대한 내용은 아에 던져버리고, 생성형 모델 $G$와 판별 모델 $D$ 를 동시에 경쟁하듯이 학습시키는 구조로 이러한 기존 문제를 해결하였다.
Architecture
논문의 예시를 빌려와 아이디어를 설명하자면, 다음과 같다.
- 생성 모델 $G$: 화폐 위조범
- 판별 모델 $D$: 경찰 경찰 $D$는 화폐위조범 $G$가 만들어낸 위조 지폐를 구별해낼 수 있도록 학습된다. 화폐위조범 $G$는 경찰에게 들키지 않기 위해 자신들이 만드는 위조지폐를 더욱이 실제 지폐와 비슷하게 만들게끔 학습하여 경찰 $D$를 속인다. 이렇게 두 모델이 서로 경쟁해가며 자신의 성능을 향상시키는 것이다.
논문에선 이 두 모델을 모두 MLP로 정의하는데, 이 경우를 Adversarial Nets라고 나타낸다.
다시 말해서, 모델 $G(z;\theta_{g})$는 입력 노이즈 변수 $z \sim p_{z}(z)$를 입력으로 받아 데이터 공간(진짜 지폐에 해당되는 데이터 공간)으로 매핑한다. 모델 $D(x;\theta_{d})$는 데이터 공간에서 한개의 스칼라로 매핑하는데, 이 스칼라는 입력 $x$가 모델 $G$가 아니라 실제 데이터의 분포로 샘플링되었을 확률을 나타낸다. 즉, 이 지폐가 위조지폐가 아니라 진짜 지폐일 확률을 반환한다.
두 모델은 다음 함수를 각각 최대, 최소화 하는 것을 목표로 학습한다.
\[\underset{G}{\min} \underset{D}{\max}V(D,G) = \mathbb{E}_{x \sim p_{\text{data}}(x)}[\log D(x)] + \mathbb{E}_{z \sim p_{z}(z)}[\log[1-D(G(z))]]\]모델 $D$가 이 함수를 최대화한 다는 것은, 실제 데이터 $x$에 대해선 $D(x) \to 1$이 되고, 모델 $G$가 생성한 값 $G(z)$에 대해선 $D(G(z)) \to 0$이 되도록 학습한다는 것이다.
그리고 모델 $G$는 자신이 생성한 값 $G(z)$에 대해서 $D(G(z)) \to 1$이 되도록 하여 위 함수를 최소화하도록 학습한다.
학습 과정은 위와 같다. 검은색 점선은 실제 지폐의 분포 ($p_{\text{data}}$), 파란색 점선은 모델 $D$의 출력, 초록색 실선은 모델 $G$가 출력한 위조 지폐의 분포이다 ($p_{g}$).
(a)는 초기 상태이고, (b)는 모델 $D$가 학습하여 $D^{*}(x) = \frac{p_{\text{data}}}{p_{\text{data}} + p_{g}}$에 수렴한 모습이다. (c)는 이렇게 학습된 모델 $D$의 기울기를 통해 실제 데이터 분포 $p_{\text{data}}$를 유추하여 $p_{g}$를 갱신한다. 이 과정을 무한히 반복하면, (d)와 같이 $p_{\text{data}} = p_{g}$가 되고, 모델 $D$는 결국 $\frac{1}{2}$만을 출력하여 실제 데이터와 모델 $G$가 생성한 데이터를 구분하지 못하게 된다.
DCGAN: Unsupervised Representation Learning with Deep Convolutional Generative Adversarial Networks
Introduction
DCGAN(Deep Convolutional GANs)은 앞서 설명한 Adversarial Nets에서 모델 $D$, $G$를 MLP로 정의한 것과 달리, CNN으로 정의한 구조를 말한다.
Architecture
위 사진은 $64 \times 64$ 이미지에 대한 모델 $G$의 구조이다. $(C, H, W)$에 대해서 입력 노이즈 $z \in (d_{z}, 1, 1)$가 주어졌을 때 각 컨볼루션 레이어를 거쳐 최종적으로 $(3, 64, 64)$의 RGB 채널의 $64 \times 64$ 이미지를 생성한다.
모델 $D$는 모델 $G$ 구조의 역순으로, 마지막 컨볼루션 레이어에서 $z \in (d_{z}, 1, 1)$ 이 아닌 스칼라를 반환한다.
물론 위 구조는 이 논문에서 사용될 LSUN의 $64 \times 64$ 이미지에 맞춘 것으로 이미지 크기에 따라 구조를 수정할 수 있다.
보다시피 모든 레이어가 컨볼루션 레이어로 이루어져 있음을 알 수 있고, 각 파라미터의 안정화를 위해 배치 정규화 (BN, Batch Normalization)을 사용하였다. 하지만 모든 레이어에 적용하면 오히려 불안정한 결과를 초래하기 때문에, $G$의 출력 레이어와 $D$의 입력레이어를 제외하고 모든 레이어에 배치 정규화를 적용하였다.
또한 모델 $G$에서 출력 레이어를 제외한 모든 레이어에 ReLU 활성화 함수를 사용하였고, 출력 레이어엔 Tanh 활성화 함수를 사용하였다.
모델 $D$에선 모든 레이어에 대해 LeakyReLU 활성화 함수를 사용하였다.
Implementation
논문에서의 DCGAN 구조를 기반으로 MNIST, LSUN 데이터셋에서 학습을 진행하였으며, 학습 환경의 제약으로 기존 LSUN 데이터의 10%인 약 30만개의 샘플만 학습에 사용하였다. 데이터셋 링크. 학습의 설정과 하이퍼파라미터는 논문의 설정을 그대로 사용하였다. 또한 MNIST 데이터셋의 경우 1채널의 $28 \times 28$ 크기의 이미지를 데이터로 가지므로, 논문에서의 $64 \times 64$ 이미지에 대한 구조를 축소시켜 훈련하였다.
Latent Space
논문에서는 이 모델이 이미지를 단순히 암기한 것인지, 아니면 데이터의 고차원적 개념을 연속적인 Manifold 공간 내에서 제대로 이해한 것인지 검증하기 위해 두 가지 잠재 공간 탐색 실험을 수행하였다.
위 사진은 잠재 공간 보간의 예시로, 임의로 샘플링한 두 개의 노이즈 벡터 $z1, z2 \sim p_{z}(z)$ 사이를 선형 내분점 공식 $(1-\alpha)z_{1} + \alpha z_{2}$를 통해 부드럽게 가로지르며 이미지를 연속 생성하여 각 생성된 이미지를 확인한 것이다.
각 이미지들이 툭툭 끊기는 식으로 전환되는 것이 아니라, 전체 이미지가 매끄럽게 모핑되며 변형되는 연속성을 확인 할 수 있다. (DC)GAN 모델이 학습을 통해서 단순히 각 이미지 데이터를 외우는게 아니라, 유의미한 이미지 생성을 학습했다고 생각할 수 있다.
위 사진은 임의의 노이즈 벡터 $z1 \sim p_{z}(z)$에서 특정 차원의 가중치만을 임의의 범위만큼 변화시키며 각 과정에서 생성된 이미지를 확인한 것이다.
각 이미지들을 확인했을 때, 전체적인 방의 구조는 고정되지만, 문(?)으로 보이는 형체가 서서히 창문(?)으로 바뀌거나 방의 전체적인 조명 (색감?) 이 바뀌는 등의 특정 단일 시각적 개념만 독립적으로 제어되는 것을 볼 수 있다. 이는 (DC)GAN 모델의 구조가 심층 표현 학습(Deep Representation Learning)을 수행하는 과정에서, 이미지의 특징들을 독립적인 잠재 공간의 차원으로 분리하여 표현하고 있음을 알 수 있다.
찌라시
기존 생성 모델들은 데이터의 확률 분포 $p_{\text{data}}$를 수학적으로 계산하거나 근사하는 식으로 학습했여야 했었지만, 이 논문에서는 이러한 확률 분포 계산 없이 두 모델 $G,D$를 서로 경쟁시키며 학습한다는 신기한(?) 구조를 제시하였다. 생성된 이미지 결과는 나름 그럴듯(?) 하다.
논문에선 추가로 이미지 생성 뿐만 아니라 판별 모델 $D$의 특징 추출기(Feature Extractor)로서의 성능도 평가했는데, ImageNet-1k 에서 학습한 후에 모델 $D$의 특징 벡터들을 활용하여 CIFAR-10 데이터셋에서 이미지 분류를 시험해보았다. 결과는 82.8%로 K-means 기반의 방법론들을 능가했다. 이때 CIFAR-10 데이터로 전혀 학습을 하지 않았다는 점에서 (DC)GAN의 모델이 사물의 본질적인 시각적 특징을 학습할 수 있음을 나타낸다.





