EntropyBottleneck with adjustable bin width #308

Open
ghost opened this issue Sep 21, 2024 · 4 comments

Comments


ghost commented Sep 21, 2024

Feature

Support for Custom Bin Width in EntropyBottleneck

Motivation

So far, only a bin width of 1 is supported, but it would be good to have this as a tunable option.

Additional context

Are quantize and _likelihood the only methods that need to change, or are there other important changes I am missing?
I'm also not sure about _get_medians.

Here is how I would change quantize:

def quantize(
    self, inputs: Tensor, mode: str, means: Optional[Tensor] = None, bin_width: float = 1.0
) -> Tensor:
    if mode not in ("noise", "dequantize", "symbols"):
        raise ValueError(f'Invalid quantization mode: "{mode}"')

    if mode == "noise":
        half = bin_width / 2
        noise = torch.empty_like(inputs).uniform_(-half, half)
        inputs = inputs + noise
        return inputs

    outputs = inputs.clone()
    if means is not None:
        outputs -= means

    # Quantize to integer symbol indices; rescale by bin_width only when dequantizing,
    # so that "symbols" mode still yields valid integer inputs for the entropy coder.
    outputs = torch.round(outputs / bin_width)

    if mode == "dequantize":
        outputs = outputs * bin_width
        if means is not None:
            outputs += means
        return outputs

    assert mode == "symbols", mode
    outputs = outputs.int()
    return outputs

and here is _likelihood:

def _likelihood(
    self, inputs: Tensor, bin_width: float = 1.0, stop_gradient: bool = False
) -> Tuple[Tensor, Tensor, Tensor]:
    half = bin_width / 2  # Adjust based on the bin width
    lower = self._logits_cumulative(inputs - half, stop_gradient=stop_gradient)
    upper = self._logits_cumulative(inputs + half, stop_gradient=stop_gradient)
    likelihood = torch.sigmoid(upper) - torch.sigmoid(lower)
    return likelihood, lower, upper

Your further guidance is appreciated!
Thank you!


YodaEmbedding commented Sep 22, 2024

At a quick glance, this should work for training, though I wonder if the lossless entropy coder also needs adjustment.

A simpler method might be to rescale the inputs and outputs by the desired bin width instead, i.e., wrap the existing unit-width quantizer in a scale/unscale step. Since both are uniform quantizers, it should be equivalent.
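
For concreteness, here is a minimal sketch of that equivalence (delta is just a hypothetical bin width, not an existing library parameter):

import torch

delta = 0.5  # hypothetical bin width
x = torch.randn(4, 192, 16, 16)

# Quantizer with an explicit bin width:
q_direct = torch.round(x / delta) * delta

# The same thing expressed as scale -> unit-width quantizer -> unscale:
q_rescaled = torch.round(x / delta)  # unit-width quantization of the scaled input
q_rescaled = q_rescaled * delta      # undo the scaling afterwards

assert torch.equal(q_direct, q_rescaled)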


uzumaki671 commented Oct 10, 2024

Thank you for the suggestion! Here's how I implemented it:

Overview of Changes:

I added a bin_width parameter to the quantize, dequantize, compress, and decompress methods of the EntropyModel to handle different quantization bin widths at test time.

Quantize and Dequantize Methods:

def quantize(
    self, inputs: Tensor, mode: str, means: Optional[Tensor] = None, bin_width: float = 1.0
) -> Tensor:
    if mode not in ("noise", "dequantize", "symbols"):
        raise ValueError(f'Invalid quantization mode: "{mode}"')

    # Scale inputs by bin width before quantization
    inputs_scaled = inputs / bin_width

    if mode == "noise":
        half = float(0.5)
        noise = torch.empty_like(inputs_scaled).uniform_(-half, half)
        inputs_scaled = inputs_scaled + noise
        return inputs_scaled * bin_width  # Scale back after adding noise

    outputs = inputs_scaled.clone()
    if means is not None:
        outputs -= means / bin_width  # Scale means accordingly

    outputs = torch.round(outputs)  # Quantize to nearest integer

    if mode == "dequantize":
        if means is not None:
            outputs += means / bin_width
        return outputs * bin_width  # Scale back to original

    assert mode == "symbols", mode
    outputs = outputs.int()
    return outputs

@staticmethod
def dequantize(
    inputs: Tensor, means: Optional[Tensor] = None, dtype: torch.dtype = torch.float, bin_width: float = 1.0
) -> Tensor:
    if means is not None:
        outputs = inputs.type_as(means)
        outputs += means / bin_width  # Adjust means
    else:
        outputs = inputs.type(dtype)
    return outputs * bin_width  # Scale back to original

Compress and Decompress Methods:

def compress(self, inputs, indexes, means=None, bin_width: float = 1.0):
    symbols = self.quantize(inputs, "symbols", means, bin_width)

    if len(inputs.size()) < 2:
        raise ValueError(
            "Invalid `inputs` size. Expected a tensor with at least 2 dimensions."
        )

    if inputs.size() != indexes.size():
        raise ValueError("`inputs` and `indexes` should have the same size.")

    self._check_cdf_size()
    self._check_cdf_length()
    self._check_offsets_size()

    strings = []
    for i in range(symbols.size(0)):
        rv = self.entropy_coder.encode_with_indexes(
            symbols[i].reshape(-1).int().tolist(),
            indexes[i].reshape(-1).int().tolist(),
            self._quantized_cdf.tolist(),
            self._cdf_length.reshape(-1).int().tolist(),
            self._offset.reshape(-1).int().tolist(),
        )
        strings.append(rv)
    return strings

def decompress(
    self,
    strings: str,
    indexes: torch.IntTensor,
    dtype: torch.dtype = torch.float,
    means: torch.Tensor = None,
    bin_width: float = 1.0,
):
    if not isinstance(strings, (tuple, list)):
        raise ValueError("Invalid `strings` parameter type.")

    if not len(strings) == indexes.size(0):
        raise ValueError("Invalid strings or indexes parameters")

    if len(indexes.size()) < 2:
        raise ValueError(
            "Invalid `indexes` size. Expected a tensor with at least 2 dimensions."
        )

    self._check_cdf_size()
    self._check_cdf_length()
    self._check_offsets_size()

    if means is not None:
        if means.size()[:2] != indexes.size()[:2]:
            raise ValueError("Invalid means or indexes parameters")
        if means.size() != indexes.size():
            for i in range(2, len(indexes.size())):
                if means.size(i) != 1:
                    raise ValueError("Invalid means parameters")

    cdf = self._quantized_cdf
    outputs = cdf.new_empty(indexes.size())

    for i, s in enumerate(strings):
        values = self.entropy_coder.decode_with_indexes(
            s,
            indexes[i].reshape(-1).int().tolist(),
            cdf.tolist(),
            self._cdf_length.reshape(-1).int().tolist(),
            self._offset.reshape(-1).int().tolist(),
        )
        outputs[i] = torch.tensor(
            values, device=outputs.device, dtype=outputs.dtype
        ).reshape(outputs[i].size())
    outputs = self.dequantize(outputs, means, dtype, bin_width)
    return outputs

EntropyBottleneck Changes:

In the EntropyBottleneck, I added the bin_width parameter to the compress and decompress methods:

def compress(self, x, bin_width=1.0):
    indexes = self._build_indexes(x.size())
    medians = self._get_medians().detach()
    spatial_dims = len(x.size()) - 2
    medians = self._extend_ndims(medians, spatial_dims)
    medians = medians.expand(x.size(0), *([-1] * (spatial_dims + 1)))
    return super().compress(x, indexes, medians, bin_width)

def decompress(self, strings, size, bin_width=1.0):
    output_size = (len(strings), self._quantized_cdf.size(0), *size)
    indexes = self._build_indexes(output_size).to(self._quantized_cdf.device)
    medians = self._extend_ndims(self._get_medians().detach(), len(size))
    medians = medians.expand(len(strings), *([-1] * (len(size) + 1)))
    return super().decompress(strings, indexes, medians.dtype, medians, bin_width)

My goal with these changes is to try different bin_width values at test time only, and to analyze the tradeoff between bitstream length and reconstruction error.
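
For reference, the kind of sweep I have in mind looks roughly like this (eb is a trained EntropyBottleneck and y a latent tensor; both are just placeholders):

import torch

with torch.no_grad():
    for bw in (0.5, 1.0, 2.0, 4.0):
        strings = eb.compress(y, bin_width=bw)
        y_hat = eb.decompress(strings, y.size()[-2:], bin_width=bw)
        num_bits = sum(len(s) for s in strings) * 8   # bitstream length
        mse = torch.mean((y - y_hat) ** 2).item()     # latent-space reconstruction error
        print(f"bin_width={bw}: {num_bits} bits, MSE={mse:.4f}")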

Do you think rescaling inputs and means by bin_width inside the quantize and dequantize methods is the correct approach?

Thank you!


YodaEmbedding commented Nov 4, 2024

Apologies for the late response.

It looks like that should work for the runtime codec.

However, if you wanted to train it as well, you would also need to modify EntropyBottleneck.forward:

class EntropyBottleneck(EntropyModel):
    def forward(
        self, x: Tensor, training: Optional[bool] = None, bin_width=1.0
    ) -> Tuple[Tensor, Tensor]:
        ...

        outputs = self.quantize(
            values,
            "noise" if training else "dequantize",
            self._get_medians(),
            bin_width=bin_width,
        )

        ...

...And use it as follows:

@register_model("bmshj2018-factorized-vbr")
class FactorizedPriorVbr(CompressionModel):
    def forward(self, x, *, bin_width=1.0):
        y = self.g_a(x)
        y_hat, y_likelihoods = self.entropy_bottleneck(y, bin_width=bin_width)
        x_hat = self.g_s(y_hat)
        return {"x_hat": x_hat, "likelihoods": {"y": y_likelihoods}}

    def compress(self, x, *, bin_width=1.0):
        y = self.g_a(x)
        y_strings = self.entropy_bottleneck.compress(y, bin_width=bin_width)
        return {"strings": [y_strings], "shape": y.size()[-2:]}

    def decompress(self, strings, shape, *, bin_width=1.0):
        assert isinstance(strings, list) and len(strings) == 1
        y_hat = self.entropy_bottleneck.decompress(strings[0], shape, bin_width=bin_width)
        x_hat = self.g_s(y_hat).clamp_(0, 1)
        return {"x_hat": x_hat}
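
Rough usage (model is a trained FactorizedPriorVbr instance and x an image batch, both placeholders; use the same bin_width for compress and decompress):

out_train = model(x, bin_width=0.5)         # training / validation forward pass
out_enc = model.compress(x, bin_width=0.5)
out_dec = model.decompress(out_enc["strings"], out_enc["shape"], bin_width=0.5)
x_hat = out_dec["x_hat"]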

QVRF [1] does something quite similar, though with two differences:

  • They apply bin_width ($≝ 1 / a$) to the GaussianConditional for the mean-scale hyperprior (mbt2018-mean) model.
  • They "finetune" a single high-rate (λ=0.18) model over a set $A = \{ a_1, a_2, \ldots, a_n \}$ of different bin widths to ensure that the model works well over those bin widths.

    We use staged training strategies, where the network parameters are optimized on $λ=0.18$ for the first 2000 epochs. Then, $A$ are optimized jointly with noise approximation for 500 epochs and straight-through estimation for another 500 epochs.
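
For illustration, joint finetuning over a set of bin widths could look roughly like this (the set A, the λ weighting, and all variable names here are assumptions for the sketch, not the paper's exact recipe):

import math
import random
import torch
import torch.nn.functional as F

A = [0.5, 1.0, 2.0, 4.0]  # candidate bin widths (assumed)
lmbda = 0.18              # high-rate operating point, as in the quote above

for x in train_loader:    # train_loader, model, optimizer assumed to exist
    optimizer.zero_grad()
    bin_width = random.choice(A)  # sample one bin width per step
    out = model(x, bin_width=bin_width)

    num_pixels = x.size(0) * x.size(2) * x.size(3)
    bpp = sum(
        torch.log(likelihoods).sum() / (-math.log(2) * num_pixels)
        for likelihoods in out["likelihoods"].values()
    )
    distortion = F.mse_loss(out["x_hat"], x)

    loss = lmbda * 255**2 * distortion + bpp
    loss.backward()
    optimizer.step()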

Footnotes

  1. Tong et al. QVRF: A Quantization-error-aware Variable Rate Framework for Learned Image Compression. https://arxiv.org/pdf/2303.05744

@uzumaki671

Thank you very much!
