Train the Model on the Device
Once the training artifacts are generated, the model can be trained on the device using the onnxruntime training Python API.
The expected training artifacts are:
The training onnx model
The checkpoint state
The optimizer onnx model
The eval onnx model (optional)
Sample usage:
from onnxruntime.training.api import CheckpointState, Module, Optimizer

# Load the checkpoint state
state = CheckpointState.load_checkpoint(path_to_the_checkpoint_artifact)

# Create the module
module = Module(path_to_the_training_model,
                state,
                path_to_the_eval_model,
                device="cpu")

# Create the optimizer
optimizer = Optimizer(path_to_the_optimizer_model, module)

# Training loop
for ...:
    module.train()
    training_loss = module(...)
    optimizer.step()
    module.lazy_reset_grad()

# Eval
module.eval()
eval_loss = module(...)

# Save the checkpoint
CheckpointState.save_checkpoint(state, path_to_the_checkpoint_artifact)
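The training loop in the sample usage can be wrapped in a small helper. The sketch below is illustrative only: `run_training_epoch` and `batches` are hypothetical names that are not part of onnxruntime, and the sketch assumes each batch is a tuple of inputs in the order the training model expects. It relies only on the `train`, `__call__`, `step`, and `lazy_reset_grad` methods documented on this page, so any objects that expose those methods can be passed in.

```python
def run_training_epoch(module, optimizer, batches):
    """Run one pass over `batches` and return the per-step training losses.

    Hypothetical helper, not part of the library. `module` and `optimizer`
    are expected to behave like onnxruntime.training.api.Module and
    onnxruntime.training.api.Optimizer.
    """
    losses = []
    module.train()  # select the training step for subsequent calls
    for batch in batches:
        loss = module(*batch)     # run the training step on this batch
        optimizer.step()          # update the parameters from the gradients
        module.lazy_reset_grad()  # schedule a gradient reset before the next step
        losses.append(loss)
    return losses
```

With real artifacts, `module` and `optimizer` would be the objects created in the sample usage, and `batches` would be whatever iterable of input tuples the application provides.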
- class onnxruntime.training.api.CheckpointState(state: CheckpointState)
Bases: object
Class that holds the state of the training session.
This class holds all the state information of the training session such as the model parameters, its gradients, the optimizer state and user defined properties.
User defined properties can be indexed by name from the CheckpointState object.
To create the CheckpointState, use the CheckpointState.load_checkpoint method.
- Parameters:
state – The C.Checkpoint state object that holds the underlying session state.
- classmethod load_checkpoint(checkpoint_uri: str | os.PathLike) -> CheckpointState
Loads the checkpoint state from the checkpoint file.
- Parameters:
checkpoint_uri – The path to the checkpoint file.
- Returns:
The checkpoint state object.
- Return type:
CheckpointState
- classmethod save_checkpoint(state: CheckpointState, checkpoint_uri: str | os.PathLike, include_optimizer_state: bool = False) -> None
Saves the checkpoint state to the checkpoint file.
- Parameters:
state – The checkpoint state object.
checkpoint_uri – The path to the checkpoint file.
include_optimizer_state – If True, the optimizer state is also saved to the checkpoint file.
- __getitem__(name: str) -> int | float | str
Gets the property associated with the given name.
- Parameters:
name – The name of the property
- Returns:
The value of the property
- class onnxruntime.training.api.Module(train_model_uri: PathLike, state: CheckpointState, eval_model_uri: Optional[PathLike] = None, device: str = 'cpu', session_options: Optional[SessionOptions] = None)
Bases: object
Trainer class that provides training and evaluation methods for ONNX models.
Before instantiating the Module class, it is expected that the training artifacts have been generated using the onnxruntime.training.artifacts.generate_artifacts utility.
- The training artifacts include:
The training model
The evaluation model (optional)
The optimizer model (optional)
The checkpoint file
- Parameters:
train_model_uri – The path to the training model.
state – The checkpoint state object.
eval_model_uri – The path to the evaluation model.
device – The device to run the model on. Default is “cpu”.
session_options – The session options to use for the model.
- __call__(*user_inputs) -> tuple[numpy.ndarray, ...] | numpy.ndarray | tuple[onnxruntime.capi.onnxruntime_inference_collection.OrtValue, ...] | onnxruntime.capi.onnxruntime_inference_collection.OrtValue
Invokes either the training or the evaluation step of the model.
- Parameters:
*user_inputs – The inputs to the model. The user inputs can be either numpy arrays or OrtValues.
- Returns:
The outputs of the model.
- train(mode: bool = True) -> Module
Sets the Module in training mode.
- Parameters:
mode – whether to set the model to training mode (True) or evaluation mode (False). Default: True.
- Returns:
self
- lazy_reset_grad()
Lazily resets the training gradients.
This function sets the internal state of the module so that the gradients are scheduled to be reset just before new gradients are computed on the next training step, that is, the next time the module is invoked in training mode.
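The "lazy" part of the reset can be illustrated with a toy model. The sketch below is not the library's implementation: `ToyGradientBuffer` and its `accumulate` method are hypothetical, and the sketch assumes that gradients accumulate across training steps unless a reset has been scheduled. It shows that scheduling a reset leaves the stored gradients untouched until the next gradient computation.

```python
class ToyGradientBuffer:
    """Toy stand-in for a module's gradient storage (hypothetical)."""

    def __init__(self, size):
        self.grad = [0.0] * size
        self._reset_pending = False

    def lazy_reset_grad(self):
        # Only record the request; the stored gradients are not touched yet.
        self._reset_pending = True

    def accumulate(self, new_grad):
        # Simulates the gradient update performed by one training step.
        if self._reset_pending:
            self.grad = [0.0] * len(self.grad)
            self._reset_pending = False
        self.grad = [g + n for g, n in zip(self.grad, new_grad)]


buf = ToyGradientBuffer(2)
buf.accumulate([1.0, 2.0])
buf.accumulate([1.0, 2.0])   # no reset scheduled: gradients add up
buf.lazy_reset_grad()        # stored gradients are still [2.0, 4.0] here
buf.accumulate([0.5, 0.5])   # reset happens just before the new gradients land
```

In this toy model, skipping `lazy_reset_grad` between steps acts as gradient accumulation over several batches.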
- get_contiguous_parameters(trainable_only: bool = False) -> OrtValue
Creates a contiguous buffer of the training session parameters.
- Parameters:
trainable_only – If True, only trainable parameters are considered. Otherwise, all parameters are considered.
- Returns:
The contiguous buffer of the training session parameters.
- get_parameters_size(trainable_only: bool = True) -> int
Returns the size of the parameters.
- Parameters:
trainable_only – If True, only trainable parameters are considered. Otherwise, all parameters are considered.
- Returns:
The number of primitive (for example, floating point) elements in the parameters.
- copy_buffer_to_parameters(buffer: OrtValue, trainable_only: bool = True) -> None
Copies the OrtValue buffer to the training session parameters.
- Parameters:
buffer – The OrtValue buffer to copy to the training session parameters.
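The three buffer-related methods above (`get_parameters_size`, `get_contiguous_parameters`, `copy_buffer_to_parameters`) form a round trip: flatten the parameters, modify or transmit the flat buffer, and copy it back. The sketch below models that round trip with plain Python containers. It is a toy under stated assumptions, not the library's implementation: the `toy_*` functions and the `params` dictionary layout are hypothetical. The toy functions reuse the default values of `trainable_only` from the signatures above, which differ between methods, so passing the flag explicitly and consistently is the safer habit.

```python
from array import array


def toy_parameters_size(params, trainable_only=True):
    """Count the scalar elements in the selected parameters."""
    return sum(
        len(p["values"])
        for p in params.values()
        if p["trainable"] or not trainable_only
    )


def toy_contiguous_parameters(params, trainable_only=False):
    """Flatten the selected parameters into one contiguous buffer."""
    buffer = array("d")
    for p in params.values():
        if p["trainable"] or not trainable_only:
            buffer.extend(p["values"])
    return buffer


def toy_copy_buffer_to_parameters(params, buffer, trainable_only=True):
    """Copy a flat buffer back into the selected parameters, in order."""
    expected = toy_parameters_size(params, trainable_only)
    if len(buffer) != expected:
        # The toy rejects a buffer built with a different trainable_only flag.
        raise ValueError(f"expected {expected} elements, got {len(buffer)}")
    offset = 0
    for p in params.values():
        if p["trainable"] or not trainable_only:
            count = len(p["values"])
            p["values"] = list(buffer[offset:offset + count])
            offset += count


# Hypothetical parameter set: one trainable weight, one frozen bias.
params = {
    "weight": {"values": [1.0, 2.0], "trainable": True},
    "bias": {"values": [3.0], "trainable": False},
}
flat = toy_contiguous_parameters(params, trainable_only=True)
flat[0] = 10.0  # e.g. an update received from elsewhere
toy_copy_buffer_to_parameters(params, flat, trainable_only=True)
```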
- export_model_for_inferencing(inference_model_uri: str | os.PathLike, graph_output_names: list[str]) -> None
Exports the model for inferencing.
Once training is complete, this function can be used to drop the training specific nodes in the onnx model. In particular, this function does the following:
Parse over the training graph and identify nodes that generate the given output names.
Drop all subsequent nodes in the graph since they are not relevant to the inference graph.
- Parameters:
inference_model_uri – The path to the inference model.
graph_output_names – The list of output names that are required for inferencing.
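The two pruning steps described above can be sketched on a toy graph. This is an illustration of the idea only, not how onnxruntime implements the export: `prune_to_outputs` and the dictionary-based node format are hypothetical, and the nodes are assumed to be listed in topological order. Starting from the requested output names, the sketch walks backwards through producers and keeps only the nodes that contribute to those outputs.

```python
def prune_to_outputs(nodes, graph_output_names):
    """Keep only the nodes needed to compute `graph_output_names`.

    `nodes` is a list of dicts with "name", "inputs" and "outputs" keys
    (a hypothetical stand-in for an ONNX graph).
    """
    # Map each value name to the node that produces it.
    producer = {out: node for node in nodes for out in node["outputs"]}

    needed = set()
    visited = set()
    stack = list(graph_output_names)
    while stack:
        value = stack.pop()
        if value in visited:
            continue
        visited.add(value)
        node = producer.get(value)
        if node is None:
            continue  # graph input or parameter: nothing produces it
        needed.add(node["name"])
        stack.extend(node["inputs"])

    # Preserve the original node order.
    return [node for node in nodes if node["name"] in needed]


# Toy training graph: forward node, loss node, gradient node.
toy_graph = [
    {"name": "matmul", "inputs": ["x", "w"], "outputs": ["logits"]},
    {"name": "loss_fn", "inputs": ["logits", "labels"], "outputs": ["loss"]},
    {"name": "grad", "inputs": ["loss"], "outputs": ["w_grad"]},
]
inference_nodes = prune_to_outputs(toy_graph, ["logits"])
```

Requesting `"logits"` keeps only the forward node, while the loss and gradient nodes are dropped because nothing requested depends on them.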
- class onnxruntime.training.api.Optimizer(optimizer_uri: str | os.PathLike, module: Module)
Bases: object
Class that provides methods to update the model parameters based on the computed gradients.
- Parameters:
optimizer_uri – The path to the optimizer model.
module – The module to be trained.
- step() -> None
Updates the model parameters based on the computed gradients.
This method applies one update step to the model parameters using the computed gradients. The update rule depends on the optimizer model provided.
- class onnxruntime.training.api.LinearLRScheduler(optimizer: Optimizer, warmup_step_count: int, total_step_count: int, initial_lr: float)
Bases: object
Linearly updates the learning rate in the optimizer.
The linear learning rate scheduler decays the learning rate by a linearly updated multiplicative factor, from the initial learning rate set on the training session down to 0. The decay begins after an initial warm-up phase, during which the learning rate is linearly increased from 0 to the initial learning rate provided.
- Parameters:
optimizer – User’s onnxruntime training Optimizer
warmup_step_count – The number of steps in the warm up phase.
total_step_count – The total number of training steps.
initial_lr – The initial learning rate.
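The schedule described above can be written as a small pure-Python function. This is a sketch of the warm-up and decay behaviour as described in the text, not the scheduler's actual implementation: `linear_lr` is a hypothetical name, and the handling of edge cases (such as steps beyond `total_step_count`) is an assumption.

```python
def linear_lr(step, warmup_step_count, total_step_count, initial_lr):
    """Learning rate at `step` for a linear warm-up followed by linear decay."""
    if step < warmup_step_count:
        # Warm-up: ramp the factor linearly from 0 to 1.
        factor = step / max(1, warmup_step_count)
    else:
        # Decay: ramp the factor linearly from 1 down to 0, never below 0.
        remaining = total_step_count - step
        factor = max(0.0, remaining / max(1, total_step_count - warmup_step_count))
    return initial_lr * factor


# Learning rates at a few steps for 10 warm-up steps out of 100 total.
schedule = [linear_lr(s, 10, 100, 0.1) for s in (0, 5, 10, 55, 100)]
```

With these arguments the rate rises from 0 to `initial_lr` over the first 10 steps and then falls linearly, reaching 0 at step 100.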