Prefetch + stream for data transfer in PyTorch DDP mode
Straight to the code.
- DDP forward (torch/nn/parallel/distributed.py):
```python
# Excerpt from DistributedDataParallel.forward
if self.device_ids:
    if len(self.device_ids) == 1:
        # Single device per process: just move the inputs to that device.
        inputs, kwargs = self.to_kwargs(inputs, kwargs, self.device_ids[0])
        output = self.module(*inputs[0], **kwargs[0])
    else:
        # Multiple devices in one process: scatter inputs, run replicas, gather outputs.
        inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
        outputs = self.parallel_apply(self._module_copies[:len(inputs)], inputs, kwargs)
        output = self.gather(outputs, self.output_device)
else:
    output = self.module(*inputs, **kwargs)

# DDP.scatter just delegates to scatter_kwargs:
def scatter(self, inputs, kwargs, device_ids):
    return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim)
```
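In the common one-process-per-GPU setup, device_ids has exactly one entry, so forward takes the first branch and simply moves the inputs to that single device. A minimal launch sketch under that assumption (the torchrun-style LOCAL_RANK variable and the nccl backend are assumptions for illustration):

```python
import os

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

# Assumed: one process per GPU, launched with torchrun so LOCAL_RANK is set.
dist.init_process_group(backend="nccl")
local_rank = int(os.environ["LOCAL_RANK"])
torch.cuda.set_device(local_rank)

model = torch.nn.Linear(128, 10).cuda(local_rank)
# len(device_ids) == 1, so forward() moves inputs via to_kwargs() and calls
# self.module directly; scatter/parallel_apply/gather are not involved.
ddp_model = DDP(model, device_ids=[local_rank])

out = ddp_model(torch.randn(32, 128))  # CPU inputs are moved to the GPU by DDP
```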
```python
# torch/nn/parallel/scatter_gather.py
def scatter_kwargs(inputs, kwargs, target_gpus, dim=0):
    r"""Scatter with support for kwargs dictionary"""
    inputs = scatter(inputs, target_gpus, dim) if inputs else []
    kwargs = scatter(kwargs, target_gpus, dim) if kwargs else []
    if len(inputs) < len(kwargs):
        inputs.extend([() for _ in range(len(kwargs) - len(inputs))])
    elif len(kwargs) < len(inputs):
        kwargs.extend([{} for _ in range(len(inputs) - len(kwargs))])
    inputs = tuple(inputs)
    kwargs = tuple(kwargs)
    return inputs, kwargs
```
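To see what scatter_kwargs actually returns, here is a small illustrative call; the two visible GPUs and the tensor shapes are assumptions made for the example:

```python
import torch
from torch.nn.parallel.scatter_gather import scatter_kwargs

# Assumes at least two visible CUDA devices.
x = torch.randn(8, 16)   # CPU batch of 8 samples
mask = torch.ones(8, 1)
inputs, kwargs = scatter_kwargs((x,), {"mask": mask}, [0, 1], dim=0)

# Expected: one (args, kwargs) pair per device, tensors chunked along dim 0:
#   inputs[0] -> (x[0:4] on cuda:0,)    kwargs[0] -> {"mask": mask[0:4] on cuda:0}
#   inputs[1] -> (x[4:8] on cuda:1,)    kwargs[1] -> {"mask": mask[4:8] on cuda:1}
```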
```python
def scatter(inputs, target_gpus, dim=0):
    r"""
    Slices tensors into approximately equal chunks and
    distributes them across given GPUs. Duplicates
    references to objects that are not tensors.
    """
    def scatter_map(obj):
        if isinstance(obj, torch.Tensor):
            return Scatter.apply(target_gpus, None, dim, obj)
        if is_namedtuple(obj):
            return [type(obj)(*args) for args in zip(*map(scatter_map, obj))]
        if isinstance(obj, tuple) and len(obj) > 0:
            return list(zip(*map(scatter_map, obj)))
        if isinstance(obj, list) and len(obj) > 0:
            return [list(i) for i in zip(*map(scatter_map, obj))]
        if isinstance(obj, dict) and len(obj) > 0:
            return [type(obj)(i) for i in zip(*map(scatter_map, obj.items()))]
        return [obj for targets in target_gpus]

    # After scatter_map is called, a scatter_map cell will exist. This cell
    # has a reference to the actual function scatter_map, which has references
    # to a closure that has a reference to the scatter_map cell (because the
    # fn is recursive). To avoid this reference cycle, we set the function to
    # None, clearing the cell
    try:
        res = scatter_map(inputs)
    finally:
        scatter_map = None
    return res
```
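The recursion above effectively transposes nested containers: one structure holding full-batch tensors becomes one structure per GPU holding that GPU's chunk, while non-tensor leaves are replicated. A CPU-only sketch of the same idea, using torch.chunk in place of Scatter.apply (the helper fake_scatter_map is made up for illustration):

```python
import torch

def fake_scatter_map(obj, n_chunks):
    """CPU-only illustration of how scatter_map transposes nested containers."""
    if isinstance(obj, torch.Tensor):
        return list(torch.chunk(obj, n_chunks, dim=0))
    if isinstance(obj, tuple) and len(obj) > 0:
        return list(zip(*(fake_scatter_map(o, n_chunks) for o in obj)))
    if isinstance(obj, dict) and len(obj) > 0:
        return [dict(i) for i in zip(*(fake_scatter_map(o, n_chunks) for o in obj.items()))]
    return [obj] * n_chunks  # non-tensor leaves are replicated to every chunk

batch = {"x": torch.arange(8).reshape(8, 1), "step": 7}
print(fake_scatter_map(batch, 2))
# -> [{'x': rows 0..3, 'step': 7}, {'x': rows 4..7, 'step': 7}]
```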
In torch/nn/parallel/_functions.py, the default tensor-to-GPU copy already runs on a dedicated CUDA stream:
```python
from torch.nn.parallel._functions import Scatter, Gather

class Scatter(Function):
    @staticmethod
    def forward(ctx, target_gpus, chunk_sizes, dim, input):
        target_gpus = [_get_device_index(x, True) for x in target_gpus]
        ctx.dim = dim
        ctx.input_device = input.get_device() if input.device.type != "cpu" else -1
        streams = None
        if torch.cuda.is_available() and ctx.input_device == -1:
            # Perform CPU to GPU copies in a background stream
            streams = [_get_stream(device) for device in target_gpus]
        outputs = comm.scatter(input, target_gpus, chunk_sizes, ctx.dim, streams)
        # Synchronize with the copy stream
        if streams is not None:
            for i, output in enumerate(outputs):
                with torch.cuda.device(target_gpus[i]):
                    main_stream = torch.cuda.current_stream()
                    main_stream.wait_stream(streams[i])
                    output.record_stream(main_stream)
        return outputs

    @staticmethod
    def backward(ctx, *grad_output):
        return None, None, None, Gather.apply(ctx.input_device, ctx.dim, *grad_output)
```
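The key pattern in Scatter.forward is: do the host-to-device copy on a background stream, make the current (compute) stream wait on it, and call record_stream so the caching allocator does not recycle the memory while the compute stream is still using it. A hand-rolled sketch of the same pattern for a single tensor (function name and device choice are illustrative):

```python
import torch

def async_copy_to_gpu(cpu_tensor, device="cuda:0"):
    """Copy a CPU tensor to the GPU on a side stream, Scatter.forward-style."""
    device = torch.device(device)
    copy_stream = torch.cuda.Stream(device=device)
    with torch.cuda.stream(copy_stream):
        # The copy is only truly asynchronous if the source is in pinned memory.
        gpu_tensor = cpu_tensor.pin_memory().to(device, non_blocking=True)

    main_stream = torch.cuda.current_stream(device)
    main_stream.wait_stream(copy_stream)   # compute waits for the copy to finish
    gpu_tensor.record_stream(main_stream)  # tie the memory's lifetime to main_stream
    return gpu_tensor

if torch.cuda.is_available():
    y = async_copy_to_gpu(torch.randn(1024, 1024))
    print(y.device)  # cuda:0
```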
Conclusion: DDP mode already comes with prefetch + stream support for moving input data to the GPU.
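The same prefetch + stream idea is often applied one level earlier, on the DataLoader side, so that copying the next batch overlaps with the current step's compute. A commonly seen hand-written prefetcher sketch, not part of DDP itself; the class name and the assumption that each batch is a list of tensors are illustrative:

```python
import torch

class CUDAPrefetcher:
    """Overlap the next batch's host-to-device copy with the current step."""

    def __init__(self, loader, device="cuda:0"):
        self.loader = iter(loader)              # assumes DataLoader(pin_memory=True)
        self.device = torch.device(device)
        self.stream = torch.cuda.Stream(device=self.device)
        self._preload()

    def _preload(self):
        try:
            batch = next(self.loader)
        except StopIteration:
            self.next_batch = None
            return
        with torch.cuda.stream(self.stream):
            # Async copy on the side stream; needs pinned host memory to overlap.
            self.next_batch = [t.to(self.device, non_blocking=True) for t in batch]

    def next(self):
        if self.next_batch is None:
            return None
        current = torch.cuda.current_stream(self.device)
        current.wait_stream(self.stream)        # compute waits for the prefetch copy
        batch = self.next_batch
        for t in batch:
            t.record_stream(current)            # keep memory alive for the compute stream
        self._preload()                         # start copying the next batch right away
        return batch
```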