# --- train ---
# Training progress recorded by train() every `sample_interval` iterations.
# Discriminator accuracy, in percent (100 * d_loss[1]).
accuracies = []
# (discriminator loss, generator loss) tuples.
losses = []
def train(iterations, batch_size, sample_interval):
    """Train the conditional GAN on the MNIST training split.

    Each iteration performs one discriminator update and one update of the
    combined ``cgan`` model:

    - The discriminator update uses a batch of real images and a batch of
      generated images, both paired with class labels.
    - The ``cgan`` update uses all-ones ("real") targets.

    Every ``sample_interval`` iterations, the function does three things:

    - prints the current losses;
    - appends them to the module-level ``losses`` / ``accuracies`` lists;
    - calls ``sample_images()``.

    Args:
        iterations: Number of training iterations (one D step + one G step each).
        batch_size: Number of samples per batch in both updates.
        sample_interval: Report/sample every this many iterations. It is used
            as a modulus, so it must be non-zero.

    Relies on names defined outside this function: ``mnist``, ``np``,
    ``z_dim``, ``num_classes``, ``generator``, ``discriminator``, ``cgan``,
    ``sample_images``, ``losses`` and ``accuracies``.
    """
    # Only the training split is used; the test split is discarded.
    (X_train, y_train), (_, _) = mnist.load_data()

    # Rescale [0, 255] pixel values to [-1, 1].
    X_train = X_train / 127.5 - 1.
    # Add a trailing channel axis.
    X_train = np.expand_dims(X_train, axis=3)
    # Give labels a column shape (N, 1). This matches the reshape(-1, 1)
    # applied to the generator-step labels below, so the label input sees one
    # consistent shape whether fed directly to the discriminator or through
    # `cgan`.
    y_train = y_train.reshape(-1, 1)

    # Targets: 1 for real images, 0 for generated images.
    real = np.ones((batch_size, 1))
    fake = np.zeros((batch_size, 1))

    for iteration in range(iterations):
        # ---- Discriminator update ----
        # Random batch of real images and their labels (indices may repeat).
        idx = np.random.randint(0, X_train.shape[0], batch_size)
        imgs, labels = X_train[idx], y_train[idx]

        # Generate fake images conditioned on the same labels.
        z = np.random.normal(0, 1, (batch_size, z_dim))
        gen_imgs = generator.predict([z, labels])

        d_loss_real = discriminator.train_on_batch([imgs, labels], real)
        d_loss_fake = discriminator.train_on_batch([gen_imgs, labels], fake)
        # Element-wise mean of the real/fake results. Indexing d_loss[1]
        # below assumes the discriminator is compiled with one extra metric
        # (presumably accuracy), i.e. results are [loss, acc]. TODO confirm.
        d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

        # ---- Generator update ----
        # Fresh noise and random labels, trained against all-ones targets.
        # `cgan` is presumably generator -> (frozen) discriminator; verify
        # where it is built.
        z = np.random.normal(0, 1, (batch_size, z_dim))
        labels = np.random.randint(0, num_classes, batch_size).reshape(-1, 1)
        g_loss = cgan.train_on_batch([z, labels], real)

        if (iteration + 1) % sample_interval == 0:
            # `%f` for g_loss assumes cgan reports a single scalar loss.
            print("%d [D loss: %f, acc.: %.2f%%] [G loss: %f]" %
                  (iteration + 1, d_loss[0], 100 * d_loss[1], g_loss))

            # Keep history so it can be plotted after training.
            losses.append((d_loss[0], g_loss))
            accuracies.append(100 * d_loss[1])

            sample_images()