Transfer Learning Convolution Neural Network (U-Net)¶

In [1]:
import os
os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"

Import Packages¶

In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms

import segmentation_models_pytorch as smp
from segmentation_models_pytorch import utils

from PIL import Image
import cv2
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import json

import os
import shutil
import glob
C:\ProgramData\Miniconda3\lib\site-packages\tqdm\auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html
  from .autonotebook import tqdm as notebook_tqdm

Define helper functions¶

In [3]:
def train_test_split(data, val_size=0.2, test_size=0.2, shuffle=True, random_seed=None):
    if shuffle:
        if random_seed is not None:
            np.random.seed(random_seed)
        np.random.shuffle(data)

    n_total = len(data)
    n_test = int(n_total * test_size)
    n_val = int(n_total * val_size)
    n_train = n_total - n_val - n_test

    train_data = data[:n_train]
    val_data = data[n_train:n_train + n_val]
    test_data = data[n_train + n_val:]

    return train_data, val_data, test_data

Defining Paths¶

In [4]:
# base path of the dataset
DATASET_PATH = r"dataset"
# define the path to the images and masks dataset
IMAGE_DATASET_PATH = os.path.join(DATASET_PATH, "images")
MASK_DATASET_PATH = os.path.join(DATASET_PATH, "masks")

# define the path to the output serialized model, model training
BASE_OUTPUT = "output"
MODEL_STR = 'unet'
MODEL_PATH = os.path.join(BASE_OUTPUT, f"{MODEL_STR}_coral.pth")

# plot, and testing image paths
TEST_PATHS = os.path.join(DATASET_PATH, "test")

# define the input image dimensions
INPUT_IMAGE_WIDTH = 256
INPUT_IMAGE_HEIGHT = 256

# determine the device to be used for training and evaluation
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# determine if we will be pinning memory during data loading
PIN_MEMORY = True if DEVICE == "cpu" else False
BATCH_SIZE = 16

Define Data Loader¶

In [5]:
class CreateDataset(Dataset):
    def __init__(self, img_dir, mask_dir, device, transform=None):
        self.img_dir = img_dir
        self.mask_dir = mask_dir
        self.device = device
        self.transform = transform
        self.image_files = os.listdir(img_dir)
        self.mask_files = os.listdir(mask_dir)
        
    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        img_path = os.path.join(self.img_dir, self.image_files[idx])
        mask_path = os.path.join(self.mask_dir, self.mask_files[idx])    
        image = Image.open(img_path).convert('RGB')
        mask = Image.open(mask_path).convert('L')
        if self.transform:
            image = self.transform(image)
            mask = self.transform(mask)
        return image.to(self.device), mask.to(self.device)

transform = transforms.Compose([
    transforms.Resize((INPUT_IMAGE_HEIGHT, INPUT_IMAGE_WIDTH)),
    transforms.ToTensor(), # Shape (C, H, W)
])

train_dataset = CreateDataset("dataset\\train\\images", "dataset\\train\\masks", DEVICE, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, pin_memory = PIN_MEMORY)

valid_dataset = CreateDataset("dataset\\valid\\images", "dataset\\valid\\masks", DEVICE, transform=transform)
valid_loader = DataLoader(valid_dataset, batch_size=BATCH_SIZE, shuffle=False, pin_memory = PIN_MEMORY)

test_dataset = CreateDataset("dataset\\test\\images", "dataset\\test\\masks", DEVICE, transform=transform)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, pin_memory = PIN_MEMORY)

Transfer Learning Resnet101¶

In [6]:
import segmentation_models_pytorch as smp
from segmentation_models_pytorch import utils

ENCODER = 'resnet101'
ENCODER_WEIGHTS = 'imagenet'
CLASSES = ['notcoral','coral']
ACTIVATION = 'sigmoid'

INIT_LR = 0.0001
NUM_EPOCHS = 500
# define threshold to filter weak predictions
THRESHOLD = 0.5

unet = smp.Unet(
    encoder_name=ENCODER, 
    encoder_weights=ENCODER_WEIGHTS, 
    classes=min(1, len(CLASSES)-1), 
    activation=ACTIVATION,
)
print(f"New model {MODEL_STR} is created.")
New model unet is created.

Model Criteria¶

In [7]:
loss = smp.utils.losses.JaccardLoss()
loss_str = 'jaccard_loss'

# define metrics
metrics = [
    smp.utils.metrics.IoU(),
    smp.utils.metrics.Fscore(),
    smp.utils.metrics.Recall(),
    smp.utils.metrics.Precision(),
]

# define optimizer
optimizer = torch.optim.Adam([ 
    dict(params=unet.parameters(), lr=INIT_LR),
])
    
train_epoch = smp.utils.train.TrainEpoch(
    unet, 
    loss=loss, 
    metrics=metrics, 
    optimizer=optimizer,
    device=DEVICE,
    verbose=True,
)

valid_epoch = smp.utils.train.ValidEpoch(
    unet, 
    loss=loss, 
    metrics=metrics, 
    device=DEVICE,
    verbose=True,
)

Model Training¶

In [8]:
class EarlyStopper:
    def __init__(self, patience=1, min_delta=0):
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0
        self.min_validation_loss = np.inf

    def early_stop(self, validation_loss):
        if validation_loss < self.min_validation_loss:
            self.min_validation_loss = validation_loss
            self.counter = 0
        elif validation_loss > (self.min_validation_loss + self.min_delta):
            self.counter += 1
            if self.counter >= self.patience:
                return True
        return False
In [9]:
best_iou_score = 0.0
train_logs_list, valid_logs_list = [], []

early_stopper = EarlyStopper(patience=3, min_delta=0.01)

for i in range(NUM_EPOCHS):

    # Perform training & validation
    print('\nEpoch: {}'.format(i+1))
    train_logs = train_epoch.run(train_loader)
    valid_logs = valid_epoch.run(valid_loader)
    train_logs_list.append(train_logs)
    valid_logs_list.append(valid_logs)

    # Save model if a better val IoU score is obtained
    if best_iou_score < valid_logs['iou_score']:
        best_iou_score = valid_logs['iou_score']
        torch.save(unet, MODEL_PATH)
        best_unet = unet
        print('Model saved!')

    if early_stopper.early_stop(valid_logs[loss_str]):
        print("Early Stopped!")
        break

# Save Training Logs
with open(os.path.join(BASE_OUTPUT, f"{MODEL_STR}_coral_train_logs.json"), 'w') as f:
    json.dump(train_logs_list, f, indent=4)
with open(os.path.join(BASE_OUTPUT, f"{MODEL_STR}_coral_valid_logs.json"), 'w') as f:
    json.dump(valid_logs_list, f, indent=4)
Epoch: 1
train: 100%|█| 14/14 [00:08<00:00,  1.68it/s, jaccard_loss - 0.9636, iou_score - 0.04122, fscore - 0.07904, recall - 0.
valid: 100%|█| 3/3 [00:00<00:00,  7.98it/s, jaccard_loss - 0.9665, iou_score - 0.04366, fscore - 0.08363, recall - 0.58
Model saved!

Epoch: 2
train: 100%|█| 14/14 [00:05<00:00,  2.52it/s, jaccard_loss - 0.951, iou_score - 0.07077, fscore - 0.1318, recall - 0.94
valid: 100%|█| 3/3 [00:00<00:00,  8.29it/s, jaccard_loss - 0.9634, iou_score - 0.0514, fscore - 0.09777, recall - 0.611
Model saved!

Epoch: 3
train: 100%|█| 14/14 [00:05<00:00,  2.53it/s, jaccard_loss - 0.9424, iou_score - 0.09745, fscore - 0.1774, recall - 0.9
valid: 100%|█| 3/3 [00:00<00:00,  7.92it/s, jaccard_loss - 0.9353, iou_score - 0.117, fscore - 0.2095, recall - 0.9309,
Model saved!

Epoch: 4
train: 100%|█| 14/14 [00:05<00:00,  2.45it/s, jaccard_loss - 0.9374, iou_score - 0.1166, fscore - 0.2085, recall - 0.98
valid: 100%|█| 3/3 [00:00<00:00,  8.13it/s, jaccard_loss - 0.9299, iou_score - 0.1282, fscore - 0.2271, recall - 0.9677
Model saved!

Epoch: 5
train: 100%|█| 14/14 [00:05<00:00,  2.50it/s, jaccard_loss - 0.9311, iou_score - 0.1423, fscore - 0.2488, recall - 0.98
valid: 100%|█| 3/3 [00:00<00:00,  8.13it/s, jaccard_loss - 0.9228, iou_score - 0.1631, fscore - 0.2804, recall - 0.9717
Model saved!

Epoch: 6
train: 100%|█| 14/14 [00:05<00:00,  2.45it/s, jaccard_loss - 0.9259, iou_score - 0.1796, fscore - 0.304, recall - 0.989
valid: 100%|█| 3/3 [00:00<00:00,  8.06it/s, jaccard_loss - 0.9186, iou_score - 0.2147, fscore - 0.3534, recall - 0.9533
Model saved!

Epoch: 7
train: 100%|█| 14/14 [00:05<00:00,  2.42it/s, jaccard_loss - 0.9201, iou_score - 0.2371, fscore - 0.3824, recall - 0.98
valid: 100%|█| 3/3 [00:00<00:00,  8.09it/s, jaccard_loss - 0.9128, iou_score - 0.3087, fscore - 0.4713, recall - 0.9044
Model saved!

Epoch: 8
train: 100%|█| 14/14 [00:05<00:00,  2.39it/s, jaccard_loss - 0.9139, iou_score - 0.3002, fscore - 0.4611, recall - 0.98
valid: 100%|█| 3/3 [00:00<00:00,  7.75it/s, jaccard_loss - 0.9055, iou_score - 0.3392, fscore - 0.5059, recall - 0.9098
Model saved!

Epoch: 9
train: 100%|█| 14/14 [00:05<00:00,  2.40it/s, jaccard_loss - 0.9063, iou_score - 0.3661, fscore - 0.5354, recall - 0.98
valid: 100%|█| 3/3 [00:00<00:00,  8.20it/s, jaccard_loss - 0.9006, iou_score - 0.3859, fscore - 0.556, recall - 0.9056,
Model saved!

Epoch: 10
train: 100%|█| 14/14 [00:05<00:00,  2.39it/s, jaccard_loss - 0.9, iou_score - 0.4101, fscore - 0.5807, recall - 0.9859,
valid: 100%|█| 3/3 [00:00<00:00,  7.54it/s, jaccard_loss - 0.892, iou_score - 0.4043, fscore - 0.5753, recall - 0.898, 
Model saved!

Epoch: 11
train: 100%|█| 14/14 [00:05<00:00,  2.35it/s, jaccard_loss - 0.8938, iou_score - 0.4434, fscore - 0.6137, recall - 0.98
valid: 100%|█| 3/3 [00:00<00:00,  7.83it/s, jaccard_loss - 0.8899, iou_score - 0.4593, fscore - 0.6291, recall - 0.8707
Model saved!

Epoch: 12
train: 100%|█| 14/14 [00:05<00:00,  2.42it/s, jaccard_loss - 0.8856, iou_score - 0.4952, fscore - 0.662, recall - 0.984
valid: 100%|█| 3/3 [00:00<00:00,  8.06it/s, jaccard_loss - 0.8778, iou_score - 0.4773, fscore - 0.6445, recall - 0.8746
Model saved!

Epoch: 13
train: 100%|█| 14/14 [00:05<00:00,  2.42it/s, jaccard_loss - 0.8795, iou_score - 0.5324, fscore - 0.6939, recall - 0.98
valid: 100%|█| 3/3 [00:00<00:00,  7.98it/s, jaccard_loss - 0.8695, iou_score - 0.4891, fscore - 0.6567, recall - 0.8841
Model saved!

Epoch: 14
train: 100%|█| 14/14 [00:05<00:00,  2.42it/s, jaccard_loss - 0.8696, iou_score - 0.5624, fscore - 0.7193, recall - 0.98
valid: 100%|█| 3/3 [00:00<00:00,  7.94it/s, jaccard_loss - 0.8618, iou_score - 0.491, fscore - 0.6567, recall - 0.891, 
Model saved!

Epoch: 15
train: 100%|█| 14/14 [00:05<00:00,  2.41it/s, jaccard_loss - 0.8629, iou_score - 0.5738, fscore - 0.7287, recall - 0.98
valid: 100%|█| 3/3 [00:00<00:00,  7.77it/s, jaccard_loss - 0.8464, iou_score - 0.5435, fscore - 0.7038, recall - 0.8782
Model saved!

Epoch: 16
train: 100%|█| 14/14 [00:05<00:00,  2.42it/s, jaccard_loss - 0.8498, iou_score - 0.6031, fscore - 0.7517, recall - 0.98
valid: 100%|█| 3/3 [00:00<00:00,  7.90it/s, jaccard_loss - 0.8304, iou_score - 0.5659, fscore - 0.7227, recall - 0.8549
Model saved!

Epoch: 17
train: 100%|█| 14/14 [00:05<00:00,  2.41it/s, jaccard_loss - 0.83, iou_score - 0.6322, fscore - 0.774, recall - 0.9807,
valid: 100%|█| 3/3 [00:00<00:00,  7.86it/s, jaccard_loss - 0.8274, iou_score - 0.5619, fscore - 0.7184, recall - 0.8778

Epoch: 18
train: 100%|█| 14/14 [00:05<00:00,  2.41it/s, jaccard_loss - 0.8114, iou_score - 0.6696, fscore - 0.8013, recall - 0.97
valid: 100%|█| 3/3 [00:00<00:00,  6.41it/s, jaccard_loss - 0.8166, iou_score - 0.6233, fscore - 0.7676, recall - 0.8152
Model saved!

Epoch: 19
train: 100%|█| 14/14 [00:05<00:00,  2.40it/s, jaccard_loss - 0.7975, iou_score - 0.672, fscore - 0.8028, recall - 0.974
valid: 100%|█| 3/3 [00:00<00:00,  8.00it/s, jaccard_loss - 0.7988, iou_score - 0.5439, fscore - 0.7033, recall - 0.8901

Epoch: 20
train: 100%|█| 14/14 [00:05<00:00,  2.40it/s, jaccard_loss - 0.774, iou_score - 0.7128, fscore - 0.8319, recall - 0.972
valid: 100%|█| 3/3 [00:00<00:00,  7.98it/s, jaccard_loss - 0.7793, iou_score - 0.6352, fscore - 0.7767, recall - 0.8263
Model saved!

Epoch: 21
train: 100%|█| 14/14 [00:06<00:00,  2.30it/s, jaccard_loss - 0.7578, iou_score - 0.7107, fscore - 0.8301, recall - 0.97
valid: 100%|█| 3/3 [00:00<00:00,  7.52it/s, jaccard_loss - 0.7611, iou_score - 0.624, fscore - 0.7676, recall - 0.8308,

Epoch: 22
train: 100%|█| 14/14 [00:06<00:00,  2.26it/s, jaccard_loss - 0.7372, iou_score - 0.7408, fscore - 0.8506, recall - 0.96
valid: 100%|█| 3/3 [00:00<00:00,  7.85it/s, jaccard_loss - 0.7427, iou_score - 0.6321, fscore - 0.7742, recall - 0.8407

Epoch: 23
train: 100%|█| 14/14 [00:05<00:00,  2.39it/s, jaccard_loss - 0.7152, iou_score - 0.7615, fscore - 0.8643, recall - 0.96
valid: 100%|█| 3/3 [00:00<00:00,  7.79it/s, jaccard_loss - 0.7281, iou_score - 0.6539, fscore - 0.7905, recall - 0.822,
Model saved!

Epoch: 24
train: 100%|█| 14/14 [00:06<00:00,  2.14it/s, jaccard_loss - 0.6904, iou_score - 0.7806, fscore - 0.8764, recall - 0.96
valid: 100%|█| 3/3 [00:00<00:00,  8.22it/s, jaccard_loss - 0.7089, iou_score - 0.6671, fscore - 0.8, recall - 0.8147, p
Model saved!

Epoch: 25
train: 100%|█| 14/14 [00:05<00:00,  2.43it/s, jaccard_loss - 0.6729, iou_score - 0.7835, fscore - 0.8784, recall - 0.97
valid: 100%|█| 3/3 [00:00<00:00,  7.83it/s, jaccard_loss - 0.7003, iou_score - 0.6654, fscore - 0.7988, recall - 0.7527

Epoch: 26
train: 100%|█| 14/14 [00:06<00:00,  2.32it/s, jaccard_loss - 0.6494, iou_score - 0.8063, fscore - 0.8923, recall - 0.95
valid: 100%|█| 3/3 [00:00<00:00,  8.38it/s, jaccard_loss - 0.6647, iou_score - 0.6622, fscore - 0.7963, recall - 0.843,

Epoch: 27
train: 100%|█| 14/14 [00:06<00:00,  2.19it/s, jaccard_loss - 0.6242, iou_score - 0.8062, fscore - 0.8925, recall - 0.96
valid: 100%|█| 3/3 [00:00<00:00,  8.02it/s, jaccard_loss - 0.6501, iou_score - 0.6673, fscore - 0.7997, recall - 0.832,
Model saved!

Epoch: 28
train: 100%|█| 14/14 [00:05<00:00,  2.39it/s, jaccard_loss - 0.5984, iou_score - 0.8291, fscore - 0.9063, recall - 0.96
valid: 100%|█| 3/3 [00:00<00:00,  7.83it/s, jaccard_loss - 0.6334, iou_score - 0.6796, fscore - 0.809, recall - 0.8008,
Model saved!

Epoch: 29
train: 100%|█| 14/14 [00:06<00:00,  2.26it/s, jaccard_loss - 0.5725, iou_score - 0.8338, fscore - 0.9091, recall - 0.96
valid: 100%|█| 3/3 [00:00<00:00,  5.94it/s, jaccard_loss - 0.6229, iou_score - 0.6745, fscore - 0.8054, recall - 0.7679

Epoch: 30
train: 100%|█| 14/14 [00:06<00:00,  2.13it/s, jaccard_loss - 0.5525, iou_score - 0.839, fscore - 0.9121, recall - 0.961
valid: 100%|█| 3/3 [00:00<00:00,  8.22it/s, jaccard_loss - 0.5942, iou_score - 0.6682, fscore - 0.8007, recall - 0.8306

Epoch: 31
train: 100%|█| 14/14 [00:05<00:00,  2.34it/s, jaccard_loss - 0.529, iou_score - 0.8404, fscore - 0.9129, recall - 0.961
valid: 100%|█| 3/3 [00:00<00:00,  7.71it/s, jaccard_loss - 0.5857, iou_score - 0.68, fscore - 0.809, recall - 0.7727, p
Model saved!

Epoch: 32
train: 100%|█| 14/14 [00:05<00:00,  2.44it/s, jaccard_loss - 0.5011, iou_score - 0.8579, fscore - 0.9235, recall - 0.95
valid: 100%|█| 3/3 [00:00<00:00,  8.09it/s, jaccard_loss - 0.557, iou_score - 0.6923, fscore - 0.8179, recall - 0.8021,
Model saved!

Epoch: 33
train: 100%|█| 14/14 [00:06<00:00,  2.20it/s, jaccard_loss - 0.4794, iou_score - 0.8515, fscore - 0.9196, recall - 0.95
valid: 100%|█| 3/3 [00:00<00:00,  7.83it/s, jaccard_loss - 0.5495, iou_score - 0.6869, fscore - 0.8141, recall - 0.772,

Epoch: 34
train: 100%|█| 14/14 [00:06<00:00,  2.24it/s, jaccard_loss - 0.4585, iou_score - 0.8591, fscore - 0.9241, recall - 0.95
valid: 100%|█| 3/3 [00:00<00:00,  8.02it/s, jaccard_loss - 0.5376, iou_score - 0.6825, fscore - 0.8108, recall - 0.7544

Epoch: 35
train: 100%|█| 14/14 [00:06<00:00,  2.14it/s, jaccard_loss - 0.437, iou_score - 0.8687, fscore - 0.9297, recall - 0.958
valid: 100%|█| 3/3 [00:00<00:00,  7.90it/s, jaccard_loss - 0.534, iou_score - 0.6704, fscore - 0.8023, recall - 0.7404,

Epoch: 36
train: 100%|█| 14/14 [00:06<00:00,  2.16it/s, jaccard_loss - 0.4161, iou_score - 0.8703, fscore - 0.9306, recall - 0.95
valid: 100%|█| 3/3 [00:00<00:00,  6.10it/s, jaccard_loss - 0.5043, iou_score - 0.687, fscore - 0.8141, recall - 0.7812,

Epoch: 37
train: 100%|█| 14/14 [00:06<00:00,  2.19it/s, jaccard_loss - 0.3962, iou_score - 0.8717, fscore - 0.9314, recall - 0.95
valid: 100%|█| 3/3 [00:00<00:00,  7.94it/s, jaccard_loss - 0.4908, iou_score - 0.6871, fscore - 0.8141, recall - 0.7909

Epoch: 38
train: 100%|█| 14/14 [00:06<00:00,  2.10it/s, jaccard_loss - 0.3801, iou_score - 0.871, fscore - 0.931, recall - 0.9532
valid: 100%|█| 3/3 [00:00<00:00,  5.78it/s, jaccard_loss - 0.4929, iou_score - 0.6743, fscore - 0.8048, recall - 0.749,

Epoch: 39
train: 100%|█| 14/14 [00:08<00:00,  1.67it/s, jaccard_loss - 0.3643, iou_score - 0.8752, fscore - 0.9334, recall - 0.95
valid: 100%|█| 3/3 [00:00<00:00,  5.63it/s, jaccard_loss - 0.472, iou_score - 0.6863, fscore - 0.8137, recall - 0.7712,

Epoch: 40
train: 100%|█| 14/14 [00:06<00:00,  2.24it/s, jaccard_loss - 0.3475, iou_score - 0.8788, fscore - 0.9355, recall - 0.95
valid: 100%|█| 3/3 [00:00<00:00,  7.94it/s, jaccard_loss - 0.4692, iou_score - 0.6759, fscore - 0.8061, recall - 0.7503

Epoch: 41
train: 100%|█| 14/14 [00:07<00:00,  1.98it/s, jaccard_loss - 0.3318, iou_score - 0.8793, fscore - 0.9357, recall - 0.95
valid: 100%|█| 3/3 [00:00<00:00,  7.59it/s, jaccard_loss - 0.4526, iou_score - 0.6857, fscore - 0.8132, recall - 0.7625

Epoch: 42
train: 100%|█| 14/14 [00:05<00:00,  2.35it/s, jaccard_loss - 0.3189, iou_score - 0.8829, fscore - 0.9378, recall - 0.95
valid: 100%|█| 3/3 [00:00<00:00,  8.01it/s, jaccard_loss - 0.438, iou_score - 0.6905, fscore - 0.8165, recall - 0.7801,

Epoch: 43
train: 100%|█| 14/14 [00:05<00:00,  2.37it/s, jaccard_loss - 0.3061, iou_score - 0.8864, fscore - 0.9398, recall - 0.95
valid: 100%|█| 3/3 [00:00<00:00,  6.93it/s, jaccard_loss - 0.4276, iou_score - 0.6941, fscore - 0.819, recall - 0.7964,
Model saved!

Epoch: 44
train: 100%|█| 14/14 [00:05<00:00,  2.34it/s, jaccard_loss - 0.292, iou_score - 0.8907, fscore - 0.9422, recall - 0.954
valid: 100%|█| 3/3 [00:00<00:00,  6.91it/s, jaccard_loss - 0.4225, iou_score - 0.6929, fscore - 0.8182, recall - 0.7826

Epoch: 45
train: 100%|█| 14/14 [00:06<00:00,  2.24it/s, jaccard_loss - 0.2813, iou_score - 0.8906, fscore - 0.9421, recall - 0.95
valid: 100%|█| 3/3 [00:00<00:00,  8.06it/s, jaccard_loss - 0.4162, iou_score - 0.6928, fscore - 0.8181, recall - 0.7751

Epoch: 46
train: 100%|█| 14/14 [00:05<00:00,  2.38it/s, jaccard_loss - 0.2706, iou_score - 0.8911, fscore - 0.9424, recall - 0.95
valid: 100%|█| 3/3 [00:00<00:00,  6.47it/s, jaccard_loss - 0.417, iou_score - 0.6857, fscore - 0.8134, recall - 0.7555,

Epoch: 47
train: 100%|█| 14/14 [00:06<00:00,  2.21it/s, jaccard_loss - 0.2669, iou_score - 0.8881, fscore - 0.9407, recall - 0.95
valid: 100%|█| 3/3 [00:00<00:00,  7.71it/s, jaccard_loss - 0.3988, iou_score - 0.6965, fscore - 0.8207, recall - 0.797,
Model saved!

Epoch: 48
train: 100%|█| 14/14 [00:08<00:00,  1.60it/s, jaccard_loss - 0.2558, iou_score - 0.8891, fscore - 0.9413, recall - 0.94
valid: 100%|█| 3/3 [00:00<00:00,  4.77it/s, jaccard_loss - 0.3993, iou_score - 0.6904, fscore - 0.8166, recall - 0.7771

Epoch: 49
train: 100%|█| 14/14 [00:09<00:00,  1.47it/s, jaccard_loss - 0.2456, iou_score - 0.8932, fscore - 0.9436, recall - 0.95
valid: 100%|█| 3/3 [00:00<00:00,  5.32it/s, jaccard_loss - 0.3933, iou_score - 0.6931, fscore - 0.8185, recall - 0.7753

Epoch: 50
train: 100%|█| 14/14 [00:09<00:00,  1.44it/s, jaccard_loss - 0.2374, iou_score - 0.8927, fscore - 0.9433, recall - 0.95
valid: 100%|█| 3/3 [00:00<00:00,  5.81it/s, jaccard_loss - 0.4032, iou_score - 0.6752, fscore - 0.8056, recall - 0.7519

Epoch: 51
train: 100%|█| 14/14 [00:09<00:00,  1.46it/s, jaccard_loss - 0.2307, iou_score - 0.8954, fscore - 0.9448, recall - 0.95
valid: 100%|█| 3/3 [00:00<00:00,  4.46it/s, jaccard_loss - 0.4172, iou_score - 0.6581, fscore - 0.7933, recall - 0.7102

Epoch: 52
train: 100%|█| 14/14 [00:09<00:00,  1.47it/s, jaccard_loss - 0.2242, iou_score - 0.8959, fscore - 0.9451, recall - 0.95
valid: 100%|█| 3/3 [00:00<00:00,  5.00it/s, jaccard_loss - 0.3837, iou_score - 0.6894, fscore - 0.8158, recall - 0.7782

Epoch: 53
train: 100%|█| 14/14 [00:09<00:00,  1.49it/s, jaccard_loss - 0.2165, iou_score - 0.8958, fscore - 0.945, recall - 0.949
valid: 100%|█| 3/3 [00:00<00:00,  5.82it/s, jaccard_loss - 0.388, iou_score - 0.682, fscore - 0.8105, recall - 0.7479, 

Epoch: 54
train: 100%|█| 14/14 [00:09<00:00,  1.52it/s, jaccard_loss - 0.2122, iou_score - 0.8943, fscore - 0.9442, recall - 0.95
valid: 100%|█| 3/3 [00:00<00:00,  4.78it/s, jaccard_loss - 0.3777, iou_score - 0.6882, fscore - 0.8149, recall - 0.7617

Epoch: 55
train: 100%|█| 14/14 [00:09<00:00,  1.53it/s, jaccard_loss - 0.2018, iou_score - 0.899, fscore - 0.9468, recall - 0.949
valid: 100%|█| 3/3 [00:00<00:00,  4.44it/s, jaccard_loss - 0.3721, iou_score - 0.6891, fscore - 0.8154, recall - 0.7698

Epoch: 56
train: 100%|█| 14/14 [00:09<00:00,  1.48it/s, jaccard_loss - 0.1938, iou_score - 0.9008, fscore - 0.9478, recall - 0.95
valid: 100%|█| 3/3 [00:00<00:00,  5.16it/s, jaccard_loss - 0.364, iou_score - 0.6934, fscore - 0.8186, recall - 0.7725,

Epoch: 57
train: 100%|█| 14/14 [00:09<00:00,  1.44it/s, jaccard_loss - 0.1899, iou_score - 0.9001, fscore - 0.9474, recall - 0.95
valid: 100%|█| 3/3 [00:00<00:00,  4.61it/s, jaccard_loss - 0.3681, iou_score - 0.6855, fscore - 0.8131, recall - 0.7515

Epoch: 58
train: 100%|█| 14/14 [00:09<00:00,  1.44it/s, jaccard_loss - 0.1853, iou_score - 0.8985, fscore - 0.9465, recall - 0.95
valid: 100%|█| 3/3 [00:00<00:00,  5.01it/s, jaccard_loss - 0.3631, iou_score - 0.6867, fscore - 0.8139, recall - 0.7531

Epoch: 59
train: 100%|█| 14/14 [00:09<00:00,  1.43it/s, jaccard_loss - 0.1784, iou_score - 0.9013, fscore - 0.9481, recall - 0.95
valid: 100%|█| 3/3 [00:00<00:00,  6.48it/s, jaccard_loss - 0.3534, iou_score - 0.6944, fscore - 0.8194, recall - 0.768,

Epoch: 60
train: 100%|█| 14/14 [00:07<00:00,  1.83it/s, jaccard_loss - 0.175, iou_score - 0.9003, fscore - 0.9475, recall - 0.949
valid: 100%|█| 3/3 [00:00<00:00,  6.76it/s, jaccard_loss - 0.3558, iou_score - 0.6884, fscore - 0.8151, recall - 0.7597

Epoch: 61
train: 100%|█| 14/14 [00:06<00:00,  2.18it/s, jaccard_loss - 0.1722, iou_score - 0.8989, fscore - 0.9467, recall - 0.95
valid: 100%|█| 3/3 [00:00<00:00,  7.06it/s, jaccard_loss - 0.3729, iou_score - 0.668, fscore - 0.8007, recall - 0.7278,

Epoch: 62
train: 100%|█| 14/14 [00:06<00:00,  2.23it/s, jaccard_loss - 0.1684, iou_score - 0.8987, fscore - 0.9466, recall - 0.94
valid: 100%|█| 3/3 [00:00<00:00,  6.77it/s, jaccard_loss - 0.3586, iou_score - 0.6809, fscore - 0.8099, recall - 0.7516

Epoch: 63
train: 100%|█| 14/14 [00:06<00:00,  2.20it/s, jaccard_loss - 0.1623, iou_score - 0.902, fscore - 0.9485, recall - 0.951
valid: 100%|█| 3/3 [00:00<00:00,  5.82it/s, jaccard_loss - 0.3589, iou_score - 0.6788, fscore - 0.8083, recall - 0.741,

Epoch: 64
train: 100%|█| 14/14 [00:06<00:00,  2.19it/s, jaccard_loss - 0.16, iou_score - 0.9015, fscore - 0.9482, recall - 0.9528
valid: 100%|█| 3/3 [00:00<00:00,  7.08it/s, jaccard_loss - 0.3518, iou_score - 0.6843, fscore - 0.8122, recall - 0.7547

Epoch: 65
train: 100%|█| 14/14 [00:06<00:00,  2.10it/s, jaccard_loss - 0.1613, iou_score - 0.8976, fscore - 0.946, recall - 0.946
valid: 100%|█| 3/3 [00:00<00:00,  7.00it/s, jaccard_loss - 0.3365, iou_score - 0.6978, fscore - 0.8216, recall - 0.7941
Model saved!

Epoch: 66
train: 100%|█| 14/14 [00:06<00:00,  2.15it/s, jaccard_loss - 0.1577, iou_score - 0.8985, fscore - 0.9465, recall - 0.95
valid: 100%|█| 3/3 [00:00<00:00,  7.69it/s, jaccard_loss - 0.3558, iou_score - 0.6768, fscore - 0.8069, recall - 0.7386

Epoch: 67
train: 100%|█| 14/14 [00:07<00:00,  1.86it/s, jaccard_loss - 0.1552, iou_score - 0.8985, fscore - 0.9465, recall - 0.94
valid: 100%|█| 3/3 [00:00<00:00,  7.71it/s, jaccard_loss - 0.3429, iou_score - 0.6884, fscore - 0.8151, recall - 0.7649

Epoch: 68
train: 100%|█| 14/14 [00:06<00:00,  2.13it/s, jaccard_loss - 0.1487, iou_score - 0.9028, fscore - 0.9489, recall - 0.95
valid: 100%|█| 3/3 [00:00<00:00,  7.75it/s, jaccard_loss - 0.3478, iou_score - 0.6824, fscore - 0.8109, recall - 0.7491

Epoch: 69
train: 100%|█| 14/14 [00:06<00:00,  2.22it/s, jaccard_loss - 0.1464, iou_score - 0.9037, fscore - 0.9494, recall - 0.95
valid: 100%|█| 3/3 [00:00<00:00,  7.79it/s, jaccard_loss - 0.3426, iou_score - 0.6867, fscore - 0.8139, recall - 0.7576

Epoch: 70
train: 100%|█| 14/14 [00:06<00:00,  2.14it/s, jaccard_loss - 0.1438, iou_score - 0.904, fscore - 0.9496, recall - 0.951
valid: 100%|█| 3/3 [00:00<00:00,  8.06it/s, jaccard_loss - 0.3377, iou_score - 0.6908, fscore - 0.8168, recall - 0.7631

Epoch: 71
train: 100%|█| 14/14 [00:06<00:00,  2.23it/s, jaccard_loss - 0.1417, iou_score - 0.9043, fscore - 0.9497, recall - 0.95
valid: 100%|█| 3/3 [00:00<00:00,  7.83it/s, jaccard_loss - 0.3408, iou_score - 0.6859, fscore - 0.8133, recall - 0.7536

Epoch: 72
train: 100%|█| 14/14 [00:06<00:00,  2.27it/s, jaccard_loss - 0.1399, iou_score - 0.9045, fscore - 0.9499, recall - 0.94
valid: 100%|█| 3/3 [00:00<00:00,  7.94it/s, jaccard_loss - 0.3301, iou_score - 0.6959, fscore - 0.8203, recall - 0.7809

Epoch: 73
train: 100%|█| 14/14 [00:06<00:00,  2.20it/s, jaccard_loss - 0.1377, iou_score - 0.9049, fscore - 0.9501, recall - 0.95
valid: 100%|█| 3/3 [00:00<00:00,  8.11it/s, jaccard_loss - 0.3354, iou_score - 0.6893, fscore - 0.8158, recall - 0.7621

Epoch: 74
train: 100%|█| 14/14 [00:06<00:00,  2.18it/s, jaccard_loss - 0.1363, iou_score - 0.9051, fscore - 0.9502, recall - 0.95
valid: 100%|█| 3/3 [00:00<00:00,  8.09it/s, jaccard_loss - 0.3314, iou_score - 0.693, fscore - 0.8183, recall - 0.7663,

Epoch: 75
train: 100%|█| 14/14 [00:06<00:00,  2.26it/s, jaccard_loss - 0.1341, iou_score - 0.9061, fscore - 0.9508, recall - 0.95
valid: 100%|█| 3/3 [00:00<00:00,  7.92it/s, jaccard_loss - 0.3305, iou_score - 0.6931, fscore - 0.8185, recall - 0.7727

Epoch: 76
train: 100%|█| 14/14 [00:06<00:00,  2.26it/s, jaccard_loss - 0.1323, iou_score - 0.906, fscore - 0.9507, recall - 0.949
valid: 100%|█| 3/3 [00:00<00:00,  7.90it/s, jaccard_loss - 0.3263, iou_score - 0.6961, fscore - 0.8205, recall - 0.7807

Epoch: 77
train: 100%|█| 14/14 [00:06<00:00,  2.19it/s, jaccard_loss - 0.1321, iou_score - 0.9047, fscore - 0.9499, recall - 0.95
valid: 100%|█| 3/3 [00:00<00:00,  7.77it/s, jaccard_loss - 0.3294, iou_score - 0.6928, fscore - 0.8182, recall - 0.7692

Epoch: 78
train: 100%|█| 14/14 [00:06<00:00,  2.28it/s, jaccard_loss - 0.1313, iou_score - 0.9045, fscore - 0.9499, recall - 0.95
valid: 100%|█| 3/3 [00:00<00:00,  8.09it/s, jaccard_loss - 0.3219, iou_score - 0.6986, fscore - 0.8222, recall - 0.789,
Model saved!

Epoch: 79
train: 100%|█| 14/14 [00:06<00:00,  2.29it/s, jaccard_loss - 0.1295, iou_score - 0.9048, fscore - 0.95, recall - 0.9519
valid: 100%|█| 3/3 [00:00<00:00,  5.32it/s, jaccard_loss - 0.3371, iou_score - 0.6825, fscore - 0.811, recall - 0.7486,

Epoch: 80
train: 100%|█| 14/14 [00:06<00:00,  2.29it/s, jaccard_loss - 0.1284, iou_score - 0.9051, fscore - 0.9502, recall - 0.94
valid: 100%|█| 3/3 [00:00<00:00,  7.83it/s, jaccard_loss - 0.3278, iou_score - 0.6916, fscore - 0.8174, recall - 0.7717

Epoch: 81
train: 100%|█| 14/14 [00:06<00:00,  2.06it/s, jaccard_loss - 0.1271, iou_score - 0.9054, fscore - 0.9503, recall - 0.95
valid: 100%|█| 3/3 [00:00<00:00,  7.87it/s, jaccard_loss - 0.3245, iou_score - 0.6944, fscore - 0.8193, recall - 0.7819

Epoch: 82
train: 100%|█| 14/14 [00:06<00:00,  2.29it/s, jaccard_loss - 0.1276, iou_score - 0.904, fscore - 0.9496, recall - 0.948
valid: 100%|█| 3/3 [00:00<00:00,  7.83it/s, jaccard_loss - 0.3337, iou_score - 0.685, fscore - 0.8128, recall - 0.7545,

Epoch: 83
train: 100%|█| 14/14 [00:06<00:00,  2.28it/s, jaccard_loss - 0.1275, iou_score - 0.9035, fscore - 0.9493, recall - 0.95
valid: 100%|█| 3/3 [00:00<00:00,  8.00it/s, jaccard_loss - 0.3347, iou_score - 0.6832, fscore - 0.8114, recall - 0.7498
Early Stopped!

Assess Model Metrics¶

In [10]:
import plotly.graph_objects as go
from plotly.subplots import make_subplots
In [11]:
titles = [x.replace("_", "").capitalize() for x in train_logs_list[0].keys()]
titles[titles.index("Iouscore")] = "IoU Score"
titles[titles.index("Fscore")] = "F Score"
titles[titles.index("Jaccardloss")] = "Jaccard loss"
titles = ["<b>" + x + "</b>" for x in titles]

fig = make_subplots(
    rows=2,
    cols=2,
    subplot_titles=titles,
    horizontal_spacing=0.1,
    vertical_spacing=0.1,
)

k = 0
for i in range(2):
    for j in range(2):
        metric = list(train_logs_list[0].keys())[k]

        train = [x[metric] for x in train_logs_list]
        valid = [x[metric] for x in valid_logs_list]

        fig.add_trace(
            go.Scatter(
                x=[x for x in range(len(train))],
                y=train,
                name="Train",
                mode="lines",
                line=dict(color="blue"),
                showlegend=(k == 3),
            ),
            row=i + 1,
            col=j + 1,
        )

        fig.add_trace(
            go.Scatter(
                x=[x for x in range(len(valid))],
                y=valid,
                name="Validation",
                mode="lines",
                line=dict(color="red"),
                showlegend=(k == 3),
            ),
            row=i + 1,
            col=j + 1,
        )

        fig.update_xaxes(title_text="Epochs", row=i + 1, col=j + 1, title_standoff=1)
        fig.update_yaxes(
            #range=[0, 1],
            dtick=0.2,
            title_text="Value",
            row=i + 1,
            col=j + 1,
            title_standoff=1,
        )

        k += 1
fig.update_layout(
    width = 800, height = 800,
    showlegend=True, legend=dict(orientation="h", x=0.35, y=1.1),
    title_text="<b>Performance Metrics</b>",
    title_font=dict(size=24),
    title_x=0.5,
    title_y=0.97
)
fig.show()
In [12]:
# Save metric graph
fig.write_json(r"output\metric.json")

Image Results (Validation)¶

In [13]:
images, masks = next(iter(valid_loader))[0] , next(iter(valid_loader))[1]
with torch.no_grad():
    unet.eval()
    pred = best_unet(images)
In [14]:
i=0
plt.imshow(masks[i].to('cpu').numpy()[0], cmap='gray')
plt.axis('off')
plt.show();
plt.imshow(np.array([1 if i >= 0.5 else 0 for i in np.array(pred[i][0].to('cpu').numpy()).flatten()]).reshape(256,256), cmap='gray')
plt.axis('off')
plt.show();

criterion_IOU = smp.utils.metrics.IoU(threshold=THRESHOLD) 
criterion_recall = smp.utils.metrics.Recall(threshold=THRESHOLD)

print(criterion_IOU(pred, masks))
print(criterion_recall(pred, masks))
tensor(0.7213, device='cuda:0')
tensor(0.7987, device='cuda:0')

Image Results (Test)¶

In [15]:
images, masks = next(iter(test_loader))[0] , next(iter(test_loader))[1]
with torch.no_grad():
    unet.eval()
    pred = best_unet(images)
pred.shape
Out[15]:
torch.Size([13, 1, 256, 256])
In [16]:
i=1

criterion_IOU = smp.utils.metrics.IoU(threshold=THRESHOLD) 
criterion_recall = smp.utils.metrics.Recall(threshold=THRESHOLD)

print(f"IoU: {criterion_IOU(pred[i][0], masks[i])}")
print(f"Recall: {criterion_recall(pred[i][0], masks[i])}")


ori = np.transpose(images[0].to('cpu').numpy(), (2, 1, 0))
true_mask = masks[i].to('cpu').numpy()[0]
pred_mask = np.array([1 if i >= 0.5 else 0 for i in np.array(pred[i][0].to('cpu').numpy()).flatten()]).reshape(256,256)

fig, ax = plt.subplots(1,3)
ax[0].imshow(ori)
ax[0].axis('off')
ax[1].imshow(true_mask, cmap='gray')
ax[1].axis('off')
ax[2].imshow(pred_mask, cmap='gray')
ax[2].axis('off')
plt.subplots_adjust(wspace=0.05, hspace=0.05)
IoU: 0.7414787411689758
Recall: 0.8505991101264954

Unrelated Image Results (Test)¶

In [17]:
fun_dataset = CreateDataset(r"dataset\fun"
                            , r"dataset\fun"
                            , DEVICE, transform=transform)
fun_loader = DataLoader(fun_dataset, batch_size=BATCH_SIZE, shuffle=False, pin_memory = PIN_MEMORY)
In [18]:
images, masks = next(iter(fun_loader))[0] , next(iter(fun_loader))[1]
with torch.no_grad():
    unet.eval()
    pred = best_unet(images)
pred.shape
Out[18]:
torch.Size([2, 1, 256, 256])
In [19]:
i = 0
raw = cv2.imread(r"dataset\fun\IMG_1401.JPG")
raw = cv2.resize(raw, (256, 256))
raw = cv2.cvtColor(raw, cv2.COLOR_BGR2RGB)

raw2 = np.transpose(raw, (2, 1, 0))
raw2 = torch.from_numpy(raw2).float().to(DEVICE)
raw2 = raw2.unsqueeze(0)

pred_mask = np.array([1 if i >= 0.5 else 0 for i in np.array(pred[i][0].to('cpu').numpy()).flatten()]).reshape(256,256)

fig, ax = plt.subplots(1,2)
ax[0].imshow(raw)
ax[0].axis('off')
ax[1].imshow(pred_mask, cmap='gray')
ax[1].axis('off')
plt.subplots_adjust(wspace=0.05, hspace=0.05)

plt.savefig(r"output/example2.png"
           , bbox_inches='tight'
            , facecolor='none'
            , transparent=True)
In [20]:
i = 1
raw = cv2.imread(r"dataset\fun\lenna.jpg")
raw = cv2.resize(raw, (256, 256))
raw = cv2.cvtColor(raw, cv2.COLOR_BGR2RGB)

raw2 = np.transpose(raw, (2, 1, 0))
raw2 = torch.from_numpy(raw2).float().to(DEVICE)
raw2 = raw2.unsqueeze(0)

pred_mask = np.array([1 if i >= 0.5 else 0 for i in np.array(pred[i][0].to('cpu').numpy()).flatten()]).reshape(256,256)

fig, ax = plt.subplots(1,2)
ax[0].imshow(raw)
ax[0].axis('off')
ax[1].imshow(pred_mask, cmap='gray')
ax[1].axis('off')
plt.subplots_adjust(wspace=0.05, hspace=0.05)

plt.savefig(r"output/example3.png"
           , bbox_inches='tight'
            , facecolor='none'
            , transparent=True)