Entropy Models

Entropy bottleneck and hyperprior models for learned compression.

Overview

Entropy models are a critical component of learned compression. They model the probability distribution of latent representations, enabling efficient entropy coding.
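
As a rule of thumb, an element that the model assigns likelihood p costs about -log2(p) bits to entropy-code, so the summed negative log-likelihood of a latent is an estimate of its compressed size. A minimal, illustrative sketch (not part of tinify):

import torch

def estimated_bits(likelihoods: torch.Tensor) -> torch.Tensor:
    # Rate estimate: each element costs -log2(likelihood) bits.
    return -torch.log2(likelihoods).sum()

likelihoods = torch.tensor([0.5, 0.25, 0.125])  # toy likelihoods
print(estimated_bits(likelihoods))              # 1 + 2 + 3 = 6 bits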

Entropy Bottleneck

The entropy bottleneck compresses a latent representation by learning a flexible, fully factorized density model of its distribution.

EntropyBottleneck

EntropyBottleneck(channels, *args, tail_mass=1e-09, init_scale=10, filters=(3, 3, 3, 3), **kwargs)

Bases: EntropyModel

Entropy bottleneck layer, introduced by J. Ballé, D. Minnen, S. Singh, S. J. Hwang, N. Johnston, in "Variational image compression with a scale hyperprior" (https://arxiv.org/abs/1802.01436).

This is a re-implementation of the entropy bottleneck layer in tensorflow/compression. See the original paper and the tensorflow documentation (https://github.com/tensorflow/compression/blob/v1.3/docs/entropy_bottleneck.md) for an introduction.

Source code in tinify/entropy_models/entropy_models.py
def __init__(
    self,
    channels: int,
    *args: Any,
    tail_mass: float = 1e-9,
    init_scale: float = 10,
    filters: tuple[int, ...] = (3, 3, 3, 3),
    **kwargs: Any,
) -> None:
    super().__init__(*args, **kwargs)

    self.channels = int(channels)
    self.filters = tuple(int(f) for f in filters)
    self.init_scale = float(init_scale)
    self.tail_mass = float(tail_mass)

    # Create parameters
    filters = (1,) + self.filters + (1,)
    scale = self.init_scale ** (1 / (len(self.filters) + 1))
    channels = self.channels

    self.matrices = nn.ParameterList()
    self.biases = nn.ParameterList()
    self.factors = nn.ParameterList()

    for i in range(len(self.filters) + 1):
        init = np.log(np.expm1(1 / scale / filters[i + 1]))
        matrix = torch.Tensor(channels, filters[i + 1], filters[i])
        matrix.data.fill_(init)
        self.matrices.append(nn.Parameter(matrix))

        bias = torch.Tensor(channels, filters[i + 1], 1)
        nn.init.uniform_(bias, -0.5, 0.5)
        self.biases.append(nn.Parameter(bias))

        if i < len(self.filters):
            factor = torch.Tensor(channels, filters[i + 1], 1)
            nn.init.zeros_(factor)
            self.factors.append(nn.Parameter(factor))

    self.quantiles = nn.Parameter(torch.Tensor(channels, 1, 3))
    init = torch.Tensor([-self.init_scale, 0, self.init_scale])
    self.quantiles.data = init.repeat(self.quantiles.size(0), 1, 1)

    target = np.log(2 / self.tail_mass - 1)
    self.register_buffer("target", torch.Tensor([-target, 0, target]))

forward

forward(x, training=None)
Source code in tinify/entropy_models/entropy_models.py
def forward(self, x: Tensor, training: bool | None = None) -> tuple[Tensor, Tensor]:
    if training is None:
        training = self.training

    D = x.dim()
    # B C ...  ->  C B ...
    perm = [1, 0] + list(range(2, D))
    inv_perm = [0] * D
    for i, p in enumerate(perm):
        inv_perm[p] = i

    x = x.permute(*perm).contiguous()
    shape = x.size()
    values = x.reshape(x.size(0), 1, -1)

    # Add noise or quantize
    outputs = self.quantize(
        values, "noise" if training else "dequantize", self._get_medians()
    )

    if not torch.jit.is_scripting():
        likelihood, _, _ = self._likelihood(outputs)
        if self.use_likelihood_bound:
            likelihood = self.likelihood_lower_bound(likelihood)
    else:
        raise NotImplementedError()
        # TorchScript not yet supported
        # likelihood = torch.zeros_like(outputs)

    # Convert back to input tensor shape
    outputs = outputs.reshape(shape).permute(*inv_perm).contiguous()
    likelihood = likelihood.reshape(shape).permute(*inv_perm).contiguous()

    return outputs, likelihood
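
A training-time usage sketch, assuming the entropy_bottleneck instance constructed above and 256x256 source images:

import torch

y = torch.randn(8, 128, 16, 16)               # batch of latents
y_hat, y_likelihoods = entropy_bottleneck(y)  # adds noise in train mode, quantizes in eval mode

num_pixels = 8 * 256 * 256                    # assumed source resolution
bpp = -torch.log2(y_likelihoods).sum() / num_pixels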

compress

compress(x)
Source code in tinify/entropy_models/entropy_models.py
def compress(self, x: Tensor) -> list[bytes]:
    indexes = self._build_indexes(x.size())
    medians = self._get_medians().detach()
    spatial_dims = len(x.size()) - 2
    medians = self._extend_ndims(medians, spatial_dims)
    medians = medians.expand(x.size(0), *([-1] * (spatial_dims + 1)))
    return super().compress(x, indexes, medians)

decompress

decompress(strings, size)
Source code in tinify/entropy_models/entropy_models.py
def decompress(self, strings: list[bytes], size: tuple[int, ...]) -> Tensor:
    output_size = (len(strings), self._quantized_cdf.size(0), *size)
    indexes = self._build_indexes(output_size).to(self._quantized_cdf.device)
    medians = self._extend_ndims(self._get_medians().detach(), len(size))
    medians = medians.expand(len(strings), *([-1] * (len(size) + 1)))
    return super().decompress(strings, indexes, medians.dtype, medians)
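
A round-trip sketch, assuming the instance above; update() (documented below) must have populated the quantized CDFs before compress() is called:

import torch

entropy_bottleneck.update(force=True)         # build the quantized CDF tables

y = torch.randn(1, 128, 16, 16)
strings = entropy_bottleneck.compress(y)
y_hat = entropy_bottleneck.decompress(strings, y.size()[2:])
assert y_hat.shape == y.shape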

loss

loss()
Source code in tinify/entropy_models/entropy_models.py
def loss(self) -> Tensor:
    logits = self._logits_cumulative(self.quantiles, stop_gradient=True)
    loss = torch.abs(logits - self.target).sum()
    return loss
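
This auxiliary loss pulls the learned quantiles toward the target tail mass; it is typically minimized with a separate optimizer alongside the main rate-distortion objective. A sketch (the optimizer choice is illustrative, not part of tinify):

import torch

# Only the quantiles receive gradients from this loss (the cumulative logits
# are evaluated with stop_gradient=True).
aux_optimizer = torch.optim.Adam([entropy_bottleneck.quantiles], lr=1e-3)

aux_loss = entropy_bottleneck.loss()
aux_loss.backward()
aux_optimizer.step()
aux_optimizer.zero_grad()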

update

update(force=False, update_quantiles=False)
Source code in tinify/entropy_models/entropy_models.py
def update(self, force: bool = False, update_quantiles: bool = False) -> bool:
    # Check whether the bottleneck parameters need updating; the offsets are
    # only computed and stored when the conditional model is update()'d.
    if self._offset.numel() > 0 and not force:
        return False

    if update_quantiles:
        self._update_quantiles()

    medians = self.quantiles[:, 0, 1]

    minima = medians - self.quantiles[:, 0, 0]
    minima = torch.ceil(minima).int()
    minima = torch.clamp(minima, min=0)

    maxima = self.quantiles[:, 0, 2] - medians
    maxima = torch.ceil(maxima).int()
    maxima = torch.clamp(maxima, min=0)

    self._offset = -minima

    pmf_start = medians - minima
    pmf_length = maxima + minima + 1

    max_length = pmf_length.max().item()
    device = pmf_start.device
    samples = torch.arange(max_length, device=device)
    samples = samples[None, :] + pmf_start[:, None, None]

    pmf, lower, upper = self._likelihood(samples, stop_gradient=True)
    pmf = pmf[:, 0, :]
    tail_mass = torch.sigmoid(lower[:, 0, :1]) + torch.sigmoid(-upper[:, 0, -1:])

    quantized_cdf = self._pmf_to_cdf(pmf, tail_mass, pmf_length, max_length)
    self._quantized_cdf = quantized_cdf
    self._cdf_length = pmf_length + 2
    return True
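
Per the early return above, update() is a no-op once the tables exist unless force=True. A sketch of the return-value semantics:

entropy_bottleneck.update()             # builds the CDF tables the first time; returns True
entropy_bottleneck.update()             # tables already exist; returns False, no recompute
entropy_bottleneck.update(force=True)   # recomputes the tables unconditionally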

Gaussian Conditional

GaussianConditional

GaussianConditional(scale_table, *args, scale_bound=0.11, tail_mass=1e-09, **kwargs)

Bases: EntropyModel

Gaussian conditional layer, introduced by J. Ballé, D. Minnen, S. Singh, S. J. Hwang, N. Johnston, in "Variational image compression with a scale hyperprior" (https://arxiv.org/abs/1802.01436).

This is a re-implementation of the Gaussian conditional layer in tensorflow/compression. See the tensorflow documentation (https://github.com/tensorflow/compression/blob/v1.3/docs/api_docs/python/tfc/GaussianConditional.md) for more information.

Source code in tinify/entropy_models/entropy_models.py
def __init__(
    self,
    scale_table: list[float] | tuple[float, ...] | None,
    *args: Any,
    scale_bound: float = 0.11,
    tail_mass: float = 1e-9,
    **kwargs: Any,
) -> None:
    super().__init__(*args, **kwargs)

    if not isinstance(scale_table, (type(None), list, tuple)):
        raise ValueError(f'Invalid type for scale_table "{type(scale_table)}"')

    if isinstance(scale_table, (list, tuple)) and len(scale_table) < 1:
        raise ValueError(f'Invalid scale_table length "{len(scale_table)}"')

    if scale_table and (
        scale_table != sorted(scale_table) or any(s <= 0 for s in scale_table)
    ):
        raise ValueError(f'Invalid scale_table "({scale_table})"')

    self.tail_mass = float(tail_mass)
    if scale_bound is None and scale_table:
        scale_bound = self.scale_table[0]
    if scale_bound <= 0:
        raise ValueError("Invalid parameters")
    self.lower_bound_scale = LowerBound(scale_bound)

    self.register_buffer(
        "scale_table",
        self._prepare_scale_table(scale_table) if scale_table else torch.Tensor(),
    )

    self.register_buffer(
        "scale_bound",
        torch.Tensor([float(scale_bound)]) if scale_bound is not None else None,
    )

forward

forward(inputs, scales, means=None, training=None)
Source code in tinify/entropy_models/entropy_models.py
def forward(
    self,
    inputs: Tensor,
    scales: Tensor,
    means: Tensor | None = None,
    training: bool | None = None,
) -> tuple[Tensor, Tensor]:
    if training is None:
        training = self.training
    outputs = self.quantize(inputs, "noise" if training else "dequantize", means)
    likelihood = self._likelihood(outputs, scales, means)
    if self.use_likelihood_bound:
        likelihood = self.likelihood_lower_bound(likelihood)
    return outputs, likelihood
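
A usage sketch with per-element scales (and means) predicted by a hyper-decoder; the shapes and stand-in predictions are assumptions:

import torch

y = torch.randn(8, 192, 16, 16)           # latent to be coded
scales_hat = torch.rand_like(y) + 0.1     # stand-in for hyper-decoder output
means_hat = torch.zeros_like(y)           # stand-in for predicted means

y_hat, y_likelihoods = gaussian_conditional(y, scales_hat, means=means_hat)
rate_bits = -torch.log2(y_likelihoods).sum()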

Entropy Model Base

EntropyModel

EntropyModel(likelihood_bound=1e-09, entropy_coder=None, entropy_coder_precision=16)

Bases: Module

Entropy model base class.

Parameters:

    likelihood_bound (float): minimum likelihood bound. Default: 1e-09
    entropy_coder (str): entropy coder to use; the default coder is used if None. Default: None
    entropy_coder_precision (int): entropy coder precision. Default: 16
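
These keyword arguments are forwarded by the subclasses above via **kwargs; a sketch with purely illustrative values:

entropy_bottleneck = EntropyBottleneck(
    128,
    likelihood_bound=1e-9,         # clamp likelihoods from below
    entropy_coder_precision=16,    # precision of the entropy coder
)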

compress

compress(inputs, indexes, means=None)

Compress input tensors to byte strings.

Parameters:

    inputs (Tensor): input tensors. Required.
    indexes (Tensor): tensor of CDF indexes. Required.
    means (Tensor): optional tensor of means. Default: None

decompress

decompress(strings, indexes, dtype=float, means=None)

Decompress byte strings to tensors.

Parameters:

    strings (list[bytes]): compressed byte strings. Required.
    indexes (Tensor): tensor of CDF indexes. Required.
    dtype (dtype): data type of the dequantized output. Default: float
    means (Tensor): optional tensor of means. Default: None
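
For the Gaussian conditional, the CDF indexes passed here are derived from the predicted scales. A round-trip sketch reusing the tensors from the previous sketch; the update() and build_indexes() calls are assumptions mirroring comparable implementations and are not documented above:

gaussian_conditional.update()                              # assumed: builds CDFs from scale_table
indexes = gaussian_conditional.build_indexes(scales_hat)   # assumed helper: scales -> table indexes

strings = gaussian_conditional.compress(y, indexes, means=means_hat)
y_hat = gaussian_conditional.decompress(strings, indexes, means=means_hat)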