Models

Compression model architectures.

Base Classes

CompressionModel

CompressionModel(entropy_bottleneck_channels=None, init_weights=None)

Bases: Module

Base class for constructing an auto-encoder with any number of EntropyBottleneck or GaussianConditional modules.

Source code in tinify/models/base.py
def __init__(
    self,
    entropy_bottleneck_channels: int | None = None,
    init_weights: bool | None = None,
) -> None:
    super().__init__()

    if entropy_bottleneck_channels is not None:
        warnings.warn(
            "The entropy_bottleneck_channels parameter is deprecated. "
            "Create an entropy_bottleneck in your model directly instead:\n\n"
            "class YourModel(CompressionModel):\n"
            "    def __init__(self):\n"
            "        super().__init__()\n"
            "        self.entropy_bottleneck = "
            "EntropyBottleneck(entropy_bottleneck_channels)\n",
            DeprecationWarning,
            stacklevel=2,
        )
        self.entropy_bottleneck = EntropyBottleneck(entropy_bottleneck_channels)

    if init_weights is not None:
        warnings.warn(
            "The init_weights parameter was removed as it was never functional.",
            DeprecationWarning,
            stacklevel=2,
        )
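The deprecation path above uses the standard `warnings` module with `stacklevel=2`, so the warning is attributed to the code that constructs the model. The sketch below shows one way a caller could detect that it still passes the deprecated argument. `LegacyModel` and `collect_deprecations` are hypothetical stand-ins for illustration and are not part of tinify.

```python
import warnings


class LegacyModel:
    """Hypothetical stand-in that mimics the deprecation path shown above."""

    def __init__(self, entropy_bottleneck_channels=None):
        if entropy_bottleneck_channels is not None:
            warnings.warn(
                "The entropy_bottleneck_channels parameter is deprecated.",
                DeprecationWarning,
                stacklevel=2,  # attribute the warning to the caller's line
            )


def collect_deprecations(factory):
    """Call factory() and return the DeprecationWarning messages it emitted."""
    with warnings.catch_warnings(record=True) as caught:
        # DeprecationWarning is frequently filtered out by default.
        warnings.simplefilter("always")
        factory()
    return [
        str(w.message)
        for w in caught
        if issubclass(w.category, DeprecationWarning)
    ]


old_style = collect_deprecations(lambda: LegacyModel(entropy_bottleneck_channels=128))
new_style = collect_deprecations(lambda: LegacyModel())
print(len(old_style), len(new_style))
```

Running a test suite with `-W error::DeprecationWarning` is another common way to surface such call sites.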

aux_loss

aux_loss()

Returns the total auxiliary loss over all EntropyBottleneck modules.

In contrast to the primary "net" loss used by the "net" optimizer, the "aux" loss is used solely by the "aux" optimizer, which updates only the EntropyBottleneck.quantiles parameters. In fact, the "aux" loss does not depend on image data at all.

The purpose of the "aux" loss is to determine the range within which most of the mass of a given distribution is contained, as well as its median (i.e. 50% probability). That is, for a given distribution, the "aux" loss converges towards satisfying the following conditions for some chosen tail_mass probability:

  • cdf(quantiles[0]) = tail_mass / 2
  • cdf(quantiles[1]) = 0.5
  • cdf(quantiles[2]) = 1 - tail_mass / 2

This ensures that the concrete _quantized_cdf tables operate primarily within a finitely supported region. Any symbols outside this range must be coded using some alternative method that does not involve the _quantized_cdf tables. Luckily, one may choose a tail_mass probability that is sufficiently small so that this rarely occurs. At the same time, a smaller tail_mass widens the supported range, and entropy coding runtime performance suffers when the _quantized_cdf tables have a large support. Thus, tail_mass should not be too small, either!
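The three conditions and the tail_mass trade-off can be illustrated with a closed-form distribution. This is a minimal sketch only: it assumes a Gaussian as a stand-in for the learned density and solves the conditions directly with the inverse CDF, whereas the library reaches them by optimizing the quantiles with the "aux" loss. The helper name `target_quantiles` is hypothetical.

```python
from statistics import NormalDist


def target_quantiles(dist, tail_mass):
    """Return (lower, median, upper) satisfying the three CDF conditions."""
    lower = dist.inv_cdf(tail_mass / 2)       # cdf(lower) = tail_mass / 2
    median = dist.inv_cdf(0.5)                # cdf(median) = 0.5
    upper = dist.inv_cdf(1 - tail_mass / 2)   # cdf(upper) = 1 - tail_mass / 2
    return lower, median, upper


# Assumed example distribution; the real density is learned per channel.
dist = NormalDist(mu=0.0, sigma=4.0)
for tail_mass in (1e-3, 1e-6, 1e-9):
    lower, median, upper = target_quantiles(dist, tail_mass)
    print(f"tail_mass={tail_mass:g}: support width ~ {upper - lower:.1f}")
```

The printed widths grow as tail_mass shrinks, which is the reason a very small tail_mass leads to larger CDF tables.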

Source code in tinify/models/base.py
def aux_loss(self) -> Tensor:
    r"""Returns the total auxiliary loss over all ``EntropyBottleneck``\s.

    In contrast to the primary "net" loss used by the "net"
    optimizer, the "aux" loss is only used by the "aux" optimizer to
    update *only* the ``EntropyBottleneck.quantiles`` parameters. In
    fact, the "aux" loss does not depend on image data at all.

    The purpose of the "aux" loss is to determine the range within
    which most of the mass of a given distribution is contained, as
    well as its median (i.e. 50% probability). That is, for a given
    distribution, the "aux" loss converges towards satisfying the
    following conditions for some chosen ``tail_mass`` probability:

    * ``cdf(quantiles[0]) = tail_mass / 2``
    * ``cdf(quantiles[1]) = 0.5``
    * ``cdf(quantiles[2]) = 1 - tail_mass / 2``

    This ensures that the concrete ``_quantized_cdf``\s operate
    primarily within a finitely supported region. Any symbols
    outside this range must be coded using some alternative method
    that does *not* involve the ``_quantized_cdf``\s. Luckily, one
    may choose a ``tail_mass`` probability that is sufficiently
    small so that this rarely occurs. It is important that we work
    with ``_quantized_cdf``\s that have a small finite support;
    otherwise, entropy coding runtime performance would suffer.
    Thus, ``tail_mass`` should not be too small, either!
    """
    loss = sum(m.loss() for m in self.modules() if isinstance(m, EntropyBottleneck))
    return cast(Tensor, loss)

update

update(scale_table=None, force=False, update_quantiles=False)

Updates EntropyBottleneck and GaussianConditional CDFs.

Call this once after training so that evaluation can later be performed with an actual entropy coder.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| scale_table | Tensor | table of scales (i.e. stdev) for initializing the Gaussian distributions (default: 64 logarithmically spaced scales from 0.11 to 256) | None |
| force | bool | overwrite previous values (default: False) | False |
| update_quantiles | bool | fast update quantiles (default: False) | False |

Returns:

| Name | Type | Description |
| --- | --- | --- |
| updated | bool | True if at least one of the modules was updated. |

Source code in tinify/models/base.py
def update(
    self,
    scale_table: Tensor | None = None,
    force: bool = False,
    update_quantiles: bool = False,
) -> bool:
    """Updates EntropyBottleneck and GaussianConditional CDFs.

    Needs to be called once after training to be able to later perform the
    evaluation with an actual entropy coder.

    Args:
        scale_table (torch.Tensor): table of scales (i.e. stdev)
            for initializing the Gaussian distributions
            (default: 64 logarithmically spaced scales from 0.11 to 256)
        force (bool): overwrite previous values (default: False)
        update_quantiles (bool): fast update quantiles (default: False)

    Returns:
        updated (bool): True if at least one of the modules was updated.
    """
    if scale_table is None:
        scale_table = get_scale_table()
    updated = False
    for _, module in self.named_modules():
        if isinstance(module, EntropyBottleneck):
            updated |= module.update(force=force, update_quantiles=update_quantiles)
        if isinstance(module, GaussianConditional):
            updated |= module.update_scale_table(scale_table, force=force)
    return updated
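The default scale table is described as 64 logarithmically spaced scales from 0.11 to 256. A minimal standard-library sketch of such a table follows. The function name `log_spaced_scales` is hypothetical, and the assumption that both endpoints are included is not confirmed by the source of `get_scale_table` shown here.

```python
import math


def log_spaced_scales(minimum=0.11, maximum=256.0, levels=64):
    """Return `levels` values spaced evenly in log space, endpoints included."""
    log_min, log_max = math.log(minimum), math.log(maximum)
    step = (log_max - log_min) / (levels - 1)
    return [math.exp(log_min + i * step) for i in range(levels)]


scales = log_spaced_scales()
print(len(scales), round(scales[0], 2), round(scales[-1], 2))
```

Log spacing keeps the ratio between neighbouring scales constant, so small standard deviations are resolved as finely (in relative terms) as large ones.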

load_state_dict

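load_state_dict(state_dict, strict=True)

Loads state_dict into the model. For every EntropyBottleneck and GaussianConditional submodule whose name prefixes at least one key in state_dict, the registered CDF buffers are first updated via update_registered_buffers; loading is then delegated to nn.Module.load_state_dict.

Source code in tinify/models/base.py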
def load_state_dict(self, state_dict: dict[str, Any], strict: bool = True) -> Any:
    for name, module in self.named_modules():
        if not any(x.startswith(name) for x in state_dict.keys()):
            continue

        if isinstance(module, EntropyBottleneck):
            update_registered_buffers(
                module,
                name,
                ["_quantized_cdf", "_offset", "_cdf_length"],
                state_dict,
            )
            state_dict = remap_old_keys(name, state_dict)

        if isinstance(module, GaussianConditional):
            update_registered_buffers(
                module,
                name,
                ["_quantized_cdf", "_offset", "_cdf_length", "scale_table"],
                state_dict,
            )

    return nn.Module.load_state_dict(self, state_dict, strict=strict)
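The `any(x.startswith(name) ...)` check above is a plain string-prefix test on the checkpoint keys. The sketch below reproduces only that test on ordinary lists, with hypothetical module names and keys, to show which modules would be visited.

```python
def modules_with_saved_state(module_names, state_dict_keys):
    """Return module names for which at least one key starts with that name."""
    return [
        name
        for name in module_names
        if any(key.startswith(name) for key in state_dict_keys)
    ]


# Hypothetical module names and checkpoint keys, for illustration only.
# named_modules() also yields the root module under the empty name "".
module_names = ["", "g_a", "entropy_bottleneck", "gaussian_conditional"]
state_dict_keys = [
    "g_a.0.weight",
    "entropy_bottleneck._quantized_cdf",
    "entropy_bottleneck._offset",
]

matched = modules_with_saved_state(module_names, state_dict_keys)
print(matched)
```

Because the test is a bare prefix match, the empty root name matches every key, and a name such as `g_a` would also match a key beginning with `g_a2.`; the subsequent `isinstance` checks are what restrict the buffer updates to the entropy modules.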

Image Compression Models

Factorized Prior

FactorizedPrior

FactorizedPrior(N, M, **kwargs)

Bases: CompressionModel

Factorized Prior model from J. Balle, D. Minnen, S. Singh, S.J. Hwang, N. Johnston: "Variational Image Compression with a Scale Hyperprior" <https://arxiv.org/abs/1802.01436>, Int. Conf. on Learning Representations (ICLR), 2018.

          ┌───┐    y
    x ──►─┤g_a├──►─┐
          └───┘    │
                   ▼
                 ┌─┴─┐
                 │ Q │
                 └─┬─┘
                   │
             y_hat ▼
                   │
                   ·
                EB :
                   ·
                   │
             y_hat ▼
                   │
          ┌───┐    │
x_hat ──◄─┤g_s├────┘
          └───┘

EB = Entropy bottleneck

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| N | int | Number of channels | required |
| M | int | Number of channels in the expansion layers (last layer of the encoder and last layer of the hyperprior decoder) | required |

from_state_dict classmethod

from_state_dict(state_dict)

Return a new model instance from state_dict.

Scale Hyperprior

ScaleHyperprior

ScaleHyperprior(N, M, **kwargs)

Bases: CompressionModel

Scale Hyperprior model from J. Balle, D. Minnen, S. Singh, S.J. Hwang, N. Johnston: "Variational Image Compression with a Scale Hyperprior" <https://arxiv.org/abs/1802.01436>, Int. Conf. on Learning Representations (ICLR), 2018.

          ┌───┐    y     ┌───┐  z  ┌───┐ z_hat      z_hat ┌───┐
    x ──►─┤g_a├──►─┬──►──┤h_a├──►──┤ Q ├───►───·⋯⋯·───►───┤h_s├─┐
          └───┘    │     └───┘     └───┘        EB        └───┘ │
                   ▼                                            │
                 ┌─┴─┐                                          │
                 │ Q │                                          ▼
                 └─┬─┘                                          │
                   │                                            │
             y_hat ▼                                            │
                   │                                            │
                   ·                                            │
                GC : ◄─────────────────────◄────────────────────┘
                   ·                 scales_hat
                   │
             y_hat ▼
                   │
          ┌───┐    │
x_hat ──◄─┤g_s├────┘
          └───┘

EB = Entropy bottleneck
GC = Gaussian conditional

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| N | int | Number of channels | required |
| M | int | Number of channels in the expansion layers (last layer of the encoder and last layer of the hyperprior decoder) | required |

from_state_dict classmethod

from_state_dict(state_dict)

Return a new model instance from state_dict.

Mean-Scale Hyperprior

MeanScaleHyperprior

MeanScaleHyperprior(N, M, **kwargs)

Bases: ScaleHyperprior

Scale Hyperprior with non-zero-mean Gaussian conditionals from D. Minnen, J. Balle, G.D. Toderici: "Joint Autoregressive and Hierarchical Priors for Learned Image Compression" <https://arxiv.org/abs/1809.02736>, Adv. in Neural Information Processing Systems 31 (NeurIPS 2018).

          ┌───┐    y     ┌───┐  z  ┌───┐ z_hat      z_hat ┌───┐
    x ──►─┤g_a├──►─┬──►──┤h_a├──►──┤ Q ├───►───·⋯⋯·───►───┤h_s├─┐
          └───┘    │     └───┘     └───┘        EB        └───┘ │
                   ▼                                            │
                 ┌─┴─┐                                          │
                 │ Q │                                          ▼
                 └─┬─┘                                          │
                   │                                            │
             y_hat ▼                                            │
                   │                                            │
                   ·                                            │
                GC : ◄─────────────────────◄────────────────────┘
                   ·                 scales_hat
                   │                 means_hat
             y_hat ▼
                   │
          ┌───┐    │
x_hat ──◄─┤g_s├────┘
          └───┘

EB = Entropy bottleneck
GC = Gaussian conditional

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| N | int | Number of channels | required |
| M | int | Number of channels in the expansion layers (last layer of the encoder and last layer of the hyperprior decoder) | required |
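To see why predicting means_hat in addition to scales_hat helps, consider the probability that a Gaussian assigns to one integer-quantized symbol. The sketch below follows the formulation in the cited paper (a Gaussian integrated over a unit-width bin). It is an assumption that tinify's GaussianConditional computes likelihoods this way, and details such as likelihood lower bounds are omitted. `symbol_probability` and `bits` are hypothetical helper names.

```python
import math
from statistics import NormalDist


def symbol_probability(y_hat, mean, scale):
    """Mass of the unit-width bin centered on y_hat under N(mean, scale^2)."""
    dist = NormalDist(mu=mean, sigma=scale)
    return dist.cdf(y_hat + 0.5) - dist.cdf(y_hat - 0.5)


def bits(probability):
    """Ideal code length in bits for a symbol with the given probability."""
    return -math.log2(probability)


# Same symbol and scale; only the assumed mean differs.
p_zero_mean = symbol_probability(3, mean=0.0, scale=1.0)
p_mean_scale = symbol_probability(3, mean=2.8, scale=1.0)
print(f"zero-mean: {bits(p_zero_mean):.2f} bits")
print(f"predicted mean: {bits(p_mean_scale):.2f} bits")
```

When the predicted mean lies close to the actual symbol, the bin captures far more probability mass and the ideal code length drops accordingly.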

Joint Autoregressive Hierarchical Priors

JointAutoregressiveHierarchicalPriors

JointAutoregressiveHierarchicalPriors(N=192, M=192, **kwargs)

Bases: MeanScaleHyperprior

Joint Autoregressive Hierarchical Priors model from D. Minnen, J. Balle, G.D. Toderici: "Joint Autoregressive and Hierarchical Priors for Learned Image Compression" <https://arxiv.org/abs/1809.02736>, Adv. in Neural Information Processing Systems 31 (NeurIPS 2018).

          ┌───┐    y     ┌───┐  z  ┌───┐ z_hat      z_hat ┌───┐
    x ──►─┤g_a├──►─┬──►──┤h_a├──►──┤ Q ├───►───·⋯⋯·───►───┤h_s├─┐
          └───┘    │     └───┘     └───┘        EB        └───┘ │
                   ▼                                            │
                 ┌─┴─┐                                          │
                 │ Q │                                   params ▼
                 └─┬─┘                                          │
             y_hat ▼                  ┌─────┐                   │
                   ├──────────►───────┤  CP ├────────►──────────┤
                   │                  └─────┘                   │
                   ▼                                            ▼
                   │                                            │
                   ·                  ┌─────┐                   │
                GC : ◄────────◄───────┤  EP ├────────◄──────────┘
                   ·     scales_hat   └─────┘
                   │      means_hat
             y_hat ▼
                   │
          ┌───┐    │
x_hat ──◄─┤g_s├────┘
          └───┘

EB = Entropy bottleneck
GC = Gaussian conditional
EP = Entropy parameters network
CP = Context prediction (masked convolution)

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| N | int | Number of channels | 192 |
| M | int | Number of channels in the expansion layers (last layer of the encoder and last layer of the hyperprior decoder) | 192 |

from_state_dict classmethod

from_state_dict(state_dict)

Return a new model instance from state_dict.
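The context prediction (CP) block in the diagram is a masked convolution: each position of y_hat may only depend on positions that have already been decoded. The sketch below builds such a mask in raster-scan order using plain lists. The 5x5 kernel size and the exclusion of the center position are assumptions taken from the usual formulation in the cited paper, not from the tinify source, and `causal_mask` is a hypothetical name.

```python
def causal_mask(kernel_size=5):
    """Build a mask that hides the center and every later raster-scan position."""
    center = kernel_size // 2
    mask = []
    for row in range(kernel_size):
        mask.append([
            # Keep rows above the center, and the part of the center row
            # that lies strictly to the left of the center.
            1 if (row < center or (row == center and col < center)) else 0
            for col in range(kernel_size)
        ])
    return mask


for row in causal_mask(5):
    print(row)
```

Multiplying the convolution weights elementwise by this mask before each forward pass would keep the layer causal, which is what allows the decoder to reproduce the same predictions symbol by symbol.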

Attention-based Models

Cheng2020Anchor

Cheng2020Anchor(N=192, **kwargs)

Bases: JointAutoregressiveHierarchicalPriors

Anchor model variant from "Learned Image Compression with Discretized Gaussian Mixture Likelihoods and Attention Modules" <https://arxiv.org/abs/2001.01568>, by Zhengxue Cheng, Heming Sun, Masaru Takeuchi, Jiro Katto.

Uses residual blocks with small convolutions (3x3 and 1x1), and sub-pixel convolutions for up-sampling.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| N | int | Number of channels | 192 |

from_state_dict classmethod

from_state_dict(state_dict)

Return a new model instance from state_dict.

Cheng2020Attention

Cheng2020Attention(N=192, **kwargs)

Bases: Cheng2020Anchor

Self-attention model variant from "Learned Image Compression with Discretized Gaussian Mixture Likelihoods and Attention Modules" <https://arxiv.org/abs/2001.01568>, by Zhengxue Cheng, Heming Sun, Masaru Takeuchi, Jiro Katto.

Uses self-attention, residual blocks with small convolutions (3x3 and 1x1), and sub-pixel convolutions for up-sampling.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| N | int | Number of channels | 192 |

Utility Functions

conv

conv(in_channels, out_channels, kernel_size=5, stride=2)

Returns an nn.Conv2d with padding set to kernel_size // 2.

Source code in tinify/models/utils.py
def conv(
    in_channels: int, out_channels: int, kernel_size: int = 5, stride: int = 2
) -> nn.Conv2d:
    return nn.Conv2d(
        in_channels,
        out_channels,
        kernel_size=kernel_size,
        stride=stride,
        padding=kernel_size // 2,
    )
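With padding fixed to `kernel_size // 2`, the spatial size produced by this layer follows from the standard `nn.Conv2d` output-size formula (dilation 1). The helper below is a hypothetical illustration of that arithmetic and does not call PyTorch.

```python
def conv_output_size(size, kernel_size=5, stride=2):
    """Spatial output size of a Conv2d with padding = kernel_size // 2."""
    padding = kernel_size // 2
    return (size + 2 * padding - kernel_size) // stride + 1


# Four stacked default conv layers, as a 256-pixel side is repeatedly halved.
sizes = [256]
for _ in range(4):
    sizes.append(conv_output_size(sizes[-1]))
print(sizes)
```

With the defaults, each layer halves the spatial resolution, rounding up for odd input sizes.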

deconv

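deconv(in_channels, out_channels, kernel_size=5, stride=2)

Returns an nn.ConvTranspose2d with padding set to kernel_size // 2 and output_padding set to stride - 1.

Source code in tinify/models/utils.py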
def deconv(
    in_channels: int, out_channels: int, kernel_size: int = 5, stride: int = 2
) -> nn.ConvTranspose2d:
    return nn.ConvTranspose2d(
        in_channels,
        out_channels,
        kernel_size=kernel_size,
        stride=stride,
        output_padding=stride - 1,
        padding=kernel_size // 2,
    )
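The choice `output_padding=stride - 1` makes this layer the size-wise inverse of the strided convolution above. The hypothetical helper below applies the standard `nn.ConvTranspose2d` output-size formula (dilation 1) without calling PyTorch.

```python
def deconv_output_size(size, kernel_size=5, stride=2):
    """Spatial output size of a ConvTranspose2d configured as in deconv()."""
    padding = kernel_size // 2
    output_padding = stride - 1
    return (size - 1) * stride - 2 * padding + (kernel_size - 1) + output_padding + 1


# Four stacked default deconv layers, starting from a 16-pixel side.
sizes = [16]
for _ in range(4):
    sizes.append(deconv_output_size(sizes[-1]))
print(sizes)
```

For odd kernel sizes the terms cancel to `size * stride`, so a decoder built from these layers restores the resolution removed by the same number of default `conv` layers.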