CycleGAN Model Principles


In earlier posts we looked at quite a few GANs. They roughly fall into two groups:
GANs that generate from random noise, including GAN, DCGAN, WGAN, and WGAN-GP;
GANs that generate under a condition, including CGAN, InfoGAN, and ACGAN.
All of them are supervised models, in the sense that the generator network has a target sample set to fit.

Besides supervised models there is also a class of unsupervised models, CycleGAN being one example.
CycleGAN can translate oil paintings to photographs and back, horses to zebras and back, and summer scenery to winter scenery and back. None of these tasks fits a sample from the current domain to a particular sample in the other domain.
They only map a shared feature of the current domain X (the color fill) onto the shared feature of the other domain Y, while keeping the main features of the individual image (its lines and contours) unchanged.
That is:
1. What all images in X share (oil-painting style) -> what all images in Y share (photographic style), while the lines of the depicted content stay unchanged;
2. What all images in X share (the horse's coat) -> what all images in Y share (zebra stripes), while the horse's outline stays unchanged;
3. What all images in X share (summer scenery) -> what all images in Y share (winter scenery), while the outlines of rivers and trees stay unchanged.


How do we achieve this? Take horse-to-zebra as the example:
1. In the horse samples X, the shared feature is the solid-colored coat; in the zebra samples Y, it is the striped coloring. To fit X onto Y and Y onto X, a plain GAN is the obvious tool: minimize the original GAN loss.
2. After step 1, however, the resulting zebra no longer resembles the original horse. To a human, the most salient feature of an object is its outline, so we additionally minimize a cycle-consistency loss: for each individual sample, translating X to Y and then back to X1 must keep the outline roughly unchanged, so that a person perceives the same object. To the computer, by contrast, the dominant feature is the large areas of color fill, and during X -> Y and Y -> X those fills genuinely are mapped between the two domains. Together, steps 1 and 2 fool the computer and the human at the same time.
3. Why is there a step 3? It was added later in the paper. For painting-to-photo translation (or the reverse), we do not want a generator to alter images that already belong to its output domain: passing an X-domain image through the generator that produces X-domain images (yielding X1), or a Y-domain image through the generator that produces Y-domain images (yielding Y1), should change nothing and simply return the input. The paper enforces this by minimizing an identity mapping loss.
In summary: the large color fills fool the computer, while the outlines, which human eyes are sensitive to, fool the human. The same idea applies to training a CGAN on MNIST with Gaussian noise plus a digit label as input: the digit plays the role of the outline, and the noise implicitly carries the coloring that the computer sees.
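
Putting the three losses into formulas: below is the full objective from the CycleGAN paper (Zhu et al., 2017), with G: X -> Y, F: Y -> X, and discriminators D_Y, D_X. Note that the Keras implementation further down swaps the log-likelihood adversarial loss for a least-squares version (compiled as 'mse'):

    \mathcal{L}_{GAN}(G, D_Y, X, Y) = \mathbb{E}_{y \sim p_{data}(y)}[\log D_Y(y)] + \mathbb{E}_{x \sim p_{data}(x)}[\log(1 - D_Y(G(x)))]
    \mathcal{L}_{cyc}(G, F) = \mathbb{E}_{x}[\lVert F(G(x)) - x \rVert_1] + \mathbb{E}_{y}[\lVert G(F(y)) - y \rVert_1]
    \mathcal{L}_{identity}(G, F) = \mathbb{E}_{y}[\lVert G(y) - y \rVert_1] + \mathbb{E}_{x}[\lVert F(x) - x \rVert_1]
    \mathcal{L}(G, F, D_X, D_Y) = \mathcal{L}_{GAN}(G, D_Y, X, Y) + \mathcal{L}_{GAN}(F, D_X, Y, X) + \lambda_{cyc} \mathcal{L}_{cyc}(G, F) + \lambda_{id} \mathcal{L}_{identity}(G, F)

In the code, \lambda_{cyc} corresponds to lambda_cycle = 10 and \lambda_{id} to lambda_id = 0.1 * lambda_cycle = 1.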

Components of the Loss Function

Following the analysis above, the overall CycleGAN objective consists of three parts:

1. The original GAN loss: X goes through generator G to produce a Y-domain image, which is scored by discriminator DY; Y goes through generator F to produce an X-domain image, which is scored by discriminator DX.

In the code:
a) X-to-Y discriminator loss (dB_loss_fake) and generator loss (g_AB_loss):
img_A -> g_AB(img_A) -> fake_B -> d_B(fake_B) -> valid_B : input img_A, output valid_B; discriminator d_B targets fake, generator g_AB targets valid;
b) real-Y discriminator loss (dB_loss_real):
img_B -> d_B(img_B) -> valid_B : input img_B, output valid_B; discriminator d_B targets valid;
c) Y-to-X discriminator loss (dA_loss_fake) and generator loss (g_BA_loss):
img_B -> g_BA(img_B) -> fake_A -> d_A(fake_A) -> valid_A : input img_B, output valid_A; discriminator d_A targets fake, generator g_BA targets valid;
d) real-X discriminator loss (dA_loss_real):
img_A -> d_A(img_A) -> valid_A : input img_A, output valid_A; discriminator d_A targets valid. (The valid and fake targets here are patch-shaped label maps; see the sketch after the loss summary below.)

2. Cycle-consistency loss: X goes through generator G to produce a Y-domain image, then through generator F back to X1; the loss measures the distance from X to X1. The network's goal is to keep this loss small, i.e. to be able to reconstruct X.

In the code:
a) generator loss (g_AB_BA_loss):
img_A -> g_AB(img_A) -> fake_B -> g_BA(fake_B) -> reconstr_A : input img_A, output reconstr_A; the generator chain g_AB -> g_BA targets imgs_A;
b) generator loss (g_BA_AB_loss):
img_B -> g_BA(img_B) -> fake_A -> g_AB(fake_A) -> reconstr_B : input img_B, output reconstr_B; the generator chain g_BA -> g_AB targets imgs_B.

3. Identity mapping loss: passing X through generator F (which outputs X-domain images) produces X1; the network must keep X essentially unchanged, i.e. an X-to-X mapping should require no modification and should preserve X as much as possible.

In the code:
a) generator loss (g_BA_Ident_loss):
img_A -> g_BA -> img_A_id : input img_A, output img_A_id, target imgs_A;
b) generator loss (g_AB_Ident_loss):
img_B -> g_AB -> img_B_id : input img_B, output img_B_id, target img_B.

To summarize, the discriminator losses are dA_loss_real, dA_loss_fake, dB_loss_real, and dB_loss_fake; the generator losses are g_AB_loss, g_BA_loss, g_AB_BA_loss, g_BA_AB_loss, g_AB_Ident_loss, and g_BA_Ident_loss.
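
Because the discriminators are PatchGANs, the valid and fake targets mentioned above are not single scalars but one label per output patch. A minimal sketch of what the training code constructs, assuming batch_size=1 and 128x128 inputs:

    import numpy as np

    # The discriminator halves the spatial size four times (four stride-2
    # convolutions), so a 128x128 image yields an 8x8x1 validity map.
    disc_patch = (128 // 2**4, 128 // 2**4, 1)   # (8, 8, 1)
    valid = np.ones((1,) + disc_patch)    # target for real inputs and for the generators
    fake = np.zeros((1,) + disc_patch)    # target for translated (fake) inputs

The excerpt below, taken from the complete Keras implementation later in this post, shows how all of these terms are wired together: the first part builds the combined generator model, the second part is the per-batch training loop. (Names such as g_AB_loss do not appear verbatim in the code; they label the individual entries of d_loss and g_loss.)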

        # Translate images to the other domain
        fake_B = self.g_AB(img_A)
        fake_A = self.g_BA(img_B)
        # Translate images back to original domain
        reconstr_A = self.g_BA(fake_B)
        reconstr_B = self.g_AB(fake_A)
        # Identity mapping of images
        img_A_id = self.g_BA(img_A)
        img_B_id = self.g_AB(img_B)
 
        # For the combined model we will only train the generators
        self.d_A.trainable = False
        self.d_B.trainable = False
 
        # Discriminators determine the validity of translated images
        valid_A = self.d_A(fake_A)
        valid_B = self.d_B(fake_B)
 
        # Combined model trains generators to fool discriminators
        self.combined = Model(inputs=[img_A, img_B],
                              outputs=[ valid_A, valid_B,
                                        reconstr_A, reconstr_B,
                                        img_A_id, img_B_id ])
        self.combined.compile(loss=['mse', 'mse',
                                    'mae', 'mae',
                                    'mae', 'mae'],
                            loss_weights=[  1, 1,
                                            self.lambda_cycle, self.lambda_cycle,
                                            self.lambda_id, self.lambda_id ],
                            optimizer=optimizer)
                # ----------------------
                #  Train Discriminators
                # ----------------------
 
                # Translate images to opposite domain
                fake_B = self.g_AB.predict(imgs_A)
                fake_A = self.g_BA.predict(imgs_B)
 
                # Train the discriminators (original images = real / translated = fake)
                dA_loss_real = self.d_A.train_on_batch(imgs_A, valid)
                dA_loss_fake = self.d_A.train_on_batch(fake_A, fake)
                dA_loss = 0.5 * np.add(dA_loss_real, dA_loss_fake)
 
                dB_loss_real = self.d_B.train_on_batch(imgs_B, valid)
                dB_loss_fake = self.d_B.train_on_batch(fake_B, fake)
                dB_loss = 0.5 * np.add(dB_loss_real, dB_loss_fake)
 
                # Total discriminator loss
                d_loss = 0.5 * np.add(dA_loss, dB_loss)
 
 
                # ------------------
                #  Train Generators
                # ------------------
 
                # Train the generators
                g_loss = self.combined.train_on_batch([imgs_A, imgs_B],
                                                        [valid, valid,
                                                        imgs_A, imgs_B,
                                                        imgs_A, imgs_B])
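
One Keras detail worth noting in this excerpt: d_A and d_B are compiled on their own before trainable is set to False, and self.combined is compiled afterwards. Since Keras captures the trainable flag at compile time, the discriminators remain trainable when updated through their own train_on_batch calls, while the copies inside the combined model are frozen, which is exactly what the two training phases need.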

Keras Implementation of CycleGAN
from __future__ import print_function, division
 
from keras_contrib.layers.normalization.instancenormalization import InstanceNormalization
from keras.layers import Input, Dropout, Concatenate
from keras.layers.advanced_activations import LeakyReLU
from keras.layers.convolutional import UpSampling2D, Conv2D
from keras.models import Model
from keras.optimizers import Adam
import datetime
import matplotlib.pyplot as plt
from data_loader import DataLoader
import numpy as np
import os
 
class CycleGAN():
    def __init__(self):
        # Input shape
        self.img_rows = 128
        self.img_cols = 128
        self.channels = 3
        self.img_shape = (self.img_rows, self.img_cols, self.channels)
 
        # Configure data loader
        self.dataset_name = 'summer2winter_yosemite'
        self.data_loader = DataLoader(dataset_name=self.dataset_name,
                                      img_res=(self.img_rows, self.img_cols))
 
 
        # Calculate output shape of D (PatchGAN): the discriminator applies
        # four stride-2 convolutions, so 128 / 2**4 = 8 -> an 8x8x1 validity map
        patch = int(self.img_rows / 2**4)
        self.disc_patch = (patch, patch, 1)
 
        # Number of filters in the first layer of G and D
        self.gf = 32
        self.df = 64
 
        # Loss weights
        self.lambda_cycle = 10.0                    # Cycle-consistency loss
        self.lambda_id = 0.1 * self.lambda_cycle    # Identity loss
 
        optimizer = Adam(0.0002, 0.5)
 
        # Build and compile the discriminators
        self.d_A = self.build_discriminator()
        self.d_B = self.build_discriminator()
        self.d_A.compile(loss='mse',
            optimizer=optimizer,
            metrics=['accuracy'])
        self.d_B.compile(loss='mse',
            optimizer=optimizer,
            metrics=['accuracy'])
 
        #-------------------------
        # Construct Computational
        #   Graph of Generators
        #-------------------------
 
        # Build the generators
        self.g_AB = self.build_generator()
        self.g_BA = self.build_generator()
 
        # Input images from both domains
        img_A = Input(shape=self.img_shape)
        img_B = Input(shape=self.img_shape)
 
        # Translate images to the other domain
        fake_B = self.g_AB(img_A)
        fake_A = self.g_BA(img_B)
        # Translate images back to original domain
        reconstr_A = self.g_BA(fake_B)
        reconstr_B = self.g_AB(fake_A)
        # Identity mapping of images
        img_A_id = self.g_BA(img_A)
        img_B_id = self.g_AB(img_B)
 
        # For the combined model we will only train the generators
        self.d_A.trainable = False
        self.d_B.trainable = False
 
        # Discriminators determine the validity of translated images
        valid_A = self.d_A(fake_A)
        valid_B = self.d_B(fake_B)
 
        # Combined model trains generators to fool discriminators
        self.combined = Model(inputs=[img_A, img_B],
                              outputs=[ valid_A, valid_B,
                                        reconstr_A, reconstr_B,
                                        img_A_id, img_B_id ])
        self.combined.compile(loss=['mse', 'mse',
                                    'mae', 'mae',
                                    'mae', 'mae'],
                            loss_weights=[  1, 1,
                                            self.lambda_cycle, self.lambda_cycle,
                                            self.lambda_id, self.lambda_id ],
                            optimizer=optimizer)
 
    def build_generator(self):
        """U-Net Generator"""
 
        def conv2d(layer_input, filters, f_size=4):
            """Layers used during downsampling"""
            d = Conv2D(filters, kernel_size=f_size, strides=2, padding='same')(layer_input)
            d = LeakyReLU(alpha=0.2)(d)
            d = InstanceNormalization()(d)
            return d
 
        def deconv2d(layer_input, skip_input, filters, f_size=4, dropout_rate=0):
            """Layers used during upsampling"""
            u = UpSampling2D(size=2)(layer_input)
            u = Conv2D(filters, kernel_size=f_size, strides=1, padding='same', activation='relu')(u)
            if dropout_rate:
                u = Dropout(dropout_rate)(u)
            u = InstanceNormalization()(u)
            u = Concatenate()([u, skip_input])
            return u
 
        # Image input
        d0 = Input(shape=self.img_shape)
 
        # Downsampling
        d1 = conv2d(d0, self.gf)
        d2 = conv2d(d1, self.gf*2)
        d3 = conv2d(d2, self.gf*4)
        d4 = conv2d(d3, self.gf*8)
 
        # Upsampling
        u1 = deconv2d(d4, d3, self.gf*4)
        u2 = deconv2d(u1, d2, self.gf*2)
        u3 = deconv2d(u2, d1, self.gf)
 
        u4 = UpSampling2D(size=2)(u3)
        output_img = Conv2D(self.channels, kernel_size=4, strides=1, padding='same', activation='tanh')(u4)
 
        return Model(d0, output_img)
 
    def build_discriminator(self):
 
        def d_layer(layer_input, filters, f_size=4, normalization=True):
            """Discriminator layer"""
            d = Conv2D(filters, kernel_size=f_size, strides=2, padding='same')(layer_input)
            d = LeakyReLU(alpha=0.2)(d)
            if normalization:
                d = InstanceNormalization()(d)
            return d
 
        img = Input(shape=self.img_shape)
 
        d1 = d_layer(img, self.df, normalization=False)
        d2 = d_layer(d1, self.df*2)
        d3 = d_layer(d2, self.df*4)
        d4 = d_layer(d3, self.df*8)
 
        validity = Conv2D(1, kernel_size=4, strides=1, padding='same')(d4)
 
        return Model(img, validity)
 
    def train(self, epochs, batch_size=1, sample_interval=50):
 
        start_time = datetime.datetime.now()
 
        # Adversarial loss ground truths
        valid = np.ones((batch_size,) + self.disc_patch)
        fake = np.zeros((batch_size,) + self.disc_patch)
 
        for epoch in range(epochs):
            for batch_i, (imgs_A, imgs_B) in enumerate(self.data_loader.load_batch(batch_size)):
 
                # ----------------------
                #  Train Discriminators
                # ----------------------
 
                # Translate images to opposite domain
                fake_B = self.g_AB.predict(imgs_A)
                fake_A = self.g_BA.predict(imgs_B)
 
                # Train the discriminators (original images = real / translated = fake)
                dA_loss_real = self.d_A.train_on_batch(imgs_A, valid)
                dA_loss_fake = self.d_A.train_on_batch(fake_A, fake)
                dA_loss = 0.5 * np.add(dA_loss_real, dA_loss_fake)
 
                dB_loss_real = self.d_B.train_on_batch(imgs_B, valid)
                dB_loss_fake = self.d_B.train_on_batch(fake_B, fake)
                dB_loss = 0.5 * np.add(dB_loss_real, dB_loss_fake)
 
                # Total discriminator loss
                d_loss = 0.5 * np.add(dA_loss, dB_loss)
 
 
                # ------------------
                #  Train Generators
                # ------------------
 
                # Train the generators
                g_loss = self.combined.train_on_batch([imgs_A, imgs_B],
                                                        [valid, valid,
                                                        imgs_A, imgs_B,
                                                        imgs_A, imgs_B])
 
                elapsed_time = datetime.datetime.now() - start_time
 
                # Plot the progress
                print ("[Epoch %d/%d] [Batch %d/%d] [D loss: %f, acc: %3d%%] [G loss: %05f, adv: %05f, recon: %05f, id: %05f] time: %s " \
                                                                        % ( epoch, epochs,
                                                                            batch_i, self.data_loader.n_batches,
                                                                            d_loss[0], 100*d_loss[1],
                                                                            g_loss[0],
                                                                            np.mean(g_loss[1:3]),
                                                                            np.mean(g_loss[3:5]),
                                                                            np.mean(g_loss[5:7]),  # average both identity losses
                                                                            elapsed_time))
 
                # If at save interval => save generated image samples
                if batch_i % sample_interval == 0:
                    self.sample_images(epoch, batch_i)
 
    def sample_images(self, epoch, batch_i):
        os.makedirs('images/%s' % self.dataset_name, exist_ok=True)
        r, c = 2, 3
 
        imgs_A = self.data_loader.load_data(domain="A", batch_size=1, is_testing=True)
        imgs_B = self.data_loader.load_data(domain="B", batch_size=1, is_testing=True)
 
        # Demo (for GIF)
        #imgs_A = self.data_loader.load_img('datasets/apple2orange/testA/n07740461_1541.jpg')
        #imgs_B = self.data_loader.load_img('datasets/apple2orange/testB/n07749192_4241.jpg')
 
        # Translate images to the other domain
        fake_B = self.g_AB.predict(imgs_A)
        fake_A = self.g_BA.predict(imgs_B)
        # Translate back to original domain
        reconstr_A = self.g_BA.predict(fake_B)
        reconstr_B = self.g_AB.predict(fake_A)
 
        gen_imgs = np.concatenate([imgs_A, fake_B, reconstr_A, imgs_B, fake_A, reconstr_B])
 
        # Rescale images 0 - 1
        gen_imgs = 0.5 * gen_imgs + 0.5
 
        titles = ['Original', 'Translated', 'Reconstructed']
        fig, axs = plt.subplots(r, c)
        cnt = 0
        for i in range(r):
            for j in range(c):
                axs[i,j].imshow(gen_imgs[cnt])
                axs[i, j].set_title(titles[j])
                axs[i,j].axis('off')
                cnt += 1
        fig.savefig("images/%s/%d_%d.png" % (self.dataset_name, epoch, batch_i))
        plt.close()
 
 
if __name__ == '__main__':
    gan = CycleGAN()
    gan.train(epochs=200, batch_size=1, sample_interval=200)
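
A note on these settings: batch_size=1 matches the training setup of the original CycleGAN paper, and sample_interval=200 makes sample_images() write a 2x3 grid (original, translated, and reconstructed images for both domains) every 200 batches. The summer2winter_yosemite dataset is assumed to already be unpacked under the datasets path hard-coded in data_loader.py below.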

data_loader.py:

from glob import glob
import numpy as np
from PIL import Image
 
class DataLoader():
    def __init__(self, dataset_name, img_res=(128, 128)):
        self.dataset_name = dataset_name
        self.img_res = img_res
 
    def load_data(self, domain, batch_size=1, is_testing=False):
        data_type = "train%s" % domain if not is_testing else "test%s" % domain
        # raw string so the backslashes in this hard-coded Windows path are
        # not treated as escape sequences
        path = glob(r'E:\workspace49\CycleGAN-tensorflow/datasets/%s/%s/*' % (self.dataset_name, data_type))
 
        batch_images = np.random.choice(path, size=batch_size)
 
        imgs = []
        for img_path in batch_images:
            img = self.imread(img_path)
            if not is_testing:
                img = img.resize(self.img_res)
 
                if np.random.random() > 0.5:
                    img = np.fliplr(img)
            else:
                img = img.resize(self.img_res)
            imgs.append(np.array(img))
 
        imgs = np.array(imgs)/127.5 - 1.
 
        return imgs
 
    def load_batch(self, batch_size=1, is_testing=False):
        data_type = "train" if not is_testing else "val"
        path_A = glob(r'E:\workspace49\CycleGAN-tensorflow/datasets/%s/%sA/*' % (self.dataset_name, data_type))
        path_B = glob(r'E:\workspace49\CycleGAN-tensorflow/datasets/%s/%sB/*' % (self.dataset_name, data_type))
 
        self.n_batches = int(min(len(path_A), len(path_B)) / batch_size)
        total_samples = self.n_batches * batch_size
 
        # Sample n_batches * batch_size from each path list so that model sees all
        # samples from both domains
        path_A = np.random.choice(path_A, total_samples, replace=False)
        path_B = np.random.choice(path_B, total_samples, replace=False)
 
        for i in range(self.n_batches-1):  # note: the final batch of the sampled paths is never yielded
            batch_A = path_A[i*batch_size:(i+1)*batch_size]
            batch_B = path_B[i*batch_size:(i+1)*batch_size]
            imgs_A, imgs_B = [], []
            for img_A, img_B in zip(batch_A, batch_B):
                img_A = self.imread(img_A)
                img_B = self.imread(img_B)
 
                img_A = img_A.resize(self.img_res)
                img_B = img_B.resize(self.img_res)
 
                if not is_testing and np.random.random() > 0.5:
                        img_A = np.fliplr(img_A)
                        img_B = np.fliplr(img_B)
 
                imgs_A.append(np.array(img_A))
                imgs_B.append(np.array(img_B))
 
            imgs_A = np.array(imgs_A)/127.5 - 1.
            imgs_B = np.array(imgs_B)/127.5 - 1.
 
            yield imgs_A, imgs_B
 
    def load_img(self, path):
        img = self.imread(path)
        img = img.resize(self.img_res)
        # convert the PIL image to an array before scaling to [-1, 1]
        img = np.array(img)/127.5 - 1.
        return img[np.newaxis, :, :, :]
 
    def imread(self, path):
        # convert to RGB so grayscale or RGBA files still yield 3 channels
        return Image.open(path).convert('RGB')
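
A minimal usage sketch for the loader (assuming the datasets/<dataset_name>/trainA and trainB folders exist under the hard-coded path above; adjust that path to your machine):

    loader = DataLoader(dataset_name='summer2winter_yosemite', img_res=(128, 128))
    for imgs_A, imgs_B in loader.load_batch(batch_size=1):
        # both arrays are (1, 128, 128, 3), scaled to [-1, 1] to match the
        # generators' tanh output
        print(imgs_A.shape, imgs_B.shape)
        break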
