Improve Comms Benchmark Timing (#833)
* Change to CUDA event-based timing

* Add event args to called funcs

* Add missing comma to args
Quentin-Anthony authored Dec 20, 2023
1 parent dd0f181 commit 8e4cdd8
Showing 5 changed files with 45 additions and 25 deletions.
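
For context on the change itself: the benchmarks previously timed the trial loop on the host with time.perf_counter() and now bracket it with a pair of torch.cuda.Event objects, reading the result via elapsed_time(), which reports milliseconds (hence the division by 1000). Below is a minimal sketch of the two approaches, assuming an already-initialized process group and a CUDA device; the all_reduce workload and trial count are placeholders, not the benchmark's own code.

import time

import torch
import torch.distributed as dist


def time_with_cuda_events(tensor, trials=10):
    # CUDA events record timestamps on the GPU stream, so the measurement
    # brackets the device-side work submitted inside the loop.
    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)
    start_event.record()
    for _ in range(trials):
        dist.all_reduce(tensor, async_op=True)
    end_event.record()
    torch.cuda.synchronize()  # both events must complete before they are read
    return start_event.elapsed_time(end_event) / 1000  # elapsed_time() is in ms


def time_with_perf_counter(tensor, trials=10):
    # Host wall-clock timing (the old approach): the final synchronize keeps the
    # measurement honest, but host-side overhead is included in the result.
    torch.cuda.synchronize()
    pre = time.perf_counter()
    for _ in range(trials):
        dist.all_reduce(tensor, async_op=True)
    torch.cuda.synchronize()
    return time.perf_counter() - pre
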
14 changes: 9 additions & 5 deletions benchmarks/communication/all_gather.py
@@ -16,7 +16,7 @@


 # Run all_gather and print metrics
-def timed_all_gather(input, output, args):
+def timed_all_gather(input, output, start_event, end_event, args):
     if args.dist == 'torch':
         import torch.distributed as dist

@@ -33,11 +33,12 @@ def timed_all_gather(input, output, args):
     sync_all()

     # time the actual comm op trials times and average it
-    pre = time.perf_counter()
+    start_event.record()
     for i in range(args.trials):
         all_gather_func(output, input, group=None, async_op=args.async_op)
+    end_event.record()
     sync_all()
-    duration = time.perf_counter() - pre
+    duration = start_event.elapsed_time(end_event) / 1000

     # maintain and clean performance data
     avg_duration = duration / args.trials
@@ -63,6 +64,9 @@ def run_all_gather(local_rank, args):
     global_rank = dist.get_rank()
     world_size = dist.get_world_size()

+    start_event = torch.cuda.Event(enable_timing=True)
+    end_event = torch.cuda.Event(enable_timing=True)
+
     if args.scan:
         # Create list of message sizes
         M_LIST = []
@@ -92,7 +96,7 @@ def run_all_gather(local_rank, args):
                 else:
                     raise e
             sync_all()
-            timed_all_gather(input, output, args)
+            timed_all_gather(input, output, start_event, end_event, args)
     else:
         # all_gather_into_tensor saves memory
         if ((args.dist == 'torch' or args.dist == 'deepspeed') and dist.has_all_gather_into_tensor()):
@@ -126,7 +130,7 @@ def run_all_gather(local_rank, args):
                 raise e

         sync_all()
-        timed_all_gather(input, output, args)
+        timed_all_gather(input, output, start_event, end_event, args)


 if __name__ == "__main__":
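
As the hunks above show, the event pair is created once in run_all_gather and re-recorded on every timed_all_gather call; torch.cuda.Event objects created with enable_timing=True can be reused this way, provided the recorded work has finished before elapsed_time() is read (the benchmark relies on sync_all() for that). A small illustrative sketch of the reuse pattern, with a placeholder workload callable:

import torch

start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)


def measure_seconds(workload):
    # 'workload' is any callable that launches GPU work on the current stream.
    start_event.record()
    workload()
    end_event.record()
    end_event.synchronize()  # narrower alternative to a full device synchronize
    return start_event.elapsed_time(end_event) / 1000  # ms -> s

Re-recording an event simply overwrites its previous timestamp, so one pair suffices for any number of measurements.
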
14 changes: 9 additions & 5 deletions benchmarks/communication/all_reduce.py
@@ -14,7 +14,7 @@
 from deepspeed.accelerator import get_accelerator


-def timed_all_reduce(input, args):
+def timed_all_reduce(input, start_event, end_event, args):
     if args.dist == 'torch':
         import torch.distributed as dist
     elif args.dist == 'deepspeed':
@@ -27,11 +27,12 @@ def timed_all_reduce(input, args):
     sync_all()

     # time the actual comm op trials times and average it
-    pre = time.perf_counter()
+    start_event.record()
     for i in range(args.trials):
         dist.all_reduce(input, async_op=args.async_op)
+    end_event.record()
     sync_all()
-    duration = time.perf_counter() - pre
+    duration = start_event.elapsed_time(end_event) / 1000

     # maintain and clean performance data
     avg_duration = duration / args.trials
@@ -59,6 +60,9 @@ def run_all_reduce(local_rank, args):
     world_size = dist.get_world_size()
     global_rank = dist.get_rank()

+    start_event = torch.cuda.Event(enable_timing=True)
+    end_event = torch.cuda.Event(enable_timing=True)
+
     if args.scan:
         M_LIST = []
         for x in (2**p for p in range(1, args.maxsize)):
@@ -82,7 +86,7 @@ def run_all_reduce(local_rank, args):
                 else:
                     raise e
             sync_all()
-            timed_all_reduce(input, args)
+            timed_all_reduce(input, start_event, end_event, args)
     else:
         # Send the biggest message size our GPUs can fit. If you're facing OOM errors, reduce the mem_factor
         # Don't need output tensor, so we double mem_factor
@@ -104,7 +108,7 @@ def run_all_reduce(local_rank, args):
             else:
                 raise e
         sync_all()
-        timed_all_reduce(input, args)
+        timed_all_reduce(input, start_event, end_event, args)


 if __name__ == "__main__":
14 changes: 9 additions & 5 deletions benchmarks/communication/all_to_all.py
@@ -14,7 +14,7 @@
 from deepspeed.accelerator import get_accelerator


-def timed_all_to_all(input, output, args):
+def timed_all_to_all(input, output, start_event, end_event, args):
     if args.dist == 'torch':
         import torch.distributed as dist
     elif args.dist == 'deepspeed':
@@ -27,11 +27,12 @@ def timed_all_to_all(input, output, args):
     sync_all()

     # time the actual comm op trials times and average it
-    pre = time.perf_counter()
+    start_event.record()
     for i in range(args.trials):
         dist.all_to_all_single(output, input, async_op=args.async_op)
+    end_event.record()
     sync_all()
-    duration = time.perf_counter() - pre
+    duration = start_event.elapsed_time(end_event) / 1000

     # maintain and clean performance data
     avg_duration = duration / args.trials
@@ -58,6 +59,9 @@ def run_all_to_all(local_rank, args):
     # Prepare benchmark header
     print_header(args, 'all_to_all')

+    start_event = torch.cuda.Event(enable_timing=True)
+    end_event = torch.cuda.Event(enable_timing=True)
+
     if args.scan:
         M_LIST = []
         for x in (2**p for p in range(1, args.maxsize)):
@@ -83,7 +87,7 @@ def run_all_to_all(local_rank, args):
                 else:
                     raise e
             sync_all()
-            timed_all_to_all(input, output, args)
+            timed_all_to_all(input, output, start_event, end_event, args)
     else:
         # Send the biggest message size our GPUs can fit. If you're facing OOM errors, reduce the mem_factor
         elements_per_gpu = max_numel(comm_op='all_to_all',
@@ -118,7 +122,7 @@ def run_all_to_all(local_rank, args):
                     print(f"Before AllToAll Input List at rank {global_rank}: {input}")
                 dist.barrier()

-        timed_all_to_all(input, output, args)
+        timed_all_to_all(input, output, start_event, end_event, args)

         if args.debug:
             for i in range(world_size):
14 changes: 9 additions & 5 deletions benchmarks/communication/broadcast.py
@@ -14,7 +14,7 @@
 from deepspeed.accelerator import get_accelerator


-def timed_broadcast(input, args):
+def timed_broadcast(input, start_event, end_event, args):
     if args.dist == 'torch':
         import torch.distributed as dist
     elif args.dist == 'deepspeed':
@@ -27,11 +27,12 @@ def timed_broadcast(input, args):
     sync_all()

     # time the actual comm op trials times and average it
-    pre = time.perf_counter()
+    start_event.record()
     for i in range(args.trials):
         dist.broadcast(input, 0, async_op=args.async_op)
+    end_event.record()
     sync_all()
-    duration = time.perf_counter() - pre
+    duration = start_event.elapsed_time(end_event) / 1000

     # maintain and clean performance data
     avg_duration = duration / args.trials
@@ -59,6 +60,9 @@ def run_broadcast(local_rank, args):
     world_size = dist.get_world_size()
     global_rank = dist.get_rank()

+    start_event = torch.cuda.Event(enable_timing=True)
+    end_event = torch.cuda.Event(enable_timing=True)
+
     if args.scan:
         M_LIST = []
         for x in (2**p for p in range(1, args.maxsize)):
@@ -82,7 +86,7 @@ def run_broadcast(local_rank, args):
                 else:
                     raise e
             sync_all()
-            timed_broadcast(input, args)
+            timed_broadcast(input, start_event, end_event, args)
     else:
         # Send the biggest message size our GPUs can fit. If you're facing OOM errors, reduce the mem_factor
         # Don't need output tensor, so we double mem_factor
@@ -102,7 +106,7 @@ def run_broadcast(local_rank, args):
                 sync_all()
                 return
         sync_all()
-        timed_broadcast(input, args)
+        timed_broadcast(input, start_event, end_event, args)


 if __name__ == "__main__":
14 changes: 9 additions & 5 deletions benchmarks/communication/pt2pt.py
@@ -14,7 +14,7 @@
 from deepspeed.accelerator import get_accelerator


-def timed_pt2pt(input, args):
+def timed_pt2pt(input, start_event, end_event, args):
     if args.dist == 'torch':
         import torch.distributed as dist
     elif args.dist == 'deepspeed':
@@ -36,7 +36,7 @@ def timed_pt2pt(input, args):
     sync_all()

     # time the actual comm op trials times and average it
-    pre = time.perf_counter()
+    start_event.record()
     for i in range(args.trials):
         if dist.get_rank() == 0:
             if args.async_op:
@@ -49,8 +49,9 @@ def timed_pt2pt(input, args):
             else:
                 dist.recv(input, src=0)

+    end_event.record()
     sync_all()
-    duration = time.perf_counter() - pre
+    duration = start_event.elapsed_time(end_event) / 1000

     # maintain and clean performance data
     avg_duration = duration / args.trials
@@ -77,6 +78,9 @@ def run_pt2pt(local_rank, args):
     global_rank = dist.get_rank()
     world_size = dist.get_world_size()

+    start_event = torch.cuda.Event(enable_timing=True)
+    end_event = torch.cuda.Event(enable_timing=True)
+
     if args.scan:
         # Create list of message sizes
         M_LIST = []
@@ -101,7 +105,7 @@ def run_pt2pt(local_rank, args):
                 else:
                     raise e
             sync_all()
-            timed_pt2pt(input, args)
+            timed_pt2pt(input, start_event, end_event, args)
     else:
         # Send the biggest message size our GPUs can fit. If you're facing OOM errors, reduce the mem_factor
         # Don't need output tensor, so double mem_factor
@@ -121,7 +125,7 @@ def run_pt2pt(local_rank, args):
                 sync_all()
                 return
         sync_all()
-        timed_pt2pt(input, args)
+        timed_pt2pt(input, start_event, end_event, args)


 if __name__ == "__main__":
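
For anyone who wants to confirm that event-based numbers line up with host-side timing over a fully synchronized region, a quick standalone cross-check follows; the matmul workload and sizes are arbitrary stand-ins, not part of the benchmark suite.

import time

import torch


def cross_check(trials=100, n=4096):
    # Arbitrary GPU-bound workload standing in for a communication op.
    a = torch.randn(n, n, device="cuda")
    b = torch.randn(n, n, device="cuda")

    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)

    torch.cuda.synchronize()
    pre = time.perf_counter()
    start_event.record()
    for _ in range(trials):
        _ = a @ b
    end_event.record()
    torch.cuda.synchronize()

    host_s = time.perf_counter() - pre
    event_s = start_event.elapsed_time(end_event) / 1000
    # For a GPU-bound, fully synchronized region the two values should be close;
    # the event measurement excludes host-side launch and synchronization overhead.
    print(f"host: {host_s:.4f}s  events: {event_s:.4f}s")


if __name__ == "__main__":
    cross_check()
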
