Building GANs from scratch in Python

Photo by Michael & Diane Weidner on Unsplash

The idea of Generative Adversarial Networks, or GANs, was introduced by Goodfellow and his colleagues [1] in 2014, and shortly after that it became extremely popular in the field of computer vision and image generation. Despite a decade of rapid development in AI and an ever-growing number of new algorithms, the simplicity and brilliance of this concept are still extremely impressive. So today I want to illustrate how powerful these networks can be by attempting to remove clouds from satellite RGB (Red, Green, Blue) images.

Preparing a properly balanced, sufficiently large and correctly pre-processed CV dataset takes a solid amount of time, so I decided to explore what Kaggle has to offer. The dataset I found most appropriate for this task is EuroSat [2], which has an open license. It comprises 27,000 labeled 64×64-pixel RGB images from Sentinel-2 and was built for the multiclass classification problem.

EuroSat dataset imagery example. License.

We are not interested in classification itself, but one of the main features of the EuroSat dataset is that all its images have a clear sky. That's exactly what we need. Adopting the approach from [3], we will use these Sentinel-2 shots as targets and create inputs by adding noise (clouds) to them.

So let's prepare our data before actually talking about GANs. Firstly, we need to download the data and merge all the classes into one directory.

🐍 The full Python code: GitHub.

import numpy as np
import pandas as pd
import random
from os import listdir, mkdir, rename
from os.path import join, exists
import shutil
import datetime
import matplotlib.pyplot as plt
from highlight_text import ax_text, fig_text
from PIL import Image
from noise import pnoise2  # pip install noise; used below for Perlin clouds
import warnings
warnings.filterwarnings('ignore')

classes = listdir('./EuroSat')
path_target = './EuroSat/all_targets'
path_input = './EuroSat/all_inputs'

"""RUN IT ONLY ONCE TO RENAME THE FILES IN THE UNPACKED ARCHIVE"""
mkdir(path_input)
mkdir(path_target)

k = 1
for kind in classes:
    path = join('./EuroSat', str(kind))
    for i, f in enumerate(listdir(path)):
        shutil.copyfile(join(path, f), join(path_target, f))
        rename(join(path_target, f), join(path_target, f'{k}.jpg'))
        k += 1

The second important step is generating noise. Whereas you can use different approaches, e.g. randomly masking out some pixels or adding Gaussian noise, in this article I want to try something that is new to me: Perlin noise. It was invented in the 80s by Ken Perlin [4] while developing cinematic smoke effects. This kind of noise has a more organic appearance compared to regular random noise.
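In a nutshell, the cloud mask built in the next snippet is a normalized, weighted sum of several Perlin layers (octaves): each successive octave has a smaller scale (finer spatial detail) and a smaller weight, so it adds progressively finer texture on top of the large shapes:

clouds(i, j) = normalize( Σ_octave persistence^octave · perlin(i / scale_octave, j / scale_octave) ), with scale_octave = base_scale / octave.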
Let me prove it:

def generate_perlin_noise(width, height, scale, octaves, persistence, lacunarity):
    noise = np.zeros((height, width))
    for i in range(height):
        for j in range(width):
            noise[i][j] = pnoise2(i / scale,
                                  j / scale,
                                  octaves=octaves,
                                  persistence=persistence,
                                  lacunarity=lacunarity,
                                  repeatx=width,
                                  repeaty=height,
                                  base=0)
    return noise

def normalize_noise(noise):
    min_val = noise.min()
    max_val = noise.max()
    return (noise - min_val) / (max_val - min_val)

def generate_clouds(width, height, base_scale, octaves, persistence, lacunarity):
    clouds = np.zeros((height, width))
    for octave in range(1, octaves + 1):
        scale = base_scale / octave
        layer = generate_perlin_noise(width, height, scale, 1, persistence, lacunarity)
        clouds += layer * (persistence ** octave)
    clouds = normalize_noise(clouds)
    return clouds

def overlay_clouds(image, clouds, alpha=0.5):
    clouds_rgb = np.stack([clouds] * 3, axis=-1)
    image = image.astype(float) / 255.0
    clouds_rgb = clouds_rgb.astype(float)
    blended = image * (1 - alpha) + clouds_rgb * alpha
    blended = (blended * 255).astype(np.uint8)
    return blended

width, height = 64, 64
octaves = 12  # number of noise layers combined
persistence = 0.5  # lower persistence reduces the amplitude of higher-frequency octaves
lacunarity = 2  # higher lacunarity increases the frequency of higher-frequency octaves

for i in range(len(listdir(path_target))):
    base_scale = random.uniform(5, 120)  # noise frequency
    alpha = random.uniform(0, 1)  # transparency
    clouds = generate_clouds(width, height, base_scale, octaves, persistence, lacunarity)
    img = np.asarray(Image.open(join(path_target, f'{i+1}.jpg')))
    image = Image.fromarray(overlay_clouds(img, clouds, alpha))
    image.save(join(path_input, f'{i+1}.jpg'))
    print(f'Processed {i+1}/{len(listdir(path_target))}')

idx = np.random.randint(1, 27001)  # images are named 1.jpg ... 27000.jpg
fig, ax = plt.subplots(1, 2)
ax[0].imshow(np.asarray(Image.open(join(path_target, f'{idx}.jpg'))))
ax[1].imshow(np.asarray(Image.open(join(path_input, f'{idx}.jpg'))))
ax[0].set_title("Target")
ax[0].axis('off')
ax[1].set_title("Input")
ax[1].axis('off')
plt.show()

Image by author.

As you can see above, the clouds in the images look very realistic: they vary in "density" and have a texture resembling real ones.

If you are as intrigued by Perlin noise as I was, it is worth watching how this kind of noise is applied in the GameDev industry.

Now that we have a ready-to-use dataset, let's talk about GANs.

To better illustrate the idea, let's imagine that you're traveling around South-East Asia and find yourself in urgent need of a hoodie, since it's too cold outside. Coming to the closest street market, you find a small shop with some branded clothes. The seller brings you a nice hoodie to try on, saying that it's the famous brand ExpensiveButNotWorthIt. You take a closer look and conclude that it's obviously a fake. The seller says: 'Wait a sec, I have the REAL one.' He returns with another hoodie, which looks more like the branded one, but is still a fake. After several iterations like this, the seller brings an indistinguishable copy of the legendary ExpensiveButNotWorthIt and you readily buy it. That's basically how GANs work!

In the case of GANs, you are called a discriminator (D). The goal of the discriminator is to distinguish a real object from a fake one, i.e. to solve a binary classification task. The seller is called a generator (G), since he's trying to produce a high-quality fake. The discriminator and generator are trained independently, each trying to outperform the other. Hence, in the end we get a high-quality fake.

GANs architecture. License.
The training process originally looks like this:

1. Sample input noise (in our case, images with clouds).
2. Feed the noise to G and collect the prediction.
3. Calculate the D loss by getting two predictions: one for G's output and another for the real data.
4. Update D's weights.
5. Sample input noise again.
6. Feed the noise to G and collect the prediction.
7. Calculate the G loss by feeding its prediction to D.
8. Update G's weights.

GANs training loop. Source: [1].

In other words, we can define a value function V(G, D) (Source: [1]):

min_G max_D V(D, G) = E_{x∼p_data(x)}[log D(x)] + E_{z∼p_z(z)}[log(1 - D(G(z)))]

where we want to minimize the term log(1 - D(G(z))) to train G and maximize log D(x) to train D (in this notation, x is a real data sample and z is noise).

Now let's try to implement it in PyTorch!

In the original paper the authors use a Multilayer Perceptron (MLP; it is also often referred to simply as an ANN), but I want to try a slightly more complicated approach: the UNet [5] architecture as the Generator and a ResNet [6] as the Discriminator. These are both well-known CNN architectures, so I won't be explaining them here (let me know in the comments if I should write a separate article about them).

Let's build them. Discriminator:

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torch.utils.data import Subset

class ResidualBlock(nn.Module):
    def __init__(self, in_channels, out_channels, stride=1, downsample=None):
        super(ResidualBlock, self).__init__()
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU())
        self.conv2 = nn.Sequential(
            nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(out_channels))
        self.downsample = downsample
        self.relu = nn.ReLU()
        self.out_channels = out_channels

    def forward(self, x):
        residual = x
        out = self.conv1(x)
        out = self.conv2(out)
        if self.downsample:
            residual = self.downsample(x)
        out += residual
        out = self.relu(out)
        return out

class ResNet(nn.Module):
    def __init__(self, block=ResidualBlock, all_connections=[3, 4, 6, 3]):
        super(ResNet, self).__init__()
        self.inputs = 16
        self.conv1 = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(16),
            nn.ReLU())  # 16x64x64
        self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)  # 16x32x32
        self.layer0 = self.makeLayer(block, 16, all_connections[0], stride=1)   # connections = 3, shape: 16x32x32
        self.layer1 = self.makeLayer(block, 32, all_connections[1], stride=2)   # connections = 4, shape: 32x16x16
        self.layer2 = self.makeLayer(block, 128, all_connections[2], stride=2)  # connections = 6, shape: 128x8x8
        self.layer3 = self.makeLayer(block, 256, all_connections[3], stride=2)  # connections = 3, shape: 256x4x4
        self.avgpool = nn.AvgPool2d(4, stride=1)
        self.fc = nn.Linear(256, 1)

    def makeLayer(self, block, outputs, connections, stride=1):
        downsample = None
        if stride != 1 or self.inputs != outputs:
            downsample = nn.Sequential(
                nn.Conv2d(self.inputs, outputs, kernel_size=1, stride=stride),
                nn.BatchNorm2d(outputs),
            )
        layers = []
        layers.append(block(self.inputs, outputs, stride, downsample))
        self.inputs = outputs
        for i in range(1, connections):
            layers.append(block(self.inputs, outputs))
        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.maxpool(x)
        x = self.layer0(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.avgpool(x)
        x = x.view(-1, 256)
        x = self.fc(x).flatten()
        return torch.sigmoid(x)  # probability that the image is real
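As a quick sanity check (a minimal sketch, assuming the 64×64 images used in this article), the discriminator should map a batch of images to one probability per image:

d = ResNet()
dummy = torch.randn(8, 3, 64, 64)  # a random batch of 8 RGB images
print(d(dummy).shape)              # torch.Size([8]): one score in (0, 1) per image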
Generator:

class DoubleConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(DoubleConv, self).__init__()
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True))

    def forward(self, x):
        return self.double_conv(x)

class UNet(nn.Module):
    def __init__(self):
        super().__init__()
        # ENCODER
        self.conv_1 = DoubleConv(3, 32)  # 32x64x64
        self.pool_1 = nn.MaxPool2d(kernel_size=2, stride=2)  # 32x32x32
        self.conv_2 = DoubleConv(32, 64)  # 64x32x32
        self.pool_2 = nn.MaxPool2d(kernel_size=2, stride=2)  # 64x16x16
        self.conv_3 = DoubleConv(64, 128)  # 128x16x16
        self.pool_3 = nn.MaxPool2d(kernel_size=2, stride=2)  # 128x8x8
        self.conv_4 = DoubleConv(128, 256)  # 256x8x8
        self.pool_4 = nn.MaxPool2d(kernel_size=2, stride=2)  # 256x4x4
        self.conv_5 = DoubleConv(256, 512)  # 512x4x4
        # DECODER
        self.upconv_1 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)  # 256x8x8
        self.conv_6 = DoubleConv(512, 256)  # 256x8x8
        self.upconv_2 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)  # 128x16x16
        self.conv_7 = DoubleConv(256, 128)  # 128x16x16
        self.upconv_3 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)  # 64x32x32
        self.conv_8 = DoubleConv(128, 64)  # 64x32x32
        self.upconv_4 = nn.ConvTranspose2d(64, 32, kernel_size=2, stride=2)  # 32x64x64
        self.conv_9 = DoubleConv(64, 32)  # 32x64x64
        self.output = nn.Conv2d(32, 3, kernel_size=3, stride=1, padding=1)  # 3x64x64

    def forward(self, batch):
        conv_1_out = self.conv_1(batch)
        conv_2_out = self.conv_2(self.pool_1(conv_1_out))
        conv_3_out = self.conv_3(self.pool_2(conv_2_out))
        conv_4_out = self.conv_4(self.pool_3(conv_3_out))
        conv_5_out = self.conv_5(self.pool_4(conv_4_out))
        conv_6_out = self.conv_6(torch.cat([self.upconv_1(conv_5_out), conv_4_out], dim=1))
        conv_7_out = self.conv_7(torch.cat([self.upconv_2(conv_6_out), conv_3_out], dim=1))
        conv_8_out = self.conv_8(torch.cat([self.upconv_3(conv_7_out), conv_2_out], dim=1))
        conv_9_out = self.conv_9(torch.cat([self.upconv_4(conv_8_out), conv_1_out], dim=1))
        output = self.output(conv_9_out)
        return torch.sigmoid(output)  # pixel values in (0, 1)

Now we need to split our data into train/test and wrap it into a torch dataset:

class dataset(Dataset):
    def __init__(self, batch_size, images_paths, targets, img_size=64):
        self.batch_size = batch_size
        self.img_size = img_size
        self.images_paths = images_paths
        self.targets = targets
        self.len = len(self.images_paths) // batch_size
        self.transform = transforms.Compose([transforms.ToTensor()])
        self.batch_im = [self.images_paths[idx * self.batch_size:(idx + 1) * self.batch_size] for idx in range(self.len)]
        self.batch_t = [self.targets[idx * self.batch_size:(idx + 1) * self.batch_size] for idx in range(self.len)]

    def __getitem__(self, idx):
        pred = torch.stack([
            self.transform(Image.open(join(path_input, file_name)))
            for file_name in self.batch_im[idx]])
        # input and target images share the same file names (1.jpg ... 27000.jpg),
        # so the clear-sky targets are loaded by the same names to keep the pairs aligned
        target = torch.stack([
            self.transform(Image.open(join(path_target, file_name)))
            for file_name in self.batch_im[idx]])
        return pred, target

    def __len__(self):
        return self.len
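To make sure the wrapper behaves as expected, here is a minimal usage sketch (it assumes the folders created earlier; the file names 1.jpg and 2.jpg are just illustrative examples):

toy = dataset(batch_size=2, images_paths=['1.jpg', '2.jpg'], targets=['1.jpg', '2.jpg'])
x, y = toy[0]
print(x.shape, y.shape)  # torch.Size([2, 3, 64, 64]) torch.Size([2, 3, 64, 64])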
Perfect. It's time to write the training loop. Before doing so, let's define our loss functions and optimizers:

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

batch_size = 64
num_epochs = 15
learning_rate_D = 1e-5
learning_rate_G = 1e-4

discriminator = ResNet().to(device)
generator = UNet().to(device)

bce = nn.BCELoss()  # both networks already end with a sigmoid, so plain BCE is used
l1loss = nn.L1Loss()

optimizer_D = optim.Adam(discriminator.parameters(), lr=learning_rate_D)
optimizer_G = optim.Adam(generator.parameters(), lr=learning_rate_G)
scheduler_D = optim.lr_scheduler.StepLR(optimizer_D, step_size=10, gamma=0.1)
scheduler_G = optim.lr_scheduler.StepLR(optimizer_G, step_size=10, gamma=0.1)

As you can see, these losses are different from the picture with the GAN algorithm. In particular, I added L1Loss. The idea is that we are not simply generating a random image from noise; we want to keep most of the information from the input and just remove the noise. So the G loss will be:

G_loss = log(1 - D(G(z))) + 𝝀 |G(z) - y|

instead of just

G_loss = log(1 - D(G(z)))

𝝀 is an arbitrary coefficient, which balances the two components of the loss.

Finally, let's split the data to start the training process:

test_ratio, train_ratio = 0.3, 0.7
num_test = int(len(listdir(path_target)) * test_ratio)
num_train = len(listdir(path_target)) - num_test
img_size = (64, 64)
print("Number of train samples:", num_train)
print("Number of test samples:", num_test)

random.seed(231)
train_idxs = np.array(random.sample(range(num_test + num_train), num_train))
mask = np.ones(num_train + num_test, dtype=bool)
mask[train_idxs] = False

images = {}
features = random.sample(listdir(path_input), num_test + num_train)
targets = random.sample(listdir(path_target), num_test + num_train)
random.Random(231).shuffle(features)
random.Random(231).shuffle(targets)

train_input_img_paths = np.array(features)[train_idxs]
train_target_img_path = np.array(targets)[train_idxs]
test_input_img_paths = np.array(features)[mask]
test_target_img_path = np.array(targets)[mask]

train_loader = dataset(batch_size=batch_size, img_size=img_size, images_paths=train_input_img_paths, targets=train_target_img_path)
test_loader = dataset(batch_size=batch_size, img_size=img_size, images_paths=test_input_img_paths, targets=test_target_img_path)

Now we can run our training loop:

train_loss_G, train_loss_D, val_loss_G, val_loss_D = [], [], [], []
all_loss_G, all_loss_D = [], []
best_generator_epoch_val_loss, best_discriminator_epoch_val_loss = np.inf, np.inf

for epoch in range(num_epochs):
    discriminator.train()
    generator.train()
    discriminator_epoch_loss, generator_epoch_loss = 0, 0

    for inputs, targets in train_loader:
        inputs, true = inputs.to(device), targets.to(device)

        '''1. Training the Discriminator (ResNet)'''
        optimizer_D.zero_grad()
        fake = generator(inputs).detach()
        pred_fake = discriminator(fake).to(device)
        loss_fake = bce(pred_fake, torch.zeros(batch_size, device=device))
        pred_real = discriminator(true).to(device)
        loss_real = bce(pred_real, torch.ones(batch_size, device=device))
        loss_D = (loss_fake + loss_real) / 2
        loss_D.backward()
        optimizer_D.step()
        discriminator_epoch_loss += loss_D.item()
        all_loss_D.append(loss_D.item())
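        # Note: fake.detach() above cuts the generator out of the graph, so this
        # backward pass updates only the discriminator. Real images are labelled 1
        # and generated ones 0, which corresponds to maximizing log D(x) + log(1 - D(G(z))).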
        '''2. Training the Generator (UNet)'''
        optimizer_G.zero_grad()
        fake = generator(inputs)
        pred_fake = discriminator(fake).to(device)
        loss_G_bce = bce(pred_fake, torch.ones_like(pred_fake, device=device))
        loss_G_l1 = l1loss(fake, true) * 100
        loss_G = loss_G_bce + loss_G_l1
        loss_G.backward()
        optimizer_G.step()
        generator_epoch_loss += loss_G.item()
        all_loss_G.append(loss_G.item())

    discriminator_epoch_loss /= len(train_loader)
    generator_epoch_loss /= len(train_loader)
    train_loss_D.append(discriminator_epoch_loss)
    train_loss_G.append(generator_epoch_loss)

    discriminator.eval()
    generator.eval()
    discriminator_epoch_val_loss, generator_epoch_val_loss = 0, 0
    with torch.no_grad():
        for inputs, targets in test_loader:
            inputs, targets = inputs.to(device), targets.to(device)
            fake = generator(inputs)
            pred = discriminator(fake).to(device)
            loss_G_bce = bce(pred, torch.ones_like(pred, device=device))
            loss_G_l1 = l1loss(fake, targets) * 100
            loss_G = loss_G_bce + loss_G_l1
            loss_D = bce(pred.to(device), torch.zeros(batch_size, device=device))
            discriminator_epoch_val_loss += loss_D.item()
            generator_epoch_val_loss += loss_G.item()

    discriminator_epoch_val_loss /= len(test_loader)
    generator_epoch_val_loss /= len(test_loader)
    val_loss_D.append(discriminator_epoch_val_loss)
    val_loss_G.append(generator_epoch_val_loss)

    print(f"-----Epoch [{epoch+1}/{num_epochs}]-----\nTrain Loss D: {discriminator_epoch_loss:.4f}, Val Loss D: {discriminator_epoch_val_loss:.4f}")
    print(f'Train Loss G: {generator_epoch_loss:.4f}, Val Loss G: {generator_epoch_val_loss:.4f}')

    if discriminator_epoch_val_loss < best_discriminator_epoch_val_loss:
        best_discriminator_epoch_val_loss = discriminator_epoch_val_loss
        torch.save(discriminator.state_dict(), "discriminator.pth")
    if generator_epoch_val_loss < best_generator_epoch_val_loss:
        best_generator_epoch_val_loss = generator_epoch_val_loss
        torch.save(generator.state_dict(), "generator.pth")
    # scheduler_D.step()
    # scheduler_G.step()

    # visual check at the end of each epoch: cloudy input, clear target, generator output
    fig, ax = plt.subplots(1, 3)
    ax[0].imshow(np.transpose(inputs.cpu().numpy()[7], (1, 2, 0)))
    ax[1].imshow(np.transpose(targets.cpu().numpy()[7], (1, 2, 0)))
    ax[2].imshow(np.transpose(fake.detach().cpu().numpy()[7], (1, 2, 0)))
    plt.show()
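Before looking at the loss curves, here is a minimal sketch of how the trained generator could be applied to a single cloudy image (the checkpoint name comes from the training loop above; the file 42.jpg is just an illustrative example):

gen = UNet()
gen.load_state_dict(torch.load("generator.pth", map_location="cpu"))
gen.eval()

cloudy = transforms.ToTensor()(Image.open(join(path_input, '42.jpg'))).unsqueeze(0)  # 1x3x64x64
with torch.no_grad():
    restored = gen(cloudy)[0].permute(1, 2, 0).numpy()  # HxWxC, values in (0, 1)

plt.imshow(restored)
plt.axis('off')
plt.show()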
After training is finished, we can plot the losses. This code was partly adapted from this cool website:

from matplotlib.font_manager import FontProperties

background_color = '#001219'
font = FontProperties(fname='LexendDeca-VariableFont_wght.ttf')

fig, ax = plt.subplots(1, 2, figsize=(16, 9))
fig.set_facecolor(background_color)
ax[0].set_facecolor(background_color)
ax[1].set_facecolor(background_color)

ax[0].plot(range(len(all_loss_G)), all_loss_G, color='#bc6c25', lw=0.5)
ax[1].plot(range(len(all_loss_D)), all_loss_D, color='#00b4d8', lw=0.5)

ax[0].scatter(
    [np.array(all_loss_G).argmax(), np.array(all_loss_G).argmin()],
    [np.array(all_loss_G).max(), np.array(all_loss_G).min()],
    s=30, color='#bc6c25',
)
ax[1].scatter(
    [np.array(all_loss_D).argmax(), np.array(all_loss_D).argmin()],
    [np.array(all_loss_D).max(), np.array(all_loss_D).min()],
    s=30, color='#00b4d8',
)

ax_text(np.array(all_loss_G).argmax() + 60, np.array(all_loss_G).max() + 0.1,
        f'{round(np.array(all_loss_G).max(), 1)}',
        fontsize=13, color='#bc6c25', font=font, ax=ax[0])
ax_text(np.array(all_loss_G).argmin() + 60, np.array(all_loss_G).min() - 0.1,
        f'{round(np.array(all_loss_G).min(), 1)}',
        fontsize=13, color='#bc6c25', font=font, ax=ax[0])
ax_text(np.array(all_loss_D).argmax() + 60, np.array(all_loss_D).max() + 0.01,
        f'{round(np.array(all_loss_D).max(), 1)}',
        fontsize=13, color='#00b4d8', font=font, ax=ax[1])
ax_text(np.array(all_loss_D).argmin() + 60, np.array(all_loss_D).min() - 0.005,
        f'{round(np.array(all_loss_D).min(), 1)}',
        fontsize=13, color='#00b4d8', font=font, ax=ax[1])

for i in range(2):
    ax[i].tick_params(axis='x', colors='white')
    ax[i].tick_params(axis='y', colors='white')
    ax[i].spines['left'].set_color('white')
    ax[i].spines['bottom'].set_color('white')
    ax[i].set_xlabel('Batch iteration', color='white', fontproperties=font, fontsize=13)
    ax[i].set_ylabel('Loss', color='white', fontproperties=font, fontsize=13)

ax[0].set_title('Generator', color='white', fontproperties=font, fontsize=18)
ax[1].set_title('Discriminator', color='white', fontproperties=font, fontsize=18)
plt.savefig('Loss.jpg')
plt.show()
# ax[0].set_axis_off()
# ax[1].set_axis_off()