integrations.base

Base class for all plugins.

A plugin is a reusable, modular, and self-contained piece of code that extends the functionality of Axolotl. Plugins can be used to integrate third-party models, modify the training process, or add new features.

To create a new plugin, subclass the BasePlugin class and override the methods for the hooks your plugin needs.
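A minimal sketch of the subclassing pattern described above. To keep it self-contained, it defines a hypothetical stand-in `BasePlugin` whose hooks are no-ops; in Axolotl you would subclass the real `integrations.base.BasePlugin` instead. `LoggingPlugin` and the `base_model` key are illustrative assumptions, and `cfg` is treated as a plain `dict`.

```python
from typing import Any, List


class BasePlugin:
    """Hypothetical stand-in for integrations.base.BasePlugin: every hook is a no-op."""

    def register(self, cfg: dict) -> None:
        pass

    def pre_model_load(self, cfg: dict) -> None:
        pass

    def post_model_load(self, cfg: dict, model: Any) -> None:
        pass


class LoggingPlugin(BasePlugin):
    """Hypothetical plugin that overrides only the hooks it cares about."""

    def __init__(self) -> None:
        self.events: List[str] = []

    def register(self, cfg: dict) -> None:
        self.events.append("register")

    def pre_model_load(self, cfg: dict) -> None:
        # "base_model" is an illustrative key; the model does not exist yet here
        self.events.append(f"pre_model_load:{cfg.get('base_model')}")

    def post_model_load(self, cfg: dict, model: Any) -> None:
        # The loaded model is available from this hook onward
        self.events.append(f"post_model_load:{type(model).__name__}")


plugin = LoggingPlugin()
cfg = {"base_model": "tiny-model"}
plugin.register(cfg)
plugin.pre_model_load(cfg)
plugin.post_model_load(cfg, model=object())
print(plugin.events)
# ['register', 'pre_model_load:tiny-model', 'post_model_load:object']
```

Hooks that a subclass does not override simply fall back to the base class behavior.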

Classes

Name Description
BaseOptimizerFactory Base class for factories that create custom optimizers.
BasePlugin Base class for all plugins. Defines the interface for plugin methods.
PluginManager The PluginManager class is responsible for loading and managing plugins.

BaseOptimizerFactory

integrations.base.BaseOptimizerFactory()

Base class for factories that create custom optimizers.

BasePlugin

integrations.base.BasePlugin()

Base class for all plugins. Defines the interface for plugin methods.

Attributes: None

Methods:
register(cfg): Registers the plugin with the given configuration.
pre_model_load(cfg): Performs actions before the model is loaded.
post_model_load(cfg, model): Performs actions after the model is loaded.
pre_lora_load(cfg, model): Performs actions before LoRA weights are loaded.
post_lora_load(cfg, model): Performs actions after LoRA weights are loaded.
create_optimizer(cfg, trainer): Creates and returns an optimizer for training.
create_lr_scheduler(cfg, trainer, optimizer): Creates and returns a learning rate scheduler.
add_callbacks_pre_trainer(cfg, model): Adds callbacks before the trainer is created.
add_callbacks_post_trainer(cfg, trainer): Adds callbacks after the trainer is created.

Methods

Name Description
add_callbacks_post_trainer Adds callbacks to the trainer after creating the trainer.
add_callbacks_pre_trainer Sets up callbacks before creating the trainer.
create_lr_scheduler Creates and returns a learning rate scheduler.
create_optimizer Creates and returns an optimizer for training.
get_input_args Returns a pydantic model for the plugin’s input arguments.
get_trainer_cls Returns a custom class for the trainer.
post_lora_load Performs actions after LoRA weights are loaded.
post_model_load Performs actions after the model is loaded.
post_train Performs actions after training is complete.
post_train_unload Performs actions after training is complete and the model is unloaded.
pre_lora_load Performs actions before LoRA weights are loaded.
pre_model_load Performs actions before the model is loaded.
register Registers the plugin with the given configuration.

add_callbacks_post_trainer
integrations.base.BasePlugin.add_callbacks_post_trainer(cfg, trainer)

Adds callbacks to the trainer after creating the trainer. This is useful for callbacks that require access to the model or trainer.

Parameters: cfg (dict): The configuration for the plugin. trainer (object): The trainer object for training.

Returns: List[callable]: A list of callback functions to be added to the trainer.
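A minimal sketch of why this hook exists: the trainer is already constructed, so a returned callback can close over it. `TrainerAwarePlugin` and `DummyTrainer` are hypothetical; in Axolotl the plugin would subclass `BasePlugin`, and plain callables here stand in for whatever callback objects the real trainer expects.

```python
from typing import Any, Callable, List


class TrainerAwarePlugin:
    """Hypothetical plugin; in Axolotl it would subclass BasePlugin."""

    def add_callbacks_post_trainer(self, cfg: dict, trainer: Any) -> List[Callable]:
        # The trainer exists at this point, so the callback can capture it.
        def describe_trainer(*args: Any, **kwargs: Any) -> str:
            return f"callback bound to {type(trainer).__name__}"

        return [describe_trainer]


class DummyTrainer:
    """Stand-in for the real trainer object."""


callbacks = TrainerAwarePlugin().add_callbacks_post_trainer({}, DummyTrainer())
print(callbacks[0]())  # callback bound to DummyTrainer
```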

add_callbacks_pre_trainer
integrations.base.BasePlugin.add_callbacks_pre_trainer(cfg, model)

Sets up callbacks before creating the trainer.

Parameters: cfg (dict): The configuration for the plugin. model (object): The loaded model.

Returns: List[callable]: A list of callback functions to be added to the TrainingArgs
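A minimal sketch of a plugin that only contributes a callback when its configuration asks for one; returning an empty list adds nothing. The flag `my_plugin_log_model` and the class name are hypothetical, and plain callables stand in for real callback objects.

```python
from typing import Any, Callable, List


class OptionalCallbackPlugin:
    """Hypothetical plugin; in Axolotl it would subclass BasePlugin."""

    def add_callbacks_pre_trainer(self, cfg: dict, model: Any) -> List[Callable]:
        # Only add a callback when the (hypothetical) flag is enabled.
        if not cfg.get("my_plugin_log_model"):
            return []

        def describe_model(*args: Any, **kwargs: Any) -> str:
            return f"model type: {type(model).__name__}"

        return [describe_model]


plugin = OptionalCallbackPlugin()
print(plugin.add_callbacks_pre_trainer({}, model=object()))  # []
enabled = plugin.add_callbacks_pre_trainer({"my_plugin_log_model": True}, model=object())
print(enabled[0]())  # model type: object
```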

create_lr_scheduler
integrations.base.BasePlugin.create_lr_scheduler(cfg, trainer, optimizer)

Creates and returns a learning rate scheduler.

Parameters: cfg (dict): The configuration for the plugin. trainer (object): The trainer object for training. optimizer (object): The optimizer for training.

Returns: object: The created learning rate scheduler.

create_optimizer
integrations.base.BasePlugin.create_optimizer(cfg, trainer)

Creates and returns an optimizer for training.

Parameters: cfg (dict): The configuration for the plugin. trainer (object): The trainer object for training.

Returns: object: The created optimizer.

get_input_args
integrations.base.BasePlugin.get_input_args()

Returns a pydantic model for the plugin’s input arguments.

get_trainer_cls
integrations.base.BasePlugin.get_trainer_cls(cfg)

Returns a custom class for the trainer.

Parameters: cfg (dict): The global Axolotl configuration.

Returns: class: The class for the trainer.
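A minimal sketch of a plugin that supplies a custom trainer class only when enabled. It returns a class, not an instance, and returns `None` otherwise, which (given that PluginManager picks the first non-None trainer class) lets other plugins or the default decide. `BaseTrainer`, `ScaledLossTrainer`, and the `use_scaled_loss` flag are hypothetical.

```python
from typing import Optional, Type


class BaseTrainer:
    """Stand-in for the trainer class that would otherwise be used."""

    def compute_loss(self) -> float:
        return 1.0


class ScaledLossTrainer(BaseTrainer):
    """Hypothetical custom trainer that halves the loss."""

    def compute_loss(self) -> float:
        return 0.5 * super().compute_loss()


class CustomTrainerPlugin:
    """Hypothetical plugin; in Axolotl it would subclass BasePlugin."""

    def get_trainer_cls(self, cfg: dict) -> Optional[Type[BaseTrainer]]:
        # Return the class itself; None means "no preference".
        if cfg.get("use_scaled_loss"):
            return ScaledLossTrainer
        return None


plugin = CustomTrainerPlugin()
trainer_cls = plugin.get_trainer_cls({"use_scaled_loss": True}) or BaseTrainer
print(trainer_cls.__name__, trainer_cls().compute_loss())  # ScaledLossTrainer 0.5
```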

post_lora_load
integrations.base.BasePlugin.post_lora_load(cfg, model)

Performs actions after LoRA weights are loaded.

Parameters: cfg (dict): The configuration for the plugin. model (object): The loaded model.

Returns: None

post_model_load
integrations.base.BasePlugin.post_model_load(cfg, model)

Performs actions after the model is loaded.

Parameters: cfg (dict): The configuration for the plugin. model (object): The loaded model.

Returns: None

post_train
integrations.base.BasePlugin.post_train(cfg, model)

Performs actions after training is complete.

Parameters: cfg (dict): The Axolotl configuration. model (object): The loaded model.

Returns: None

post_train_unload
integrations.base.BasePlugin.post_train_unload(cfg)

Performs actions after training is complete and the model is unloaded.

Parameters: cfg (dict): The configuration for the plugin.

Returns: None

pre_lora_load
integrations.base.BasePlugin.pre_lora_load(cfg, model)

Performs actions before LoRA weights are loaded.

Parameters: cfg (dict): The configuration for the plugin. model (object): The loaded model.

Returns: None

pre_model_load
integrations.base.BasePlugin.pre_model_load(cfg)

Performs actions before the model is loaded.

Parameters: cfg (dict): The configuration for the plugin.

Returns: None

register
integrations.base.BasePlugin.register(cfg)

Registers the plugin with the given configuration.

Parameters: cfg (dict): The configuration for the plugin.

Returns: None

PluginManager

integrations.base.PluginManager()

The PluginManager class is responsible for loading and managing plugins. It is a singleton, so the same instance can be accessed from anywhere in the codebase through get_instance().

Attributes: plugins (List[BasePlugin]): A list of loaded plugins.

Methods:
get_instance(): Static method that returns the singleton instance of PluginManager.
register(plugin_name: str): Registers a new plugin by its name.
pre_model_load(cfg): Calls the pre_model_load method of all registered plugins.

Methods

Name Description
add_callbacks_post_trainer Calls the add_callbacks_post_trainer method of all registered plugins.
add_callbacks_pre_trainer Calls the add_callbacks_pre_trainer method of all registered plugins.
create_lr_scheduler Calls the create_lr_scheduler method of all registered plugins and returns the first non-None scheduler.
create_optimizer Calls the create_optimizer method of all registered plugins and returns the first non-None optimizer.
get_input_args Returns a list of Pydantic classes for all registered plugins’ input arguments.
get_instance Returns the singleton instance of PluginManager.
get_trainer_cls Calls the get_trainer_cls method of all registered plugins and returns the first non-None trainer class.
post_lora_load Calls the post_lora_load method of all registered plugins.
post_model_load Calls the post_model_load method of all registered plugins.
post_train_unload Calls the post_train_unload method of all registered plugins.
pre_lora_load Calls the pre_lora_load method of all registered plugins.
pre_model_load Calls the pre_model_load method of all registered plugins.
register Registers a new plugin by its name.

add_callbacks_post_trainer
integrations.base.PluginManager.add_callbacks_post_trainer(cfg, trainer)

Calls the add_callbacks_post_trainer method of all registered plugins.

Parameters: cfg (dict): The configuration for the plugins. trainer (object): The trainer object for training.

Returns: List[callable]: A list of callback functions to be added to the TrainingArgs.
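A minimal sketch of how a manager might fan this hook out, assuming it simply concatenates each plugin's list in registration order (the ordering is an assumption, not stated above). `PluginA`, `PluginB`, and `collect_post_trainer_callbacks` are hypothetical names.

```python
from typing import Any, Callable, List


class PluginA:
    def add_callbacks_post_trainer(self, cfg: dict, trainer: Any) -> List[Callable]:
        return [lambda: "a1", lambda: "a2"]


class PluginB:
    def add_callbacks_post_trainer(self, cfg: dict, trainer: Any) -> List[Callable]:
        return [lambda: "b1"]


def collect_post_trainer_callbacks(plugins: List[Any], cfg: dict, trainer: Any) -> List[Callable]:
    """Call each plugin's hook in order and concatenate the returned lists."""
    callbacks: List[Callable] = []
    for plugin in plugins:
        callbacks.extend(plugin.add_callbacks_post_trainer(cfg, trainer))
    return callbacks


callbacks = collect_post_trainer_callbacks([PluginA(), PluginB()], {}, trainer=None)
print([cb() for cb in callbacks])  # ['a1', 'a2', 'b1']
```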

add_callbacks_pre_trainer
integrations.base.PluginManager.add_callbacks_pre_trainer(cfg, model)

Calls the add_callbacks_pre_trainer method of all registered plugins.

Parameters: cfg (dict): The configuration for the plugins. model (object): The loaded model.

Returns: List[callable]: A list of callback functions to be added to the TrainingArgs.

create_lr_scheduler
integrations.base.PluginManager.create_lr_scheduler(cfg, trainer, optimizer)

Calls the create_lr_scheduler method of all registered plugins and returns the first non-None scheduler.

Parameters: cfg (dict): The configuration for the plugins. trainer (object): The trainer object for training. optimizer (object): The optimizer for training.

Returns: object: The created learning rate scheduler, or None if none was found.

create_optimizer
integrations.base.PluginManager.create_optimizer(cfg, trainer)

Calls the create_optimizer method of all registered plugins and returns the first non-None optimizer.

Parameters: cfg (dict): The configuration for the plugins. trainer (object): The trainer object for training.

Returns: object: The created optimizer, or None if none was found.
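A minimal sketch of the "first non-None" rule described above: a plugin that has no opinion returns `None`, and the first plugin that returns an optimizer wins. This sketch stops at the first hit; whether the real manager still calls the remaining plugins is not specified here. The plugin classes, `first_optimizer`, and the string standing in for an optimizer object are hypothetical. The same rule is described for create_lr_scheduler and get_trainer_cls.

```python
from typing import Any, List, Optional


class DeferringPlugin:
    def create_optimizer(self, cfg: dict, trainer: Any) -> Optional[Any]:
        return None  # no opinion; let another plugin (or the default) decide


class CustomOptimizerPlugin:
    def create_optimizer(self, cfg: dict, trainer: Any) -> Optional[Any]:
        return "my-optimizer"  # stand-in for a real optimizer object


def first_optimizer(plugins: List[Any], cfg: dict, trainer: Any) -> Optional[Any]:
    """Return the first non-None optimizer, or None if every plugin defers."""
    for plugin in plugins:
        optimizer = plugin.create_optimizer(cfg, trainer)
        if optimizer is not None:
            return optimizer
    return None


print(first_optimizer([DeferringPlugin(), CustomOptimizerPlugin()], {}, None))  # my-optimizer
print(first_optimizer([DeferringPlugin()], {}, None))  # None
```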

get_input_args
integrations.base.PluginManager.get_input_args()

Returns a list of Pydantic classes for all registered plugins’ input arguments.

Returns: list[str]: A list of references (as strings) to the Pydantic classes for all registered plugins’ input arguments.
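A minimal sketch of collecting input-argument references, assuming, as the `list[str]` return type suggests, that each plugin's get_input_args returns a dotted import path to its Pydantic class (or `None` if it adds no configuration fields). The path `my_package.args.MyPluginArgs` and the helper name are hypothetical.

```python
from typing import Any, List, Optional


class PluginWithArgs:
    def get_input_args(self) -> Optional[str]:
        # Hypothetical dotted path to this plugin's Pydantic args class
        return "my_package.args.MyPluginArgs"


class PluginWithoutArgs:
    def get_input_args(self) -> Optional[str]:
        return None  # this plugin adds no config fields


def collect_input_args(plugins: List[Any]) -> List[str]:
    """Gather each plugin's input-args reference, skipping plugins that return None."""
    input_args: List[str] = []
    for plugin in plugins:
        args_ref = plugin.get_input_args()
        if args_ref is not None:
            input_args.append(args_ref)
    return input_args


print(collect_input_args([PluginWithArgs(), PluginWithoutArgs()]))
# ['my_package.args.MyPluginArgs']
```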

get_instance
integrations.base.PluginManager.get_instance()

Returns the singleton instance of PluginManager. If the instance doesn’t exist, it creates a new one.
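A minimal sketch of a lazily created singleton with a static `get_instance()`, as described above. This is a hypothetical illustration, not Axolotl's code: it is not thread-safe and does not stop callers from constructing `PluginManager()` directly.

```python
from typing import List, Optional


class PluginManager:
    """Hypothetical minimal singleton; not the real integrations.base.PluginManager."""

    _instance: Optional["PluginManager"] = None

    def __init__(self) -> None:
        self.plugins: List[object] = []

    @staticmethod
    def get_instance() -> "PluginManager":
        # Create the shared instance on first use, then keep returning it.
        if PluginManager._instance is None:
            PluginManager._instance = PluginManager()
        return PluginManager._instance


first = PluginManager.get_instance()
first.plugins.append("example-plugin")
second = PluginManager.get_instance()
print(second is first, second.plugins)  # True ['example-plugin']
```

Because every caller goes through `get_instance()`, plugins registered in one part of the codebase are visible everywhere else.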

get_trainer_cls
integrations.base.PluginManager.get_trainer_cls(cfg)

Calls the get_trainer_cls method of all registered plugins and returns the first non-None trainer class.

Parameters: cfg (dict): The configuration for the plugins.

Returns: object: The trainer class, or None if none was found.

post_lora_load
integrations.base.PluginManager.post_lora_load(cfg, model)

Calls the post_lora_load method of all registered plugins.

Parameters: cfg (dict): The configuration for the plugins. model (object): The loaded model.

Returns: None

post_model_load
integrations.base.PluginManager.post_model_load(cfg, model)

Calls the post_model_load method of all registered plugins.

Parameters: cfg (dict): The configuration for the plugins. model (object): The loaded model.

Returns: None

post_train_unload
integrations.base.PluginManager.post_train_unload(cfg)

Calls the post_train_unload method of all registered plugins.

Parameters: cfg (dict): The configuration for the plugins.

Returns: None

pre_lora_load
integrations.base.PluginManager.pre_lora_load(cfg, model)

Calls the pre_lora_load method of all registered plugins.

Parameters: cfg (dict): The configuration for the plugins. model (object): The loaded model.

Returns: None

pre_model_load
integrations.base.PluginManager.pre_model_load(cfg)

Calls the pre_model_load method of all registered plugins.

Parameters: cfg (dict): The configuration for the plugins.

Returns: None

register
integrations.base.PluginManager.register(plugin_name)

Registers a new plugin by its name.

Parameters: plugin_name (str): The name of the plugin to be registered.

Returns: None

Raises: ImportError: If the plugin module cannot be imported.
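A minimal sketch of registering a plugin by name. `PluginRegistry` is hypothetical and not Axolotl's code; it assumes registration means "import `module.Class`, instantiate it, and append it to the plugin list". To avoid needing files on disk, the example injects a throwaway module, `demo_plugins`, into `sys.modules`.

```python
import importlib
import sys
import types
from typing import List

# Build a throwaway module at runtime so the example needs no files on disk.
demo_module = types.ModuleType("demo_plugins")


class HelloPlugin:
    """Hypothetical plugin class exposed by the demo module."""


demo_module.HelloPlugin = HelloPlugin
sys.modules["demo_plugins"] = demo_module


class PluginRegistry:
    """Hypothetical sketch of register(plugin_name)."""

    def __init__(self) -> None:
        self.plugins: List[object] = []

    def register(self, plugin_name: str) -> None:
        module_name, class_name = plugin_name.rsplit(".", 1)
        # import_module raises ImportError (ModuleNotFoundError) for unknown modules
        module = importlib.import_module(module_name)
        self.plugins.append(getattr(module, class_name)())


registry = PluginRegistry()
registry.register("demo_plugins.HelloPlugin")
print(type(registry.plugins[0]).__name__)  # HelloPlugin

try:
    registry.register("no_such_module_xyz.Plugin")
except ImportError as err:
    print("ImportError raised:", type(err).__name__)
```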

Functions

Name Description
load_plugin Loads a plugin based on the given plugin name.

load_plugin

integrations.base.load_plugin(plugin_name)

Loads a plugin based on the given plugin name.

The plugin name should be in the format “module_name.class_name”, where module_name may itself be a dotted module path. This function splits the plugin name into the module path and the class name, imports the module, retrieves the class from it, and returns a new instance of that class.

Parameters: plugin_name (str): The name of the plugin to be loaded. The name should be in the format “module_name.class_name”.

Returns: BasePlugin: An instance of the loaded plugin.

Raises: ImportError: If the plugin module cannot be imported.
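A minimal sketch of the loading steps described above, using `importlib` from the standard library. It is an illustration, not Axolotl's implementation: it splits at the last dot so dotted module paths work, and it assumes the plugin class takes no constructor arguments. A name with no dot at all would raise `ValueError` when unpacking the split.

```python
import importlib
from typing import Any


def load_plugin(plugin_name: str) -> Any:
    """Minimal sketch: import "module_name.class_name" and return an instance."""
    # Split at the LAST dot so dotted module paths such as "pkg.sub.Class" work.
    module_name, class_name = plugin_name.rsplit(".", 1)
    module = importlib.import_module(module_name)  # may raise ImportError
    plugin_class = getattr(module, class_name)  # AttributeError if the class is missing
    return plugin_class()


# Any importable, zero-argument class works for a demo; here, a stdlib class.
instance = load_plugin("collections.OrderedDict")
print(type(instance).__name__)  # OrderedDict
```

In practice the name would point at a BasePlugin subclass, for example one provided by an installed package.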