Python

【Pytorch】Total Variationの実装

2022年2月13日

※ 本ブログはプロモーションが含まれています

【Pytorch】Total Variationの実装

ポロシャツエンジニア

3分で技術が身に付くブログ!
・ポロシャツを愛するエンジニア
・企業にて研究開発職
・画像処理 | Deep Learning

画像の平滑化を損失関数に組み込むことを目的に,PytorchにてTotal Variationを実装しました!
Pytorchの学習には、 PyTorchニューラルネットワーク 実装ハンドブック がおすすめです!

おすすめの読者

  • PytorchでTotal Variationを実装したい!
  • Total Variationを活用した、画像処理に挑戦したい!
こちらもおすすめ!

【理論】Total Variationの理論を理解する

Total Variation(総変動)はあるデータの変動の総量を表し,2次元画像信号\(x\)におけるanisotropic total variationは

$$TV_{anisotropic} = \sum_{i, j} |x_{i+1, j}-x_{i, j}| + |x_{i, j+1}-x_{i, j}|$$

のように表されます。

Wikipediaのページは以下を参照してください。
Total variation on Wikipedia
Total variation denoising on Wikipedia

【実装】Pytorchでの実装例

Interface

以下のTensorに対応する形で実装しました。

  • 2-dimensional tensor: torch.Tensor([height, width])
  • 3-dimensional tensor: torch.Tensor([channel, height, width])
  • 4-dimensional tensor: torch.Tensor([batch, channel, height, width])

出力Tensorのshapeは入力Tensorのshapeに依存します。

  • Input tensor was 2 or 3 dimensional: return tensor as a scalar
  • Input tensor was 4 dimensional: return tensor as an array
  • batch毎にtotal variationを算出する

下記方法でインスタンス生成。

# There are two ways to make instance.
# with `is_mean_reduction=False`
loss_ = TotalVariation()
# with `is_mean_reduction=True`
loss_ = TotalVariation(is_mean_reduction=True)
  • is_mean_reduction=False: 各ピクセルの変動の合計値を出力
  • is_mean_reduction=True: 各ピクセルの変動の平均値を出力

Pytorchでの実装

import torch
from torch import Tensor

class TotalVariation(torch.nn.Module):
    """Calculate the total variation for one or batch tensor.

    The total variation is the sum of the absolute differences for neighboring
    pixel-values in the input images.

    Example:
    >>> import torch
    >>> loss_ = TotalVariation()

    >>> # Example for 2-dimensional tensor.
    >>> tensor_ = torch.arange(0, 2.5, 0.1, requires_grad=True).reshape(5, 5)
    >>> tensor_.shape
    torch.Size([5, 5])
    >>> loss_(tensor_)
    tensor(12., grad_fn=<AddBackward0>)

    >>> # Example for 3-dimensional tensor.
    >>> tensor_ = torch.arange(0, 2.5, 0.1, requires_grad=True).reshape(1, 5, 5)
    >>> tensor_.shape
    torch.Size([1, 5, 5])
    >>> loss_(tensor_)
    tensor(12., grad_fn=<AddBackward0>)

    >>> # Example for 4-dimensional tensor.
    >>> tensor_ = (
    ...     torch.arange(0, 10.0, 0.1, requires_grad=True).reshape(4, 1, 5, 5)
    ... )
    >>> tensor_.shape
    torch.Size([4, 1, 5, 5])
    >>> loss_(tensor_)
    tensor([12.0000, 12.0000, 12.0000, 12.0000], grad_fn=<AddBackward0>)

    >>> # Example for 4-dimensional tensor with `is_mean_reduction=True`.
    >>> loss_ = TotalVariation(is_mean_reduction=True)
    >>> tensor_ = (
    ...     torch.arange(0, 10.0, 0.1, requires_grad=True).reshape(4, 1, 5, 5)
    ... )
    >>> loss_(tensor_)
    tensor([0.6000, 0.6000, 0.6000, 0.6000], grad_fn=<AddBackward0>)
    """

    def __init__(self, *, is_mean_reduction: bool = False) -> None:
        """Constructor.

        Args:
            is_mean_reduction (bool, optional):
                When `is_mean_reduction` is True, the sum of the output will be
                divided by the number of elements those used
                for total variation calculation. Defaults to False.
        """
        super(TotalVariation, self).__init__()
        self._is_mean = is_mean_reduction

    def forward(self, tensor_: Tensor) -> Tensor:
        return self._total_variation(tensor_)

    def _total_variation(self, tensor_: Tensor) -> Tensor:
        """Calculate total variation.

        Args:
            tensor_ (Tensor): input tensor must be the any following shapes:
                - 2-dimensional: [height, width]
                - 3-dimensional: [channel, height, width]
                - 4-dimensional: [batch, channel, height, width]

        Raises:
            ValueError: Input tensor is not either 2, 3 or 4-dimensional.

        Returns:
            Tensor: the output tensor shape depends on the size of the input.
                - Input tensor was 2 or 3 dimensional
                    return tensor as a scalar
                - Input tensor was 4 dimensional
                    return tensor as an array
        """
        ndims_ = tensor_.dim()

        if ndims_ == 2:
            y_diff = tensor_[1:, :] - tensor_[:-1, :]
            x_diff = tensor_[:, 1:] - tensor_[:, :-1]
        elif ndims_ == 3:
            y_diff = tensor_[:, 1:, :] - tensor_[:, :-1, :]
            x_diff = tensor_[:, :, 1:] - tensor_[:, :, :-1]
        elif ndims_ == 4:
            y_diff = tensor_[:, :, 1:, :] - tensor_[:, :, :-1, :]
            x_diff = tensor_[:, :, :, 1:] - tensor_[:, :, :, :-1]
        else:
            raise ValueError(
                'Input tensor must be either 2, 3 or 4-dimensional.')

        sum_axis = tuple({abs(x) for x in range(ndims_ - 3, ndims_)})
        y_denominator = (
            y_diff.shape[sum_axis[0]::].numel() if self._is_mean else 1
        )
        x_denominator = (
            x_diff.shape[sum_axis[0]::].numel() if self._is_mean else 1
        )

        return (
            torch.sum(torch.abs(y_diff), dim=sum_axis) / y_denominator
            + torch.sum(torch.abs(x_diff), dim=sum_axis) / x_denominator
        )

使用例

$ python
Python 3.8.2 
[GCC 7.5.0] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> import torch
>>> from total_variation import TotalVariation
>>> loss_ = TotalVariation(is_mean_reduction=True)
>>> tensor_ = torch.arange(0, 10.0, 0.1, requires_grad=True).reshape(4, 1, 5, 5)
>>> tensor_
tensor([[[[0.0000, 0.1000, 0.2000, 0.3000, 0.4000],
          [0.5000, 0.6000, 0.7000, 0.8000, 0.9000],
          [1.0000, 1.1000, 1.2000, 1.3000, 1.4000],
          [1.5000, 1.6000, 1.7000, 1.8000, 1.9000],
          [2.0000, 2.1000, 2.2000, 2.3000, 2.4000]]],


        [[[2.5000, 2.6000, 2.7000, 2.8000, 2.9000],
          [3.0000, 3.1000, 3.2000, 3.3000, 3.4000],
          [3.5000, 3.6000, 3.7000, 3.8000, 3.9000],
          [4.0000, 4.1000, 4.2000, 4.3000, 4.4000],
          [4.5000, 4.6000, 4.7000, 4.8000, 4.9000]]],


        [[[5.0000, 5.1000, 5.2000, 5.3000, 5.4000],
          [5.5000, 5.6000, 5.7000, 5.8000, 5.9000],
          [6.0000, 6.1000, 6.2000, 6.3000, 6.4000],
          [6.5000, 6.6000, 6.7000, 6.8000, 6.9000],
          [7.0000, 7.1000, 7.2000, 7.3000, 7.4000]]],


        [[[7.5000, 7.6000, 7.7000, 7.8000, 7.9000],
          [8.0000, 8.1000, 8.2000, 8.3000, 8.4000],
          [8.5000, 8.6000, 8.7000, 8.8000, 8.9000],
          [9.0000, 9.1000, 9.2000, 9.3000, 9.4000],
          [9.5000, 9.6000, 9.7000, 9.8000, 9.9000]]]], grad_fn=<ViewBackward>)
>>> output = loss_(tensor_)
>>> output
tensor([0.6000, 0.6000, 0.6000, 0.6000], grad_fn=<AddBackward0>)
>>> output[0].backward()

上記例では,すべてのピクセルにおいて,

  • Y方向の差分: 0.5000
  • X方向の差分: 0.1000

であるので,Y方向およびX方向の変動の合計値である0.6000が出力されています.

L1Lossへの接続

下記のように,良しなに実装すると,L1 Lossとの接続が可能です。

$ python
Python 3.8.2 
[GCC 7.5.0] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> import torch
>>> from total_variation import TotalVariation
>>> tv_loss = TotalVariation(is_mean_reduction=True)
>>> l1_loss = torch.nn.L1Loss()
>>> input_ = torch.arange(0, 10.0, 0.1, requires_grad=True).reshape(4, 1, 5, 5)
>>> reference = torch.arange(0, 20.0, 0.2, requires_grad=True).reshape(4, 1, 5, 5)
>>> tv_input = tv_loss(input_)
>>> tv_reference = tv_loss(reference)
>>> tv_input
tensor([0.6000, 0.6000, 0.6000, 0.6000], grad_fn=<AddBackward0>)
>>> tv_reference
tensor([1.2000, 1.2000, 1.2000, 1.2000], grad_fn=<AddBackward0>)
>>> l1_output = l1_loss(tv_input, tv_reference)
>>> l1_output
tensor(0.6000, grad_fn=<L1LossBackward>)
>>> l1_output.backward()

【まとめ】PytorchでTotal variationを実装する

Total Variation自体は以前より存在するテクニックだが,Deep Learningを用いた超解像やノイズ除去等においても有用な場面があるのかなと思います!

Pytorchの学習には、 PyTorchニューラルネットワーク 実装ハンドブック がおすすめです!

こちらもおすすめ!
  • この記事を書いた人

ポロシャツエンジニア

3分で技術が身に付くブログ!
・ポロシャツを愛するエンジニア
・企業にて研究開発職
・画像処理 | Deep Learning

-Python