Convolutional neural networks

Thu 22 April 2021

In [37]:
import torch
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import torchvision.datasets as datasets
import pandas as pd
import os
import numpy as np
from PIL import Image
from torchvision.transforms import ToTensor
from PIL import Image
from torch.utils.data import Dataset, DataLoader, ConcatDataset
from torchvision import transforms, utils
import torchvision.transforms.functional as TF
import glob
from sklearn.model_selection import KFold
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F
import gc
plt.rcParams['figure.figsize'] = [12, 8]
In [2]:
!rm -rf data/
!curl -L -o data.zip https://www.dropbox.com/sh/jhvugjvowcnuovb/AADHwtqWk7p2y7KkO8OUg_lha?dl=0
!unzip -oq data.zip -d data
  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
100   151    0   151    0     0    269      0 --:--:-- --:--:-- --:--:--   268
  0     0    0     0    0     0      0      0 --:--:--  0:00:01 --:--:--     0
100 14.2M  100 14.2M    0     0  2861k      0  0:00:05  0:00:05 --:--:-- 4859k
warning:  stripped absolute path spec from /
mapname:  conversion of  failed
In [18]:
# Hyperperamiters
LEARNING_RATE = 0.001
USE_CUDA = torch.cuda.is_available()
DEVICE = torch.device("cuda:0" if USE_CUDA else "cpu")
NUM_EPOCHS = 20
NUM_WORKERS = 2
BATCH_SIZE = 4
#IMAGE_HEIGHT = 160
#IMAGE_WIDTH = 160  
RBG_CHANNELS = 1 # all channels
WEIGHT_DECAY = 0.01 # weight decay values
In [4]:
class RetinalDataSet(Dataset):
    def __init__(self, image_paths, target_paths, train=True, image_height=None, image_width=None):
        self.image_paths = image_paths
        self.target_paths = target_paths
        self.image_height = image_height
        self.image_width = image_width

    def transform(self, image, mask):
        # Resize
        if self.image_height and self.image_width is not None:
            resize_image = transforms.Resize((self.image_height, self.image_width))
            image = resize_image(image)
            mask = resize_image(mask)
        
        # Transform to tensor
        to_tensor = transforms.ToTensor()
        image = to_tensor(image)
        # Use only the first channel in the image
        # image = image[:1,:,:]
        mask = to_tensor(mask)
        
        return image, mask

    def __getitem__(self, index):
        image = Image.open(self.image_paths[index])
        mask = Image.open(self.target_paths[index])
        image, mask = self.transform(image, mask)
        return image, mask

    def __len__(self):
        return len(self.image_paths)
In [5]:
# Create image and segmentation paths
image_paths = glob.glob('data/Images_train/*.tif')
image_paths.sort()
mask_paths = glob.glob('data/Labels_train/*.gif')
mask_paths.sort()

# Create retinal dataset
retinal_data = RetinalDataSet(image_paths=image_paths,target_paths=mask_paths)
In [15]:
train_set, val_set = torch.utils.data.random_split(retinal_data, [12, 3], generator=torch.Generator().manual_seed(42))
In [20]:
# Print information about dataset
train_set, val_set = torch.utils.data.random_split(retinal_data, [12, 3], generator=torch.Generator().manual_seed(42))


print('Training dataset:')
for i in range(len(train_set)):
  sample = train_set[i]

  print(i, sample[0].shape, sample[1].shape)

print('\nValidation dataset:')
for i in range(len(val_set)):
  sample = val_set[i]

  print(i, sample[0].shape, sample[1].shape)

# Create trainloader and validationloader
trainloader = DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)
validationloader = DataLoader(val_set, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)

print('\nTrainloader:')
for i_batch, sample_batched in enumerate(trainloader):
    print(i_batch, sample_batched[0].size(),
          sample_batched[1].size())

print('\nValidationloader:')
for i_batch, sample_batched in enumerate(validationloader):
    print(i_batch, sample_batched[0].size(),
          sample_batched[1].size())
    
Training dataset:
0 torch.Size([3, 584, 565]) torch.Size([1, 584, 565])
1 torch.Size([3, 584, 565]) torch.Size([1, 584, 565])
2 torch.Size([3, 584, 565]) torch.Size([1, 584, 565])
3 torch.Size([3, 584, 565]) torch.Size([1, 584, 565])
4 torch.Size([3, 584, 565]) torch.Size([1, 584, 565])
5 torch.Size([3, 584, 565]) torch.Size([1, 584, 565])
6 torch.Size([3, 584, 565]) torch.Size([1, 584, 565])
7 torch.Size([3, 584, 565]) torch.Size([1, 584, 565])
8 torch.Size([3, 584, 565]) torch.Size([1, 584, 565])
9 torch.Size([3, 584, 565]) torch.Size([1, 584, 565])
10 torch.Size([3, 584, 565]) torch.Size([1, 584, 565])
11 torch.Size([3, 584, 565]) torch.Size([1, 584, 565])

Validation dataset:
0 torch.Size([3, 584, 565]) torch.Size([1, 584, 565])
1 torch.Size([3, 584, 565]) torch.Size([1, 584, 565])
2 torch.Size([3, 584, 565]) torch.Size([1, 584, 565])

Trainloader:
0 torch.Size([4, 3, 584, 565]) torch.Size([4, 1, 584, 565])
1 torch.Size([4, 3, 584, 565]) torch.Size([4, 1, 584, 565])
2 torch.Size([4, 3, 584, 565]) torch.Size([4, 1, 584, 565])

Validationloader:
0 torch.Size([3, 3, 584, 565]) torch.Size([3, 1, 584, 565])
In [38]:
###### 2. visualize images ######
# functions to show an image
def imshow(img, channel=None):
    img = img   
    npimg = img.cpu().numpy()
    
    if channel is None:
      plt.imshow(np.transpose(npimg, (1, 2, 0)))  
    else:
      plt.imshow(np.transpose(npimg, (1, 2, 0))[:,:,channel])
    
    plt.show()

# get some random training images
dataiter = iter(trainloader)
images, labels = dataiter.next()

# show images in the 3 rgb channels
print('All three channels')
imshow(torchvision.utils.make_grid(images))
print('Only the first channel')
imshow(torchvision.utils.make_grid(images),0)
print('Only the second channel')
imshow(torchvision.utils.make_grid(images),1)
print('Only the third channel')
imshow(torchvision.utils.make_grid(images),2)
print('The corresponding segmentations')
imshow(torchvision.utils.make_grid(labels),0)
All three channels
Only the first channel
Only the second channel
Only the third channel
The corresponding segmentations
In [26]:
class DoubleConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(DoubleConv, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, 1, 1, bias=False), # Making it a "Same" convolution
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, 3, 1, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        return self.conv(x)

class UNET(nn.Module):
    def __init__(
            self, in_channels=3, out_channels=1, features=[64, 128, 256, 512],
    ):
        super(UNET, self).__init__()
        self.ups = nn.ModuleList()
        self.downs = nn.ModuleList()
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

        # The downward part of UNET
        for feature in features:
            self.downs.append(DoubleConv(in_channels, feature))
            in_channels = feature

        # The upward part of the UNET model
        for feature in reversed(features):
            self.ups.append(
                nn.ConvTranspose2d(
                    feature*2, feature, kernel_size=2, stride=2,
                )
            )
            self.ups.append(DoubleConv(feature*2, feature))

        # The bottom part
        self.bottom = DoubleConv(features[-1], features[-1]*2)
        self.final_conv = nn.Conv2d(features[0], out_channels, kernel_size=1)

    # The downward part
    def forward(self, x):
        skipped_connections = []

        for down in self.downs:
            x = down(x)
            skipped_connections.append(x)
            x = self.pool(x)

        x = self.bottom(x)
        skipped_connections = skipped_connections[::-1]

        for idx in range(0, len(self.ups), 2):
            x = self.ups[idx](x)
            skipped_connection = skipped_connections[idx//2]


            # Make sure shapes dimensions of input is divisible by 2
            if x.shape != skipped_connection.shape:
                x = TF.resize(x, size=skipped_connection.shape[2:]) # because of the step of 2 above

            concat_skip = torch.cat((skipped_connection, x), dim=1)
            x = self.ups[idx+1](concat_skip)

        return self.final_conv(x)

model = UNET(in_channels=3, out_channels=1).to(device=DEVICE)
In [27]:
# Loss function and optimization choices
criterion = nn.BCEWithLogitsLoss() # Using logits because then the sigmoid is not needed at the end
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
In [28]:
# free up cuda memory
gc.collect()
torch.cuda.empty_cache()
In [29]:
##### 7. loss evolving over epochs #####
for epoch in range(NUM_EPOCHS):
  
  running_loss = 0.0
  for i, (inputs, labels) in enumerate(trainloader, 0):
  
  # indicate that we are training
      model.train()
      
      # send inputs to DEVICE
      inputs, labels = inputs.to(device=DEVICE), labels.to(device=DEVICE)

      #if RBG_CHANNELS is not None:
       #   inputs = inputs[:, :RBG_CHANNELS,:,:]
      
      # Zero the parameter gradients
      optimizer.zero_grad()
      
      # Feed the model
      outputs = model(inputs)

      #dice = dice_coefficient(outputs, labels)
      
      # compute loss
      loss = criterion(outputs, labels)
      loss.backward()
      optimizer.step()

      # print statistics
      running_loss += loss.item() 
      if epoch % 20 == 0:
          correct = 0
          total = 0
          for val_images, val_labels in validationloader:
                  preds = model(val_images.to(DEVICE))
                  preds[preds >= 0.5] = 1
                  preds[preds < 0.5] = 0

                  val_labels = val_labels.to(DEVICE)
                  total += val_labels.shape[0] * val_labels.shape[1] * val_labels.shape[2] * val_labels.shape[3]

                  correct += (preds == val_labels).sum()

                  #imshow(torchvision.utils.make_grid(preds),0)
          accuracy = round((correct/total).item(),2)

          print(f'[Epoch:{epoch + 1}, i:{i+1}], running loss: {running_loss / 2}, accuracy: {accuracy}')
          running_loss = 0.0
      
print(f'Finished Training')
[Epoch:1, i:1], running loss: 0.3038671612739563, accuracy: 0.89
[Epoch:1, i:2], running loss: 0.26890453696250916, accuracy: 0.92
[Epoch:1, i:3], running loss: 0.2546672821044922, accuracy: 0.92
Finished Training
In [46]:
# Visualize the validation results 

preds = model(val_images.to(device=DEVICE))
preds[preds >= 0.5] = 1
preds[preds < 0.5] = 0
imshow(torchvision.utils.make_grid(preds),0)