1 Introduction

Image classification is a classical topic in computer vision, and many state-of-the-art networks have been proposed in the ImageNet challenge [1]. These deep neural networks commonly require a large and balanced dataset for training. In medical image classification, however, the performance of most networks deteriorates because the datasets are imbalanced. Neural networks are trained by minimizing a loss function via gradient descent; on an imbalanced dataset, the gradients easily fall into the trap of predicting the majority class. Apart from reducing the number of majority-class samples, to the best of our knowledge, the only effective solution is to increase the number of minority-class samples. In the field of medical imaging, collecting pathological cases is time-consuming, so the best solution is to generate new minority-class images with both high quality and diversity.

Generative adversarial networks (GANs) [2] are currently the most powerful generative models. Like other deep neural networks, GANs require a large dataset for training, yet the minority-class subset is usually too small to train a good GAN. In particular, balancing GAN (BAGAN) [3] provided a new method to train GANs on imbalanced datasets while specifically aiming to generate high-quality minority-class images. The main contributions of BAGAN are (1) using an autoencoder to initialize the GAN training, which gives the GAN common knowledge of all classes, and (2) combining the real/fake loss and the classification loss fairly into one output at the discriminator, which ensures balanced training for each class.

  • Problem statement

Although BAGAN proposed autoencoder initialization to stabilize GAN training, its performance is still sometimes unstable, especially on medical image datasets. Medical image datasets are typically (1) highly imbalanced because pathological cases are rare, and (2) hard to distinguish among classes. As shown in [3], the imbalanced Flowers dataset has many similar classes, so BAGAN does not perform well on it. In our experiments, BAGAN fails to generate good samples on a small-scale medical image dataset. We consider that the encoder fails to separate images by class when translating them into latent vectors. Furthermore, like traditional GANs, BAGAN is hard to train and sensitive to its architecture and hyperparameters. The objective of this work is to generate high-quality minority-class images even from a small-scale imbalanced dataset. Our contributions are:

  • We improve the loss function of BAGAN with gradient penalty and build the corresponding architecture of generator and discriminator (BAGAN-GP).

  • We propose a novel architecture of autoencoder with an intermediate embedding model, which helps the autoencoder learn the label information directly.

  • We discuss the drawbacks of the original BAGAN, demonstrate performance improvements over it, and analyze the potential reasons.

2 Background

Literature review of GANs. The underlying method of generative adversarial networks (GANs) is solving a minimax problem [2, 4]. A typical GAN model contains a generator and a discriminator. The generator aims to generate images realistic enough to confuse the discriminator, while the discriminator works to distinguish whether the images in a mixture of original and generated images are real or fake. In this game, the generator attempts to mimic the distribution of the real data.

GAN techniques have developed rapidly in recent years. There are various types of GANs: with different metrics for comparing two distributions (e.g., JS divergence for the original GAN [2], Wasserstein distance for WGAN [5, 6], EBGAN [7], BEGAN [8], Loss-Sensitive GAN [9]), with regularization on the loss function (e.g., WGAN-GP [5], DRAGAN [10]), with well-designed architectures (e.g., CycleGAN [11, 12], PGGAN [13], SAGAN [14]), using a single image for generation (e.g., SinGAN [15]), with conditions (e.g., ACGAN [16]), for augmentation (e.g., AugGAN [17], BAGAN [3]), and for reducing the mode-collapse problem (e.g., VEEGAN [18]).


GAN-based augmentation. Data augmentation extracts more information from the original dataset to improve model performance. Traditional image augmentation simply applies linear transformations to the original images, e.g., reflections, rotations, and shears. As long as these transformations do not affect the recognizability of the images, they help models learn more from the original dataset. To extract even more information, it is also reasonable to apply nonlinear transformations, and GANs are exactly good at creating similar images through the nonlinear transformations inside the network. The literature review [19] compared many data augmentation methods in deep learning, especially those based on GANs.

GANs can simulate the distribution of a real dataset and generate new, high-quality data samples. Therefore, several recent works have applied GANs as an augmentation technique. However, the small training set of minority-class images remains a challenge when training a GAN to generate high-quality samples. GAMO [20] introduced an oversampling method into an end-to-end adversarial learning system to deal with the imbalance issue in classification. AugGAN [17] and AugCGAN [12] proposed image-to-image translation frameworks to generate images in a target domain. BAGAN [3] proposed an overall approach to generate high-quality minority-class images to balance the original dataset. [21] used a conditional WGAN-GP (cWGAN-GP) to generate face-emotion samples for data augmentation. [22] discussed the importance of data augmentation in medical image analysis and considered GANs the most promising technique. For brain tumor image synthesis, [23] used GANs and [24] used a conditional PGGAN for better tumor detection. Notably, some recent research applied GANs for augmentation to detect COVID-19 lesions in pulmonary CT images [25].

3 Methods

3.1 BAGAN architecture

Fig. 1 The architecture of BAGAN. BAGAN proposed three effective steps to improve the quality of generated images when training GANs on imbalanced datasets


Autoencoder initialization. Autoencoder initialization helps the generator and discriminator build common knowledge of the dataset across all classes. Besides, the autoencoder leads the initialized GAN toward a good and stable solution. BAGAN uses a typical autoencoder: the encoder translates a given image into a latent vector, and the decoder translates a given latent vector back into a reconstructed image. The autoencoder is trained by minimizing the L2 loss between real images and reconstructed images. In this step, there is no class information, and the autoencoder learns all images in an unsupervised manner. A minimal sketch of this step follows.
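To make this step concrete, the following is a minimal Keras sketch of such an autoencoder, assuming 64×64×3 inputs and a 128-dimensional latent space (as used later in this paper); the layer configuration is illustrative rather than BAGAN's exact architecture.

```python
from tensorflow.keras import layers, Model

latent_dim = 128

def build_encoder():
    x_in = layers.Input(shape=(64, 64, 3))
    x = layers.Conv2D(64, 4, strides=2, padding="same")(x_in)    # 32x32
    x = layers.LeakyReLU(0.2)(x)
    x = layers.Conv2D(128, 4, strides=2, padding="same")(x)      # 16x16
    x = layers.LeakyReLU(0.2)(x)
    x = layers.Flatten()(x)
    z = layers.Dense(latent_dim)(x)                              # latent vector
    return Model(x_in, z, name="encoder")

def build_decoder():
    z_in = layers.Input(shape=(latent_dim,))
    x = layers.Dense(16 * 16 * 128)(z_in)
    x = layers.Reshape((16, 16, 128))(x)
    x = layers.Conv2DTranspose(64, 4, strides=2, padding="same")(x)   # 32x32
    x = layers.LeakyReLU(0.2)(x)
    x_out = layers.Conv2DTranspose(3, 4, strides=2, padding="same",
                                   activation="tanh")(x)             # 64x64
    return Model(z_in, x_out, name="decoder")

encoder, decoder = build_encoder(), build_decoder()
autoencoder = Model(encoder.input, decoder(encoder.output))
# L2 reconstruction loss between real and reconstructed images; no labels here.
autoencoder.compile(optimizer="adam", loss="mse")
```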


Labeled latent vector generation. In this step, class information is attached to each latent vector. The real images can be divided into different classes, and the encoder translates these images into latent vectors. Under the assumption that the latent vectors are normally distributed within each class, a probabilistic generator can be derived by computing per-class means and covariances, as sketched below.
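A minimal NumPy sketch of this step, assuming the `encoder` from the previous sketch and arrays `images` and `labels` for the training set:

```python
import numpy as np

def fit_class_gaussians(encoder, images, labels):
    # Encode all training images and fit one Gaussian per class.
    latents = encoder.predict(images)
    stats = {}
    for c in np.unique(labels):
        z_c = latents[labels == c]
        stats[c] = (z_c.mean(axis=0), np.cov(z_c, rowvar=False))
    return stats

def sample_labeled_latents(stats, c, n):
    # Draw n latent vectors for class c from its fitted Gaussian.
    mean, cov = stats[c]
    return np.random.multivariate_normal(mean, cov, size=n)
```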


Balanced training in GAN. The generator and discriminator start with prior knowledge from the initialized autoencoder: the generator inherits the architecture and weights of the trained decoder, and the discriminator inherits the weights of the trained encoder as its first part and adds an auxiliary softmax layer to identify the classes. Unlike ACGAN [16], the discriminator has a single output that classifies an image as fake or as one of the real classes. Furthermore, in each training batch, the proportion of fake images is the same as that of any real class, so the gradients propagate equally across each class and the real/fake validity (see the sketch below). Although the majority-class images are easier for the GAN to learn and reproduce, the balanced training guarantees that the minority-class images are not ignored.
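A sketch of this balanced batch composition, under our reading of BAGAN: with `n_classes` real classes, fake images occupy the same share of each discriminator batch as any single real class.

```python
import numpy as np

def split_batch(batch_size, n_classes):
    # Equal share per real class plus one share for fake images.
    per_class = batch_size // (n_classes + 1)
    # Fake images get uniformly sampled target labels.
    fake_labels = np.random.randint(0, n_classes, size=per_class)
    return per_class, fake_labels

per_class, fake_labels = split_batch(128, 10)   # e.g., 128-image batch, 10 classes
```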

3.2 Enhancements on BAGAN

3.2.1 Enhanced loss function

In this work, we use two advanced loss functions with gradient penalty (from WGAN-GP [5] and DRAGAN [10]) and compare them against the original loss function of BAGAN.


Original GAN. In the original GAN model, the loss function is based on the Jensen–Shannon (JS) divergence: using the cross-entropy loss to minimize the difference between two distributions is equivalent to minimizing the JS divergence. However, the JS divergence gives meaningful gradients only when the two distributions overlap; it cannot measure how far apart two distributions are when they have no intersection. The loss function \(L\left( X_r,X_g\right) \) of the original GAN is defined as:

$$\begin{aligned} \underset{\theta _G}{\min }\underset{\theta _D}{\max }L\left( X_r,X_g\right)&={\mathbb {E}}_{x_r\sim X_r}\left[ \log \left( D\left( x_r\right) \right) \right] \nonumber \\&\quad +{\mathbb {E}}_{x_g\sim X_g}\left[ \log \left( 1-D\left( x_g\right) \right) \right] \end{aligned}$$
(1)

where D denotes the discriminator function, G denotes the generator function, \(\theta _G \) are the parameters of the generator, and \(\theta _D \) are the parameters of the discriminator; \(x_r \) is sampled from the real distribution \(X_r \), and \(x_g \) is sampled from the generated distribution \(X_g \), where \(x_g=G\left( z\right) \) and z is a random noise vector sampled from a normal distribution, \(z\sim N\left( 0,I_{\dim (z)}\right) \). The discriminator is minimizing:

$$\begin{aligned} L^{\left( D\right) }\left( X_r,X_g\right) =&-{\mathbb {E}}_{x_r\sim X_r}\left[ \log \left( D\left( x_r\right) \right) \right] \nonumber \\&-{\mathbb {E}}_{x_g\sim X_g}\left[ \log \left( 1-D\left( x_g\right) \right) \right] \end{aligned}$$
(2)

The generator is minimizing:

$$L^{\left( G\right) }\left( X_g\right) =-{\mathbb {E}}_{x_g\sim X_g}\left[ \log \left( D\left( x_g\right) \right) \right]$$
(3)

WGAN. In the loss function, we can replace the JS divergence with the Wasserstein distance to improve performance and training stability. In practice, when constructing an original GAN, the discriminator should not be too powerful, because a powerful discriminator cannot give meaningful gradients for training its generator. WGAN [6] proposed the Wasserstein distance to solve this problem. The Wasserstein distance is the minimum transport cost of moving mass from one distribution to another, which is also called the Earth Mover's Distance (EMD). EMD is continuous and differentiable, so the gradients are always meaningful, which ensures the stability of GAN training. Based on the theory of WGAN, the generator's performance eventually converges to that of the discriminator; hence, WGAN requires a deep discriminator architecture so that the generator can ultimately achieve strong performance. The Wasserstein distance is defined as:

$$\begin{aligned} W\left( X_r,X_g\right) =\underset{\gamma \sim \varPi \left( X_r,X_g\right) }{\inf }{\mathbb {E}}_{\left( x_r,x_g\right) \sim \gamma }\Vert {x_r-x_g}\Vert \end{aligned}$$
(4)

where \(\varPi \left( X_r,X_g\right) \) denotes all possible joint distributions between the real distribution \(X_r \) and the generated distribution \(X_g \). Each \(\gamma \) represents a transport plan.

However, it is impossible to find the lower bound by traversing all possible \(\gamma \) in this equation. Using the Kantorovich–Rubinstein duality, it is equivalent to finding the upper bound in:

$$\begin{aligned} W\left( X_r,X_g\right) =\underset{\Vert D\Vert _L\le 1}{\sup } \left( {\mathbb {E}}_{x_r\sim {X_r}}\left[ D(x_r)\right] -{\mathbb {E}}_{x_g\sim X_g}\left[ D(x_g)\right] \right) \end{aligned}$$
(5)

where \(\Vert D\Vert _L\le 1 \) denotes that D belongs to the set of 1-Lipschitz functions. Without the constraint, the objective function for the discriminator is maximizing:

$$\begin{aligned} W^{\left( D\right) }\left( X_r,X_g\right) ={\mathbb {E}}_{x_r\sim {X_r}}\left[ D\left( x_r\right) \right] -{\mathbb {E}}_{x_g\sim {X_g}}\left[ D\left( x_g\right) \right] \end{aligned}$$
(6)

The discriminator in WGAN uses an unconstrained real number rather than a classification probability to measure the validity of real/fake images. Compared with the original GAN, the WGAN loss function does not contain log-sigmoid functions.

Gradient penalty. The 1-Lipschitz constraint is equivalent to requiring the norm of the gradients \(\Vert \nabla _{{x}}{D}({x})\Vert _2\le 1 \) everywhere. The gradient penalty term is defined as:

$$\begin{aligned} GP={\mathbb {E}}_{{x}\sim {{X}}}\left[ (\Vert \nabla _{{x}}{D}({x})\Vert _2-1)^{2}\right] \end{aligned}$$
(7)

WGAN-GP [5] adds an extra gradient penalty term to the discriminator loss function. The discriminator minimizes:

$$\begin{aligned} W^{\left( D\right) }\left( X_r,X_g\right)&= {\mathbb {E}}_{x_r\sim {X_r}}\left[ D\left( x_r\right) \right] -{\mathbb {E}}_{x_g\sim {X_g}}\left[ D\left( x_g\right) \right] \nonumber \\&\quad +\lambda {\mathbb {E}}_{{\hat{x}}\sim {{\hat{X}}}}\left[ (\Vert \nabla _{{\hat{x}}}{D}({\hat{x}})\Vert _2-1)^{2}\right] \end{aligned}$$
(8)

where \({\widehat{x}}=\alpha x_r+\left( 1-\alpha \right) x_g,\;\alpha \sim U(0,1) \), which we refer to as “model interpolation,” and \(\lambda \) is a hyperparameter controlling the penalty strength. A sketch of this penalty term is given below.
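The following is a TensorFlow 2 sketch of the gradient penalty in Eq. (8) with “model interpolation,” assuming a discriminator `D` and equally sized batches of real and fake images:

```python
import tensorflow as tf

def gradient_penalty(D, x_real, x_fake, lam=10.0):
    # Model interpolation: random points on lines between real and fake images.
    # (DRAGAN instead interpolates between x_real and noisy samples x_noise.)
    alpha = tf.random.uniform([tf.shape(x_real)[0], 1, 1, 1], 0.0, 1.0)
    x_hat = alpha * x_real + (1.0 - alpha) * x_fake
    with tf.GradientTape() as tape:
        tape.watch(x_hat)
        d_hat = D(x_hat, training=True)
    grads = tape.gradient(d_hat, x_hat)
    # Penalize gradient norms that deviate from 1 (the 1-Lipschitz target).
    norm = tf.sqrt(tf.reduce_sum(tf.square(grads), axis=[1, 2, 3]) + 1e-12)
    return lam * tf.reduce_mean(tf.square(norm - 1.0))
```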

The gradient penalty is applied only in the discriminator loss. The generator minimizes:

$$\begin{aligned} W^{\left( G\right) }\left( X_g\right) =-{\mathbb {E}}_{x_g\sim {X_g}}\left[ D\left( x_g\right) \right] \end{aligned}$$
(9)

DRAGAN [10] borrowed the idea of the gradient penalty from WGAN-GP [5]. [5] indicated that the gradient penalty term can be adapted to the standard GAN loss function in Eq. (1). [10] applied the Wasserstein-based gradient penalty to the original log-sigmoid loss function, and [26] demonstrated that this is also effective. The discriminator minimizes:

$$\begin{aligned} L^{\left( D\right) }\left( X_r,X_g\right) =&-{\mathbb {E}}_{x_r\sim X_r}\left[ \log \left( D\left( x_r\right) \right) \right] \nonumber \\&-{\mathbb {E}}_{x_g\sim X_g}\left[ \log \left( 1-D\left( x_g\right) \right) \right] \nonumber \\&+\lambda {\mathbb {E}}_{{\hat{x}}\sim {{\hat{X}}}}\left[ (\Vert \nabla _{{\hat{x}}}{D}({\hat{x}})\Vert _2-1)^{2}\right] \end{aligned}$$
(10)

where \({\widehat{x}}=\alpha x_r+\left( 1-\alpha \right) x_\mathrm{noise},\;\alpha \sim U(0,1),\;x_\mathrm{noise}\sim p_\mathrm{noise} \), which we refer to as “noise interpolation.” Although DRAGAN modifies the gradient penalty of WGAN-GP, we do not discuss the difference in depth.

There is no gradient penalty in the generator loss, so the loss function is the same as the original GAN:

$$\begin{aligned} L^{\left( G\right) }\left( X_g\right) =-{\mathbb {E}}_{x_g\sim X_g}\left[ \log \left( D\left( x_g\right) \right) \right] \end{aligned}$$
(11)

After comparing these loss functions in practice, our enhanced BAGAN uses a DRAGAN-like loss function with the “model interpolation” gradient penalty.

With conditionality. For data augmentation, we need a conditional GAN to generate minority-class samples. The architectures of DRAGAN and WGAN-GP are almost the same. Referring to ACGAN [16] and cWGAN-GP [21], we built a feasible architecture for a conditional DRAGAN (cDRAGAN). Because of the gradient penalty, we cannot add a softmax layer at the end of the discriminator to identify classes; the output of the discriminator still needs to be an unconstrained real number. In our work, we keep the outputs of the generator and discriminator the same as in WGAN-GP, whereas we attach the label information to the inputs of both: the label is expanded by an embedding layer and combined with the other input by a multiply layer (see the sketch after Eq. (13)). The loss function for the discriminator is:

$$\begin{aligned}&L^{\left( D\right) }\left( X_r,X_g,Y_r\right) =-{\mathbb {E}}_{(x_r, y_r)\sim (X_r,Y_r)}\left[ \log \left( D\left( x_r,y_r\right) \right) \right] \nonumber \\&\quad -{\mathbb {E}}_{(x_g,y_r)\sim (X_g,Y_r)}\left[ \log \left( 1-D\left( x_g,y_r\right) \right) \right] \nonumber \\&\quad +\lambda {\mathbb {E}}_{({\hat{x}},y_r)\sim ({{\hat{X}},Y_r})}\left[ (\Vert \nabla _{({\hat{x}},y_r)}{D}({\hat{x}},y_r)\Vert _2-1)^{2}\right] \end{aligned}$$
(12)

Similar to ACGAN and cWGAN-GP, the generated images use the real labels for training in both G and D. The loss function for the generator is:

$$\begin{aligned} L^{\left( G\right) }\left( X_g,Y_r\right) =-{\mathbb {E}}_{(x_g,y_r)\sim (X_g,Y_r)}\left[ \log \left( D\left( x_g,y_r\right) \right) \right] \end{aligned}$$
(13)
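The following Keras sketch shows the conditional inputs described above, assuming a 128-dimensional latent space and a hypothetical `n_classes`; the label is expanded by an Embedding layer and merged with the noise vector by a Multiply layer.

```python
from tensorflow.keras import layers

latent_dim, n_classes = 128, 10

z_in = layers.Input(shape=(latent_dim,))
y_in = layers.Input(shape=(1,), dtype="int32")
y_emb = layers.Flatten()(layers.Embedding(n_classes, latent_dim)(y_in))
labeled_z = layers.Multiply()([z_in, y_emb])   # conditional generator input
# The discriminator conditions on labels in the same way and ends in a single
# linear unit, since the gradient penalty requires an unconstrained output.
```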

Combine with BAGAN. BAGAN has state-of-the-art performance in generating minority-class images on imbalanced datasets, but the GAN architecture in BAGAN is just a typical conditional GAN, and we noticed that the GAN model inside the BAGAN learning system is completely independent. Following the evolution of GANs, there have been improvements to GAN architectures and loss functions that achieve more stable training and better performance. We enhanced the GAN part of BAGAN by adopting the architecture and loss function of the cDRAGAN proposed in the previous section, and we modified the loss function with the idea of balanced training from BAGAN. The loss function of the discriminator is:

$$\begin{aligned}L^{\left( D\right) }\left( X_r,Z,Y_r,Y_f,Y_\mathrm{wrong}\right)&= -{\mathbb {E}}_{(x_r, y_r)\sim (X_r,Y_r)}\left[ \log \left( D\left( x_r,y_r\right) \right) \right] \nonumber \\&\quad -{\mathbb {E}}_{(z,y_f)\sim (Z,Y_f)}\left[ \log \left( 1-D\left( G(z,y_f),y_f\right) \right) \right] \nonumber \\&\quad -{\mathbb {E}}_{(x_r, y_\mathrm{wrong})\sim (X_r,Y_\mathrm{wrong})}\left[ \log \left( 1-D\left( x_r,y_\mathrm{wrong}\right) \right) \right] \nonumber \\&\quad +\lambda {\mathbb {E}}_{({\hat{x}},y_r)\sim ({{\hat{X}},Y_r})}\left[ (\Vert \nabla _{({\hat{x}},y_r)}{D}({\hat{x}},y_r)\Vert _2-1)^{2}\right] \end{aligned}$$
(14)

where z is a random noise vector, \(z\sim N\left( 0,I_{\dim \left( z\right) }\right) \equiv Z \), \(y_{f}\sim U\{0,1,2,\ldots \}\equiv Y_{f} \), and \(y_\mathrm{wrong}\sim U\{0,1,2,\ldots \}\equiv Y_\mathrm{wrong} \). Previously, the real labels were shared between the real images and the fake images when training the discriminator. In an imbalanced dataset, real labels randomly sampled from the dataset are still imbalanced, so the GAN automatically trains more on the majority classes; in practice, if we instead sample from stratified real labels for training, the GAN learns slowly. Referring to BAGAN, we randomly sample a fake label from a balanced label set \(Y_{f} \) for each fake image. To enhance the learning of class information from the real dataset, we add an extra cross-entropy loss for wrongly classified cases. For the gradient penalty term, we borrow the “model interpolation” method from WGAN-GP.

In the setting of balanced training, the loss function of the generator becomes:

$$\begin{aligned} L^{\left( G\right) }\left( Z,Y_{f}\right) =-{\mathbb {E}}_{(z,y_f)\sim (Z,Y_f)}\left[ \log \left( D\left( G\left( z,y_{f}\right) ,y_{f}\right) \right) \right] \end{aligned}$$
(15)
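A TensorFlow sketch of the losses in Eqs. (14) and (15), assuming a conditional discriminator `D([x, y])` that returns an unconstrained logit, a conditional generator `G([z, y])`, and sigmoid cross-entropy from logits for the log terms:

```python
import tensorflow as tf

bce = tf.keras.losses.BinaryCrossentropy(from_logits=True)

def gp_term(D, x_real, x_fake, y, lam=10.0):
    # "Model interpolation" between real and generated images (labels fixed).
    alpha = tf.random.uniform([tf.shape(x_real)[0], 1, 1, 1], 0.0, 1.0)
    x_hat = alpha * x_real + (1.0 - alpha) * x_fake
    with tf.GradientTape() as tape:
        tape.watch(x_hat)
        d_hat = D([x_hat, y], training=True)
    grads = tape.gradient(d_hat, x_hat)
    norm = tf.sqrt(tf.reduce_sum(tf.square(grads), axis=[1, 2, 3]) + 1e-12)
    return lam * tf.reduce_mean(tf.square(norm - 1.0))

def d_loss(D, G, x_real, y_real, z, y_fake, y_wrong):
    x_fake = G([z, y_fake], training=True)
    d_real = D([x_real, y_real], training=True)
    d_fake = D([x_fake, y_fake], training=True)
    d_wrong = D([x_real, y_wrong], training=True)   # real image, mismatched label
    return (bce(tf.ones_like(d_real), d_real)       # -E[log D(x_r, y_r)]
            + bce(tf.zeros_like(d_fake), d_fake)    # -E[log(1 - D(G(z,y_f), y_f))]
            + bce(tf.zeros_like(d_wrong), d_wrong)  # -E[log(1 - D(x_r, y_wrong))]
            + gp_term(D, x_real, x_fake, y_real))

def g_loss(D, G, z, y_fake):
    d_fake = D([G([z, y_fake], training=True), y_fake], training=True)
    return bce(tf.ones_like(d_fake), d_fake)        # -E[log D(G(z,y_f), y_f)]
```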

3.2.2 Enhanced autoencoder

Compared with an ordinary conditional GAN, BAGAN has two key steps: autoencoder initialization and labeled latent vector generation. In our work, we design a new autoencoder architecture with an embedding section. In BAGAN, the labeled latent vector generation is based on the assumption that the latent vectors are normally distributed within each class. This assumption restricts the performance of BAGAN in practice.

1. There may be overlaps between the latent-vector distributions of different classes (Fig. 2). As a result, samples generated from latent vectors in the overlapping regions look like mixed-class images. In application, we cannot feed a random latent vector into the generator to obtain an image of a specific class; instead, we must compute a labeled latent vector from the means and covariances of the encoded training data.

Fig. 2 Distributions of latent vectors of different classes overlap

2. The autoencoder in BAGAN does not learn the label information directly, so the latent vectors it encodes are not well separated by class. The labeled latent vectors are defined and restricted by these overlapping distributions, i.e., the label information is unclear, and the rough label information attached to the latent vectors will mislead the subsequent GAN training. Furthermore, even if the latent vectors were perfectly dispersed, the labeled latent vectors would only suit the trained decoder: during GAN training, the generator (the pretrained decoder) is updated, but after the autoencoder initialization, the distributions of the labeled latent vectors can no longer be updated. In our work, we instead use an embedding model to generate labeled latent vectors (Figs. 3, 4, and 5).

Fig. 3 Autoencoder with an intermediate embedding model. Our proposed autoencoder is supervised: the label is embedded into a dense vector of the same size as the latent vector, and a multiply layer combines the two vectors into a labeled latent vector
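A Keras sketch of this supervised autoencoder, reusing the `encoder`, `decoder`, `latent_dim`, and `n_classes` names assumed in the earlier sketches:

```python
from tensorflow.keras import layers, Model

# Intermediate embedding model: latent vector * embedded label -> labeled latent.
z_in = layers.Input(shape=(latent_dim,))
y_in = layers.Input(shape=(1,), dtype="int32")
y_emb = layers.Flatten()(layers.Embedding(n_classes, latent_dim)(y_in))
embedding_model = Model([z_in, y_in], layers.Multiply()([z_in, y_emb]),
                        name="embedding")

# Supervised autoencoder: image + label in, reconstruction out.
x_in = layers.Input(shape=(64, 64, 3))
y_lab = layers.Input(shape=(1,), dtype="int32")
recon = decoder(embedding_model([encoder(x_in), y_lab]))
supervised_ae = Model([x_in, y_lab], recon, name="supervised_autoencoder")
supervised_ae.compile(optimizer="adam", loss="mse")   # trained on (image, label) pairs
```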

Fig. 4 GAN architecture and our proposed generator. Our proposed generator is an aggregate of the pretrained embedding model and decoder. We feed a random latent vector and a random label into the generator and obtain a generated image of the specified class. The embedding model inside the generator is updated along with the GAN training

Fig. 5 The discriminator architecture is similar to cWGAN-GP. Our proposed discriminator is an extended model of the pretrained encoder. Note that the discriminator does not use the whole encoder model: excluding the encoder's output layer, we adopt the second-last output (feature map) and combine it with the embedded labels into a new dense vector. The output of the discriminator is an unconstrained real number, which indicates the combined validity of real/fake and class matching
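The following sketch shows one way to wire up the reuse described in Figs. 4 and 5, continuing the previous sketch (the exact wiring in our implementation may differ):

```python
from tensorflow.keras import layers, Model

# Generator: (z, y) -> labeled latent -> decoded image; the embedding model
# keeps updating during GAN training.
gz_in = layers.Input(shape=(latent_dim,))
gy_in = layers.Input(shape=(1,), dtype="int32")
generator = Model([gz_in, gy_in],
                  decoder(embedding_model([gz_in, gy_in])), name="generator")

# Discriminator: encoder feature map * embedded label -> single linear unit.
feat = Model(encoder.input, encoder.layers[-2].output)   # drop encoder's output layer
dx_in = layers.Input(shape=(64, 64, 3))
dy_in = layers.Input(shape=(1,), dtype="int32")
f = feat(dx_in)
dy_emb = layers.Flatten()(layers.Embedding(n_classes, f.shape[-1])(dy_in))
merged = layers.Multiply()([f, dy_emb])
d_out = layers.Dense(1)(merged)   # unconstrained real number
discriminator = Model([dx_in, dy_in], d_out, name="discriminator")
```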

4 Experiments and results

The optimizer for all models in this work is Adam with learning rate 0.0002 and momentum terms (0.5, 0.9). The mini-batch size is 128. All image inputs are resized to \(64\times 64\times \mathrm{channels} \). The default latent dimension is 128. We use batch normalization only in the generator/decoder. The generator's output activation is tanh and the discriminator's output is linear; all other activations are LeakyReLU with negative slope 0.2. The quality of generated images is measured by the Fréchet Inception Distance. All experiments use Keras with the TensorFlow backend on an NVIDIA Tesla P4 GPU with 8 GB memory. Most of our results are trained within 3600 s. For the Cells dataset, we train 100 epochs, each taking 18 s on our device; for MNIST Fashion, 15 epochs at 154 s each; for CIFAR-10, 30 epochs at 129 s each. A sketch of this configuration is given below.
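The configuration above as a Keras sketch, reading “momentum (0.5, 0.9)” as Adam's \(\beta_1\) and \(\beta_2\):

```python
from tensorflow.keras.optimizers import Adam

g_optimizer = Adam(learning_rate=2e-4, beta_1=0.5, beta_2=0.9)
d_optimizer = Adam(learning_rate=2e-4, beta_1=0.5, beta_2=0.9)
BATCH_SIZE, IMG_SIZE, LATENT_DIM = 128, 64, 128   # mini-batch, image side, latent dim
```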

Note. In each figure of representative images in this section, the first row (\(\mathrm{row}=0 \)) shows real images by class. For each column, we feed the generator the class label \(c_\mathrm{column} \). Starting from the second row, we feed the generator a fixed noise vector \(z_{\mathrm{row}-1} \). The generated images in these figures are derived by

$$\begin{aligned} Im\left( \mathrm{row}>0,\mathrm{column}\ge 0\right) =G\left( z_{\mathrm{row}-1},c_{\mathrm{column}}\right) \end{aligned}$$
(16)

4.1 MNIST Fashion & CIFAR-10

Fig. 6 Representative samples generated on MNIST Fashion. The order of these images follows Eq. (16)

Table 1 Class weight of MNIST Fashion (balanced & imbalanced)

We start our experiments with two well-known balanced datasets, MNIST Fashion and CIFAR-10. We first sample 70% of the images as the training set for the generative models (A for MNIST Fashion, Table 1; C for CIFAR-10, Table 2). To exemplify the quality of minority-class generation, an imbalanced version (B for MNIST Fashion, Table 1; D for CIFAR-10, Table 2) is created manually for comparison. We observe that our model works well not only on the balanced datasets (A, C) but also on the highly imbalanced datasets (B, D). From the representative images generated with the imbalanced datasets (Figs. 6 and 7), we cannot easily tell which columns are the minority classes. Therefore, our model trains fairly on each class regardless of the class imbalance; the learning outcome only depends on the complexity of the images themselves. For example, dataset B contains 73 trousers and 370 sandals. Although the training set of sandals is five times as large as that of trousers, the generated trouser images have even better quality.

Fig. 7 Representative samples generated on CIFAR-10. The order of these images follows Eq. (16)

Table 2 Class weight of CIFAR-10 (balanced & imbalanced)

The discriminator in our BAGAN-GP has a similar architecture to WGAN-GP. Hence, we can set the train ratio of the discriminator to the generator to 5 and boost the training with high stability; in the original BAGAN, a train ratio larger than 1 makes the training oscillate. In other words, the stability of BAGAN requires a competitive relation between the generator and the discriminator, while our BAGAN-GP only pursues a powerful discriminator to lead the generator, as sketched below. Furthermore, our BAGAN-GP still performs excellently when we initialize only the generator, because a good generator accelerates the learning process of the discriminator.
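A sketch of this train ratio, assuming hypothetical helpers `train_d_step` and `train_g_step` wrapping the losses of Eqs. (14) and (15), and an assumed batch iterator `real_batches`:

```python
D_STEPS = 5   # discriminator updates per generator update

for _ in range(total_steps):                  # `total_steps` assumed from the setup
    for _ in range(D_STEPS):
        x_real, y_real = next(real_batches)   # balanced real batch with labels
        train_d_step(x_real, y_real)
    train_g_step()
```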

4.2 Medical image dataset: cells

The Cells dataset is a highly imbalanced medical image dataset containing one majority class and three minority classes (Table 3): “red blood cell,” “ring,” “schizont,” and “trophozoite,” respectively. Except for the first type, the cells indicate different stages of malaria infection.

Table 3 Class weight of Cells dataset
Fig. 8 Real images per class of the Cells dataset

Unlike the images of MNIST Fashion and CIFAR-10, these four classes are all types of red blood cells (Fig. 8): they look similar but differ in specific features. Visually, it is hard to distinguish some type 2 cells from type 3 cells.

Fig. 9 Generated images by BAGAN (left) and BAGAN-GP (right). The order of these images follows Eq. (16)

In Fig. 9, we observe that BAGAN tries to improve the minority-class generation by sacrificing the quality of the majority class. This is exactly the objective of BAGAN, but we are not satisfied with this result. With our BAGAN-GP, all types of cells are generated in high quality. In Sect. 5, we quantitatively analyze the performance of our model.

Fig. 10 Two-dimensional t-SNE plot of the encoded latent vectors. Left: encoder of BAGAN. Middle: encoder of the enhanced BAGAN-GP (ours). Right: encoder + embedding (ours)

In practice, BAGAN is unstable to train on some imbalanced datasets, especially medical image datasets such as the Cells dataset in our experiment. The encoder of the original BAGAN cannot translate the input images into well-dispersed groups of latent vectors (Fig. 10). The labeled latent vectors are then generated from the distributions of these undivided latent vectors, so the subsequent GAN model fails to generate images of different classes due to the misleading labeled latent vectors. With our enhanced autoencoder, we observe that BAGAN becomes stable in training and insensitive to the GAN architecture and hyperparameters.

Fig. 11 Comparing the real samples (o) and generated samples (x) by the feature-layer output via ResNet-50

At the feature-level cognition of ResNet-50 (Fig. 11), the generated samples can be regarded as effective augmented images. Furthermore, we observe that the generated-image manifold is evenly distributed around the real-image manifold. This means that, for each class, our generator is not creating only one or a few modes of images; in other words, the generator comprehensively learns the real data distribution and does not suffer from mode collapse.

5 Evaluation

Table 4 FID: comparison with real samples (in the validation set).
  • Metric: Fréchet Inception Distance.

There are two common metrics to evaluate the quality of generated images: the Inception Score (IS) [27] and the Fréchet Inception Distance (FID) [28]. Both measurements are based on the Inception V3 network pretrained on the ImageNet dataset. IS is derived from the classification logits, while FID is derived from the feature layer. IS only measures the distance between the generated sample distribution and the ImageNet distribution, whereas FID calculates the feature-level distance between the generated and real sample distributions. Our target datasets in this work, medical image datasets, are quite different from ImageNet, so we adopt FID as the evaluation metric. The FID is defined as:

$$\begin{aligned} FID=\Vert \mu _r-\mu _g\Vert ^{2}+Tr\left( \varSigma _r+\varSigma _g-2\left( \varSigma _r\varSigma _g\right) ^{1/2}\right) \end{aligned}$$

where \(\mu _r \) is the mean of the real features, \(\mu _g \) is the mean of the generated features, \(\varSigma _r \) is the covariance matrix of the real features, and \(\varSigma _g \) is the covariance matrix of the generated features. A sketch of this computation is given below.
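A NumPy/SciPy sketch of this computation, assuming `feat_r` and `feat_g` are (n, 2048) feature arrays extracted from Inception V3:

```python
import numpy as np
from scipy import linalg

def fid(feat_r, feat_g):
    mu_r, mu_g = feat_r.mean(axis=0), feat_g.mean(axis=0)
    sigma_r = np.cov(feat_r, rowvar=False)
    sigma_g = np.cov(feat_g, rowvar=False)
    covmean = linalg.sqrtm(sigma_r @ sigma_g)   # matrix square root
    if np.iscomplexobj(covmean):                # numerical noise can add tiny imaginary parts
        covmean = covmean.real
    return (np.sum((mu_r - mu_g) ** 2)
            + np.trace(sigma_r + sigma_g - 2.0 * covmean))
```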

  • FID on Cells. Table 4

All FID scores are calculated between the real samples from the validation set and the target samples. For comparison, we introduce two baseline FID scores: the reconstructed samples from the autoencoder and the real samples from the training set. The FID of the reconstructed samples is regarded as a lower baseline, and the FID of the real samples as an upper baseline. The quality of the target samples is higher when their FID is lower.

On the Cells dataset, BAGAN can only generate poor samples; its performance is only better than the autoencoder's. To construct our BAGAN-GP model, we first build a cDRAGAN model with Eqs. (12) and (13) and then combine cDRAGAN with the BAGAN framework to obtain our final model. We need to demonstrate that the combined model is better than the previous independent models. cDRAGAN can generate majority-class images with high quality but ignores the minority classes, which is the typical drawback of non-BAGAN models. When we apply autoencoder initialization to cDRAGAN and keep the same loss function, BAGAN-GP (v1) further improves the quality of the majority classes but shows no improvement on the minority classes.

Note on BAGAN-GP. (v1): using real labels for generated images, Eqs. (12) and (13). (v2): feeding balanced labels to the generator at training, Eqs. (14) and (15). (v3): replacing the original BAGAN encoder with our encoder. (100/200): the number of training epochs; 100 epochs take 1800 s and 200 epochs take 3600 s.

Comparing BAGAN-GP (v1) with BAGAN-GP (v2), applying balanced training to the generator has a negative effect on the majority-class generation, analogous to BAGAN. However, the improvement in minority-class generation is significant, while the negative effect on majority-class generation is small. If the purpose is to generate minority-class images, we recommend using balanced training (v2); otherwise, the balanced training step can be omitted to generate the highest-quality majority-class images. Many traditional GANs fail to converge over long training runs; thanks to the gradient penalty term, our BAGAN-GP remains stable during long training, and we observe that the longer BAGAN-GP trains, the better overall performance it achieves.

Although BAGAN-GP is stable and requires less hyperparameter tuning, we offer some suggestions for building a better BAGAN-GP in future work. In our experiments, we observe that a high latent dimension and a complex embedding model are not recommended. Besides, we suggest that the discriminator need not inherit the weights of the pretrained encoder; the potential reason is that the pretrained encoder is not powerful without the embedding part.

6 Conclusion

In this work, we proposed a new BAGAN architecture with a gradient penalty in the loss function, which makes BAGAN training more stable. For the autoencoder initialization, we proposed a supervised autoencoder with an intermediate embedding model that learns the label information directly, which helps encode similar images of different classes into well-separated latent regions.

We compared the enhanced BAGAN-GP against the original BAGAN. From the dispersion of the labeled latent vectors to the quality of the generated images, our model outperforms the original BAGAN. Besides, our model handles minority-class generation on a wide range of datasets, including medical image datasets.

  • Future work

We observe that our model can generate images of different classes unambiguously. If we can transfer the class knowledge from generative models to classification models, we believe it will significantly improve the performance of classifiers on imbalanced datasets.

In this work, we trained the GAN model only on the plain dataset. In practice, applying data augmentation during GAN training could further improve the final results.

There are many research topics dealing with the scarcity of data, such as data augmentation, few-shot, and zero-shot learning. We hope our work can broaden the ideas in these topics.