Chromosome Segmentation in Pytorch using U-Net Architecture

Details about the use case can be found in the References section below. The source code follows; it was run on Google Colab.

Imports & Utilities

Importing the Libraries

import numpy as np
import h5py
from matplotlib import *
import sys
import pylab as pl
from torch.utils.data import Dataset, DataLoader
from torchvision.transforms import ToTensor
from torch.autograd import Variable
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt

Following is the code to set device based on the type of machine (CPU or GPU).

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

Following is the code for mounting the google drive on Google Colab.

# Colab-only: mount Google Drive at /content/drive so files can be written to it.
from google.colab import drive
drive.mount('/content/drive')

The following utility functions are taken from the linked source.

def plotSamplesOneHots(labels_of_samples, output_file=False):
    '''
    Plot every one-hot channel of every sample on a grid of sub-plots.

    labels_of_samples of shape (num_samples, num_onehots, x, y): row i of the
    grid shows sample i and column j shows one-hot channel j.
    output_file: path the figure is saved to; any falsy value (the default
    False) skips saving.

    Prints a message and returns None without plotting when the input is not
    four-dimensional.
    '''
    if len(labels_of_samples.shape) != 4:
        print("Incorrect input size - should be (num_samples, num_onehots, x, y)")
        return
    num_samples = labels_of_samples.shape[0]
    num_onehots = labels_of_samples.shape[1]
    figure_size = (4*num_onehots, 4*num_samples)
    # squeeze=False keeps `ax` two-dimensional even when there is only one
    # sample or one channel, so ax[i, j] below is always valid.
    fig, ax = plt.subplots(num_samples, num_onehots, sharex=True, sharey=True,
                           figsize=figure_size, squeeze=False)
    for i in range(num_samples):
        for j in range(num_onehots):
            ax[i, j].imshow(labels_of_samples[i, j, ...], aspect="auto")
    fig.tight_layout()
    # Save before showing, so the file is written even on backends where
    # show() discards the figure once its window is closed.
    if output_file:
        fig.savefig(output_file)
    plt.show()
        
def makeXbyY(data, X, Y):
    '''
    Centre-crop data to size X by Y.

    data: array of shape (num_samples, x, y, ...); axis 1 is cropped to X and
        axis 2 is cropped to Y, any trailing axes are kept whole.
    Returns the cropped view data[:, x0:x0+X, y0:y0+Y, ...].
    Raises ValueError if data has fewer than three axes or the requested crop
    is larger than the data.
    '''
    if len(data.shape) < 3:
        raise ValueError('Input should be of size (num_samples, x, y,...)')
    if X > data.shape[1] or Y > data.shape[2]:
        raise ValueError('Crop size (%d, %d) exceeds data size (%d, %d)'
                         % (X, Y, data.shape[1], data.shape[2]))
    data_x_start = (data.shape[1] - X) // 2
    # The y offset must come from axis 2; using axis 1 here shifts the crop
    # off-centre whenever the images are not square.
    data_y_start = (data.shape[2] - Y) // 2
    arrayXbyY = data[:, data_x_start:(data_x_start + X), data_y_start:(data_y_start + Y), ...]
    return arrayXbyY

def findNearestNeighbourLabel(array):
    '''
    Return the most frequent label in a 2-D window, ignoring the centre cell.

    array: 2-D integer label window (labels 0 to 4 are counted).
    The centre index is derived from the first axis and used for both axes.
    Ties are resolved in favour of the smallest label (argmax behaviour).
    '''
    middle = int(array.shape[0]/2)
    votes = np.zeros(5)
    for row, col in np.ndindex(array.shape[0], array.shape[1]):
        # Skip only the centre cell; every other cell casts one vote.
        if row == middle and col == middle:
            continue
        votes[array[row, col]] += 1
    return votes.argmax()

def cleanLabelNearestNeighbour(label, num_of_classes):
    '''
    Corrects incorrect labels in a single image based on a threshold on the
    number of nearest neighbours with the same label.

    label: 2-D integer label image of shape (x, y).
    num_of_classes: number of one-hot channels in the result (channel 0 is the
        background).
    Returns a float array of shape (x, y, num_of_classes). Interior pixels
    whose label is out of range, or that have at most 3 pixels of their own
    label in the surrounding 3x3 window (the pixel itself included), are
    relabelled with the most frequent label among their neighbours. Border
    pixels are never visited and therefore end up as background.
    '''
    x_length = label.shape[0]
    y_length = label.shape[1]
    cleaned_labels = np.zeros((x_length, y_length, num_of_classes))
    for x in range(1, x_length-1):
        for y in range(1, y_length-1):
            temp_label = label[x, y]
            if temp_label <= 0:
                # Background (or negative) pixels are left for channel 0 below.
                continue
            window = label[(x-1):(x+2), (y-1):(y+2)]
            out_of_range = temp_label > num_of_classes - 1
            if out_of_range or np.count_nonzero(window == temp_label) <= 3:
                # NOTE(review): findNearestNeighbourLabel counts labels 0-4, so
                # it can return a label >= num_of_classes, which would raise an
                # IndexError on the assignment below - confirm the data never
                # has a majority of such labels in a window.
                temp_label = findNearestNeighbourLabel(window)
            cleaned_labels[x, y, temp_label] = 1
    # Background is every pixel without a foreground class. Computed once after
    # the scan, so it is also filled in for images too small to have interior
    # pixels.
    non_zero_array = cleaned_labels[..., 1:].sum(axis=2).astype('bool')
    cleaned_labels[..., 0] = ~non_zero_array
    return cleaned_labels

def cleanLabelNearestNeighbour_alllabels(labels):
    '''
    Clean every label image in a stack.

    labels: array whose first axis indexes the images.
    Returns a float array with the input's shape plus a trailing axis of 4
    one-hot class channels, filled image by image with the result of
    cleanLabelNearestNeighbour.
    '''
    class_count = 4
    # Same shape as the input with the one-hot class axis appended.
    result_shape = list(labels.shape) + [class_count]
    result = np.zeros(result_shape)
    for index in range(labels.shape[0]):
        single_image = labels[index, ...]
        result[index, ...] = cleanLabelNearestNeighbour(single_image, class_count)
    return result

def meanIOU_per_image(y_pred, y_true):
    '''
    Calculate the IOU, averaged across images.

    y_pred, y_true: one-hot arrays of shape (num_images, num_classes, ...)
        with at least 4 classes on axis 1 (the layout meanIOU and IOU use).
    For each image the IOU of every class is computed (a class absent from
    both prediction and truth counts as 1) and averaged over the classes; the
    result is the mean of these per-image scores. An empty batch gives 1.0.
    Prints a message and returns None when the input is not one-hot shaped.
    '''
    if len(y_pred.shape) < 4 or (y_pred.shape[1] < 4):
        print('Wrong dimensions: one hot encoding expected')
        return
    y_pred = y_pred.astype('bool')
    y_true = y_true.astype('bool')
    image_scores = []
    for pred_image, true_image in zip(y_pred, y_true):
        IUs = []
        for layer in range(true_image.shape[0]):
            intersection = pred_image[layer] & true_image[layer]
            union = pred_image[layer] | true_image[layer]
            if union.sum() == 0:
                IUs.append(1)
            else:
                IUs.append(intersection.sum()/union.sum())
        image_scores.append(sum(IUs)/len(IUs))
    if not image_scores:
        # No images: every union is empty, which by the convention above is 1.
        return 1.0
    return sum(image_scores)/len(image_scores)

def meanIOU(y_pred, y_true):
    '''
    Mean intersection-over-union, averaged over the class channels.

    Both inputs are treated as one-hot masks with the classes on axis 1; each
    class is scored over the whole batch at once. A class that is empty in
    both prediction and truth scores 1. Prints a message and returns None
    when y_pred has fewer than 4 axes or fewer than 4 classes.
    '''
    has_onehot_layout = len(y_pred.shape) >= 4 and y_pred.shape[1] >= 4
    if not has_onehot_layout:
        print('Wrong dimensions: one hot encoding expected')
        return None
    pred_mask = y_pred.astype('bool')
    true_mask = y_true.astype('bool')
    scores = []
    for channel in range(true_mask.shape[1]):
        pred_channel = pred_mask[:, channel, ...]
        true_channel = true_mask[:, channel, ...]
        union_size = np.logical_or(pred_channel, true_channel).sum()
        if union_size == 0:
            scores.append(1)
            continue
        overlap_size = np.logical_and(pred_channel, true_channel).sum()
        scores.append(overlap_size/union_size)
    return sum(scores)/len(scores)
	
def IOU(y_pred, y_true):
    '''
    Intersection-over-union of every class, returned as a list.

    Both inputs are treated as one-hot masks with the classes on axis 1; entry
    k of the result is the IOU of class k taken over the whole batch, or 1
    when class k is empty in both prediction and truth. Prints a message and
    returns None when y_pred has fewer than 4 axes or fewer than 4 classes.
    '''
    if len(y_pred.shape) < 4 or y_pred.shape[1] < 4:
        print('Wrong dimensions: one hot encoding expected')
        return None
    pred_mask = y_pred.astype('bool')
    true_mask = y_true.astype('bool')

    def _class_iou(class_index):
        # Score a single class channel across all images.
        pred_channel = pred_mask[:, class_index, ...]
        true_channel = true_mask[:, class_index, ...]
        union_size = (pred_channel | true_channel).sum()
        if union_size == 0:
            return 1
        return (pred_channel & true_channel).sum()/union_size

    return [_class_iou(class_index) for class_index in range(true_mask.shape[1])]

# One-hot encoding
def oneHotEncode(initial_array):
    '''
    One hot encode the labels.

    initial_array: array of class labels, e.g. of shape (num_images, x, y).
    Returns a float64 array with the input's shape plus a trailing axis of 4
    channels. Labels 0, 1 and 2 set their own channel; every label >= 3 sets
    channel 3. Any other value (negative or non-integer below 3) leaves all
    four channels at zero.
    '''
    allowed_max_class_num = 3
    labels = np.asarray(initial_array)
    # Fold every label at or above the cap into the last class, then compare
    # against the class ids in one broadcast instead of looping per pixel.
    capped = np.minimum(labels, allowed_max_class_num)
    class_ids = np.arange(allowed_max_class_num + 1)
    return (capped[..., np.newaxis] == class_ids).astype(np.float64)

# Global Accuracy
def globalAccuracy(y_pred, y_true):
    '''
    Global accuracy, i.e. the fraction of pixels labelled correctly.

    Both inputs are cast to boolean one-hot masks. The number of positions set
    in both is divided by the pixel count, taken as the product of y_true's
    axis lengths after dropping one axis of length 4 (the class axis).
    Raises ValueError (from list.remove) when y_true has no axis of length 4.
    '''
    pred_mask = y_pred.astype('bool')
    true_mask = y_true.astype('bool')
    hits = np.logical_and(pred_mask, true_mask).sum()

    pixel_dims = list(true_mask.shape)
    pixel_dims.remove(4)
    pixel_count = int(np.prod(pixel_dims))
    return hits/pixel_count

Data Preprocessing

Data Download & Extraction

!wget 'https://github.com/jeanpat/DeepFISH/raw/master/dataset/LowRes_13434_overlapping_pairs.h5'
!mv 'LowRes_13434_overlapping_pairs.h5' dataset/

file_path = 'dataset/LowRes_13434_overlapping_pairs.h5'
h5f = h5py.File(file_path,'r')
xdata = h5f['dataset_1'][...,0]
labels = h5f['dataset_1'][...,1]
h5f.close()
print(f'Shape of Images: {xdata.shape} \nShape of Labels: {labels.shape}')

Dataset Analysis

# First overlapping-chromosome image.
# It is a grayscale image of shape (94, 93).
xdata[0].shape

# Some pixel intensities are negative; the reason is not known.
np.unique(xdata[0])

# Count how often each intensity occurs; the negative values are rare.

unique, counts = np.unique(xdata[0], return_counts=True)
print(np.asarray((unique, counts)).T)

# Show the image.

fig = pl.figure(figsize = (5,5)) 
ax = fig.add_subplot(111)
ax.imshow(xdata[0], aspect="auto")

# Label mask of the first overlapping-chromosome image.
# List the unique label values and how many pixels carry each one.
#np.unique(labels[0])
unique, counts = np.unique(labels[0], return_counts=True)
print(np.asarray((unique, counts)).T)

# The label values identify the masks of the four classes.
# Show the label mask:
fig = pl.figure(figsize = (5,5)) 
ax = fig.add_subplot(111)
ax.imshow(labels[0], aspect="auto")

# The mask shows stray colours (labels) along the boundaries, so the labels need cleaning.

# Clean the labels; the result gains a trailing axis of 4 one-hot class channels.
labels_cleaned = cleanLabelNearestNeighbour_alllabels(labels)
print(f'Shape of Images: {xdata.shape} \nShape of Labels: {labels_cleaned.shape}')

# Crop images and labels to height = width = 88
xdata_equal = makeXbyY(xdata, 88, 88)
labels_equal = makeXbyY(labels_cleaned, 88, 88)

print(f'Shape of Images: {xdata_equal.shape} \nShape of Labels: {labels_equal.shape}')

# Rearrange data and labels into channels-first layout:
# add a single channel axis to the grayscale images, (13434, 88, 88) to (13434, 1, 88, 88),
# and move the one-hot axis of the labels forward, (13434, 88, 88, 4) to (13434, 4, 88, 88)
xdata1 = np.expand_dims(xdata_equal, axis=1)
labels1 = np.transpose(labels_equal, (0, 3, 1, 2))
print(f'Shape of Images: {xdata1.shape} \nShape of Labels: {labels1.shape}')

# Show the first two images and their final one-hot labels
fig = pl.figure(figsize = (5,5)) 
ax = fig.add_subplot(111)
ax.imshow(xdata1[0][0], aspect="auto")

fig = pl.figure(figsize = (5,5)) 
ax = fig.add_subplot(111)
ax.imshow(xdata1[1][0], aspect="auto")

plotSamplesOneHots(labels1[0:2])

# Save the processed data (np.save appends the .npy extension)
np.save('dataset/xdata_88x88', xdata1)
np.save('dataset/ydata_88x88_0123_onehot', labels1)

Data Loading

# Load the processed images and one-hot labels written by np.save above
xdata_loaded = np.load('dataset/xdata_88x88.npy')
labels_loaded = np.load('dataset/ydata_88x88_0123_onehot.npy')

print(f'Shape of Images: {xdata_loaded.shape} \nShape of Labels: {labels_loaded.shape}')

# Dataset Conversion in Pytorch format

class ChromosomeDataset(Dataset):
    """Pair each image with its label mask for use with a DataLoader.

    The arrays handed to the constructor are stored on the instance, so the
    dataset no longer depends on module-level variables of the same name.

    Args:
        xdata_loaded: indexable collection of images; its length is the
            dataset size.
        labels_loaded: indexable collection of labels, aligned with the images.
        transform: optional callable applied to the image only when a sample
            is fetched; the label is returned unchanged.

    Raises:
        ValueError: if the two collections differ in length.
    """

    def __init__(self, xdata_loaded, labels_loaded, transform=None):
        super().__init__()
        if len(xdata_loaded) != len(labels_loaded):
            raise ValueError(
                f'Got {len(xdata_loaded)} images but {len(labels_loaded)} labels'
            )
        self.xdata = xdata_loaded
        self.labels = labels_loaded
        self.transform = transform

    def __len__(self):
        return len(self.xdata)

    def __getitem__(self, idx):
        image = self.xdata[idx]
        if self.transform is not None:
            image = self.transform(image)
        return image, self.labels[idx]

batch_size = 1
dataset = ChromosomeDataset(xdata_loaded, labels_loaded)

# Hold out 1000 samples for testing and train on the rest (12434 of the 13434
# samples). Deriving the training size from the dataset keeps the split valid
# if the number of samples changes.
test_size = 1000
train_size = len(dataset) - test_size
train_ds, test_ds = torch.utils.data.random_split(dataset, (train_size, test_size))
train_dataloader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=0)
test_dataloader = DataLoader(test_ds, batch_size=batch_size, shuffle=False, num_workers=0)

# Sanity check: fetch one batch, print its shapes and show its first image.
xb, yb = next(iter(train_dataloader))
print(xb.shape)
print(yb.shape)

fig = pl.figure(figsize=(5, 5))
ax = fig.add_subplot(111)
ax.imshow(xb.cpu().detach().numpy()[0][0], aspect="auto")

U-Net Architecture

Fig: U-Net architecture for chromosome segmentation. Each blue box corresponds to a multi-channel feature map, while the white boxes are the copied feature maps [1]
# Remove any earlier copy of unet.py and download the U-Net implementation
# again (presumably so the fresh download is saved under the expected name).
!rm unet.py
!wget https://raw.githubusercontent.com/jvanvugt/pytorch-unet/master/unet.py

from unet import UNet
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Four output classes, matching the four one-hot label channels.
model = UNet(n_classes=4, depth=4, padding=True, up_mode='upsample').to(device)
# Adam with its default hyper-parameters (no learning rate is passed here).
optim = torch.optim.Adam(model.parameters())

# Print a summary of the model for a single 1-channel 88x88 input.
!pip install torchinfo
from torchinfo import summary
summary(model, input_size=(1,1, 88, 88))

Training

epochs = 10
# NOTE(review): learning_rate is not passed to any optimizer in this cell, so
# it has no effect on the `optim` object used below - confirm the intended
# learning rate and wire it into the optimizer if it should apply.
learning_rate = 0.00001

# Build the loss once instead of re-creating it for every batch.
criterion = torch.nn.CrossEntropyLoss()

# Put the model in training mode in case it was switched to eval mode earlier.
model.train()

for epoch in range(epochs):
    print(f'---------- epoch: {epoch} ----------')
    running_loss = []
    IOU_per_epoch = []
    accuracy_per_epoch = []

    for X, y in train_dataloader:
        X = X.float().to(device)
        y = y.to(device)

        # CrossEntropyLoss takes class indices, so collapse the one-hot target.
        y1 = torch.argmax(y, dim=1)

        optim.zero_grad()
        outputs = model(X)
        loss = criterion(outputs, y1)
        loss.backward()
        optim.step()

        running_loss.append(loss.item())

        # Metrics: turn the predicted class map back into a channels-first
        # one-hot array so it can be compared with the one-hot target.
        predicted_classes = torch.argmax(outputs, dim=1).cpu().detach().numpy()
        predicted_onehot = np.transpose(oneHotEncode(predicted_classes), (0, 3, 1, 2))
        target_onehot = y.cpu().detach().numpy()

        IOU_per_epoch.append(IOU(predicted_onehot, target_onehot))
        accuracy_per_epoch.append(globalAccuracy(predicted_onehot, target_onehot))

    training_loss = sum(running_loss)/len(running_loss)
    print(f'train loss: {training_loss}')

    # Average the per-batch metrics over the epoch (IOU stays per class).
    accuracy_per_epoch = np.average(np.stack(accuracy_per_epoch, axis=0))
    print(f'Accuracy per epoch: {accuracy_per_epoch}')
    mean_IOU_per_epoch = np.average(np.stack(IOU_per_epoch, axis=0), axis=0)
    print(f'Mean IOU per epoch: {mean_IOU_per_epoch}')

print('Finished Training')

Checkpoint the model

# Checkpoint: save only the model's state_dict to the path below (a location
# under /content/drive).
PATH = '/content/drive/MyDrive/projects/chromosomes/segmentation_unet/model/unet_chromosome_01.pth'
torch.save(model.state_dict(), PATH)

Testing

y_pred_test = []
true_y_pred_test = []
model.eval()
# Inference only: skip building the autograd graph.
with torch.no_grad():
    for X, y in test_dataloader:
        X = X.float().to(device)
        outputs = model(X)

        true_y_pred_test.append(y.cpu().numpy())
        y_pred_test.append(outputs.cpu().numpy())

# Join the batches along the sample axis; this works for any batch size.
y_pred_test1 = np.concatenate(y_pred_test, axis=0)
true_y_pred_test1 = np.concatenate(true_y_pred_test, axis=0)

# The model outputs raw scores, so pick the winning class per pixel and
# one-hot encode it (channels first) before comparing with the one-hot labels.
# Casting the raw scores to bool would mark almost every pixel as predicted
# for every class.
y_pred_onehot = np.transpose(oneHotEncode(np.argmax(y_pred_test1, axis=1)), (0, 3, 1, 2))

testIOU = IOU(y_pred_onehot, true_y_pred_test1)
print(f'testIOU: {testIOU}')

# Global Accuracy
global_test_accuracy = globalAccuracy(y_pred_onehot, true_y_pred_test1)
print(f'Global Test Acuracy: {global_test_accuracy}')
del y_pred_test1
del true_y_pred_test1
del y_pred_onehot

References

  1. Overlapping Chromosome Segmentation using U-Net: Convolutional Networks with Test Time Augmentation – Hariyanti Mohd Saleha et al., Paper Link
  2. U-Net: Convolutional Networks for Biomedical Image Segmentation – Olaf Ronneberger et al., Paper Link
  3. Image Segmentation to Distinguish Between Overlapping Human Chromosomes, 2017 – R. Lily Hu et al. – Paper Link, Github Link

Note: A portion of above code was contributed by Mr. Somex Gupta.