integrations.base
Base class for all plugins.
A plugin is a reusable, modular, and self-contained piece of code that extends the functionality of Axolotl. Plugins can be used to integrate third-party models, modify the training process, or add new features.
To create a new plugin, you need to inherit from the BasePlugin class and implement the required methods.
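A minimal sketch of a custom plugin follows. To keep it self-contained, it defines a small stand-in `BasePlugin` whose hooks are no-ops and are named after the methods documented for `BasePlugin` below. In a real project you would import the actual class instead (presumably from `axolotl.integrations.base`, which is an assumption here). `ExamplePlugin` and the `base_model` config key are hypothetical.

```python
from typing import Any, List, Optional


class BasePlugin:
    """Stand-in for integrations.base.BasePlugin: every hook is a no-op by default."""

    def register(self, cfg: dict) -> None:
        pass

    def pre_model_load(self, cfg: dict) -> None:
        pass

    def post_model_load(self, cfg: dict, model: Any) -> None:
        pass

    def create_optimizer(self, cfg: dict, trainer: Any) -> Optional[Any]:
        return None  # None means "this plugin provides no custom optimizer"


class ExamplePlugin(BasePlugin):
    """Hypothetical plugin that overrides a few hooks and records which ones ran."""

    def __init__(self) -> None:
        self.events: List[str] = []

    def register(self, cfg: dict) -> None:
        self.events.append("register")

    def pre_model_load(self, cfg: dict) -> None:
        # For example, inspect or validate cfg before the model is built.
        self.events.append(f"pre_model_load:{cfg.get('base_model')}")

    def post_model_load(self, cfg: dict, model: Any) -> None:
        self.events.append(f"post_model_load:{type(model).__name__}")


plugin = ExamplePlugin()
cfg = {"base_model": "my-model"}  # hypothetical configuration
plugin.register(cfg)
plugin.pre_model_load(cfg)
plugin.post_model_load(cfg, model=object())
print(plugin.events)  # ['register', 'pre_model_load:my-model', 'post_model_load:object']
```

Only the hooks a plugin cares about need to be overridden; the rest fall back to the base class defaults.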
Classes
Name | Description |
---|---|
BaseOptimizerFactory | Base class for factories to create custom optimizers |
BasePlugin | Base class for all plugins. Defines the interface for plugin methods. |
PluginManager | The PluginManager class is responsible for loading and managing plugins. |
BaseOptimizerFactory
integrations.base.BaseOptimizerFactory()
Base class for factories to create custom optimizers
BasePlugin
integrations.base.BasePlugin(self)
Base class for all plugins. Defines the interface for plugin methods.
Attributes: None
Methods:
- register(cfg): Registers the plugin with the given configuration.
- pre_model_load(cfg): Performs actions before the model is loaded.
- post_model_load(cfg, model): Performs actions after the model is loaded.
- pre_lora_load(cfg, model): Performs actions before LoRA weights are loaded.
- post_lora_load(cfg, model): Performs actions after LoRA weights are loaded.
- create_optimizer(cfg, trainer): Creates and returns an optimizer for training.
- create_lr_scheduler(cfg, trainer, optimizer): Creates and returns a learning rate scheduler.
- add_callbacks_pre_trainer(cfg, model): Sets up callbacks before the trainer is created.
- add_callbacks_post_trainer(cfg, trainer): Adds callbacks to the trainer after it is created.
Methods
Name | Description |
---|---|
add_callbacks_post_trainer | Adds callbacks to the trainer after creating the trainer. |
add_callbacks_pre_trainer | Sets up callbacks before the trainer is created. |
create_lr_scheduler | Creates and returns a learning rate scheduler. |
create_optimizer | Creates and returns an optimizer for training. |
get_input_args | Returns a pydantic model for the plugin’s input arguments. |
get_trainer_cls | Returns a custom class for the trainer. |
post_lora_load | Performs actions after LoRA weights are loaded. |
post_model_load | Performs actions after the model is loaded. |
post_train | Performs actions after training is complete. |
post_train_unload | Performs actions after training is complete and the model is unloaded. |
pre_lora_load | Performs actions before LoRA weights are loaded. |
pre_model_load | Performs actions before the model is loaded. |
register | Registers the plugin with the given configuration. |
add_callbacks_post_trainer
integrations.base.BasePlugin.add_callbacks_post_trainer(cfg, trainer)
Adds callbacks to the trainer after creating the trainer. This is useful for callbacks that require access to the model or trainer.
Parameters: cfg (dict): The configuration for the plugin. trainer (object): The trainer object for training.
Returns: List[callable]: A list of callback functions to be added to the trainer.
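Because this hook runs after the trainer exists, a callback can hold a reference to it. A minimal sketch, assuming a hypothetical `StubTrainer` and plain Python callables in place of whatever callback type your trainer actually expects:

```python
from typing import Any, Callable, List


class StubTrainer:
    """Hypothetical stand-in for a trainer object: a name and a step counter."""

    def __init__(self, name: str) -> None:
        self.name = name
        self.global_step = 0


class StepLoggerPlugin:
    """Hypothetical plugin; the hook mirrors the documented signature."""

    def add_callbacks_post_trainer(self, cfg: dict, trainer: Any) -> List[Callable[[], str]]:
        # The trainer already exists here, so the callback can close over it.
        def log_step() -> str:
            return f"{trainer.name} at step {trainer.global_step}"

        return [log_step]


trainer = StubTrainer("demo")
callbacks = StepLoggerPlugin().add_callbacks_post_trainer({}, trainer)
trainer.global_step = 10
print(callbacks[0]())  # demo at step 10
```

The closure reads `trainer.global_step` at call time, so it reflects the trainer's state when the callback fires rather than when it was created.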
add_callbacks_pre_trainer
integrations.base.BasePlugin.add_callbacks_pre_trainer(cfg, model)
Sets up callbacks before the trainer is created.
Parameters: cfg (dict): The configuration for the plugin. model (object): The loaded model.
Returns: List[callable]: A list of callback functions to be added to the TrainingArgs
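A common pattern is to contribute callbacks only when the plugin's own option is enabled. A minimal sketch; the plugin class and the `enable_extra_callback` config key are hypothetical, and the placeholder callback does nothing:

```python
from typing import Any, Callable, List


class OptionalCallbackPlugin:
    """Hypothetical plugin: contributes a callback only when its config key is set."""

    def add_callbacks_pre_trainer(self, cfg: dict, model: Any) -> List[Callable[..., None]]:
        callbacks: List[Callable[..., None]] = []
        if cfg.get("enable_extra_callback"):  # hypothetical config key

            def on_step_end(*args: Any, **kwargs: Any) -> None:
                pass  # placeholder for real per-step logic

            callbacks.append(on_step_end)
        # An empty list means this plugin contributes no callbacks.
        return callbacks


plugin = OptionalCallbackPlugin()
print(len(plugin.add_callbacks_pre_trainer({}, model=None)))  # 0
print(len(plugin.add_callbacks_pre_trainer({"enable_extra_callback": True}, model=None)))  # 1
```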
create_lr_scheduler
integrations.base.BasePlugin.create_lr_scheduler(cfg, trainer, optimizer)
Creates and returns a learning rate scheduler.
Parameters: cfg (dict): The configuration for the plugin. trainer (object): The trainer object for training. optimizer (object): The optimizer for training.
Returns: object: The created learning rate scheduler.
create_optimizer
integrations.base.BasePlugin.create_optimizer(cfg, trainer)
Creates and returns an optimizer for training.
Parameters: cfg (dict): The configuration for the plugin. trainer (object): The trainer object for training.
Returns: object: The created optimizer.
get_input_args
integrations.base.BasePlugin.get_input_args()
Returns a pydantic model for the plugin’s input arguments.
get_trainer_cls
integrations.base.BasePlugin.get_trainer_cls(cfg)
Returns a custom class for the trainer.
Parameters: cfg (dict): The global Axolotl configuration.
Returns: class: The class for the trainer.
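Note that the hook returns a class, not an instance. A minimal sketch in which a stand-in `Trainer` plays the default trainer; `CustomLossTrainer`, `CustomTrainerPlugin`, and the `use_custom_loss` key are hypothetical:

```python
from typing import Optional, Type


class Trainer:
    """Stand-in for the default trainer class."""

    def compute_loss(self) -> str:
        return "default loss"


class CustomLossTrainer(Trainer):
    """Hypothetical trainer subclass that overrides one method."""

    def compute_loss(self) -> str:
        return "custom loss"


class CustomTrainerPlugin:
    """Hypothetical plugin returning a trainer class only when enabled in cfg."""

    def get_trainer_cls(self, cfg: dict) -> Optional[Type[Trainer]]:
        if cfg.get("use_custom_loss"):  # hypothetical config key
            return CustomLossTrainer
        return None  # no opinion: another plugin or the default applies


plugin = CustomTrainerPlugin()
trainer_cls = plugin.get_trainer_cls({"use_custom_loss": True}) or Trainer
print(trainer_cls().compute_loss())  # custom loss
```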
post_lora_load
integrations.base.BasePlugin.post_lora_load(cfg, model)
Performs actions after LoRA weights are loaded.
Parameters: cfg (dict): The configuration for the plugin. model (object): The loaded model.
Returns: None
post_model_load
integrations.base.BasePlugin.post_model_load(cfg, model)
Performs actions after the model is loaded.
Parameters: cfg (dict): The configuration for the plugin. model (object): The loaded model.
Returns: None
post_train
integrations.base.BasePlugin.post_train(cfg, model)
Performs actions after training is complete.
Parameters: cfg (dict): The Axolotl configuration. model (object): The loaded model.
Returns: None
post_train_unload
integrations.base.BasePlugin.post_train_unload(cfg)
Performs actions after training is complete and the model is unloaded.
Parameters: cfg (dict): The configuration for the plugin.
Returns: None
pre_lora_load
integrations.base.BasePlugin.pre_lora_load(cfg, model)
Performs actions before LoRA weights are loaded.
Parameters: cfg (dict): The configuration for the plugin. model (object): The loaded model.
Returns: None
pre_model_load
integrations.base.BasePlugin.pre_model_load(cfg)
Performs actions before the model is loaded.
Parameters: cfg (dict): The configuration for the plugin.
Returns: None
register
integrations.base.BasePlugin.register(cfg)
Registers the plugin with the given configuration.
Parameters: cfg (dict): The configuration for the plugin.
Returns: None
PluginManager
integrations.base.PluginManager()
The PluginManager class is responsible for loading and managing plugins. It is designed as a singleton so that the same instance can be accessed from anywhere in the codebase.
Attributes: plugins (List[BasePlugin]): A list of loaded plugins.
Methods:
- get_instance(): Static method to get the singleton instance of PluginManager.
- register(plugin_name: str): Registers a new plugin by its name.
- pre_model_load(cfg): Calls the pre_model_load method of all registered plugins.
Methods
Name | Description |
---|---|
add_callbacks_post_trainer | Calls the add_callbacks_post_trainer method of all registered plugins. |
add_callbacks_pre_trainer | Calls the add_callbacks_pre_trainer method of all registered plugins. |
create_lr_scheduler | Calls the create_lr_scheduler method of all registered plugins and returns the first non-None scheduler. |
create_optimizer | Calls the create_optimizer method of all registered plugins and returns the first non-None optimizer. |
get_input_args | Returns a list of Pydantic classes for all registered plugins’ input arguments. |
get_instance | Returns the singleton instance of PluginManager. |
get_trainer_cls | Calls the get_trainer_cls method of all registered plugins and returns the first non-None trainer class. |
post_lora_load | Calls the post_lora_load method of all registered plugins. |
post_model_load | Calls the post_model_load method of all registered plugins. |
post_train_unload | Calls the post_train_unload method of all registered plugins. |
pre_lora_load | Calls the pre_lora_load method of all registered plugins. |
pre_model_load | Calls the pre_model_load method of all registered plugins. |
register | Registers a new plugin by its name. |
add_callbacks_post_trainer
integrations.base.PluginManager.add_callbacks_post_trainer(cfg, trainer)
Calls the add_callbacks_post_trainer method of all registered plugins.
Parameters: cfg (dict): The configuration for the plugins. trainer (object): The trainer object for training.
Returns: List[callable]: A list of callback functions to be added to the TrainingArgs.
add_callbacks_pre_trainer
integrations.base.PluginManager.add_callbacks_pre_trainer(cfg, model)
Calls the add_callbacks_pre_trainer method of all registered plugins.
Parameters: cfg (dict): The configuration for the plugins. model (object): The loaded model.
Returns: List[callable]: A list of callback functions to be added to the TrainingArgs.
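For the callback hooks, "calls the method of all registered plugins" amounts to collecting every plugin's list into one. A minimal sketch of that aggregation with toy plugins and a hypothetical `collect_pre_trainer_callbacks` helper; that the real manager concatenates in registration order is an assumption:

```python
from typing import Any, Callable, List


class PluginA:
    def add_callbacks_pre_trainer(self, cfg: dict, model: Any) -> List[Callable[[], str]]:
        return [lambda: "a1", lambda: "a2"]


class PluginB:
    def add_callbacks_pre_trainer(self, cfg: dict, model: Any) -> List[Callable[[], str]]:
        return [lambda: "b1"]


def collect_pre_trainer_callbacks(
    plugins: List[Any], cfg: dict, model: Any
) -> List[Callable[[], str]]:
    """Hypothetical helper: concatenate each plugin's callbacks in registration order."""
    callbacks: List[Callable[[], str]] = []
    for plugin in plugins:
        callbacks.extend(plugin.add_callbacks_pre_trainer(cfg, model))
    return callbacks


cbs = collect_pre_trainer_callbacks([PluginA(), PluginB()], cfg={}, model=None)
print([cb() for cb in cbs])  # ['a1', 'a2', 'b1']
```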
create_lr_scheduler
integrations.base.PluginManager.create_lr_scheduler(cfg, trainer, optimizer)
Calls the create_lr_scheduler method of all registered plugins and returns the first non-None scheduler.
Parameters: cfg (dict): The configuration for the plugins. trainer (object): The trainer object for training. optimizer (object): The optimizer for training.
Returns: object: The created learning rate scheduler, or None if none was found.
create_optimizer
integrations.base.PluginManager.create_optimizer(cfg, trainer)
Calls the create_optimizer method of all registered plugins and returns the first non-None optimizer.
Parameters: cfg (dict): The configuration for the plugins. trainer (object): The trainer object for training.
Returns: object: The created optimizer, or None if none was found.
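The "first non-None" rule means a plugin declines by returning None, and registration order decides which plugin wins when several provide an optimizer. A minimal sketch with toy plugins and a hypothetical `first_optimizer` helper; a string stands in for a real optimizer object, and whether the actual manager stops at the first hit (as this sketch does) is not stated:

```python
from typing import Any, List, Optional


class NoOptimizerPlugin:
    def create_optimizer(self, cfg: dict, trainer: Any) -> Optional[Any]:
        return None  # declines to provide an optimizer


class NamedOptimizerPlugin:
    def __init__(self, name: str) -> None:
        self.name = name

    def create_optimizer(self, cfg: dict, trainer: Any) -> Optional[Any]:
        return f"optimizer from {self.name}"  # string stands in for an optimizer


def first_optimizer(plugins: List[Any], cfg: dict, trainer: Any) -> Optional[Any]:
    """Hypothetical helper: return the first non-None optimizer, or None."""
    for plugin in plugins:
        optimizer = plugin.create_optimizer(cfg, trainer)
        if optimizer is not None:
            return optimizer  # later plugins are not consulted in this sketch
    return None


plugins = [NoOptimizerPlugin(), NamedOptimizerPlugin("p1"), NamedOptimizerPlugin("p2")]
print(first_optimizer(plugins, cfg={}, trainer=None))  # optimizer from p1
print(first_optimizer([NoOptimizerPlugin()], cfg={}, trainer=None))  # None
```

The same pattern applies to create_lr_scheduler and get_trainer_cls.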
get_input_args
integrations.base.PluginManager.get_input_args()
Returns a list of Pydantic classes for all registered plugins’ input arguments.
Returns: list[str]: A list of Pydantic classes for all registered plugins’ input arguments.
get_instance
integrations.base.PluginManager.get_instance()
Returns the singleton instance of PluginManager. If the instance doesn’t exist, it creates a new one.
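The behavior described here is a lazily created singleton. A minimal sketch of that pattern using a static `get_instance` method; this is an illustration, not the actual Axolotl class:

```python
from typing import List, Optional


class PluginManager:
    """Sketch of a lazily created singleton; not the actual Axolotl implementation."""

    _instance: Optional["PluginManager"] = None

    def __init__(self) -> None:
        self.plugins: List[object] = []

    @staticmethod
    def get_instance() -> "PluginManager":
        # Create the instance on first use, then keep returning the same object.
        if PluginManager._instance is None:
            PluginManager._instance = PluginManager()
        return PluginManager._instance


a = PluginManager.get_instance()
b = PluginManager.get_instance()
a.plugins.append("example")
print(a is b, b.plugins)  # True ['example']
```

In this sketch, calling `PluginManager()` directly still creates a separate object; a stricter singleton would also override `__new__`.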
get_trainer_cls
integrations.base.PluginManager.get_trainer_cls(cfg)
Calls the get_trainer_cls method of all registered plugins and returns the first non-None trainer class.
Parameters: cfg (dict): The configuration for the plugins.
Returns: object: The trainer class, or None if none was found.
post_lora_load
integrations.base.PluginManager.post_lora_load(cfg, model)
Calls the post_lora_load method of all registered plugins.
Parameters: cfg (dict): The configuration for the plugins. model (object): The loaded model.
Returns: None
post_model_load
integrations.base.PluginManager.post_model_load(cfg, model)
Calls the post_model_load method of all registered plugins.
Parameters: cfg (dict): The configuration for the plugins. model (object): The loaded model.
Returns: None
post_train_unload
integrations.base.PluginManager.post_train_unload(cfg)
Calls the post_train_unload method of all registered plugins.
Parameters: cfg (dict): The configuration for the plugins.
Returns: None
pre_lora_load
integrations.base.PluginManager.pre_lora_load(cfg, model)
Calls the pre_lora_load method of all registered plugins.
Parameters: cfg (dict): The configuration for the plugins. model (object): The loaded model.
Returns: None
pre_model_load
integrations.base.PluginManager.pre_model_load(cfg)
Calls the pre_model_load method of all registered plugins.
Parameters: cfg (dict): The configuration for the plugins.
Returns: None
register
integrations.base.PluginManager.register(plugin_name)
Registers a new plugin by its name.
Parameters: plugin_name (str): The name of the plugin to be registered.
Returns: None
Raises: ImportError: If the plugin module cannot be imported.
Functions
Name | Description |
---|---|
load_plugin | Loads a plugin based on the given plugin name. |
load_plugin
integrations.base.load_plugin(plugin_name)
Loads a plugin based on the given plugin name.
The plugin name should be in the format “module_name.class_name”. This function splits the plugin name into module and class, imports the module, retrieves the class from the module, and creates an instance of the class.
Parameters: plugin_name (str): The name of the plugin to be loaded. The name should be in the format “module_name.class_name”.
Returns: BasePlugin: An instance of the loaded plugin.
Raises: ImportError: If the plugin module cannot be imported.
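The steps described above (split the name, import the module, look up the class, instantiate it) can be sketched with the standard library's `importlib`. This is an illustrative version, not the Axolotl source; splitting on the last dot is this sketch's choice so that dotted module paths work:

```python
import importlib
from typing import Any


def load_plugin(plugin_name: str) -> Any:
    """Sketch: import "module_name.class_name" and return an instance of the class."""
    # Split on the last dot so dotted module paths like "pkg.sub.ClassName" work.
    module_name, class_name = plugin_name.rsplit(".", 1)
    module = importlib.import_module(module_name)  # ImportError if the module is missing
    plugin_class = getattr(module, class_name)  # AttributeError if the class is absent
    return plugin_class()


# Any importable class works for illustration; a real plugin would subclass BasePlugin.
instance = load_plugin("collections.OrderedDict")
print(type(instance).__name__)  # OrderedDict
```

A missing module raises `ModuleNotFoundError`, which is a subclass of `ImportError`, matching the documented exception.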