cli.merge_sharded_fsdp_weights
CLI to merge sharded FSDP model checkpoints into a single combined checkpoint.
Classes
Name | Description |
---|---|
BFloat16CastPlanner | A custom planner to cast tensors to bfloat16 on the fly during loading. |
BFloat16CastPlanner
cli.merge_sharded_fsdp_weights.BFloat16CastPlanner()
A custom planner to cast tensors to bfloat16 on the fly during loading.
Functions
Name | Description |
---|---|
do_cli | Parses axolotl config, CLI args, and calls merge_fsdp_weights. |
merge_fsdp_weights | Merge the weights from sharded FSDP model checkpoints into a single combined checkpoint. Should be used if SHARDED_STATE_DICT was used for the model. |
do_cli
cli.merge_sharded_fsdp_weights.do_cli(config=Path('examples/'), **kwargs)
Parses axolotl config, CLI args, and calls merge_fsdp_weights.
Parameters
Name | Type | Description | Default |
---|---|---|---|
config | Union[Path, str] | Path to axolotl config YAML file. | Path('examples/') |
kwargs | | Additional keyword arguments to override config file values. | {} |
merge_fsdp_weights
cli.merge_sharded_fsdp_weights.merge_fsdp_weights(
    checkpoint_dir,
    output_path,
    safe_serialization=False,
    remove_checkpoint_dir=False,
)
Merge the weights from sharded FSDP model checkpoints into a single combined checkpoint. Should be used if SHARDED_STATE_DICT was used for the model. Weights will be saved to {output_path}/model.safetensors if safe_serialization else pytorch_model.bin.
Note: this is a CPU-bound process.
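The naming rule described above can be sketched as a small helper. Note that `merged_weights_path` is a hypothetical name used only for illustration and is not part of this module; it simply mirrors the documented rule so a caller can predict where the merged file should appear.

```python
from pathlib import Path


def merged_weights_path(output_path: str, safe_serialization: bool = False) -> Path:
    """Return the file the merged weights are expected to be written to.

    Hypothetical helper: it only mirrors the documented naming rule and
    does not call into axolotl.
    """
    filename = "model.safetensors" if safe_serialization else "pytorch_model.bin"
    return Path(output_path) / filename


if __name__ == "__main__":
    # With safetensors enabled, the merged file is model.safetensors.
    print(merged_weights_path("merged", safe_serialization=True))
    # With the documented default (False), it is pytorch_model.bin.
    print(merged_weights_path("merged"))
```

A helper like this can be handy in scripts that need to verify the merged file exists before removing the sharded checkpoint directory.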
Parameters
Name | Type | Description | Default |
---|---|---|---|
checkpoint_dir | str | The directory containing the FSDP checkpoints (can be either the model or optimizer). | required |
output_path | str | The path to save the merged checkpoint. | required |
safe_serialization | bool, optional | Whether to save the merged weights with safetensors (recommended). | False |
remove_checkpoint_dir | bool, optional | Whether to remove the checkpoint directory after merging. | False |
Raises
Name | Type | Description |
---|---|---|
 | ValueError | If torch version < 2.3.0, or if checkpoint_dir does not exist. |
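The two documented failure conditions can be checked up front before starting a long merge. The sketch below is an assumption about how such a pre-check might look, not the module's actual implementation: `parse_version` and `check_merge_preconditions` are hypothetical names, and the torch version is passed in as a string (for example `torch.__version__`) so the snippet runs without torch installed.

```python
import os


def parse_version(version: str) -> tuple:
    """Turn a string such as '2.3.0+cu121' into (2, 3, 0).

    Local build tags ('+cu121') and trailing pre-release text are ignored.
    """
    core = version.split("+")[0]
    parts = []
    for piece in core.split(".")[:3]:
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break
            digits += ch
        parts.append(int(digits or 0))
    while len(parts) < 3:
        parts.append(0)
    return tuple(parts)


def check_merge_preconditions(torch_version: str, checkpoint_dir: str) -> None:
    """Raise ValueError for either documented failure condition (hypothetical helper)."""
    if parse_version(torch_version) < (2, 3, 0):
        raise ValueError(f"torch >= 2.3.0 is required, found {torch_version}")
    if not os.path.isdir(checkpoint_dir):
        raise ValueError(f"checkpoint_dir does not exist: {checkpoint_dir}")
```

Comparing version tuples rather than raw strings avoids the pitfall where "2.10.0" sorts before "2.3.0" lexicographically.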