Simple Art Restoration with U-Net

Simple Art Restoration with U-Net


import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch.utils.data import DataLoader, Dataset

import matplotlib.pyplot as plt
import matplotlib.patches as patches
import numpy as np
import random
# Step 1: Define a Tiny U-Net

class DoubleConv(nn.Module):
    """Two stacked 3x3 convolutions, each followed by an in-place ReLU.

    Both convolutions use ``padding=1``, so height and width are preserved;
    only the channel count changes (``in_channels`` -> ``out_channels`` ->
    ``out_channels``).
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # Layers are created in forward order (conv, relu, conv, relu) so the
        # parameter names under ``self.conv`` are ``0.*`` and ``2.*``.
        stages = []
        for source_channels in (in_channels, out_channels):
            stages.append(
                nn.Conv2d(source_channels, out_channels, kernel_size=3, padding=1)
            )
            stages.append(nn.ReLU(inplace=True))
        self.conv = nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` through the conv-ReLU-conv-ReLU stack."""
        return self.conv(x)

class UNet(nn.Module):
    """Small two-level U-Net that maps an image to an image of the same size.

    Channel widths are fixed at 64 -> 128 -> 256 (bottleneck) -> 128 -> 64.
    The encoder halves the resolution twice and the decoder doubles it twice,
    so height and width should be divisible by 4 for the skip-connection
    concatenations to line up.

    Shapes below are for a ``[3, 64, 64]`` input, written as ``[C, H, W]``.
    """

    def __init__(self, in_channels=3, out_channels=3):
        super().__init__()

        # --- Encoder ---
        self.down1 = DoubleConv(in_channels, 64)      # [64, 64, 64]
        self.pool1 = nn.MaxPool2d(kernel_size=2)      # [64, 32, 32]
        self.down2 = DoubleConv(64, 128)              # [128, 32, 32]
        self.pool2 = nn.MaxPool2d(kernel_size=2)      # [128, 16, 16]

        # --- Bottleneck ---
        self.bottleneck = DoubleConv(128, 256)        # [256, 16, 16]

        # --- Decoder ---
        # Each transposed conv doubles H and W; its output is concatenated
        # with the matching encoder feature map, which is why the following
        # DoubleConv takes twice as many input channels.
        self.up2 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)  # [128, 32, 32]
        self.dec2 = DoubleConv(256, 128)              # [128, 32, 32]
        self.up1 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)   # [64, 64, 64]
        self.dec1 = DoubleConv(128, 64)               # [64, 64, 64]

        # 1x1 projection to the requested number of output channels.
        self.final = nn.Conv2d(64, out_channels, kernel_size=1)  # [3, 64, 64]

    def forward(self, x):
        """Return the restored image predicted for the batch ``x``."""
        # Encoder: keep both feature maps around for the skip connections.
        skip_fine = self.down1(x)
        skip_mid = self.down2(self.pool1(skip_fine))

        # Bottleneck: most compressed representation.
        latent = self.bottleneck(self.pool2(skip_mid))

        # Decoder level 2: upsample, then re-attach mid-level detail.
        merged = torch.cat((self.up2(latent), skip_mid), dim=1)
        decoded = self.dec2(merged)

        # Decoder level 1: upsample, then re-attach fine detail (edges,
        # brushstrokes).
        merged = torch.cat((self.up1(decoded), skip_fine), dim=1)
        decoded = self.dec1(merged)

        return self.final(decoded)
# Step 2: Prepare the Dataset

class DamagedDataset(Dataset):
    """CIFAR-10 images paired with synthetically damaged copies.

    Each item is a ``(damaged, clean, mask)`` triple of equally shaped
    ``[C, H, W]`` tensors. ``mask`` is 1 (in every channel) on pixels that
    were overwritten and 0 elsewhere. Damaged pixels are painted pure red
    (channel 0 set to 1, all other channels set to 0).

    Damage is drawn with Python's ``random`` module on every access, so the
    same index yields different damage each time unless ``random`` is seeded.
    """

    def __init__(self, train=True, root="./data", image_size=64):
        """Load (downloading if needed) the CIFAR-10 split.

        Args:
            train: Use the training split when True, the test split otherwise.
            root: Directory where CIFAR-10 is stored / downloaded.
            image_size: Side length images are resized to (square).
        """
        transform = transforms.Compose([
            transforms.Resize((image_size, image_size)),
            transforms.ToTensor(),
        ])
        self.data = datasets.CIFAR10(root=root, train=train,
                                     transform=transform, download=True)

    @staticmethod
    def _paint_damage(img, mask, rows, cols):
        """Paint the region ``[rows, cols]`` red in ``img`` and flag it in ``mask``.

        ``rows`` / ``cols`` may be ints or slices. Both tensors are modified
        in place.
        """
        img[:, rows, cols] = 0
        img[0, rows, cols] = 1
        mask[:, rows, cols] = 1

    @staticmethod
    def _draw_streak(img, mask, x1, y1, length, angle):
        """Draw one scratch starting at ``(x1, y1)``, advancing one column per step.

        The row at step ``i`` is ``y1 + i * angle``. Pixels that fall outside
        the image are skipped.
        """
        _, H, W = img.shape
        for i in range(length):
            xi = x1 + i
            y_exact = y1 + i * angle
            # Bounds are checked on the exact (float) row *before* truncating:
            # int() rounds toward zero, so a row in (-1, 0) would otherwise be
            # mapped onto row 0 and drawn although it lies above the image.
            if 0 <= xi < W and 0 <= y_exact < H:
                DamagedDataset._paint_damage(img, mask, int(y_exact), xi)

    def add_damage(self, img):
        """Return a damaged copy of ``img`` and the corresponding damage mask.

        Independently, each with probability 0.5:
          * one square block with side 12-24 px (capped at the image size), and
          * 3-7 scratches of 10-30 px with slope in [-0.5, 0.5]
        are painted onto the copy. ``img`` itself is left untouched.

        Args:
            img: Image tensor of shape ``[C, H, W]``.

        Returns:
            ``(damaged, mask)``, both with the shape of ``img``.
        """
        img = img.clone()
        mask = torch.zeros_like(img)

        _, H, W = img.shape

        # Block damage
        if random.random() > 0.5:
            # Cap the side so the randint ranges below stay valid for images
            # smaller than the drawn block size.
            block_size = min(random.randint(12, 24), H, W)
            x = random.randint(0, W - block_size)
            y = random.randint(0, H - block_size)
            self._paint_damage(img, mask,
                               slice(y, y + block_size),
                               slice(x, x + block_size))

        # Streaks / scratches
        if random.random() > 0.5:
            num_streaks = random.randint(3, 7)
            for _ in range(num_streaks):
                x1, y1 = random.randint(0, W - 1), random.randint(0, H - 1)
                length = random.randint(10, 30)
                angle = random.uniform(-0.5, 0.5)
                self._draw_streak(img, mask, x1, y1, length, angle)
        return img, mask

    def __len__(self):
        """Number of images in the underlying CIFAR-10 split."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(damaged, clean, mask)`` for image ``idx`` (label is dropped)."""
        clean, _ = self.data[idx]
        damaged, mask = self.add_damage(clean)
        return damaged, clean, mask

# Training split (CIFAR-10 is downloaded on first use), served in shuffled
# mini-batches of 32 (damaged, clean, mask) triples.
train_dataset = DamagedDataset(train=True)
train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=32,
    shuffle=True,
)
100%|██████████| 170M/170M [00:06<00:00, 27.8MB/s]
# Step 3: Training

# Train on GPU when one is available, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = UNet().to(device)
# Pixel-wise reconstruction loss between the prediction and the clean image.
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)

print("Training started...")
for epoch in range(3):
    # Accumulate a sample-weighted sum so the figure reported below is the
    # mean loss over the whole epoch rather than the loss of whichever
    # (possibly short) mini-batch happened to come last.
    loss_sum = torch.zeros((), device=device)
    samples_seen = 0
    for damaged, clean, mask in train_loader:
        # The mask is not needed for training; the loss covers every pixel.
        damaged, clean = damaged.to(device), clean.to(device)
        output = model(damaged)
        loss = criterion(output, clean)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        batch_size = damaged.size(0)
        # detach(): keep the running total out of the autograd graph.
        loss_sum += loss.detach() * batch_size
        samples_seen += batch_size
    epoch_loss = loss_sum.item() / max(samples_seen, 1)
    print(f"Epoch {epoch+1}: Loss = {epoch_loss:.4f}")
Training started...
Epoch 1: Loss = 0.0004
Epoch 2: Loss = 0.0005
Epoch 3: Loss = 0.0006
# Helper to add black border
def add_black_border(img_tensor, border=1):
    """Return a copy of a ``[C, H, W]`` tensor with its outer frame zeroed.

    The outermost ``border`` rows (top and bottom) and columns (left and
    right) are set to 0 in every channel. The input tensor is not modified.
    """
    framed = img_tensor.clone()
    _, height, width = framed.shape

    edge_rows = (slice(None, border), slice(height - border, None))
    edge_cols = (slice(None, border), slice(width - border, None))

    # Top and bottom strips
    for row_span in edge_rows:
        framed[:, row_span, :] = 0

    # Left and right strips
    for col_span in edge_cols:
        framed[:, :, col_span] = 0

    return framed

# Step 4: Visualization

def _as_display_image(chw_tensor):
    """Convert a ``[C, H, W]`` tensor to a bordered ``[H, W, C]`` array for imshow."""
    img = torch.clamp(chw_tensor.detach().cpu().float(), 0, 1)
    img = add_black_border(img, border=1)
    return img.permute(1, 2, 0).numpy()


def show_images(damaged, restored, clean, masks, n_samples=3):
    """Plot damaged / AI patch / restored / original panels, one row per sample.

    Args:
        damaged: Batch of damaged inputs, indexable as ``[C, H, W]`` tensors.
        restored: Batch of model outputs matching ``damaged``.
        clean: Batch of ground-truth images matching ``damaged``.
        masks: Batch of damage masks (1 where the input was overwritten).
        n_samples: Number of rows to draw; capped at the number of samples
            actually provided. Nothing is drawn when this is zero.

    The "AI Patch" column shows the model output only inside the damaged
    region, on a white background.
    """
    titles = ['Damaged', 'AI Patch', 'Restored', 'Original']

    # Never index past the end of any of the batches.
    n_samples = min(n_samples, len(damaged), len(restored), len(clean), len(masks))
    if n_samples <= 0:
        return

    # squeeze=False keeps ``axes`` two-dimensional even for a single row.
    fig, axes = plt.subplots(n_samples, len(titles),
                             figsize=(10, 2 * n_samples),
                             facecolor='white', squeeze=False)

    for row in range(n_samples):
        output = restored[row].detach().cpu()
        hole = masks[row].detach().cpu()
        # White canvas everywhere except the damaged pixels, which show the
        # colours the model predicted for them.
        patch = torch.ones_like(output) * (1 - hole) + output * hole

        panels = (damaged[row], patch, output, clean[row])
        for col, (title, panel) in enumerate(zip(titles, panels)):
            ax = axes[row, col]
            ax.set_facecolor('white')
            ax.imshow(_as_display_image(panel), origin='upper')
            if row == 0:
                ax.set_title(title)
            ax.axis('off')

    fig.tight_layout()
    plt.show()
# Test on a few samples
# Take the first five test images; each access applies fresh random damage.
test_dataset = DamagedDataset(train=False)
damaged, clean, masks = zip(*[test_dataset[i] for i in range(5)])
damaged = torch.stack(damaged).to(device)
clean = torch.stack(clean).to(device)
masks = torch.stack(masks)

# Inference only: switch to eval mode and skip autograd bookkeeping so no
# computation graph is built or kept alive for the outputs.
model.eval()
with torch.no_grad():
    restored = model(damaged)

show_images(damaged, restored, clean, masks)
../_images/60dff0cfcd5a2657e6823e27003aa9646004ca1f8e301bd959591d152461bb3f.png

Report: Download (PDF)