Prepare for training#

Before training can start on edge devices, the training artifacts need to be generated in an offline step.

These artifacts include:

  1. The training onnx model

  2. The checkpoint state

  3. The optimizer onnx model

  4. The eval onnx model (optional)

It is assumed that a forward-only onnx model is already available. If using PyTorch, this model can be generated by exporting the PyTorch model using the torch.onnx.export() API.

Note

If using PyTorch to export the model, please use the following export arguments so that training artifact generation can succeed (see the example export call after this list):

  • export_params: True

  • do_constant_folding: False

  • training: torch.onnx.TrainingMode.TRAINING
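
A minimal sketch of such an export call is shown below. The model definition, input shape, output file name, and input/output names are placeholders for illustration; only the three export arguments listed above are required for artifact generation.

import torch

# A tiny placeholder model for illustration; substitute your own module.
class MyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = torch.nn.Linear(10, 2)

    def forward(self, x):
        return self.fc(x)

model = MyModel()
example_input = torch.randn(1, 10)

# Export a forward-only onnx model with the required export arguments.
torch.onnx.export(model,
                  example_input,
                  "forward_only_model.onnx",
                  export_params=True,                         # store the parameters in the exported model
                  do_constant_folding=False,                  # do not fold away nodes needed for training
                  training=torch.onnx.TrainingMode.TRAINING,  # export the model in training mode
                  input_names=["input"],
                  output_names=["output"])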

Once the forward only onnx model is available, the training artifacts can be generated using the onnxruntime.training.artifacts.generate_artifacts() API.

Sample usage:

import onnx
from onnxruntime.training import artifacts

# Load the forward-only onnx model
model = onnx.load(path_to_forward_only_onnx_model)

# Generate the training artifacts
artifacts.generate_artifacts(model,
                             requires_grad=["parameters", "needing", "gradients"],
                             frozen_params=["parameters", "not", "needing", "gradients"],
                             loss=artifacts.LossType.CrossEntropyLoss,
                             optimizer=artifacts.OptimType.AdamW,
                             artifact_directory=path_to_output_artifact_directory)
class onnxruntime.training.artifacts.LossType(value)[source]#

Loss type to be added to the training model.

To be used with the loss parameter of generate_artifacts function.

MSELoss = 1#
CrossEntropyLoss = 2#
BCEWithLogitsLoss = 3#
L1Loss = 4#
class onnxruntime.training.artifacts.OptimType(value)[source]#

Optimizer type to be used while generating the optimizer model for training.

To be used with the optimizer parameter of generate_artifacts function.

AdamW = 1#
SGD = 2#
onnxruntime.training.artifacts.generate_artifacts(model: ModelProto, requires_grad: Optional[List[str]] = None, frozen_params: Optional[List[str]] = None, loss: Optional[Union[LossType, Block]] = None, optimizer: Optional[OptimType] = None, artifact_directory: Optional[Union[str, bytes, PathLike]] = None, **extra_options) None[source]#

Generates artifacts required for training with ORT training api.

This function generates the following artifacts:
  1. Training model (onnx.ModelProto): Contains the base model graph, loss sub graph and the gradient graph.

  2. Eval model (onnx.ModelProto): Contains the base model graph and the loss sub graph.

  3. Checkpoint (directory): Contains the model parameters.

  4. Optimizer model (onnx.ModelProto): Model containing the optimizer graph.

Parameters:
  • model – The base model to be used for gradient graph generation.

  • requires_grad – List of names of model parameters that require gradient computation.

  • frozen_params – List of names of model parameters that should be frozen.

  • loss – The loss function enum to be used for training. If None, no loss node is added to the graph.

  • optimizer – The optimizer enum to be used for training. If None, no optimizer model is generated.

  • artifact_directory – The directory to save the generated artifacts. If None, the current working directory is used.

  • prefix (str) – The prefix to be used for the generated artifacts. If not specified, no prefix is used.

  • ort_format (bool) – Whether to save the generated artifacts in ORT format or not. Default is False.

  • custom_op_library (str | os.PathLike) – The path to the custom op library. If not specified, no custom op library is used.

Raises:
  • RuntimeError – If the loss provided is neither one of the supported losses nor an instance of onnxblock.Block.

  • RuntimeError – If the optimizer provided is not one of the supported optimizers.
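
For instance, the prefix and ort_format options described above can be passed directly as keyword arguments. A minimal sketch, with a placeholder model path and parameter names:

import onnx
from onnxruntime.training import artifacts

model = onnx.load("forward_only_model.onnx")

# Generate artifacts with a file-name prefix and save them in ORT format
# (prefix and ort_format are the documented extra options).
artifacts.generate_artifacts(model,
                             requires_grad=["fc.weight", "fc.bias"],
                             loss=artifacts.LossType.MSELoss,
                             optimizer=artifacts.OptimType.SGD,
                             artifact_directory="artifacts",
                             prefix="mnist_",
                             ort_format=True)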

Custom Loss#

If a custom loss is needed, the user can provide a custom loss function to the onnxruntime.training.artifacts.generate_artifacts() API. This is done by inheriting from the onnxruntime.training.onnxblock.Block class and implementing the build method.

The following example shows how to implement a custom loss function:

Let’s assume we want to use a custom loss function with a model. For this example, we assume that our model generates two outputs, and that the custom loss function must apply a loss function to each output and then compute a weighted average of the two losses. Mathematically,

loss = 0.4 * mse_loss1(output1, target1) + 0.6 * mse_loss2(output2, target2)

Since this is a custom loss function, it is not one of the loss types exposed by the LossType enum.

To build it, we make use of onnxblock.

import onnx
import onnxruntime.training.onnxblock as onnxblock
from onnxruntime.training import artifacts

# Define a custom loss block that takes in two inputs
# and performs a weighted average of the losses from these
# two inputs.
class WeightedAverageLoss(onnxblock.Block):
    def __init__(self):
        super().__init__()
        self._loss1 = onnxblock.loss.MSELoss()
        self._loss2 = onnxblock.loss.MSELoss()
        self._w1 = onnxblock.blocks.Constant(0.4)
        self._w2 = onnxblock.blocks.Constant(0.6)
        self._add = onnxblock.blocks.Add()
        self._mul = onnxblock.blocks.Mul()

    def build(self, loss_input_name1, loss_input_name2):
        # The build method defines how the block should be stacked on top of
        # loss_input_name1 and loss_input_name2

        # Returns weighted average of the two losses
        return self._add(
            self._mul(self._w1(), self._loss1(loss_input_name1, target_name="target1")),
            self._mul(self._w2(), self._loss2(loss_input_name2, target_name="target2"))
        )

my_custom_loss = WeightedAverageLoss()

# Load the onnx model
model_path = "model.onnx"
base_model = onnx.load(model_path)

# Define the parameters that need their gradient computed
requires_grad = ["weight1", "bias1", "weight2", "bias2"]
frozen_params = ["weight3", "bias3"]

# Now, we can invoke generate_artifacts with this custom loss function
artifacts.generate_artifacts(base_model, requires_grad=requires_grad, frozen_params=frozen_params,
                             loss=my_custom_loss, optimizer=artifacts.OptimType.AdamW)

# Successful completion of the above call will generate 4 files in the current working directory,
# one for each of the artifacts mentioned above (training_model.onnx, eval_model.onnx, checkpoint, optimizer_model.onnx)
class onnxruntime.training.onnxblock.Block[source]#

Bases: ABC

Base class for all building blocks that can be stacked on top of each other.

All blocks that want to manipulate the model must subclass this class. The subclass’s implementation of the build method must return the names of the intermediate outputs from the block.

The subclass’s implementation of the build method must manipulate the base model as it deems fit, but the manipulated model must be valid (as deemed by the onnx checker).

base#

The base model that the subclass can manipulate.

Type:

onnx.ModelProto

abstract build(*args, **kwargs)[source]#

Customize the model by stacking up blocks on top of the inputs to this function.

This method must be overridden by the subclass.
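
As a concrete illustration of this contract, here is a minimal sketch of a Block subclass (a hypothetical example, reusing the onnxblock.blocks.Mul block from the custom loss example above). It squares an existing graph output and returns the name of the new intermediate output, as build is required to do.

import onnxruntime.training.onnxblock as onnxblock

# Hypothetical block that squares the tensor named by input_name by
# stacking a Mul block on top of the existing graph output.
class Square(onnxblock.Block):
    def __init__(self):
        super().__init__()
        self._mul = onnxblock.blocks.Mul()

    def build(self, input_name):
        # Return the name of the intermediate output produced by Mul.
        return self._mul(input_name, input_name)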

Advanced Usage#

onnxblock is a library that can be used to build complex onnx models by stacking simple blocks on top of each other. An example of this is the ability to build a custom loss function as shown above.

onnxblock also provides a way to build a custom forward-only or training (forward + backward) onnx model through the onnxruntime.training.onnxblock.ForwardBlock and onnxruntime.training.onnxblock.TrainingBlock classes, respectively. These blocks inherit from the base onnxruntime.training.onnxblock.Block class and provide additional functionality for building inference and training models.

class onnxruntime.training.onnxblock.ForwardBlock[source]#

Bases: Block

Base class for all blocks that require the forward model to be automatically built.

Blocks wanting to build a forward model by stacking blocks on top of the existing model must subclass this class. The subclass’s implementation of the build method must return the name of the graph output. This block will automatically register the output as a graph output and build the model.

Example:

>>> class MyForwardBlock(ForwardBlock):
>>>     def __init__(self):
>>>         super().__init__()
>>>         self.loss = onnxblock.loss.CrossEntropyLoss()
>>>
>>>     def build(self, loss_input_name: str):
>>>         # Add a cross entropy loss on top of the output so far (loss_input_name)
>>>         return self.loss(loss_input_name)

The above example will automatically build the forward graph that is composed of the existing model and the cross entropy loss function stacked on top of it.

abstract build(*args, **kwargs)[source]#

Customize the forward graph for this model by stacking up blocks on top of the inputs to this function.

This method should be overridden by the subclass. The output of this method should be the name of the graph output.

to_model_proto()[source]#

Returns the forward model.

Returns:

The forward model.

Return type:

model (onnx.ModelProto)

Raises:

RuntimeError – If the build method has not been invoked (i.e. the forward model has not been built yet).
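
A hedged sketch of how such a block might be used end to end is shown below. It assumes the onnxblock.base() accessor context manager for selecting the model that blocks build on, a model file named model.onnx, and a model with a single graph output; treat the exact invocation as illustrative rather than definitive.

import onnx
import onnxruntime.training.onnxblock as onnxblock

# Assumes MyForwardBlock from the example above.
block = MyForwardBlock()
base_model = onnx.load("model.onnx")

# Build the forward graph by stacking the block on top of the model's output.
# onnxblock.base(...) is assumed to expose base_model to the block while it builds.
with onnxblock.base(base_model):
    _ = block(base_model.graph.output[0].name)

# Retrieve the completed forward model (raises if build has not been invoked).
forward_model = block.to_model_proto()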

class onnxruntime.training.onnxblock.TrainingBlock[source]#

Bases: Block

Base class for all blocks that require the gradient model to be automatically built.

Blocks that require the gradient graph to be computed based on the output of the block must subclass this class. The subclass’s implementation of the build method must return the name of the output from where backpropagation must begin (typically the name of the output from the loss function).

Example:

>>> class MyTrainingBlock(TrainingBlock):
>>>     def __init__(self):
>>>         super().__init__()
>>>         self.loss = onnxblock.loss.CrossEntropyLoss()
>>>
>>>     def build(self, loss_input_name: str):
>>>         # Add a cross entropy loss on top of the output so far (loss_input_name)
>>>         return self.loss(loss_input_name)

The above example will automatically build the gradient graph for the entire model starting from the output of the loss function.

abstract build(*args, **kwargs)[source]#

Customize the forward graph for this model by stacking up blocks on top of the inputs to this function.

This method should be overridden by the subclass. The output of this method should be the name of the output from where backpropagation must begin (typically the name of the output from the loss function).

requires_grad(argument_name: str, value: bool = True)[source]#

Specify whether the argument requires gradient or not.

The auto-diff will compute the gradient graph for only the arguments that require gradient. By default, none of the arguments require gradient. The user must explicitly specify which arguments require gradient.

Parameters:
  • argument_name (str) – The name of the argument that requires/does not require gradient.

  • value (bool) – True if the argument requires gradient, False otherwise.

parameters() Tuple[List[TensorProto], List[TensorProto]][source]#

Trainable as well as non-trainable (frozen) parameters of the model.

Model parameters that are extracted while building the training model are returned by this method.

Note that the parameters are not known before the training model is built. As a result, if this method is invoked before the training model is built, an exception will be raised.

Returns:

A tuple (trainable_params, frozen_params):

  • trainable_params (list of onnx.TensorProto) – The trainable parameters of the model.

  • frozen_params (list of onnx.TensorProto) – The non-trainable (frozen) parameters of the model.

Raises:

RuntimeError – If the build method has not been invoked (i.e. the training model has not been built yet).

to_model_proto() Tuple[ModelProto, ModelProto][source]#

Returns the training and eval models.

Once the gradient graph is built, the training and eval models can be retrieved by invoking this method.

Returns:

A tuple (training_model, eval_model):

  • training_model (onnx.ModelProto) – The training model.

  • eval_model (onnx.ModelProto) – The eval model.

Raises:

RuntimeError – If the build method has not been invoked (i.e. the training model has not been built yet).
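
Putting these pieces together, a hedged end-to-end sketch of using a TrainingBlock might look like the following. As in the ForwardBlock sketch above, the onnxblock.base() accessor context, the model path, and the parameter names are assumptions made for illustration.

import onnx
import onnxruntime.training.onnxblock as onnxblock

# Assumes MyTrainingBlock from the example above.
block = MyTrainingBlock()

# Mark the parameters that should be trained (names are placeholders).
block.requires_grad("fc.weight")
block.requires_grad("fc.bias")

base_model = onnx.load("model.onnx")

# Build the training (forward + backward) graph on top of the model's output.
# onnxblock.base(...) is assumed to expose base_model to the block while it builds.
with onnxblock.base(base_model):
    _ = block(base_model.graph.output[0].name)

# Retrieve the generated models and the extracted parameters.
training_model, eval_model = block.to_model_proto()
trainable_params, frozen_params = block.parameters()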