import torch
import torch.distributed as dist


class AllGather(torch.autograd.Function):
    """all_gather with gradient back-propagation."""

    @staticmethod
    def forward(ctx, tensor_list, tensor):
        # Gather `tensor` from every rank into the pre-allocated buffers.
        dist.all_gather(tensor_list, tensor)
        return tuple(tensor_list)
    @staticmethod
    def backward(ctx, *grad_list):
        grad_list = list(grad_list)
        rank = dist.get_rank()
        # Gathered output i receives a gradient on every rank; sum those
        # gradients onto rank i (the rank that produced that slice) using
        # asynchronous reduces, then wait for all of them to complete.
        dist_ops = [
            dist.reduce(grad_list[i], i, async_op=True)
            for i in range(dist.get_world_size())
        ]
        for op in dist_ops:
            op.wait()
        # forward() took (tensor_list, tensor): the list gets no gradient,
        # and the local input's gradient is this rank's reduced slice.
        return None, grad_list[rank]

all_gather = AllGather.apply
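

# A minimal usage sketch, not part of the original module: gather per-rank
# feature tensors into one global batch while keeping gradients flowing back
# to the local shard. It assumes torch.distributed has already been
# initialized (e.g. via dist.init_process_group); `gather_features` is a
# hypothetical helper name introduced only for illustration.
def gather_features(features):
    """Gather `features` from all ranks and concatenate along dim 0."""
    world_size = dist.get_world_size()
    # Pre-allocate one receive buffer per rank, as dist.all_gather requires.
    tensor_list = [torch.zeros_like(features) for _ in range(world_size)]
    gathered = all_gather(tensor_list, features)
    # Concatenation preserves the autograd graph: during backward, each
    # slice's gradient is routed to the rank that produced it via dist.reduce.
    return torch.cat(gathered, dim=0)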