cli.merge_sharded_fsdp_weights

CLI to merge sharded FSDP model checkpoints into a single combined checkpoint.

Classes

| Name | Description |
| --- | --- |
| BFloat16CastPlanner | A custom planner to cast tensors to bfloat16 on the fly during loading. |

BFloat16CastPlanner

cli.merge_sharded_fsdp_weights.BFloat16CastPlanner()

A custom planner to cast tensors to bfloat16 on the fly during loading.

Functions

| Name | Description |
| --- | --- |
| do_cli | Parses axolotl config, CLI args, and calls merge_fsdp_weights. |
| merge_fsdp_weights | Merge the weights from sharded FSDP model checkpoints into a single combined checkpoint. Should be used if SHARDED_STATE_DICT was used for the model. |

do_cli

cli.merge_sharded_fsdp_weights.do_cli(config=Path('examples/'), **kwargs)

Parses axolotl config, CLI args, and calls merge_fsdp_weights.

Parameters

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| config | Union[Path, str] | Path to axolotl config YAML file. | Path('examples/') |
| kwargs | | Additional keyword arguments to override config file values. | {} |
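The override behaviour of kwargs can be pictured with a minimal sketch. The helper `apply_overrides` and the config keys below are hypothetical and only illustrate the idea of layering keyword arguments over values read from the YAML file; they are not the loader that do_cli actually uses.

```python
from typing import Any, Dict


def apply_overrides(config: Dict[str, Any], **kwargs: Any) -> Dict[str, Any]:
    """Return a copy of `config` with non-None keyword arguments layered on top."""
    merged = dict(config)  # copy so the caller's dict is not mutated
    for key, value in kwargs.items():
        # Assumption for this sketch: None means "option not passed on the CLI".
        if value is not None:
            merged[key] = value
    return merged


# Values as they might come from the YAML file (illustrative keys only).
file_config = {"output_dir": "./outputs", "bf16": False}
effective = apply_overrides(file_config, output_dir="./merged", bf16=None)
print(effective)  # {'output_dir': './merged', 'bf16': False}
```

Treating None as "not provided" is one possible convention; a real CLI may distinguish unset options differently.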

merge_fsdp_weights

cli.merge_sharded_fsdp_weights.merge_fsdp_weights(
    checkpoint_dir,
    output_path,
    safe_serialization=False,
    remove_checkpoint_dir=False,
)

Merge the weights from sharded FSDP model checkpoints into a single combined checkpoint. Use this when the model was saved with SHARDED_STATE_DICT. Weights are saved to {output_path}/model.safetensors if safe_serialization is True, otherwise to {output_path}/pytorch_model.bin.

Note: this is a CPU-bound process.

Parameters

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| checkpoint_dir | str | The directory containing the FSDP checkpoints (can be either the model or optimizer). | required |
| output_path | str | The path to save the merged checkpoint. | required |
| safe_serialization | bool, optional | Whether to save the merged weights with safetensors (recommended). | False |
| remove_checkpoint_dir | bool, optional | Whether to remove the checkpoint directory after merging. | False |
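A minimal usage sketch, assuming the module is importable as `axolotl.cli.merge_sharded_fsdp_weights` (the page only shows the `cli.` prefix, so the package name is an assumption). The wrapper `merge_and_locate` and its `merge_fn` parameter are hypothetical; they exist only so the call can be exercised without a real checkpoint. The returned path follows the file naming given in the description above.

```python
from pathlib import Path
from typing import Callable, Optional


def merge_and_locate(
    checkpoint_dir: str,
    output_path: str,
    safe_serialization: bool = True,
    merge_fn: Optional[Callable[..., object]] = None,
) -> Path:
    """Merge a sharded checkpoint and return where the merged weights are expected."""
    if merge_fn is None:
        # Assumption: the documented module lives under the `axolotl` package.
        from axolotl.cli.merge_sharded_fsdp_weights import merge_fsdp_weights

        merge_fn = merge_fsdp_weights

    merge_fn(
        checkpoint_dir=checkpoint_dir,
        output_path=output_path,
        safe_serialization=safe_serialization,
        remove_checkpoint_dir=False,  # keep the shards until the merge is verified
    )
    # File naming as stated in the function description.
    filename = "model.safetensors" if safe_serialization else "pytorch_model.bin"
    return Path(output_path) / filename
```

Passing `safe_serialization=True` explicitly is worth considering, since the signature defaults to False while the parameter description recommends safetensors.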

Raises

| Name | Type | Description |
| --- | --- | --- |
| | ValueError | If torch version < 2.3.0, or if checkpoint_dir does not exist. |
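The two error conditions can be checked up front before starting a long merge. The sketch below is a hypothetical pre-flight helper, not the library's own validation; it takes the torch version as a string (in practice `torch.__version__`) so that it runs without torch installed.

```python
from pathlib import Path
from typing import Tuple

MIN_TORCH = (2, 3, 0)  # minimum version named in the Raises section


def parse_version(version: str) -> Tuple[int, ...]:
    """Turn '2.4.1+cu121' into (2, 4, 1), ignoring local and pre-release suffixes."""
    core = version.split("+")[0]
    parts = []
    for piece in core.split(".")[:3]:
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break
            digits += ch
        parts.append(int(digits or 0))
    while len(parts) < 3:  # pad '2.3' to (2, 3, 0) so tuple comparison is fair
        parts.append(0)
    return tuple(parts)


def check_merge_preconditions(checkpoint_dir: str, torch_version: str) -> None:
    """Raise ValueError for the same two conditions the documentation lists."""
    if parse_version(torch_version) < MIN_TORCH:
        raise ValueError(f"torch >= 2.3.0 is required, found {torch_version}")
    if not Path(checkpoint_dir).exists():
        raise ValueError(f"checkpoint_dir does not exist: {checkpoint_dir}")
```

A caller would typically run `check_merge_preconditions(checkpoint_dir, torch.__version__)` and only then invoke the merge.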