| from __future__ import print_function, division |
= |
from __future__ import print_function, division |
| |
|
|
| from keras.datasets import mnist |
|
from keras.datasets import mnist |
| from keras.layers import Input, Dense, Reshape, Flatten, Dropout |
|
from keras.layers import Input, Dense, Reshape, Flatten, Dropout |
| from keras.layers import BatchNormalization, Activation, ZeroPadding2D |
|
from keras.layers import BatchNormalization, Activation, ZeroPadding2D |
| from keras.layers.advanced_activations import LeakyReLU |
|
from keras.layers.advanced_activations import LeakyReLU |
| from keras.layers.convolutional import UpSampling2D, Conv2D |
|
from keras.layers.convolutional import UpSampling2D, Conv2D |
| from keras.models import Sequential, Model |
|
from keras.models import Sequential, Model |
| from keras.optimizers import RMSprop |
<> |
from keras.optimizers import Adam |
| |
= |
|
| import keras.backend as K |
+- |
|
| |
= |
|
| import matplotlib.pyplot as plt |
|
import matplotlib.pyplot as plt |
| |
|
|
| import sys |
|
import sys |
| |
|
|
| import numpy as np |
|
import numpy as np |
| |
|
|
| class WGAN(): |
<> |
class DCGAN(): |
| def __init__(self): |
= |
def __init__(self): |
| |
|
|
| self.img_rows = 28 |
|
self.img_rows = 28 |
| self.img_cols = 28 |
|
self.img_cols = 28 |
| self.channels = 1 |
|
self.channels = 1 |
| self.img_shape = (self.img_rows, self.img_cols, self.channels) |
|
self.img_shape = (self.img_rows, self.img_cols, self.channels) |
| self.latent_dim = 100 |
|
self.latent_dim = 100 |
| |
|
|
|
|
|
| self.n_critic = 5 |
<> |
|
| self.clip_value = 0.01 |
|
|
| optimizer = RMSprop(lr=0.00005) |
|
optimizer = Adam(0.0002, 0.5) |
| |
= |
|
|
|
|
| self.critic = self.build_critic() |
<> |
self.discriminator = self.build_discriminator() |
| self.critic.compile(loss=self.wasserstein_loss, |
|
self.discriminator.compile(loss='binary_crossentropy', |
| optimizer=optimizer, |
= |
optimizer=optimizer, |
| metrics=['accuracy']) |
|
metrics=['accuracy']) |
| |
|
|
|
|
|
| self.generator = self.build_generator() |
|
self.generator = self.build_generator() |
| |
|
|
|
|
|
| z = Input(shape=(self.latent_dim,)) |
|
z = Input(shape=(self.latent_dim,)) |
| img = self.generator(z) |
|
img = self.generator(z) |
| |
|
|
|
|
|
| self.critic.trainable = False |
<> |
self.discriminator.trainable = False |
| |
= |
|
|
|
|
| valid = self.critic(img) |
<> |
valid = self.discriminator(img) |
| |
= |
|
|
|
|
| |
|
|
| self.combined = Model(z, valid) |
|
self.combined = Model(z, valid) |
| self.combined.compile(loss=self.wasserstein_loss, |
<> |
self.combined.compile(loss='binary_crossentropy', optimizer=optimizer) |
| optimizer=optimizer, |
|
|
| metrics=['accuracy']) |
|
|
| |
= |
|
| def wasserstein_loss(self, y_true, y_pred): |
+- |
|
| return K.mean(y_true * y_pred) |
|
|
| |
= |
|
| def build_generator(self): |
|
def build_generator(self): |
| |
|
|
| model = Sequential() |
|
model = Sequential() |
| |
|
|
| model.add(Dense(128 * 7 * 7, activation="relu", input_dim=self.latent_dim)) |
|
model.add(Dense(128 * 7 * 7, activation="relu", input_dim=self.latent_dim)) |
| model.add(Reshape((7, 7, 128))) |
|
model.add(Reshape((7, 7, 128))) |
| model.add(UpSampling2D()) |
|
model.add(UpSampling2D()) |
| model.add(Conv2D(128, kernel_size=4, padding="same")) |
<> |
model.add(Conv2D(128, kernel_size=3, padding="same")) |
| model.add(BatchNormalization(momentum=0.8)) |
= |
model.add(BatchNormalization(momentum=0.8)) |
| model.add(Activation("relu")) |
|
model.add(Activation("relu")) |
| model.add(UpSampling2D()) |
|
model.add(UpSampling2D()) |
| model.add(Conv2D(64, kernel_size=4, padding="same")) |
<> |
model.add(Conv2D(64, kernel_size=3, padding="same")) |
| model.add(BatchNormalization(momentum=0.8)) |
= |
model.add(BatchNormalization(momentum=0.8)) |
| model.add(Activation("relu")) |
|
model.add(Activation("relu")) |
| model.add(Conv2D(self.channels, kernel_size=4, padding="same")) |
<> |
model.add(Conv2D(self.channels, kernel_size=3, padding="same")) |
| model.add(Activation("tanh")) |
= |
model.add(Activation("tanh")) |
| |
|
|
| model.summary() |
|
model.summary() |
| |
|
|
| noise = Input(shape=(self.latent_dim,)) |
|
noise = Input(shape=(self.latent_dim,)) |
| img = model(noise) |
|
img = model(noise) |
| |
|
|
| return Model(noise, img) |
|
return Model(noise, img) |
| |
|
|
| def build_critic(self): |
<> |
def build_discriminator(self): |
| |
= |
|
| model = Sequential() |
|
model = Sequential() |
| |
|
|
| model.add(Conv2D(16, kernel_size=3, strides=2, input_shape=self.img_shape, padding="same")) |
<> |
model.add(Conv2D(32, kernel_size=3, strides=2, input_shape=self.img_shape, padding="same")) |
| model.add(LeakyReLU(alpha=0.2)) |
= |
model.add(LeakyReLU(alpha=0.2)) |
| model.add(Dropout(0.25)) |
|
model.add(Dropout(0.25)) |
| model.add(Conv2D(32, kernel_size=3, strides=2, padding="same")) |
<> |
model.add(Conv2D(64, kernel_size=3, strides=2, padding="same")) |
| model.add(ZeroPadding2D(padding=((0,1),(0,1)))) |
= |
model.add(ZeroPadding2D(padding=((0,1),(0,1)))) |
| model.add(BatchNormalization(momentum=0.8)) |
|
model.add(BatchNormalization(momentum=0.8)) |
| model.add(LeakyReLU(alpha=0.2)) |
|
model.add(LeakyReLU(alpha=0.2)) |
| model.add(Dropout(0.25)) |
|
model.add(Dropout(0.25)) |
| model.add(Conv2D(64, kernel_size=3, strides=2, padding="same")) |
<> |
model.add(Conv2D(128, kernel_size=3, strides=2, padding="same")) |
| model.add(BatchNormalization(momentum=0.8)) |
= |
model.add(BatchNormalization(momentum=0.8)) |
| model.add(LeakyReLU(alpha=0.2)) |
|
model.add(LeakyReLU(alpha=0.2)) |
| model.add(Dropout(0.25)) |
|
model.add(Dropout(0.25)) |
| model.add(Conv2D(128, kernel_size=3, strides=1, padding="same")) |
<> |
model.add(Conv2D(256, kernel_size=3, strides=1, padding="same")) |
| model.add(BatchNormalization(momentum=0.8)) |
= |
model.add(BatchNormalization(momentum=0.8)) |
| model.add(LeakyReLU(alpha=0.2)) |
|
model.add(LeakyReLU(alpha=0.2)) |
| model.add(Dropout(0.25)) |
|
model.add(Dropout(0.25)) |
| model.add(Flatten()) |
|
model.add(Flatten()) |
| model.add(Dense(1)) |
<> |
model.add(Dense(1, activation='sigmoid')) |
| |
= |
|
| model.summary() |
|
model.summary() |
| |
|
|
| img = Input(shape=self.img_shape) |
|
img = Input(shape=self.img_shape) |
| validity = model(img) |
|
validity = model(img) |
| |
|
|
| return Model(img, validity) |
|
return Model(img, validity) |
| |
|
|
| def train(self, epochs, batch_size=128, sample_interval=50): |
<> |
def train(self, epochs, batch_size=128, save_interval=50): |
| |
= |
|
|
|
|
| (X_train, _), (_, _) = mnist.load_data() |
|
(X_train, _), (_, _) = mnist.load_data() |
| |
|
|
|
|
|
| X_train = (X_train.astype(np.float32) - 127.5) / 127.5 |
<> |
X_train = X_train / 127.5 - 1. |
| X_train = np.expand_dims(X_train, axis=3) |
= |
X_train = np.expand_dims(X_train, axis=3) |
| |
|
|
|
|
|
| valid = -np.ones((batch_size, 1)) |
<> |
valid = np.ones((batch_size, 1)) |
| fake = np.ones((batch_size, 1)) |
|
fake = np.zeros((batch_size, 1)) |
| |
= |
|
| for epoch in range(epochs): |
|
for epoch in range(epochs): |
| |
|
|
| for _ in range(self.n_critic): |
+- |
|
| |
= |
|
|
|
|
|
|
|
|
|
|
| |
|
|
|
|
|
| idx = np.random.randint(0, X_train.shape[0], batch_size) |
<> |
idx = np.random.randint(0, X_train.shape[0], batch_size) |
| imgs = X_train[idx] |
|
imgs = X_train[idx] |
|
= |
|
|
|
|
| noise = np.random.normal(0, 1, (batch_size, self.latent_dim)) |
<> |
noise = np.random.normal(0, 1, (batch_size, self.latent_dim)) |
| |
= |
|
|
|
|
| gen_imgs = self.generator.predict(noise) |
<> |
gen_imgs = self.generator.predict(noise) |
| |
= |
|
|
|
|
| d_loss_real = self.critic.train_on_batch(imgs, valid) |
<> |
d_loss_real = self.discriminator.train_on_batch(imgs, valid) |
| d_loss_fake = self.critic.train_on_batch(gen_imgs, fake) |
|
d_loss_fake = self.discriminator.train_on_batch(gen_imgs, fake) |
| d_loss = 0.5 * np.add(d_loss_fake, d_loss_real) |
|
d_loss = 0.5 * np.add(d_loss_real, d_loss_fake) |
| |
= |
|
|
|
|
| for l in self.critic.layers: |
+- |
|
| weights = l.get_weights() |
|
|
| weights = [np.clip(w, -self.clip_value, self.clip_value) for w in weights] |
|
|
| l.set_weights(weights) |
|
|
| |
= |
|
| |
|
|
|
|
|
|
|
|
|
|
|
| |
|
|
| |
|
|
| g_loss = self.combined.train_on_batch(noise, valid) |
|
g_loss = self.combined.train_on_batch(noise, valid) |
| |
|
|
|
|
|
| print ("%d [D loss: %f] [G loss: %f]" % (epoch, 1 - d_loss[0], 1 - g_loss[0])) |
<> |
print ("%d [D loss: %f, acc.: %.2f%%] [G loss: %f]" % (epoch, d_loss[0], 100*d_loss[1], g_loss)) |
| |
= |
|
|
|
|
| if epoch % sample_interval == 0: |
<> |
if epoch % save_interval == 0: |
| self.sample_images(epoch) |
|
self.save_imgs(epoch) |
| |
= |
|
| def sample_images(self, epoch): |
<> |
def save_imgs(self, epoch): |
| r, c = 5, 5 |
= |
r, c = 5, 5 |
| noise = np.random.normal(0, 1, (r * c, self.latent_dim)) |
|
noise = np.random.normal(0, 1, (r * c, self.latent_dim)) |
| gen_imgs = self.generator.predict(noise) |
|
gen_imgs = self.generator.predict(noise) |
| |
|
|
|
|
|
| gen_imgs = 0.5 * gen_imgs + 0.5 |
|
gen_imgs = 0.5 * gen_imgs + 0.5 |
| |
|
|
| fig, axs = plt.subplots(r, c) |
|
fig, axs = plt.subplots(r, c) |
| cnt = 0 |
|
cnt = 0 |
| for i in range(r): |
|
for i in range(r): |
| for j in range(c): |
|
for j in range(c): |
| axs[i,j].imshow(gen_imgs[cnt, :,:,0], cmap='gray') |
|
axs[i,j].imshow(gen_imgs[cnt, :,:,0], cmap='gray') |
| axs[i,j].axis('off') |
|
axs[i,j].axis('off') |
| cnt += 1 |
|
cnt += 1 |
| fig.savefig("images/mnist_%d.png" % epoch) |
|
fig.savefig("images/mnist_%d.png" % epoch) |
| plt.close() |
|
plt.close() |
| |
|
|
| |
|
|
| if __name__ == '__main__': |
|
if __name__ == '__main__': |
| wgan = WGAN() |
<> |
dcgan = DCGAN() |
| wgan.train(epochs=4000, batch_size=32, sample_interval=50) |
|
dcgan.train(epochs=4000, batch_size=32, save_interval=50) |