I implemented Total Variation in PyTorch with the aim of building image smoothness into a loss function!
For learning PyTorch, I recommend the book 「PyTorchニューラルネットワーク 実装ハンドブック」 (PyTorch Neural Network Implementation Handbook)!
Who this post is for
- You want to implement Total Variation in PyTorch!
- You want to try image processing that makes use of Total Variation!
[Theory] Understanding Total Variation
Total Variation expresses the total amount of variation in a piece of data. For a 2-dimensional image signal \(x\), the anisotropic total variation is expressed as
$$TV_{anisotropic} = \sum_{i, j} |x_{i+1, j}-x_{i, j}| + |x_{i, j+1}-x_{i, j}|$$
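As a quick sanity check, the definition can be evaluated directly with plain Python loops. The following is a minimal sketch (the function name tv_anisotropic and the 3x3 example image are just for illustration):

def tv_anisotropic(x):
    """Anisotropic total variation of a 2-D image given as a nested list."""
    height, width = len(x), len(x[0])
    tv = 0.0
    for i in range(height):
        for j in range(width):
            if i + 1 < height:
                tv += abs(x[i + 1][j] - x[i][j])  # vertical (Y) difference
            if j + 1 < width:
                tv += abs(x[i][j + 1] - x[i][j])  # horizontal (X) difference
    return tv

# 3x3 ramp: each vertical step is 0.3, each horizontal step is 0.1.
image = [[0.0, 0.1, 0.2],
         [0.3, 0.4, 0.5],
         [0.6, 0.7, 0.8]]
print(tv_anisotropic(image))  # 6 * 0.3 + 6 * 0.1 = 2.4 (up to float rounding)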
For more detail, see the following Wikipedia pages:
Total variation on Wikipedia
Total variation denoising on Wikipedia
[Implementation] Example implementation in PyTorch
Interface
The implementation supports the following input tensors:
- 2-dimensional tensor: shape [height, width]
- 3-dimensional tensor: shape [channel, height, width]
- 4-dimensional tensor: shape [batch, channel, height, width]
The shape of the output tensor depends on the shape of the input tensor:
- 2- or 3-dimensional input: returns a scalar tensor
- 4-dimensional input: returns a 1-dimensional tensor
  (the total variation is calculated for each batch element)
Create an instance in either of the following ways:
# There are two ways to create an instance.
# with `is_mean_reduction=False`
loss_ = TotalVariation()
# with `is_mean_reduction=True`
loss_ = TotalVariation(is_mean_reduction=True)
- is_mean_reduction=False: outputs the sum of the per-pixel variations
- is_mean_reduction=True: outputs the mean of the per-pixel variations
Implementation in PyTorch
import torch
from torch import Tensor


class TotalVariation(torch.nn.Module):
    """Calculate the total variation for a single tensor or a batch of tensors.

    The total variation is the sum of the absolute differences between
    neighboring pixel values in the input images.

    Example:
        >>> import torch
        >>> loss_ = TotalVariation()

        >>> # Example for a 2-dimensional tensor.
        >>> tensor_ = torch.arange(0, 2.5, 0.1, requires_grad=True).reshape(5, 5)
        >>> tensor_.shape
        torch.Size([5, 5])
        >>> loss_(tensor_)
        tensor(12., grad_fn=<AddBackward0>)

        >>> # Example for a 3-dimensional tensor.
        >>> tensor_ = torch.arange(0, 2.5, 0.1, requires_grad=True).reshape(1, 5, 5)
        >>> tensor_.shape
        torch.Size([1, 5, 5])
        >>> loss_(tensor_)
        tensor(12., grad_fn=<AddBackward0>)

        >>> # Example for a 4-dimensional tensor.
        >>> tensor_ = (
        ...     torch.arange(0, 10.0, 0.1, requires_grad=True).reshape(4, 1, 5, 5)
        ... )
        >>> tensor_.shape
        torch.Size([4, 1, 5, 5])
        >>> loss_(tensor_)
        tensor([12.0000, 12.0000, 12.0000, 12.0000], grad_fn=<AddBackward0>)

        >>> # Example for a 4-dimensional tensor with `is_mean_reduction=True`.
        >>> loss_ = TotalVariation(is_mean_reduction=True)
        >>> tensor_ = (
        ...     torch.arange(0, 10.0, 0.1, requires_grad=True).reshape(4, 1, 5, 5)
        ... )
        >>> loss_(tensor_)
        tensor([0.6000, 0.6000, 0.6000, 0.6000], grad_fn=<AddBackward0>)
    """

    def __init__(self, *, is_mean_reduction: bool = False) -> None:
        """Constructor.

        Args:
            is_mean_reduction (bool, optional):
                When `is_mean_reduction` is True, the output is divided by
                the number of elements used in the total variation
                calculation. Defaults to False.
        """
        super().__init__()
        self._is_mean = is_mean_reduction

    def forward(self, tensor_: Tensor) -> Tensor:
        return self._total_variation(tensor_)

    def _total_variation(self, tensor_: Tensor) -> Tensor:
        """Calculate the total variation.

        Args:
            tensor_ (Tensor): the input tensor must have one of the
                following shapes:
                - 2-dimensional: [height, width]
                - 3-dimensional: [channel, height, width]
                - 4-dimensional: [batch, channel, height, width]

        Raises:
            ValueError: If the input tensor is not 2-, 3- or 4-dimensional.

        Returns:
            Tensor: the output shape depends on the size of the input.
                - 2- or 3-dimensional input: a scalar tensor
                - 4-dimensional input: a 1-dimensional tensor with one
                  value per batch element
        """
        ndims_ = tensor_.dim()

        if ndims_ == 2:
            y_diff = tensor_[1:, :] - tensor_[:-1, :]
            x_diff = tensor_[:, 1:] - tensor_[:, :-1]
        elif ndims_ == 3:
            y_diff = tensor_[:, 1:, :] - tensor_[:, :-1, :]
            x_diff = tensor_[:, :, 1:] - tensor_[:, :, :-1]
        elif ndims_ == 4:
            y_diff = tensor_[:, :, 1:, :] - tensor_[:, :, :-1, :]
            x_diff = tensor_[:, :, :, 1:] - tensor_[:, :, :, :-1]
        else:
            raise ValueError(
                'Input tensor must be either 2, 3 or 4-dimensional.')

        # Sum over every dimension except the batch dimension:
        # (0, 1) for 2-D, (0, 1, 2) for 3-D and (1, 2, 3) for 4-D input.
        sum_axis = tuple(range(max(0, ndims_ - 3), ndims_))
        y_denominator = (
            y_diff.shape[sum_axis[0]:].numel() if self._is_mean else 1
        )
        x_denominator = (
            x_diff.shape[sum_axis[0]:].numel() if self._is_mean else 1
        )

        return (
            torch.sum(torch.abs(y_diff), dim=sum_axis) / y_denominator
            + torch.sum(torch.abs(x_diff), dim=sum_axis) / x_denominator
        )
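Incidentally, the examples in the docstring are written as doctests, so if the class is saved as total_variation.py (as assumed in the sessions below), the implementation can be checked with the standard doctest runner:

$ python -m doctest total_variation.py

No output means that all the embedded examples passed.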
Usage example
$ python
Python 3.8.2
[GCC 7.5.0] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> import torch
>>> from total_variation import TotalVariation
>>> loss_ = TotalVariation(is_mean_reduction=True)
>>> tensor_ = torch.arange(0, 10.0, 0.1, requires_grad=True).reshape(4, 1, 5, 5)
>>> tensor_
tensor([[[[0.0000, 0.1000, 0.2000, 0.3000, 0.4000],
          [0.5000, 0.6000, 0.7000, 0.8000, 0.9000],
          [1.0000, 1.1000, 1.2000, 1.3000, 1.4000],
          [1.5000, 1.6000, 1.7000, 1.8000, 1.9000],
          [2.0000, 2.1000, 2.2000, 2.3000, 2.4000]]],


        [[[2.5000, 2.6000, 2.7000, 2.8000, 2.9000],
          [3.0000, 3.1000, 3.2000, 3.3000, 3.4000],
          [3.5000, 3.6000, 3.7000, 3.8000, 3.9000],
          [4.0000, 4.1000, 4.2000, 4.3000, 4.4000],
          [4.5000, 4.6000, 4.7000, 4.8000, 4.9000]]],


        [[[5.0000, 5.1000, 5.2000, 5.3000, 5.4000],
          [5.5000, 5.6000, 5.7000, 5.8000, 5.9000],
          [6.0000, 6.1000, 6.2000, 6.3000, 6.4000],
          [6.5000, 6.6000, 6.7000, 6.8000, 6.9000],
          [7.0000, 7.1000, 7.2000, 7.3000, 7.4000]]],


        [[[7.5000, 7.6000, 7.7000, 7.8000, 7.9000],
          [8.0000, 8.1000, 8.2000, 8.3000, 8.4000],
          [8.5000, 8.6000, 8.7000, 8.8000, 8.9000],
          [9.0000, 9.1000, 9.2000, 9.3000, 9.4000],
          [9.5000, 9.6000, 9.7000, 9.8000, 9.9000]]]], grad_fn=<ViewBackward>)
>>> output = loss_(tensor_)
>>> output
tensor([0.6000, 0.6000, 0.6000, 0.6000], grad_fn=<AddBackward0>)
>>> output[0].backward()
In the example above, for every pixel,
- difference in the Y direction: 0.5000
- difference in the X direction: 0.1000
Since every difference is identical, the mean Y variation is 0.5 and the mean X variation is 0.1, so their sum, 0.6000, is output for each batch element.
Connecting to L1Loss
If you wire things up appropriately, as shown below, the output can be fed into an L1 loss.
$ python
Python 3.8.2
[GCC 7.5.0] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> import torch
>>> from total_variation import TotalVariation
>>> tv_loss = TotalVariation(is_mean_reduction=True)
>>> l1_loss = torch.nn.L1Loss()
>>> input_ = torch.arange(0, 10.0, 0.1, requires_grad=True).reshape(4, 1, 5, 5)
>>> reference = torch.arange(0, 20.0, 0.2, requires_grad=True).reshape(4, 1, 5, 5)
>>> tv_input = tv_loss(input_)
>>> tv_reference = tv_loss(reference)
>>> tv_input
tensor([0.6000, 0.6000, 0.6000, 0.6000], grad_fn=<AddBackward0>)
>>> tv_reference
tensor([1.2000, 1.2000, 1.2000, 1.2000], grad_fn=<AddBackward0>)
>>> l1_output = l1_loss(tv_input, tv_reference)
>>> l1_output
tensor(0.6000, grad_fn=<L1LossBackward>)
>>> l1_output.backward()
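In practice, total variation is often added straight to a reconstruction loss as a regularization term instead. The following is a minimal sketch of that pattern; model, optimizer, noisy, clean, and tv_weight are hypothetical placeholders, not part of the implementation above.

import torch
from total_variation import TotalVariation

tv_loss = TotalVariation(is_mean_reduction=True)
reconstruction_loss = torch.nn.MSELoss()
tv_weight = 1e-3  # hypothetical regularization strength

def training_step(model, optimizer, noisy, clean):
    """One optimization step with a TV-regularized loss (sketch)."""
    optimizer.zero_grad()
    output = model(noisy)
    # tv_loss returns one value per batch element; average into a scalar.
    loss = (reconstruction_loss(output, clean)
            + tv_weight * tv_loss(output).mean())
    loss.backward()
    optimizer.step()
    return loss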
[Summary] Implementing Total Variation in PyTorch
Total Variation itself is a technique that has been around for a long time, but I think it still has useful applications in deep-learning-based super-resolution, denoising, and the like!
- Image Restoration using Total Variation Regularized Deep Image Prior
- Deep Learning for Image Super-resolution: A Survey
Once again, for learning PyTorch I recommend 「PyTorchニューラルネットワーク 実装ハンドブック」 (PyTorch Neural Network Implementation Handbook)!