Custom Models

This tutorial shows how to implement a custom autoencoder architecture using Tinify modules.

Basic Architecture

Let's build a simple autoencoder with:

  • An EntropyBottleneck module
  • 3 convolutional layers for encoding
  • 3 transposed convolutions for decoding
  • GDN activation functions
import torch.nn as nn

from tinify.entropy_models import EntropyBottleneck
from tinify.layers import GDN


class Network(nn.Module):
    def __init__(self, N=128):
        super().__init__()
        self.entropy_bottleneck = EntropyBottleneck(N)
        self.encode = nn.Sequential(
            nn.Conv2d(3, N, stride=2, kernel_size=5, padding=2),
            GDN(N),
            nn.Conv2d(N, N, stride=2, kernel_size=5, padding=2),
            GDN(N),
            nn.Conv2d(N, N, stride=2, kernel_size=5, padding=2),
        )

        self.decode = nn.Sequential(
            nn.ConvTranspose2d(N, N, kernel_size=5, padding=2, output_padding=1, stride=2),
            GDN(N, inverse=True),
            nn.ConvTranspose2d(N, N, kernel_size=5, padding=2, output_padding=1, stride=2),
            GDN(N, inverse=True),
            nn.ConvTranspose2d(N, 3, kernel_size=5, padding=2, output_padding=1, stride=2),
        )

    def forward(self, x):
        y = self.encode(x)
        y_hat, y_likelihoods = self.entropy_bottleneck(y)
        x_hat = self.decode(y_hat)
        return x_hat, y_likelihoods

The three stride-2 convolutions progressively reduce the spatial dimensions (by a factor of 8 in each direction) while mapping the 3-channel input to an N-channel latent representation. The entropy bottleneck provides a differentiable estimate of the latent's entropy during training, which is what makes the bitrate optimizable.
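
A quick sanity check (a minimal sketch using the Network defined above; the per-element likelihood tensor is assumed to have the same shape as the latent):

import torch

net = Network()
x = torch.rand(1, 3, 256, 256)                    # dummy input
y = net.encode(x)                                 # latent: 1 x 128 x 32 x 32 (256 / 2**3)
y_hat, y_likelihoods = net.entropy_bottleneck(y)
print(y.shape, y_hat.shape, y_likelihoods.shape)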

Note

See the original paper Variational image compression with a scale hyperprior and the tensorflow/compression documentation for detailed explanations.

Loss Functions

Rate-Distortion Loss

The rate-distortion loss trades off reconstruction quality (the distortion term \(\mathcal{D}\), here measured with MSE) against the bitrate of the quantized latent (the rate term \(\mathcal{R}\)):

\[\mathcal{L} = \mathcal{D} + \lambda \cdot \mathcal{R}\]
import math

import torch
import torch.nn.functional as F

x = torch.rand(1, 3, 64, 64)
net = Network()
x_hat, y_likelihoods = net(x)

# Bitrate of the quantized latent
N, _, H, W = x.size()  # N is the batch size here, not the channel count
num_pixels = N * H * W
bpp_loss = torch.log(y_likelihoods).sum() / (-math.log(2) * num_pixels)

# Mean square error
mse_loss = F.mse_loss(x, x_hat)

# Final loss term
lmbda = 0.01  # Trade-off parameter
loss = mse_loss + lmbda * bpp_loss

Tip

Variable bit-rate architectures are possible but beyond this tutorial's scope. See Variable Rate Deep Image Compression With a Conditional Autoencoder.

Auxiliary Loss

The entropy bottleneck has its own internal parameters (the quantiles of its learned density model), which are optimized with a separate auxiliary loss:

aux_loss = net.entropy_bottleneck.loss()

This auxiliary loss is independent of the rate-distortion loss and must be minimized either during training (as in the training loop below) or after it.
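
For the post-training option, a minimal sketch (reusing the net instance from above and the ".quantiles" parameter naming introduced in the optimizer section below):

import torch.optim as optim

# Refine only the bottleneck's quantile parameters once the main training is done
aux_optimizer = optim.Adam(
    (p for n, p in net.named_parameters() if n.endswith(".quantiles")),
    lr=1e-3,
)
for _ in range(1000):
    aux_optimizer.zero_grad()
    aux_loss = net.entropy_bottleneck.loss()
    aux_loss.backward()
    aux_optimizer.step()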

Using CompressionModel Base Class

Tinify provides a CompressionModel base class with helpful utilities, such as the conv/deconv layer helpers used below and an aux_loss() method (used later in the training loop):

from tinify.models import CompressionModel
from tinify.models.utils import conv, deconv


class Network(CompressionModel):
    def __init__(self, N=128):
        super().__init__()
        self.entropy_bottleneck = EntropyBottleneck(N)
        self.encode = nn.Sequential(
            conv(3, N),
            GDN(N),
            conv(N, N),
            GDN(N),
            conv(N, N),
        )

        self.decode = nn.Sequential(
            deconv(N, N),
            GDN(N, inverse=True),
            deconv(N, N),
            GDN(N, inverse=True),
            deconv(N, 3),
        )

    def forward(self, x):
        y = self.encode(x)
        y_hat, y_likelihoods = self.entropy_bottleneck(y)
        x_hat = self.decode(y_hat)
        return x_hat, y_likelihoods
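
Since the base class exposes the auxiliary loss directly, the one-liner from the previous section becomes (a quick check; aux_loss() is assumed to aggregate the losses of the model's entropy bottlenecks):

net = Network()
aux_loss = net.aux_loss()  # same role as net.entropy_bottleneck.loss() above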

Setting Up Optimizers

Train both the compression network and entropy bottleneck with separate optimizers:

import torch.optim as optim

# Main network parameters (exclude quantiles)
parameters = set(
    p for n, p in net.named_parameters()
    if not n.endswith(".quantiles")
)

# Auxiliary parameters (entropy bottleneck quantiles)
aux_parameters = set(
    p for n, p in net.named_parameters()
    if n.endswith(".quantiles")
)

optimizer = optim.Adam(parameters, lr=1e-4)
aux_optimizer = optim.Adam(aux_parameters, lr=1e-3)

Note

You can also use PyTorch's parameter groups to define a single optimizer.
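
For example, a minimal sketch of the single-optimizer variant (reusing the parameter split defined above, with the same per-group learning rates):

optimizer = optim.Adam(
    [
        {"params": list(parameters), "lr": 1e-4},      # main network
        {"params": list(aux_parameters), "lr": 1e-3},  # bottleneck quantiles
    ]
)

You would then backpropagate both losses and call optimizer.step() once per iteration.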

Training Loop

x = torch.rand(1, 3, 64, 64)  # dummy input; iterate over a real image DataLoader in practice
num_epochs = 1000             # toy number of training steps for this example

for i in range(num_epochs):
    optimizer.zero_grad()
    aux_optimizer.zero_grad()

    x_hat, y_likelihoods = net(x)

    # Compute rate-distortion loss
    N, _, H, W = x.size()
    num_pixels = N * H * W
    bpp_loss = torch.log(y_likelihoods).sum() / (-math.log(2) * num_pixels)
    mse_loss = F.mse_loss(x, x_hat)
    loss = mse_loss + lmbda * bpp_loss

    loss.backward()
    optimizer.step()

    # Update auxiliary parameters
    aux_loss = net.aux_loss()
    aux_loss.backward()
    aux_optimizer.step()

    if i % 100 == 0:
        print(f"Step {i}: loss={loss.item():.4f}, bpp={bpp_loss.item():.4f}")

Adding a Hyperprior

For better rate-distortion performance, add a hyperprior network that transmits side information to model the distribution of the latent:

from tinify.models import ScaleHyperprior

# Use the built-in scale hyperprior model
model = ScaleHyperprior(N=128, M=192)

Or implement your own by following the patterns in tinify.models.google.
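
With a hyperprior there are two latents to encode (the latent y and the hyper-latent z), so the rate term of the loss sums the estimated bits of both. A minimal sketch, assuming the model's forward pass returns the reconstruction together with one likelihood tensor per latent (the exact return format depends on the model implementation):

x = torch.rand(1, 3, 64, 64)
out = model(x)  # assumed format: {"x_hat": ..., "likelihoods": {"y": ..., "z": ...}}

N, _, H, W = x.size()
num_pixels = N * H * W

# Sum the bitrates of every latent and add the distortion term as before
bpp_loss = sum(
    torch.log(likelihoods).sum() / (-math.log(2) * num_pixels)
    for likelihoods in out["likelihoods"].values()
)
loss = F.mse_loss(x, out["x_hat"]) + lmbda * bpp_loss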

Next Steps

  • Explore the Model Zoo for pre-trained architectures
  • Check the API Reference for all available base classes
  • Look at examples/train.py for a complete training script