Source code for vlkit.ops.non_local

import torch
import torch.nn as nn
from einops import rearrange

[docs]class NonLocal(torch.nn.Module): def __init__(self, in_chs, hidden_chs=None, return_affinity=False): super().__init__() self.in_chs = in_chs if hidden_chs is not None: self.hidden_chs = hidden_chs else: self.hidden_chs = in_chs self.return_affinity = return_affinity self.conv_k = nn.Conv2d(self.in_chs, self.hidden_chs, kernel_size=1, bias=False) self.conv_q = nn.Conv2d(self.in_chs, self.hidden_chs, kernel_size=1, bias=False) self.conv_v = nn.Conv2d(self.in_chs, self.hidden_chs, kernel_size=1, bias=False) self.conv = nn.Conv2d(self.hidden_chs, self.in_chs, kernel_size=1, bias=False)
[docs] def forward(self, x): assert x.ndim == 4 n, c, h, w = x.shape k = self.conv_k(x) q = self.conv_q(x) v = self.conv_v(x) k = rearrange(k, "n c h w -> n (h w) c") q = rearrange(q, "n c h w -> n c (h w)") v = rearrange(v, "n c h w -> n c (h w)") affinity = torch.bmm(k, q).softmax(dim=0) x = torch.bmm(v, affinity).view(n, -1, h, w) v = v.view(n, -1, h, w) x = self.conv(x + v) if self.return_affinity: return x, affinity else: return x