utils.distributed

Utility helpers for distributed checks.

Functions

Name                          Description
barrier                       Acts as a barrier to wait for all processes.
compute_and_broadcast         Compute a value using the function 'fn' on rank 0 and broadcast it to all other ranks.
gather_from_all_ranks         Run a callable 'fn' on all ranks and gather the results on rank 0.
gather_scalar_from_all_ranks  Run a callable 'fn' on all ranks and gather the scalar results on rank 0.
is_distributed                Check if distributed training is initialized.
is_main_process               Check if the current process is the main process.
reduce_and_broadcast          Run a callable 'fn1' on all ranks, gather the results, reduce them using 'fn2', and broadcast the result.
zero_first                    Context manager that runs the wrapped block on rank 0 first, before other ranks.
zero_only                     Context manager that only runs the enclosed block on the main rank.

barrier

utils.distributed.barrier()

Acts as a barrier to wait for all processes. This ensures that all processes reach the barrier before proceeding further.
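
A minimal usage sketch, assuming the module is importable as utils.distributed and that the process group has already been initialized; preprocess_dataset is a hypothetical helper:

    from utils.distributed import barrier, is_main_process

    if is_main_process():
        preprocess_dataset()  # hypothetical: only rank 0 does the one-time work
    barrier()  # every rank waits here until all ranks (including rank 0) arrive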

compute_and_broadcast

utils.distributed.compute_and_broadcast(fn)

Compute a value using the function 'fn' on rank 0 only. The value is then broadcast to all other ranks.

Args:

- fn (callable): A function that computes the value. This should not have any side effects.

Returns:

- The computed value (int or float).
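
A sketch of typical use under the same import assumption; compute_max_sequence_length is a hypothetical helper:

    from utils.distributed import compute_and_broadcast

    # The lambda runs on rank 0 only; every rank receives the same value.
    max_len = compute_and_broadcast(lambda: compute_max_sequence_length())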

gather_from_all_ranks

utils.distributed.gather_from_all_ranks(fn, world_size=1)

Run a callable 'fn' on all ranks and gather the results on rank 0.

Args:

- fn (callable): A function that computes the value. This should not have any side effects.
- world_size (int, optional): Total number of processes in the current distributed setup. Default is 1.

Returns:

- A list of computed values from all ranks if on rank 0, otherwise None.
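
A sketch under the same import assumption; local_dataset is a hypothetical per-rank dataset shard:

    import torch.distributed as dist

    from utils.distributed import gather_from_all_ranks

    # Each rank reports its own shard size; rank 0 receives the full list.
    counts = gather_from_all_ranks(
        lambda: len(local_dataset), world_size=dist.get_world_size()
    )
    if counts is not None:  # only rank 0 gets the gathered list
        print(f"total samples: {sum(counts)}")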

gather_scalar_from_all_ranks

utils.distributed.gather_scalar_from_all_ranks(fn, world_size=1)

Run a callable 'fn' on all ranks and gather the scalar results on rank 0.

Args:

- fn (callable): A function that computes a scalar value. This should not have any side effects.
- world_size (int, optional): Total number of processes in the current distributed setup. Default is 1.

Returns:

- A list of computed values from all ranks if on rank 0, otherwise None.
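
Usage mirrors gather_from_all_ranks; in this sketch, loss is a hypothetical tensor from the current training step:

    import torch.distributed as dist

    from utils.distributed import gather_scalar_from_all_ranks

    losses = gather_scalar_from_all_ranks(
        lambda: loss.item(), world_size=dist.get_world_size()
    )
    if losses is not None:  # only rank 0 gets the gathered list
        print(f"mean loss: {sum(losses) / len(losses)}")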

is_distributed

utils.distributed.is_distributed()

Check if distributed training is initialized.
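
A sketch of the usual guard, so the same script works in both single-process and distributed runs:

    from utils.distributed import barrier, is_distributed

    # Only issue collective calls when a process group actually exists.
    if is_distributed():
        barrier()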

is_main_process

utils.distributed.is_main_process()

Check if the current process is the main process. If not in distributed mode, always return True.
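
A sketch of the common pattern of restricting side effects to one process; save_checkpoint is a hypothetical helper:

    from utils.distributed import is_main_process

    if is_main_process():
        save_checkpoint(model)  # hypothetical: avoid N ranks writing the same file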

reduce_and_broadcast

utils.distributed.reduce_and_broadcast(fn1, fn2)

Run a callable 'fn1' on all ranks, gather the results, reduce them using 'fn2', and then broadcast the reduced result to all ranks.

Args:

- fn1 (callable): A function that computes the value on each rank.
- fn2 (callable): A reduction function that takes a list of values and returns a single value.

Returns:

- The reduced and broadcast value.
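
A sketch under the same import assumption; count_local_tokens is a hypothetical per-rank metric:

    from utils.distributed import reduce_and_broadcast

    # Each rank computes a local count, rank 0 sums the gathered list,
    # and the total is broadcast back so every rank agrees on it.
    total_tokens = reduce_and_broadcast(
        lambda: count_local_tokens(),
        lambda values: sum(values),
    )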

zero_first

utils.distributed.zero_first(is_main)

Context manager that runs the wrapped block so that rank 0 runs first, before the other ranks. 'is_main' indicates whether the current process is rank 0.
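
A sketch of the typical caching pattern, using is_main_process from this module to supply the is_main flag; load_and_tokenize is a hypothetical helper:

    from utils.distributed import is_main_process, zero_first

    # Rank 0 enters the block first (e.g. to populate a dataset cache);
    # the other ranks then run the same block and hit the warm cache.
    with zero_first(is_main_process()):
        dataset = load_and_tokenize()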

zero_only

utils.distributed.zero_only()

Context manager that only runs the enclosed block on the main rank.
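
A sketch, assuming the block is skipped on non-main ranks as documented; log_metrics is a hypothetical helper:

    from utils.distributed import zero_only

    with zero_only():
        log_metrics(step, loss)  # hypothetical: runs on the main rank only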