utils.distributed
Utility helpers for distributed checks.
Functions
| Name | Description |
|---|---|
| barrier | Acts as a barrier to wait for all processes. |
| compute_and_broadcast | Compute a value using the function `fn` only on the specified rank (default is 0). |
| gather_from_all_ranks | Run a callable `fn` on all ranks and gather the results on the specified rank. |
| gather_scalar_from_all_ranks | Run a callable `fn` on all ranks and gather the results on the specified rank. |
| is_distributed | Check if distributed training is initialized. |
| is_main_process | Check if the current process is the main process. |
| reduce_and_broadcast | Run a callable `fn1` on all ranks, gather the results, and reduce them using `fn2`. |
| zero_first | Runs the wrapped context so that rank 0 runs first, before the other ranks. |
| zero_only | Context manager that only runs the enclosed block on the main rank. |
barrier
utils.distributed.barrier()
Acts as a barrier to wait for all processes. This ensures that all processes reach the barrier before proceeding further.
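A minimal usage sketch, assuming a `torch.distributed` process group has already been initialized (e.g. via `torchrun`); `download_weights` is a hypothetical helper used only for illustration:

```python
from utils.distributed import barrier, is_main_process

def download_weights():
    """Hypothetical helper: fetch files into a shared cache."""

if is_main_process():
    download_weights()
barrier()  # every rank blocks here until all ranks have arrived
```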
compute_and_broadcast
utils.distributed.compute_and_broadcast(fn)
Compute a value using the function `fn` only on the specified rank (default is 0). The value is then broadcast to all other ranks.
Args:

- `fn` (callable): A function that computes the value. This should not have any side effects.
- `rank` (int, optional): The rank that computes the value. Default is 0.

Returns:

- The computed value (int or float).
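A minimal sketch, assuming an initialized process group: one rank draws a random seed and the same value ends up on every rank.

```python
import random

from utils.distributed import compute_and_broadcast

# fn runs only on the computing rank; the result (an int here) is
# broadcast so that every rank agrees on the same seed.
seed = compute_and_broadcast(lambda: random.randint(0, 2**31 - 1))
```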
gather_from_all_ranks
utils.distributed.gather_from_all_ranks(fn, world_size=1)
Run a callable `fn` on all ranks and gather the results on the specified rank.

Args:

- `fn` (callable): A function that computes the value. This should not have any side effects.
- `rank` (int, optional): The rank that gathers the values. Default is 0.
- `world_size` (int, optional): Total number of processes in the current distributed setup.

Returns:

- A list of computed values from all ranks if on the gathering rank, otherwise None.
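A sketch, assuming an initialized process group; each rank contributes its own value and only the gathering rank receives the list:

```python
import os

from utils.distributed import gather_from_all_ranks

world_size = int(os.environ.get("WORLD_SIZE", "1"))
# Each rank evaluates fn; the gathering rank receives the full list,
# every other rank receives None.
pids = gather_from_all_ranks(lambda: float(os.getpid()), world_size=world_size)
if pids is not None:
    print(f"collected {len(pids)} values: {pids}")
```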
gather_scalar_from_all_ranks
utils.distributed.gather_scalar_from_all_ranks(fn, world_size=1)
Run a callable `fn` on all ranks and gather the results on the specified rank.

Args:

- `fn` (callable): A function that computes the value. This should not have any side effects.
- `rank` (int, optional): The rank that gathers the values. Default is 0.
- `world_size` (int, optional): Total number of processes in the current distributed setup.

Returns:

- A list of computed values from all ranks if on the gathering rank, otherwise None.
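Usage mirrors `gather_from_all_ranks`; a sketch gathering a per-rank scalar (a dummy local loss here) on the gathering rank:

```python
import os

from utils.distributed import gather_scalar_from_all_ranks

world_size = int(os.environ.get("WORLD_SIZE", "1"))
local_loss = 0.123  # stand-in for a value computed on this rank
losses = gather_scalar_from_all_ranks(lambda: local_loss, world_size=world_size)
if losses is not None:  # only the gathering rank gets the list
    print(sum(losses) / len(losses))
```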
is_distributed
utils.distributed.is_distributed()
Check if distributed training is initialized.
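A small sketch: guard collective calls so the same code also runs in single-process mode.

```python
from utils.distributed import barrier, is_distributed

if is_distributed():
    barrier()  # only meaningful once a process group is initialized
```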
is_main_process
utils.distributed.is_main_process()
Check if the current process is the main process. If not in distributed mode, always returns True.
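A common pattern, sketched here: restrict logging or filesystem writes to a single process.

```python
from utils.distributed import is_main_process

if is_main_process():
    # Runs on rank 0 in distributed mode, and always in single-process mode.
    print("saving checkpoint metadata from the main process only")
```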
reduce_and_broadcast
utils.distributed.reduce_and_broadcast(fn1, fn2)
Run a callable `fn1` on all ranks, gather the results, reduce them using `fn2`, and then broadcast the reduced result to all ranks.
Args:

- `fn1` (callable): A function that computes the value on each rank.
- `fn2` (callable): A reduction function that takes a list of values and returns a single value.
- `world_size` (int, optional): Total number of processes in the current distributed setup.

Returns:

- The reduced value, broadcast to all ranks.
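A minimal sketch, assuming an initialized process group: each rank contributes a value, the gathered list is reduced with `max`, and the result is shared with every rank.

```python
import os

from utils.distributed import reduce_and_broadcast

# fn1 produces a per-rank value; fn2 reduces the gathered list on the
# gathering rank; the reduced value is then broadcast to all ranks.
highest_pid = reduce_and_broadcast(
    lambda: float(os.getpid()),
    lambda values: max(values),
)
```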
zero_first
utils.distributed.zero_first(is_main)
Runs the wrapped context so that rank 0 runs first, before the other ranks.
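A sketch of the intended pattern: rank 0 populates a cache first (e.g. dataset preprocessing) so the remaining ranks can reuse it; `load_or_build_dataset` is a hypothetical helper.

```python
from utils.distributed import is_main_process, zero_first

def load_or_build_dataset():
    """Hypothetical helper: read from a cache, building it if missing."""

# Rank 0 runs the block first and warms the cache; the other ranks
# enter afterwards and read the cached result.
with zero_first(is_main_process()):
    dataset = load_or_build_dataset()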
zero_only
utils.distributed.zero_only()
Context manager that only runs the enclosed block on the main rank.
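A sketch taking the docstring at face value: the enclosed block is intended to execute on the main rank only.

```python
from utils.distributed import zero_only

with zero_only():
    print("main-rank-only side effect, e.g. writing a summary file")
```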