Generative Adversarial Networks (GANs) are a type of deep learning framework introduced by Ian Goodfellow and colleagues in 2014. GANs are used to generate new data samples that resemble a given dataset. They consist of two neural networks, a generator and a discriminator, which are trained simultaneously in a competitive manner.
1. Generator: Takes random noise as input and generates fake data samples.
2. Discriminator: Takes both real and generated data samples as input and predicts whether the samples are real or fake.
3. Adversarial Training: The generator and discriminator are trained alternately. The generator tries to fool the discriminator by producing realistic samples, while the discriminator learns to tell real samples from fake ones. Each training iteration consists of two steps:
1. Generator Training: Update the generator so that the discriminator becomes more likely to classify its generated samples as real, i.e., to maximize the discriminator's error on generated samples.
2. Discriminator Training: Update the discriminator to better distinguish between real and generated samples.
Let's implement a simple GAN using TensorFlow/Keras to generate handwritten digits similar to those in the MNIST dataset. 👇👇
# Import necessary libraries
import numpy as np
import matplotlib.pyplot as plt
from tensorflow.keras.datasets import mnist
from tensorflow.keras.layers import Dense, Flatten, Reshape
from tensorflow.keras.layers import Input
from tensorflow.keras.layers import LeakyReLU, BatchNormalization
from tensorflow.keras.models import Model
from tensorflow.keras.models import Sequential
from tensorflow.keras.optimizers import Adam
# Load the MNIST training images; the labels and the test split are unused.
(X_train, _), (_, _) = mnist.load_data()
# Scale pixel values from [0, 255] to [-1, 1] so real images share the range
# of the generator's tanh output.
X_train = (X_train.astype(np.float32) - 127.5) / 127.5
# Keep every image as a 28x28 array. The generator emits (28, 28) images and
# the discriminator's Flatten layer expects (28, 28) input, so real and
# generated batches must have the same per-sample shape to be concatenated
# and fed to the discriminator. (Flattening to 784 here breaks that.)
X_train = X_train.reshape(X_train.shape[0], 28, 28)
# Generator: maps a 100-dimensional noise vector to a 28x28 image.
# Hidden layers widen 256 -> 512 -> 1024, each followed by LeakyReLU and
# BatchNormalization; the final tanh keeps outputs in [-1, 1].
generator = Sequential()
generator.add(Dense(256, input_dim=100))
generator.add(LeakyReLU(alpha=0.2))
generator.add(BatchNormalization())
generator.add(Dense(512))
generator.add(LeakyReLU(alpha=0.2))
generator.add(BatchNormalization())
generator.add(Dense(1024))
generator.add(LeakyReLU(alpha=0.2))
generator.add(BatchNormalization())
generator.add(Dense(784, activation='tanh'))
# Fold the 784 outputs back into image form.
generator.add(Reshape((28, 28)))
# Discriminator: takes a 28x28 image and outputs a single sigmoid score,
# trained toward high values for real images and low values for generated ones.
# Hidden layers narrow 1024 -> 512 -> 256 with LeakyReLU activations.
discriminator = Sequential()
discriminator.add(Flatten(input_shape=(28, 28)))
discriminator.add(Dense(1024))
discriminator.add(LeakyReLU(alpha=0.2))
discriminator.add(Dense(512))
discriminator.add(LeakyReLU(alpha=0.2))
discriminator.add(Dense(256))
discriminator.add(LeakyReLU(alpha=0.2))
discriminator.add(Dense(1, activation='sigmoid'))
# Compile the discriminator on its own, while it is still trainable, so that
# discriminator.train_on_batch updates its weights.
discriminator.compile(
    loss='binary_crossentropy',
    optimizer=Adam(learning_rate=0.0002, beta_1=0.5),
    metrics=['accuracy'],
)

# Build the stacked GAN: noise (100,) -> generator -> discriminator -> score.
# The discriminator is frozen before this compile so that training `gan`
# adjusts only the generator's weights.
discriminator.trainable = False
gan_input = Input(shape=(100,))
x = generator(gan_input)
gan_output = discriminator(x)
gan = Model(inputs=gan_input, outputs=gan_output)
gan.compile(
    loss='binary_crossentropy',
    optimizer=Adam(learning_rate=0.0002, beta_1=0.5),
)
# Function to train the GAN
def train_gan(epochs=1, batch_size=128):
    """Train the discriminator and the generator alternately on ``X_train``.

    Each epoch runs ``X_train.shape[0] // batch_size`` iterations (at least
    one). Each iteration has two steps:

    1. Train the discriminator on ``batch_size`` real images labelled 0.9
       (one-sided label smoothing) plus ``batch_size`` generated images
       labelled 0.
    2. Train the generator through the stacked ``gan`` model on fresh noise,
       with every target labelled 1 so it learns to fool the discriminator.

    After each epoch, the losses from the epoch's last iteration are printed.
    On epochs 0, 10, 20, ... a sample image strip is produced by
    ``plot_generated_images``.

    Args:
        epochs: Number of epochs to run.
        batch_size: Number of real images (and, separately, of generated
            images) used per iteration.
    """
    # Guarantee at least one iteration per epoch so d_loss / g_loss are always
    # bound when printed, even if batch_size exceeds the dataset size.
    batch_count = max(1, X_train.shape[0] // batch_size)
    for e in range(epochs):
        for _ in range(batch_count):
            # Generate random noise as input for the generator
            noise = np.random.normal(0, 1, size=[batch_size, 100])
            # verbose=0 suppresses the Keras progress bar, which would
            # otherwise be printed once per batch.
            generated_images = generator.predict(noise, verbose=0)
            # Draw a random batch of real images (indices drawn independently,
            # so duplicates are possible).
            batch_idx = np.random.randint(0, X_train.shape[0], batch_size)
            real_images = X_train[batch_idx]
            # Real images occupy the first half, generated images the second.
            X = np.concatenate([real_images, generated_images])
            # Labels: 0.9 for real (one-sided label smoothing), 0 for generated.
            y_dis = np.zeros(2 * batch_size)
            y_dis[:batch_size] = 0.9
            # Train the discriminator
            discriminator.trainable = True
            d_loss = discriminator.train_on_batch(X, y_dis)
            # Train the generator via the stacked GAN: fresh noise, all targets
            # "real" (1), with the discriminator frozen.
            # NOTE(review): tf.keras fixes trainability when each model is
            # compiled/traced, so these toggles are likely redundant; kept to
            # make the intent explicit.
            noise = np.random.normal(0, 1, size=[batch_size, 100])
            y_gen = np.ones(batch_size)
            discriminator.trainable = False
            g_loss = gan.train_on_batch(noise, y_gen)
        # Print the progress (d_loss[0] is the loss; the discriminator was
        # compiled with an accuracy metric, so d_loss is [loss, accuracy]).
        print(f"Epoch {e+1}, Discriminator Loss: {d_loss[0]}, "
              f"Generator Loss: {g_loss}")
        # Periodically save a strip of generated samples.
        if e % 10 == 0:
            plot_generated_images(e, generator)
# Function to plot generated images
def plot_generated_images(epoch, generator, examples=10, dim=(1, 10),
                          figsize=(10, 1)):
    """Sample images from ``generator``, save them as a PNG, and show them.

    Args:
        epoch: Number embedded in the output filename
            ``gan_generated_image_epoch_{epoch}.png``.
        generator: Model that maps an ``(examples, 100)`` noise array to
            outputs reshapeable to ``(examples, 28, 28)``.
        examples: Number of images to sample and draw.
        dim: ``(rows, cols)`` subplot grid; ``rows * cols`` must be at least
            ``examples``.
        figsize: Matplotlib figure size (width, height) in inches.
    """
    noise = np.random.normal(0, 1, size=[examples, 100])
    # verbose=0 keeps Keras from printing a progress bar for this call.
    generated_images = generator.predict(noise, verbose=0)
    generated_images = generated_images.reshape(examples, 28, 28)
    fig = plt.figure(figsize=figsize)
    for i in range(examples):
        plt.subplot(dim[0], dim[1], i + 1)
        plt.imshow(generated_images[i], interpolation='nearest', cmap='gray')
        plt.axis('off')
    plt.tight_layout()
    plt.savefig(f'gan_generated_image_epoch_{epoch}.png')
    plt.show()
    # This function is called repeatedly during training; close the figure so
    # pyplot does not keep every one alive (it warns past 20 open figures).
    plt.close(fig)
# Train the GAN for 100 epochs; each iteration uses 128 real and 128 generated
# images for the discriminator step.
train_gan(batch_size=128, epochs=100)
No comments:
Post a Comment