Convolutional neural networks
Thu 22 April 2021
In [37]:
import torch
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import torchvision.datasets as datasets
import pandas as pd
import os
import numpy as np
from PIL import Image
from torch.utils.data import Dataset, DataLoader, ConcatDataset
from torchvision import transforms, utils
import torchvision.transforms.functional as TF
import glob
from sklearn.model_selection import KFold
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F
import gc
plt.rcParams['figure.figsize'] = [12, 8]
In [2]:
!rm -rf data/
!curl -L -o data.zip https://www.dropbox.com/sh/jhvugjvowcnuovb/AADHwtqWk7p2y7KkO8OUg_lha?dl=0
!unzip -oq data.zip -d data
In [18]:
# Hyperparameters
LEARNING_RATE = 0.001
USE_CUDA = torch.cuda.is_available()
DEVICE = torch.device("cuda:0" if USE_CUDA else "cpu")
NUM_EPOCHS = 20
NUM_WORKERS = 2
BATCH_SIZE = 4
#IMAGE_HEIGHT = 160
#IMAGE_WIDTH = 160
RGB_CHANNELS = 1 # channels to keep if slicing the input (only used in commented-out code below)
WEIGHT_DECAY = 0.01 # weight decay for the optimizer
In [4]:
class RetinalDataSet(Dataset):
    def __init__(self, image_paths, target_paths, train=True, image_height=None, image_width=None):
        self.image_paths = image_paths
        self.target_paths = target_paths
        self.image_height = image_height
        self.image_width = image_width

    def transform(self, image, mask):
        # Resize (only when both target dimensions are given)
        if self.image_height is not None and self.image_width is not None:
            resize_image = transforms.Resize((self.image_height, self.image_width))
            image = resize_image(image)
            mask = resize_image(mask)
        # Transform to tensor
        to_tensor = transforms.ToTensor()
        image = to_tensor(image)
        # Use only the first channel in the image
        # image = image[:1,:,:]
        mask = to_tensor(mask)
        return image, mask

    def __getitem__(self, index):
        image = Image.open(self.image_paths[index])
        mask = Image.open(self.target_paths[index])
        image, mask = self.transform(image, mask)
        return image, mask

    def __len__(self):
        return len(self.image_paths)
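The `train=True` flag is accepted but no augmentation is actually applied. If random augmentation were added, the same random draw would have to be applied to both the image and its mask; a minimal sketch with a horizontal flip (hypothetical helper, not used below):

In [ ]:
import random

def paired_random_hflip(image, mask, p=0.5):
    # draw the coin once so the image and its mask stay aligned
    if random.random() < p:
        image = TF.hflip(image)
        mask = TF.hflip(mask)
    return image, mask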
In [5]:
# Create image and segmentation paths
image_paths = glob.glob('data/Images_train/*.tif')
image_paths.sort()
mask_paths = glob.glob('data/Labels_train/*.gif')
mask_paths.sort()
# Create retinal dataset
retinal_data = RetinalDataSet(image_paths=image_paths, target_paths=mask_paths)
In [15]:
train_set, val_set = torch.utils.data.random_split(retinal_data, [12, 3], generator=torch.Generator().manual_seed(42))
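The hard-coded `[12, 3]` split assumes exactly 15 training images. A length-agnostic alternative that holds out roughly 20% for validation (a sketch; with 15 images it reproduces the same [12, 3] split):

In [ ]:
n_val = max(1, len(retinal_data) // 5)   # roughly 20% for validation
n_train = len(retinal_data) - n_val
train_set, val_set = torch.utils.data.random_split(
    retinal_data, [n_train, n_val], generator=torch.Generator().manual_seed(42))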
In [20]:
# Print information about dataset
print('Training dataset:')
for i in range(len(train_set)):
    sample = train_set[i]
    print(i, sample[0].shape, sample[1].shape)
print('\nValidation dataset:')
for i in range(len(val_set)):
    sample = val_set[i]
    print(i, sample[0].shape, sample[1].shape)
# Create trainloader and validationloader
trainloader = DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)
validationloader = DataLoader(val_set, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)
print('\nTrainloader:')
for i_batch, sample_batched in enumerate(trainloader):
    print(i_batch, sample_batched[0].size(),
          sample_batched[1].size())
print('\nValidationloader:')
for i_batch, sample_batched in enumerate(validationloader):
    print(i_batch, sample_batched[0].size(),
          sample_batched[1].size())
In [38]:
###### 2. visualize images ######
# function to show an image, optionally restricted to a single channel
def imshow(img, channel=None):
    npimg = img.cpu().detach().numpy()
    if channel is None:
        plt.imshow(np.transpose(npimg, (1, 2, 0)))
    else:
        plt.imshow(np.transpose(npimg, (1, 2, 0))[:, :, channel])
    plt.show()
# get some random training images
dataiter = iter(trainloader)
images, labels = next(dataiter)  # dataiter.next() was removed in newer PyTorch
# show the images in the three RGB channels
print('All three channels')
imshow(torchvision.utils.make_grid(images))
print('Only the first channel')
imshow(torchvision.utils.make_grid(images), 0)
print('Only the second channel')
imshow(torchvision.utils.make_grid(images), 1)
print('Only the third channel')
imshow(torchvision.utils.make_grid(images), 2)
print('The corresponding segmentations')
imshow(torchvision.utils.make_grid(labels), 0)
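Viewing image and segmentation side by side works, but an overlay makes the correspondence easier to judge. A small sketch (the colormap and transparency are arbitrary choices):

In [ ]:
img, msk = train_set[0]
plt.imshow(np.transpose(img.numpy(), (1, 2, 0)))
plt.imshow(msk.numpy()[0], cmap='Reds', alpha=0.4)  # semi-transparent mask on top
plt.show()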
In [26]:
class DoubleConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(DoubleConv, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, 1, 1, bias=False),  # "same" convolution
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, 3, 1, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        return self.conv(x)
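# Conv arithmetic note: with kernel_size=3, stride=1, padding=1 the spatial
# size is preserved, H_out = (H + 2*1 - 3)/1 + 1 = H, hence "same" convolution.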
class UNET(nn.Module):
    def __init__(
        self, in_channels=3, out_channels=1, features=[64, 128, 256, 512],
    ):
        super(UNET, self).__init__()
        self.ups = nn.ModuleList()
        self.downs = nn.ModuleList()
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        # The downward (encoder) part of UNET
        for feature in features:
            self.downs.append(DoubleConv(in_channels, feature))
            in_channels = feature
        # The upward (decoder) part of UNET
        for feature in reversed(features):
            self.ups.append(
                nn.ConvTranspose2d(
                    feature*2, feature, kernel_size=2, stride=2,
                )
            )
            self.ups.append(DoubleConv(feature*2, feature))
        # The bottleneck at the bottom of the U
        self.bottom = DoubleConv(features[-1], features[-1]*2)
        self.final_conv = nn.Conv2d(features[0], out_channels, kernel_size=1)

    def forward(self, x):
        skipped_connections = []
        for down in self.downs:
            x = down(x)
            skipped_connections.append(x)
            x = self.pool(x)
        x = self.bottom(x)
        skipped_connections = skipped_connections[::-1]
        for idx in range(0, len(self.ups), 2):
            x = self.ups[idx](x)
            skipped_connection = skipped_connections[idx//2]
            # If an input dimension was not divisible by 2, max pooling floored
            # it, so resize before concatenating the skip connection
            if x.shape != skipped_connection.shape:
                x = TF.resize(x, size=skipped_connection.shape[2:])
            concat_skip = torch.cat((skipped_connection, x), dim=1)
            x = self.ups[idx+1](concat_skip)
        return self.final_conv(x)

model = UNET(in_channels=3, out_channels=1).to(device=DEVICE)
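A quick sanity check on a dummy batch (a sketch; the 160x160 size is an arbitrary choice divisible by 16): the output should keep the input's spatial size while reducing to a single channel.

In [ ]:
x = torch.randn(1, 3, 160, 160, device=DEVICE)
with torch.no_grad():
    y = model(x)
print(y.shape)  # expected: torch.Size([1, 1, 160, 160])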
In [27]:
# Loss function and optimization choices
criterion = nn.BCEWithLogitsLoss() # works on raw logits, so no sigmoid layer is needed at the end of the model
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
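The training loop below references a `dice_coefficient` helper in a commented-out line but never defines it. A plausible definition (a sketch, assuming logits as input and binary masks as targets):

In [ ]:
def dice_coefficient(logits, targets, eps=1e-8):
    # Dice = 2|A ∩ B| / (|A| + |B|), computed on predicted probabilities
    probs = torch.sigmoid(logits)
    intersection = (probs * targets).sum()
    return (2 * intersection + eps) / (probs.sum() + targets.sum() + eps)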
In [28]:
# free up cuda memory
gc.collect()
torch.cuda.empty_cache()
In [29]:
##### 7. loss evolving over epochs #####
for epoch in range(NUM_EPOCHS):
    running_loss = 0.0
    for i, (inputs, labels) in enumerate(trainloader, 0):
        # indicate that we are training
        model.train()
        # send inputs to DEVICE
        inputs, labels = inputs.to(device=DEVICE), labels.to(device=DEVICE)
        #if RGB_CHANNELS is not None:
        #    inputs = inputs[:, :RGB_CHANNELS, :, :]
        # zero the parameter gradients
        optimizer.zero_grad()
        # feed the model
        outputs = model(inputs)
        #dice = dice_coefficient(outputs, labels)
        # compute loss and take an optimization step
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        # accumulate statistics
        running_loss += loss.item()
    # evaluate pixel accuracy on the validation set after each epoch
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for val_images, val_labels in validationloader:
            # the model outputs logits, so apply a sigmoid before thresholding
            preds = torch.sigmoid(model(val_images.to(DEVICE)))
            preds[preds >= 0.5] = 1
            preds[preds < 0.5] = 0
            val_labels = val_labels.to(DEVICE)
            total += val_labels.numel()
            correct += (preds == val_labels).sum()
    #imshow(torchvision.utils.make_grid(preds),0)
    accuracy = round((correct / total).item(), 2)
    print(f'[Epoch: {epoch + 1}] average loss: {running_loss / (i + 1):.4f}, accuracy: {accuracy}')
print('Finished Training')
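The cell header promises a loss curve, but the loop above only prints. A minimal sketch of such a plot, assuming a `losses` list is appended to once per epoch inside the loop (a hypothetical addition, e.g. `losses.append(running_loss / (i + 1))`):

In [ ]:
# hypothetical: `losses` holds one average loss per epoch, collected in the loop above
plt.plot(range(1, len(losses) + 1), losses)
plt.xlabel('epoch')
plt.ylabel('average BCE loss')
plt.title('Loss evolving over epochs')
plt.show()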
In [46]:
# Visualize the validation results (sigmoid first, since the model outputs logits)
with torch.no_grad():
    preds = torch.sigmoid(model(val_images.to(device=DEVICE)))
preds[preds >= 0.5] = 1
preds[preds < 0.5] = 0
imshow(torchvision.utils.make_grid(preds), 0)
In [31]:
# create test dataset
class RetinalTestSet(Dataset):
    def __init__(self, image_paths, train=False):
        self.image_paths = image_paths

    def transform(self, image):
        # Resize
        #resize_image = transforms.Resize((IMAGE_HEIGHT, IMAGE_WIDTH))
        #image = resize_image(image)
        # Transform to tensor
        to_tensor = transforms.ToTensor()
        image = to_tensor(image)
        # Use only the first channel in the image
        #image = image[:1,:,:]
        return image

    def __getitem__(self, index):
        image = Image.open(self.image_paths[index])
        image = self.transform(image)
        return image

    def __len__(self):
        return len(self.image_paths)
In [47]:
test_paths = glob.glob('data/Images_test/*.tif')
test_paths.sort()
test_set = RetinalTestSet(test_paths)
# free up cuda memory
gc.collect()
torch.cuda.empty_cache()
In [48]:
# visualize the test predictions
BATCH_SIZE = 5
testloader = DataLoader(test_set, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)
model.eval()
with torch.no_grad():
    for images in testloader:
        # sigmoid first, since the model outputs logits
        preds = torch.sigmoid(model(images.to(device=DEVICE)))
        preds[preds >= 0.5] = 1
        preds[preds < 0.5] = 0
        print('Prediction results:')
        imshow(torchvision.utils.make_grid(preds), 0)
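To keep the predictions, each mask in the last batch could also be written out as an image. A sketch (the `predictions/` directory name is an assumption):

In [ ]:
os.makedirs('predictions', exist_ok=True)
for j in range(preds.shape[0]):
    utils.save_image(preds[j], f'predictions/pred_{j}.png')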