Train the Model on the Device

Once the training artifacts are generated, the model can be trained on the device using the ONNX Runtime training Python API.

The expected training artifacts are:

  1. The training ONNX model

  2. The checkpoint state

  3. The optimizer ONNX model

  4. The eval ONNX model (optional)

Sample usage:

from onnxruntime.training.api import CheckpointState, Module, Optimizer

# Load the checkpoint state
state = CheckpointState.load_checkpoint(path_to_the_checkpoint_artifact)

# Create the module
module = Module(path_to_the_training_model,
                state,
                path_to_the_eval_model,
                device="cpu")

# Create the optimizer
optimizer = Optimizer(path_to_the_optimizer_model, module)

# Training loop
for ...:
    module.train()
    training_loss = module(...)
    optimizer.step()
    module.lazy_reset_grad()

# Eval
module.eval()
eval_loss = module(...)

# Save the checkpoint
CheckpointState.save_checkpoint(state, path_to_the_checkpoint_artifact)

class onnxruntime.training.api.CheckpointState(state: CheckpointState)

Bases: object

Class that holds the state of the training session

This class holds all the state information of the training session such as the model parameters, its gradients, the optimizer state and user defined properties.

User defined properties can be indexed by name from the CheckpointState object.

To create the CheckpointState, use the CheckpointState.load_checkpoint method.

Parameters:

state – The C.Checkpoint state object that holds the underlying session state.

classmethod load_checkpoint(checkpoint_uri: str | os.PathLike) -> CheckpointState

Loads the checkpoint state from the checkpoint file

Parameters:

checkpoint_uri – The path to the checkpoint file.

Returns:

The checkpoint state object.

Return type:

CheckpointState

classmethod save_checkpoint(state: CheckpointState, checkpoint_uri: str | os.PathLike, include_optimizer_state: bool = False) -> None

Saves the checkpoint state to the checkpoint file

Parameters:
  • state – The checkpoint state object.

  • checkpoint_uri – The path to the checkpoint file.

  • include_optimizer_state – If True, the optimizer state is also saved to the checkpoint file.

__getitem__(name: str) -> int | float | str

Gets the property associated with the given name

Parameters:

name – The name of the property

Returns:

The value of the property

__setitem__(name: str, value: int | float | str) -> None

Sets the property value for the given name

Parameters:
  • name – The name of the property

  • value – The value of the property

__contains__(name: str) -> bool

Checks if the property exists in the state

Parameters:

name – The name of the property

Returns:

True if the property exists, False otherwise
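The three property accessors above can be combined as in the following minimal sketch, which keeps a running counter in the checkpoint state. The helper name `bump_epoch` and the property name "epoch" are hypothetical. The function relies only on `in`, indexing, and item assignment, so it accepts a CheckpointState or any object that behaves the same way.

```python
def bump_epoch(state, key="epoch"):
    """Increment an integer counter stored as a user-defined property.

    `state` is expected to behave like CheckpointState: it must support
    `key in state`, `state[key]`, and `state[key] = value`.
    The property name "epoch" is an illustrative assumption.
    """
    # __contains__ guards the first run, when the property does not exist yet.
    current = state[key] if key in state else 0
    # __setitem__ accepts int, float, or str values; an int is stored here.
    state[key] = current + 1
    return state[key]
```

Because the property is stored on the state object itself, it should be written out with the rest of the state the next time CheckpointState.save_checkpoint is called on it.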

class onnxruntime.training.api.Module(train_model_uri: PathLike, state: CheckpointState, eval_model_uri: Optional[PathLike] = None, device: str = 'cpu', session_options: Optional[SessionOptions] = None)

Bases: object

Trainer class that provides training and evaluation methods for ONNX models.

Before instantiating the Module class, it is expected that the training artifacts have been generated using the onnxruntime.training.artifacts.generate_artifacts utility.

The training artifacts include:
  • The training model

  • The evaluation model (optional)

  • The optimizer model (optional)

  • The checkpoint file

training

True if the model is in training mode, False if it is in evaluation mode.

Type:

bool

Parameters:
  • train_model_uri – The path to the training model.

  • state – The checkpoint state object.

  • eval_model_uri – The path to the evaluation model. Optional; defaults to None.

  • device – The device to run the model on. Defaults to "cpu".

  • session_options – The session options to use for the model. Optional; defaults to None.

__call__(*user_inputs) -> tuple[numpy.ndarray, ...] | numpy.ndarray | tuple[onnxruntime.capi.onnxruntime_inference_collection.OrtValue, ...] | onnxruntime.capi.onnxruntime_inference_collection.OrtValue

Invokes either the training or the evaluation step of the model.

Parameters:

*user_inputs – The inputs to the model. The user inputs can be either numpy arrays or OrtValues.

Returns:

The outputs of the model.

train(mode: bool = True) -> Module

Sets the Module in training mode.

Parameters:

mode – whether to set the model to training mode (True) or evaluation mode (False). Default: True.

Returns:

self

eval() -> Module

Sets the Module in evaluation mode.

Returns:

self
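As an illustration of how train(), eval(), and the training attribute fit together, the sketch below runs a single evaluation step and then restores whichever mode the module was in beforehand. The helper name `evaluate` is hypothetical, and the sketch assumes an evaluation model was supplied when the Module was created.

```python
def evaluate(module, *inputs):
    """Run one evaluation step, then restore the module's previous mode.

    `module` is expected to expose `training`, `train(mode)`, `eval()`,
    and `__call__`, as documented for Module.
    """
    was_training = module.training  # True in training mode, False in eval mode
    module.eval()                   # returns self; the return value is unused here
    try:
        return module(*inputs)
    finally:
        # train(mode) takes a bool, so this restores either mode.
        module.train(was_training)
```

Restoring the mode in a `finally` clause keeps a surrounding training loop in a consistent state even if the evaluation step raises.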

lazy_reset_grad()

Lazily resets the training gradients.

This function sets the internal state of the module such that the module gradients will be scheduled to be reset just before the new gradients are computed on the next training step, that is, the next call to the module while it is in training mode.
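One use of the lazy reset is gradient accumulation. The sketch below assumes that gradients keep accumulating across training steps until lazy_reset_grad() is called, which is what the lazy-reset wording suggests but is stated here as an assumption. The helper name `train_with_accumulation` and the `accumulation_steps` parameter are hypothetical.

```python
def train_with_accumulation(module, optimizer, batches, accumulation_steps=4):
    """Apply one optimizer update per `accumulation_steps` training steps.

    Assumes gradients accumulate across training steps until
    lazy_reset_grad() is called. `batches` is any iterable of input tuples.
    Returns the list of per-batch training outputs.
    """
    if accumulation_steps < 1:
        raise ValueError("accumulation_steps must be at least 1")
    module.train()
    losses = []
    for index, batch in enumerate(batches, start=1):
        losses.append(module(*batch))
        if index % accumulation_steps == 0:
            optimizer.step()
            # Only schedules the reset; it takes effect on the next training step.
            module.lazy_reset_grad()
    return losses
```

Note that a trailing group of fewer than `accumulation_steps` batches is not applied by this sketch; a caller would need to decide whether to take a final optimizer step for it.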

get_contiguous_parameters(trainable_only: bool = False) -> OrtValue

Creates a contiguous buffer of the training session parameters

Parameters:

trainable_only – If True, only trainable parameters are considered. Otherwise, all parameters are considered.

Returns:

The contiguous buffer of the training session parameters.

get_parameters_size(trainable_only: bool = True) -> int

Returns the size of the parameters.

Parameters:

trainable_only – If True, only trainable parameters are considered. Otherwise, all parameters are considered.

Returns:

The number of primitive (for example, floating-point) elements in the parameters.

copy_buffer_to_parameters(buffer: OrtValue, trainable_only: bool = True) -> None

Copies the OrtValue buffer to the training session parameters.

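Parameters:
  • buffer – The OrtValue buffer to copy to the training session parameters.

  • trainable_only – If True, the buffer is copied to the trainable parameters only. Otherwise, it is copied to all parameters.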
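The three parameter-buffer methods above can be used together to move parameters from one module to another. The following is a minimal sketch with a hypothetical helper name, `copy_parameters`. It passes trainable_only explicitly on every call because the documented defaults differ: False for get_contiguous_parameters, True for get_parameters_size and copy_buffer_to_parameters.

```python
def copy_parameters(source, target, trainable_only=True):
    """Copy parameters from `source` to `target` through a contiguous buffer.

    Both arguments are expected to expose get_parameters_size,
    get_contiguous_parameters, and copy_buffer_to_parameters, as documented
    for Module. Returns the number of elements copied.
    """
    expected = target.get_parameters_size(trainable_only=trainable_only)
    available = source.get_parameters_size(trainable_only=trainable_only)
    if expected != available:
        raise ValueError(
            f"parameter size mismatch: source has {available}, target expects {expected}"
        )
    # Use the same trainable_only value on both sides so the buffer layout matches.
    buffer = source.get_contiguous_parameters(trainable_only=trainable_only)
    target.copy_buffer_to_parameters(buffer, trainable_only=trainable_only)
    return expected
```

The size check up front is a design choice of this sketch; it turns a layout mismatch into an explicit error before any parameters are overwritten.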

export_model_for_inferencing(inference_model_uri: str | os.PathLike, graph_output_names: list[str]) -> None

Exports the model for inferencing.

Once training is complete, this function can be used to drop the training-specific nodes from the ONNX model. In particular, this function does the following:

  • Parses the training graph and identifies the nodes that generate the given output names.

  • Drops all subsequent nodes in the graph, since they are not relevant to the inference graph.

Parameters:
  • inference_model_uri – The path to the inference model.

  • graph_output_names – The list of output names that are required for inferencing.

input_names() -> list[str]

Returns the input names of the training or eval model.

output_names() -> list[str]

Returns the output names of the training or eval model.
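Since the module is called with positional inputs, input_names() can help build that tuple from named data. The sketch below assumes that the positional order expected by the module matches the order returned by input_names(); this is an assumption rather than something the text above states. The helper name `order_inputs` is hypothetical.

```python
def order_inputs(module, named_inputs):
    """Arrange a {name: value} mapping into the order reported by input_names().

    Assumes the module expects its positional inputs in that same order.
    Raises KeyError if any reported input name is missing from the mapping.
    """
    names = module.input_names()
    missing = [name for name in names if name not in named_inputs]
    if missing:
        raise KeyError(f"missing model inputs: {missing}")
    return tuple(named_inputs[name] for name in names)
```

The result can then be unpacked into the call, for example `module(*order_inputs(module, batch))`, where `batch` is a hypothetical dictionary of inputs.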

class onnxruntime.training.api.Optimizer(optimizer_uri: str | os.PathLike, module: Module)

Bases: object

Class that provides methods to update the model parameters based on the computed gradients.

Parameters:
  • optimizer_uri – The path to the optimizer model.

  • module – The module to be trained.

step() -> None

Updates the model parameters based on the computed gradients.

This method applies one update to the model parameters using the computed gradients. The update rule depends on the optimizer model provided.

set_learning_rate(learning_rate: float) -> None

Sets the learning rate for the optimizer.

Parameters:

learning_rate – The learning rate to be set.

get_learning_rate() -> float

Gets the current learning rate of the optimizer.

Returns:

The current learning rate.
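The getter and setter above are enough to implement a simple manual schedule. The following sketch scales the current learning rate by a constant factor, for instance at the end of an epoch. The helper name `scale_learning_rate` and the `factor` parameter are hypothetical.

```python
def scale_learning_rate(optimizer, factor):
    """Multiply the optimizer's current learning rate by `factor`.

    `optimizer` is expected to expose get_learning_rate() and
    set_learning_rate(), as documented for Optimizer.
    Returns the new learning rate.
    """
    if factor <= 0:
        raise ValueError("factor must be positive")
    new_lr = optimizer.get_learning_rate() * factor
    optimizer.set_learning_rate(new_lr)
    return new_lr
```

For example, `scale_learning_rate(optimizer, 0.5)` would halve the learning rate currently set on the optimizer.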

class onnxruntime.training.api.LinearLRScheduler(optimizer: Optimizer, warmup_step_count: int, total_step_count: int, initial_lr: float)

Bases: object

Linearly updates the learning rate in the optimizer

The linear learning rate scheduler decays the learning rate by a linearly updated multiplicative factor, from the initial learning rate set on the training session down to 0. The decay starts after an initial warm-up phase, during which the learning rate is increased linearly from 0 to the initial learning rate provided.

Parameters:
  • optimizer – User’s onnxruntime training Optimizer

  • warmup_step_count – The number of steps in the warm up phase.

  • total_step_count – The total number of training steps.

  • initial_lr – The initial learning rate.
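The schedule described above can be restated as a small pure-Python function. This is a sketch of the description, not the library's implementation: the function name `linear_lr`, the clamping at 0, and the guards against division by zero are assumptions, and the exact behaviour at the phase boundaries may differ in the library.

```python
def linear_lr(step, warmup_step_count, total_step_count, initial_lr):
    """Learning rate at `step` for a linear warm-up followed by linear decay.

    During warm-up the multiplicative factor rises linearly from 0 to 1;
    afterwards it falls linearly from 1 to 0 at `total_step_count`.
    """
    if step < warmup_step_count:
        # Warm-up phase: 0 at step 0, approaching 1 at the end of warm-up.
        factor = step / max(1, warmup_step_count)
    else:
        # Decay phase: 1 at the end of warm-up, 0 at total_step_count,
        # clamped so the rate never goes negative past the last step.
        remaining = total_step_count - step
        factor = max(0.0, remaining / max(1, total_step_count - warmup_step_count))
    return initial_lr * factor
```

With a warm-up of 10 steps, 100 total steps, and an initial learning rate of 0.1, this sketch gives 0 at step 0, the full 0.1 at step 10, half of that at step 55, and 0 from step 100 onward.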

step() -> None

Updates the learning rate of the optimizer linearly.

This method should be called at each step of training to ensure that the learning rate is properly adjusted.
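A training loop that calls the scheduler once per step might look like the sketch below. Placing the scheduler step after the optimizer step follows a common convention and is an assumption here, as is the helper name `train_steps`. The function accepts any objects exposing the methods documented for Module, Optimizer, and LinearLRScheduler.

```python
def train_steps(module, optimizer, scheduler, batches):
    """Run one training step per batch, updating the learning rate each step.

    `batches` is any iterable of input tuples. Returns the per-batch
    training outputs.
    """
    module.train()
    losses = []
    for batch in batches:
        losses.append(module(*batch))  # computes outputs and gradients
        optimizer.step()               # updates the parameters
        scheduler.step()               # adjusts the learning rate for the next step
        module.lazy_reset_grad()       # schedules the gradient reset
    return losses
```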