import torch
class WassersteinLossFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, prediction, label, M, reg, numItermax=100, eps=1e-6):
        # Build the one-hot target distribution on the same device as the input.
        bs = prediction.size(0)
        dim = prediction.size(1)
        device = prediction.device
        target = torch.zeros(bs, dim, dtype=M.dtype, device=device)
        idx = torch.arange(bs, device=device)
        # Dataset-specific offset: raw labels are assumed to start at 11.
        target[idx, label - 11] = 1
        # Compute the entropy-regularised Wasserstein distance via Sinkhorn.
        # Scalings u and v are initialised uniformly.
        u = torch.ones(bs, dim, dtype=M.dtype, device=device) / dim
        v = torch.ones(bs, dim, dtype=M.dtype, device=device) / dim
        # Gibbs kernel K = exp(-M / reg - 1)
        K = torch.exp(-M / reg - 1)
        # Precompute K * M and K * log(K) for the loss and entropy terms below.
        KM = K * M
        KlogK = K * torch.log(K)
        # Sinkhorn iterations: alternately rescale v and u so the plan's
        # marginals approach target and prediction.
        for _ in range(numItermax):
            v = target / torch.mm(u, K)
            u = prediction / torch.mm(v, K.t())
            # Clamp tiny magnitudes so log(u) and log(v) stay finite below.
            u[torch.abs(u) < eps] = eps
            v[torch.abs(v) < eps] = eps
        # Transport cost <T, M>, where T = diag(u) K diag(v) per batch element.
        tmp1 = torch.mm(u, KM)
        loss = torch.mul(v, tmp1).sum()
        # Entropy term <T, log T>, split via log T = log u + log K + log v.
        ulogu = torch.mul(u, torch.log(u))
        tmp2 = torch.mm(ulogu, K)
        entropy1 = torch.mul(tmp2, v).sum()
        vlogv = torch.mul(v, torch.log(v))
        tmp3 = torch.mm(vlogv, K.t())
        entropy2 = torch.mul(tmp3, u).sum()
        tmp4 = torch.mm(u, KlogK)
        entropy3 = torch.mul(tmp4, v).sum()
        entropy = (entropy1 + entropy2 + entropy3) * reg
        loss_total = loss + entropy
        # Save intermediate variables for the backward pass.
        ctx.save_for_backward(u, torch.tensor([reg], dtype=M.dtype, device=device))
        return loss_total / bs
    @staticmethod
    def backward(ctx, grad_output):
        u, reg = ctx.saved_tensors
        bs = u.size(0)
        dim = u.size(1)
        grad_input = grad_output.clone()
        # Gradient w.r.t. prediction is reg * log(u), shifted to zero mean per
        # row; the 1 / bs factor mirrors the normalisation in forward().
        grad = torch.log(u)
        shifting = torch.sum(grad, dim=1, keepdim=True) / dim
        return grad_input * (grad - shifting) * reg / bs, None, None, None, None, None
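
# --- Illustrative helper (a sketch, not part of the original module) ---
# The Sinkhorn scalings computed in forward() define a transport plan
# T[b] = diag(u[b]) @ K @ diag(v[b]) for each batch element. Rebuilding T can
# be useful for inspecting how probability mass moves between classes; this
# hypothetical helper assumes you kept u, v, and K around yourself, since
# forward() only saves u for the backward pass.
def transport_plan(u, v, K):
    # u, v: (bs, dim) Sinkhorn scalings; K: (dim, dim) Gibbs kernel.
    # Returns T of shape (bs, dim, dim); each T[b] sums to ~1 with row and
    # column marginals close to prediction[b] and target[b] respectively.
    return u.unsqueeze(2) * K.unsqueeze(0) * v.unsqueeze(1)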
class WassersteinLoss(torch.nn.Module):
    def __init__(self, gm, reg, max_iter, eps=1e-6):
        super().__init__()
        self.gm = gm
        self.reg = reg
        self.max_iter = max_iter
        self.eps = eps
        # Store apply itself as the function; it must not be called here.
        self.wasserstein_func = WassersteinLossFunction.apply

    def forward(self, prediction, target):
        return self.wasserstein_func(prediction, target, self.gm, self.reg, self.max_iter, self.eps)
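
# --- Usage sketch (illustrative, not part of the original module) ---
# A minimal example of driving the loss, assuming an ordinal task whose raw
# labels start at 11 (matching the `label - 11` indexing in forward) and a
# hypothetical |i - j| ground cost between the `dim` classes. The cost matrix
# is normalised to [0, 1] so the Gibbs kernel exp(-M / reg) does not underflow.
if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    bs, dim = 8, 30
    classes = torch.arange(dim, dtype=torch.float32, device=device)
    M = torch.abs(classes[:, None] - classes[None, :])
    M = M / M.max()  # keep M / reg in a numerically safe range
    criterion = WassersteinLoss(gm=M, reg=0.1, max_iter=100)

    logits = torch.randn(bs, dim, device=device, requires_grad=True)
    prediction = torch.softmax(logits, dim=1)  # rows must be distributions
    label = torch.randint(11, 11 + dim, (bs,), device=device)

    loss = criterion(prediction, label)
    loss.backward()
    print(loss.item(), logits.grad.shape)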