Source code for vlkit.ops.distributed

import torch
import torch.distributed as dist

class AllGather(torch.autograd.Function):
    """all_gather with gradient back-propagation."""

    @staticmethod
    def forward(ctx, tensor_list, tensor):
        # Fill each pre-allocated buffer in tensor_list with the tensor
        # contributed by the corresponding rank.
        dist.all_gather(tensor_list, tensor)
        return tuple(tensor_list)

    @staticmethod
    def backward(ctx, *grad_list):
        grad_list = list(grad_list)
        rank = dist.get_rank()
        # Sum the gradient of each gathered output across ranks and deliver it
        # to the rank that contributed the corresponding input tensor.
        dist_ops = [
            dist.reduce(grad_list[i], i, async_op=True)
            for i in range(dist.get_world_size())
        ]
        for op in dist_ops:
            op.wait()
        # No gradient for tensor_list; the local tensor receives the reduced
        # gradient of this rank's slot.
        return None, grad_list[rank]
all_gather = AllGather.apply
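
Because the gathered outputs stay in the autograd graph, features collected from other ranks can enter a loss while the local tensor still receives its gradient. Below is a minimal usage sketch, assuming the default process group is already initialized (for example via torchrun); the function and variable names are illustrative and not part of vlkit.

import torch
import torch.distributed as dist
from vlkit.ops.distributed import all_gather

def gather_features(local_feats):
    # One buffer per rank to receive the gathered tensors.
    buffers = [torch.zeros_like(local_feats) for _ in range(dist.get_world_size())]
    # Unlike plain dist.all_gather, gradients flow back to local_feats.
    gathered = all_gather(buffers, local_feats)
    return torch.cat(gathered, dim=0)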