home search
paper Review
Machine Learning
Deep Learning
Generative Adversarial Nets 구현
2022. 07. 29

이 글은 나동빈 님의 [꼼꼼한 딥러닝 논문 리뷰와 코드 실습: Deep Learning Paper Review and Practice] 강의와 코드를 보고 실습한 내용을 담고 있다. 나동빈 님의 코드 구현을 약간 변형하여 실습을 진행하였다. 실습 코드는 깃허브 레포지토리에 있다.

기존 내용과의 변경점

  • 주어진 실습 코드는 GPU를 사용하였으나, GPU를 이용하지 못하는 관계로 CPU 기반 코드로 바꾸었다.
  • Jupyter Notebook에서 작성된 코드를 일반 *.py로 바꾸었다.
  • TorchSummary와 Tensorboard를 사용하여 학습 내용을 출력하고 저장한다.

모델 구조

Generator(
  (model): Sequential(
    (0): Linear(in_features=100, out_features=128, bias=True)
    (1): LeakyReLU(negative_slope=0.2, inplace=True)
    (2): Linear(in_features=128, out_features=256, bias=True)
    (3): BatchNorm1d(256, eps=0.8, momentum=0.1, affine=True, track_running_stats=True)
    (4): LeakyReLU(negative_slope=0.2, inplace=True)
    (5): Linear(in_features=256, out_features=512, bias=True)
    (6): BatchNorm1d(512, eps=0.8, momentum=0.1, affine=True, track_running_stats=True)
    (7): LeakyReLU(negative_slope=0.2, inplace=True)
    (8): Linear(in_features=512, out_features=1024, bias=True)
    (9): BatchNorm1d(1024, eps=0.8, momentum=0.1, affine=True, track_running_stats=True)
    (10): LeakyReLU(negative_slope=0.2, inplace=True)
    (11): Linear(in_features=1024, out_features=784, bias=True)
    (12): Tanh()
  )
)
Discriminator(
  (model): Sequential(
    (0): Linear(in_features=784, out_features=512, bias=True)
    (1): LeakyReLU(negative_slope=0.2, inplace=True)
    (2): Linear(in_features=512, out_features=256, bias=True)
    (3): LeakyReLU(negative_slope=0.2, inplace=True)
    (4): Linear(in_features=256, out_features=1, bias=True)
    (5): Sigmoid()
  )
)

학습

  • MNIST 데이터셋을 사용했다.
  • 크게 생성자 모델은 5개, 판별자 모델은 3개의 층을 쌓아 만들었다.
  • 200 에폭으로 지정하고 100 에폭 정도를 학습시켰다. 출력 결과 일부는 아래와 같다.
 0 / 200 epoch (  0 %)   0 iter          G loss: 0.67209220      D loss: 0.68603897
 0 / 200 epoch ( 43 %) 200 iter          G loss: 0.79006857      D loss: 0.52829790
 0 / 200 epoch ( 85 %) 400 iter          G loss: 0.78521383      D loss: 0.48324159
-> [Epoch   0 / 200] D loss: 0.488496 | G loss: 1.059803 (Elapsed time:    15.80 s)
 1 / 200 epoch ( 28 %) 600 iter          G loss: 0.66291004      D loss: 0.52356422
 1 / 200 epoch ( 71 %) 800 iter          G loss: 2.27265215      D loss: 0.45594978
-> [Epoch   1 / 200] D loss: 0.436787 | G loss: 1.112136 (Elapsed time:    32.37 s)
 2 / 200 epoch ( 13 %) 1000 iter         G loss: 2.08537960      D loss: 0.67539519
 2 / 200 epoch ( 56 %) 1200 iter         G loss: 0.79396963      D loss: 0.46769536
 2 / 200 epoch ( 99 %) 1400 iter         G loss: 2.26668715      D loss: 0.47183973
-> [Epoch   2 / 200] D loss: 0.282499 | G loss: 1.210893 (Elapsed time:    48.05 s)
 3 / 200 epoch ( 41 %) 1600 iter         G loss: 2.09876323      D loss: 0.40114748
 3 / 200 epoch ( 84 %) 1800 iter         G loss: 1.45045888      D loss: 0.31495392
-> [Epoch   3 / 200] D loss: 0.266879 | G loss: 1.206154 (Elapsed time:    64.50 s)
 4 / 200 epoch ( 26 %) 2000 iter         G loss: 0.63941079      D loss: 0.49253616
 4 / 200 epoch ( 69 %) 2200 iter         G loss: 1.96853662      D loss: 0.21636574
-> [Epoch   4 / 200] D loss: 0.307741 | G loss: 2.022917 (Elapsed time:    81.93 s)
 5 / 200 epoch ( 12 %) 2400 iter         G loss: 1.11803007      D loss: 0.29675204
 5 / 200 epoch ( 54 %) 2600 iter         G loss: 1.15375996      D loss: 0.26232159
 5 / 200 epoch ( 97 %) 2800 iter         G loss: 3.05328846      D loss: 0.32685372
-> [Epoch   5 / 200] D loss: 0.571236 | G loss: 2.913118 (Elapsed time:    98.59 s)

[중간 생략]

103 / 200 epoch ( 20 %) 48400 iter       G loss: 1.23042071      D loss: 0.34976822
103 / 200 epoch ( 62 %) 48600 iter       G loss: 1.44759262      D loss: 0.34187675
-> [Epoch 103 / 200] D loss: 0.287411 | G loss: 2.787339 (Elapsed time:  2336.81 s)
104 / 200 epoch (  5 %) 48800 iter       G loss: 1.73160005      D loss: 0.22666973
104 / 200 epoch ( 48 %) 49000 iter       G loss: 1.81321466      D loss: 0.36260408
104 / 200 epoch ( 90 %) 49200 iter       G loss: 2.66965365      D loss: 0.31132776
-> [Epoch 104 / 200] D loss: 0.221305 | G loss: 2.094731 (Elapsed time:  2360.17 s)
105 / 200 epoch ( 33 %) 49400 iter       G loss: 5.80921221      D loss: 0.87963939
105 / 200 epoch ( 76 %) 49600 iter       G loss: 1.84587741      D loss: 0.27305806
-> [Epoch 105 / 200] D loss: 0.278796 | G loss: 2.474602 (Elapsed time:  2382.44 s)
106 / 200 epoch ( 18 %) 49800 iter       G loss: 1.43034863      D loss: 0.29280686
106 / 200 epoch ( 61 %) 50000 iter       G loss: 2.17637467      D loss: 0.22612974
-> [Epoch 106 / 200] D loss: 0.390988 | G loss: 3.475564 (Elapsed time:  2404.98 s)
107 / 200 epoch (  4 %) 50200 iter       G loss: 3.01599503      D loss: 0.40974611
107 / 200 epoch ( 46 %) 50400 iter       G loss: 1.70899105      D loss: 0.28430742
107 / 200 epoch ( 89 %) 50600 iter       G loss: 2.39048934      D loss: 0.24424908
-> [Epoch 107 / 200] D loss: 0.374506 | G loss: 2.111918 (Elapsed time:  2426.77 s)
108 / 200 epoch ( 32 %) 50800 iter       G loss: 1.74034941      D loss: 0.31988096
108 / 200 epoch ( 74 %) 51000 iter       G loss: 3.36526370      D loss: 0.45345509
-> [Epoch 108 / 200] D loss: 0.327548 | G loss: 1.956849 (Elapsed time:  2450.05 s)
109 / 200 epoch ( 17 %) 51200 iter       G loss: 2.06895447      D loss: 0.25632367
109 / 200 epoch ( 59 %) 51400 iter       G loss: 2.01403880      D loss: 0.31695950
-> [Epoch 109 / 200] D loss: 0.253708 | G loss: 2.829304 (Elapsed time:  2473.93 s)

image

결과

각 iteration에서의 생성 이미지는 아래와 같다. 초반에는 글씨를 알아볼 수 없을 정도이다가, 점차 숫자의 형태를 갖춰감을 알 수 있다.

image

References

arrow_upward arrow_downward
loading