ORTTrainingSession
Objective-C
@interface ORTTrainingSession : NSObject
Swift
class ORTTrainingSession : NSObject
Trainer class that provides methods to train, evaluate and optimize ONNX models.
The training session uses the following training artifacts:
- Training ONNX model
- Evaluation ONNX model (optional)
- Optimizer ONNX model
- Checkpoint directory
The onnxruntime-training Python utility can be used to generate the training artifacts above.
Available since 1.16.
Note
This class is only available when the training APIs are enabled.
-
Unavailable
Declaration
Objective-C
- (instancetype)init NS_UNAVAILABLE;
-
Creates a training session from the training artifacts that can be used to begin or resume training.
The initializer instantiates the training session based on the provided env and session options. The session can then be used to begin or resume training from a given checkpoint state. The checkpoint state represents the parameters of the training session, which will be moved to the device specified in the session options if needed.
Note
A training session created with a checkpoint state uses this state to store the entire training state (including the model parameters, their gradients, the optimizer states and the properties). The training session keeps a strong (owning) pointer to the checkpoint state.
Declaration
Objective-C
- (nullable instancetype)initWithEnv:(nonnull ORTEnv *)env sessionOptions:(nonnull ORTSessionOptions *)sessionOptions checkpoint:(nonnull ORTCheckpoint *)checkpoint trainModelPath:(nonnull NSString *)trainModelPath evalModelPath:(nullable NSString *)evalModelPath optimizerModelPath:(nullable NSString *)optimizerModelPath error:(NSError *_Nullable *_Nullable)error;
Swift
init(env: ORTEnv, sessionOptions: ORTSessionOptions, checkpoint: ORTCheckpoint, trainModelPath: String, evalModelPath: String?, optimizerModelPath: String?) throws
Parameters
env
The ORTEnv instance to use for the training session.
sessionOptions
The ORTSessionOptions to use for the training session.
checkpoint
Training states that are used as a starting point for training.
trainModelPath
The path to the training ONNX model.
evalModelPath
The path to the evaluation ONNX model, or nil if no evaluation model is used.
optimizerModelPath
The path to the optimizer ONNX model used to perform gradient descent.
error
Optional error information set if an error occurs.
Return Value
The instance, or nil if an error occurs.
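Because the evaluation model is optional while the other artifacts are not, it can help to resolve the artifact paths before calling the initializer. The sketch below assumes all artifacts sit in one directory under the file names `checkpoint`, `training_model.onnx`, `eval_model.onnx` and `optimizer_model.onnx`; those names and the `TrainingArtifactPaths` type are hypothetical, so adjust them to whatever your artifact generation step actually wrote.

```swift
import Foundation

// Hypothetical helper, not part of ONNX Runtime. The default file names are
// ASSUMPTIONS about what the onnxruntime-training artifact generator writes.
struct TrainingArtifactPaths {
    let checkpointPath: String
    let trainModelPath: String
    let evalModelPath: String?      // nil when no evaluation model is present
    let optimizerModelPath: String

    init?(directory: URL, fileManager: FileManager = .default) {
        func path(_ name: String) -> String {
            return directory.appendingPathComponent(name).path
        }
        let checkpoint = path("checkpoint")
        let train = path("training_model.onnx")
        let optimizer = path("optimizer_model.onnx")
        let eval = path("eval_model.onnx")

        // The checkpoint, training model and optimizer model are required for training.
        for required in [checkpoint, train, optimizer] where !fileManager.fileExists(atPath: required) {
            return nil
        }
        checkpointPath = checkpoint
        trainModelPath = train
        optimizerModelPath = optimizer
        // The evaluation model is optional, matching the nullable evalModelPath parameter.
        evalModelPath = fileManager.fileExists(atPath: eval) ? eval : nil
    }
}
```

The resolved paths map onto the Swift initializer declared above, for example `try ORTTrainingSession(env: env, sessionOptions: sessionOptions, checkpoint: checkpoint, trainModelPath: paths.trainModelPath, evalModelPath: paths.evalModelPath, optimizerModelPath: paths.optimizerModelPath)`, with `env`, `sessionOptions` and `checkpoint` created beforehand.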
-
Performs a training step, which is equivalent to a forward and backward propagation in a single step.
The training step computes the outputs of the training model and the gradients of the trainable parameters for the given input values, running forward and backward propagation in a single step. The train step is performed based on the training model that was provided to the training session. The computed gradients are stored inside the training session state so that they can later be consumed by optimizerStep. The gradients can be lazily reset by calling the lazyResetGrad method.
Declaration
Parameters
inputs
The input values to the training model.
error
Optional error information set if an error occurs.
Return Value
The output values of the training model.
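A typical training round calls trainStep, then optimizerStep to apply the gradients, then lazyResetGrad so the next round starts from fresh gradients. The sketch below writes that loop against a hypothetical `TrainingStepping` protocol whose requirements mirror the Swift methods on this page, so it can be exercised with a mock. The Swift name `trainStep(withInputValues:)` is an assumption, because its declaration is not reproduced here.

```swift
// Hypothetical protocol mirroring ORTTrainingSession's step methods.
// trainStep(withInputValues:) is an ASSUMED Swift name; optimizerStep() and
// lazyResetGrad() match the Swift declarations documented on this page.
protocol TrainingStepping {
    associatedtype Value
    func trainStep(withInputValues inputs: [Value]) throws -> [Value]
    func optimizerStep() throws
    func lazyResetGrad() throws
}

/// Runs one pass over `batches`: compute outputs and gradients, apply the
/// optimizer update, then schedule the gradients to be reset.
/// Returns the outputs of every training step.
func runEpoch<S: TrainingStepping>(_ session: S, batches: [[S.Value]]) throws -> [[S.Value]] {
    var outputs: [[S.Value]] = []
    for batch in batches {
        let stepOutputs = try session.trainStep(withInputValues: batch)
        outputs.append(stepOutputs)
        try session.optimizerStep()
        try session.lazyResetGrad()
    }
    return outputs
}

// Mock standing in for ORTTrainingSession: records the call order and
// returns the sum of the inputs as a pretend loss.
final class RecordingSession: TrainingStepping {
    var calls: [String] = []
    func trainStep(withInputValues inputs: [Float]) throws -> [Float] {
        calls.append("train")
        return [inputs.reduce(0, +)]
    }
    func optimizerStep() throws { calls.append("optimize") }
    func lazyResetGrad() throws { calls.append("reset") }
}
```

If the assumed signature matches the real one, an empty `extension ORTTrainingSession: TrainingStepping {}` would let `runEpoch` drive a real session with `ORTValue` inputs.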
-
Performs an evaluation step that computes the outputs of the evaluation model for the given inputs. The eval step is performed based on the evaluation model that was provided to the training session.
Declaration
Parameters
inputs
The input values to the eval model.
error
Optional error information set if an error occurs.
Return Value
The output values of the eval model.
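Evaluation is a plain loop over evalStep, since it only computes the outputs of the evaluation model. In this sketch the `EvalStepping` protocol, the Swift name `evalStep(withInputValues:)`, and the idea that a scalar loss can be read from the outputs are all assumptions; the `loss` closure is where real code would locate the loss output, for example by its position in getEvalOutputNames.

```swift
// Hypothetical protocol; evalStep(withInputValues:) is an ASSUMED Swift name,
// since the declaration is not reproduced on this page.
protocol EvalStepping {
    associatedtype Value
    func evalStep(withInputValues inputs: [Value]) throws -> [Value]
}

/// Averages a scalar loss over `batches`, or returns nil when there are none.
/// `loss` turns one step's outputs into a number (how is up to the caller).
func meanEvalLoss<S: EvalStepping>(
    _ session: S,
    batches: [[S.Value]],
    loss: ([S.Value]) throws -> Float
) throws -> Float? {
    guard !batches.isEmpty else { return nil }
    var total: Float = 0
    for batch in batches {
        // Only the evaluation model runs; no optimizer or gradient calls are made.
        let outputs = try session.evalStep(withInputValues: batch)
        total += try loss(outputs)
    }
    return total / Float(batches.count)
}

// Mock whose single output is the sum of the inputs.
final class SummingEvalSession: EvalStepping {
    func evalStep(withInputValues inputs: [Float]) throws -> [Float] {
        return [inputs.reduce(0, +)]
    }
}
```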
-
Resets the gradients of all trainable parameters to zero lazily.
Calling this method sets the internal state of the training session such that the gradients of the trainable parameters in the ORTCheckpoint are scheduled to be reset just before new gradients are computed on the next invocation of the trainStep method.
Declaration
Objective-C
- (BOOL)lazyResetGradWithError:(NSError *_Nullable *_Nullable)error;
Swift
func lazyResetGrad() throws
Parameters
error
Optional error information set if an error occurs.
Return Value
YES if the gradients were successfully scheduled to be reset, NO otherwise.
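The lazy behavior matters when gradients are accumulated over several trainStep calls. The toy model below is not ONNX Runtime code; it only illustrates the documented timing, assuming (as the description implies) that gradients accumulate across trainStep calls until a reset is scheduled. lazyResetGrad only sets a flag, and the reset happens at the start of the next trainStep.

```swift
// Toy model, NOT ONNX Runtime code. It ASSUMES gradients accumulate across
// trainStep calls until a reset has been scheduled with lazyResetGrad.
struct LazyGradientBuffer {
    private(set) var gradient: Float = 0
    private var resetPending = false

    init() {}

    /// Schedules a reset; the stored gradient is left untouched for now.
    mutating func lazyResetGrad() {
        resetPending = true
    }

    /// Adds a new batch gradient, first applying any pending reset.
    mutating func trainStep(batchGradient: Float) {
        if resetPending {
            gradient = 0
            resetPending = false
        }
        gradient += batchGradient
    }
}
```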
-
Performs the weight updates for the trainable parameters using the optimizer model. The optimizer step is performed based on the optimizer model that was provided to the training session. The updated parameters are stored inside the training state so that they can be used by the next trainStep method call.
Declaration
Objective-C
- (BOOL)optimizerStepWithError:(NSError *_Nullable *_Nullable)error;
Swift
func optimizerStep() throws
Parameters
error
Optional error information set if an error occurs.
Return Value
YES if the optimizer step was performed successfully, NO otherwise.
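Since optimizerStep consumes whatever gradients are stored in the session, it can be called once per several trainStep calls to emulate a larger batch. This sketch assumes gradients accumulate between resets; the `AccumulatingTrainer` protocol, the `trainWithAccumulation` function and the Swift name `trainStep(withInputValues:)` are hypothetical.

```swift
// Hypothetical protocol. optimizerStep() and lazyResetGrad() mirror the
// documented Swift declarations; trainStep(withInputValues:) is ASSUMED.
protocol AccumulatingTrainer {
    associatedtype Value
    func trainStep(withInputValues inputs: [Value]) throws -> [Value]
    func optimizerStep() throws
    func lazyResetGrad() throws
}

/// Calls optimizerStep once every `accumulationSteps` micro-batches (and once
/// more after the last batch), assuming gradients accumulate in between.
func trainWithAccumulation<T: AccumulatingTrainer>(
    _ trainer: T,
    microBatches: [[T.Value]],
    accumulationSteps: Int
) throws {
    precondition(accumulationSteps > 0, "accumulationSteps must be positive")
    for (index, batch) in microBatches.enumerated() {
        _ = try trainer.trainStep(withInputValues: batch)
        let isLast = index == microBatches.count - 1
        if (index + 1) % accumulationSteps == 0 || isLast {
            try trainer.optimizerStep()   // apply the accumulated gradients
            try trainer.lazyResetGrad()   // start a fresh accumulation window
        }
    }
}

// Mock that records the call order.
final class RecordingTrainer: AccumulatingTrainer {
    var calls: [String] = []
    func trainStep(withInputValues inputs: [Int]) throws -> [Int] {
        calls.append("train")
        return inputs
    }
    func optimizerStep() throws { calls.append("optimize") }
    func lazyResetGrad() throws { calls.append("reset") }
}
```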
-
Returns the names of the user inputs for the training model that can be associated with the ORTValue inputs provided to trainStep.
Declaration
Objective-C
- (nullable NSArray<NSString *> *)getTrainInputNamesWithError: (NSError *_Nullable *_Nullable)error;
Swift
func getTrainInputNames() throws -> [String]
Parameters
error
Optional error information set if an error occurs.
Return Value
The names of the user inputs for the training model.
-
Returns the names of the user inputs for the evaluation model that can be associated with the ORTValue inputs provided to evalStep.
Declaration
Objective-C
- (nullable NSArray<NSString *> *)getEvalInputNamesWithError: (NSError *_Nullable *_Nullable)error;
Swift
func getEvalInputNames() throws -> [String]
Parameters
error
Optional error information set if an error occurs.
Return Value
The names of the user inputs for the evaluation model.
-
Returns the names of the user outputs for the training model that can be associated with the ORTValue outputs returned by trainStep.
Declaration
Objective-C
- (nullable NSArray<NSString *> *)getTrainOutputNamesWithError: (NSError *_Nullable *_Nullable)error;
Swift
func getTrainOutputNames() throws -> [String]
Parameters
error
Optional error information set if an error occurs.
Return Value
The names of the user outputs for the training model.
-
Returns the names of the user outputs for the evaluation model that can be associated with the ORTValue outputs returned by evalStep.
Declaration
Objective-C
- (nullable NSArray<NSString *> *)getEvalOutputNamesWithError: (NSError *_Nullable *_Nullable)error;
Swift
func getEvalOutputNames() throws -> [String]
Parameters
error
Optional error information set if an error occurs.
Return Value
The names of the user outputs for the evaluation model.
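The four name getters are mainly useful for keeping values in the right positions, since trainStep and evalStep take and return arrays. The helpers below are a hypothetical sketch that assumes the array order matches the order of the names returned by the getters; `orderedValues`, `namedOutputs` and `InputOrderingError` are illustrative, not part of the API.

```swift
// Hypothetical helpers, not part of ONNX Runtime. They ASSUME the arrays taken
// and returned by trainStep/evalStep follow the order of the getter names.
enum InputOrderingError: Error, Equatable {
    case missingInput(String)
}

/// Arranges named values in the order given by, for example, getTrainInputNames().
func orderedValues<V>(for names: [String], from valuesByName: [String: V]) throws -> [V] {
    return try names.map { (name: String) throws -> V in
        guard let value = valuesByName[name] else {
            throw InputOrderingError.missingInput(name)
        }
        return value
    }
}

/// Pairs output names (for example from getTrainOutputNames()) with step outputs.
/// zip stops at the shorter array; names are expected to be unique.
func namedOutputs<V>(names: [String], values: [V]) -> [String: V] {
    return Dictionary(uniqueKeysWithValues: zip(names, values))
}
```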
-
Registers a linear learning rate scheduler for the training session.
The scheduler gradually decreases the learning rate from the initial value to zero over the course of the training. The decrease is performed by multiplying the current learning rate by a linearly updated factor. Before the decrease, the learning rate is gradually increased from zero to the initial value during a warmup phase.
Declaration
Objective-C
- (BOOL) registerLinearLRSchedulerWithWarmupStepCount:(int64_t)warmupStepCount totalStepCount:(int64_t)totalStepCount initialLr:(float)initialLr error:(NSError *_Nullable *_Nullable) error;
Swift
func registerLinearLRScheduler(withWarmupStepCount warmupStepCount: Int64, totalStepCount: Int64, initialLr: Float) throws
Parameters
warmupStepCount
The number of steps to perform the linear warmup.
totalStepCount
The total number of steps to perform the linear decay.
initialLr
The initial learning rate.
error
Optional error information set if an error occurs.
Return Value
YES if the scheduler was registered successfully, NO otherwise.
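The shape of this schedule can be sketched as a pure function. It follows the common "linear schedule with warmup" form; the exact step indexing inside ONNX Runtime is an assumption, and `linearLR` is a hypothetical helper for illustration only. Registering the built-in scheduler is a single call, for example `try session.registerLinearLRScheduler(withWarmupStepCount: 2, totalStepCount: 6, initialLr: 0.1)`.

```swift
/// Hypothetical sketch of a linear warmup-then-decay learning rate.
/// ASSUMPTION: the multiplier rises from 0 to 1 over `warmupStepCount` steps,
/// then falls linearly to 0 at `totalStepCount`; ONNX Runtime's exact
/// step indexing may differ.
func linearLR(step: Int64, warmupStepCount: Int64, totalStepCount: Int64, initialLr: Float) -> Float {
    let factor: Float
    if step < warmupStepCount {
        // Warmup: ramp linearly from 0 toward initialLr.
        factor = Float(step) / Float(max(1, warmupStepCount))
    } else {
        // Decay: ramp linearly from initialLr down to 0, never below 0.
        factor = max(0, Float(totalStepCount - step) / Float(max(1, totalStepCount - warmupStepCount)))
    }
    return initialLr * factor
}
```

With a warmup of 2 steps, 6 total steps and an initial rate of 0.1, this sketch gives 0, 0.05, 0.1, 0.05 and 0 at steps 0, 1, 2, 4 and 6.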
-
Updates the learning rate based on the registered learning rate scheduler.
Performs a scheduler step that updates the learning rate that is being used by the training session. This function should typically be called before invoking the optimizer step for each round, or as necessary to update the learning rate being used by the training session.
Note
A valid predefined learning rate scheduler must be registered before this method is invoked.
Declaration
Objective-C
- (BOOL)schedulerStepWithError:(NSError *_Nullable *_Nullable)error;
Swift
func schedulerStep() throws
Parameters
error
Optional error information set if an error occurs.
Return Value
YES if the scheduler step was performed successfully, NO otherwise.
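Following the guidance above, a round can call schedulerStep after computing gradients and before optimizerStep, and read the rate back with getLearningRate for logging. The `ScheduledTrainer` protocol and the mock are hypothetical; all method names except `trainStep(withInputValues:)`, which is assumed, mirror the Swift declarations on this page.

```swift
// Hypothetical protocol. schedulerStep(), optimizerStep(), lazyResetGrad()
// and getLearningRate() mirror the documented Swift declarations;
// trainStep(withInputValues:) is an ASSUMED Swift name.
protocol ScheduledTrainer {
    associatedtype Value
    func trainStep(withInputValues inputs: [Value]) throws -> [Value]
    func schedulerStep() throws
    func optimizerStep() throws
    func lazyResetGrad() throws
    func getLearningRate() throws -> Float
}

/// Runs one round per batch and returns the learning rate in effect for
/// each optimizer step.
func trainWithScheduler<T: ScheduledTrainer>(_ trainer: T, batches: [[T.Value]]) throws -> [Float] {
    var learningRates: [Float] = []
    for batch in batches {
        _ = try trainer.trainStep(withInputValues: batch)
        try trainer.schedulerStep()        // update the rate before the optimizer uses it
        let rate = try trainer.getLearningRate()
        learningRates.append(rate)
        try trainer.optimizerStep()
        try trainer.lazyResetGrad()
    }
    return learningRates
}

// Mock whose scheduler raises the rate by 0.25 per step (a stand-in for warmup).
final class WarmupRecorder: ScheduledTrainer {
    var learningRate: Float = 0
    var calls: [String] = []
    func trainStep(withInputValues inputs: [Float]) throws -> [Float] {
        calls.append("train")
        return inputs
    }
    func schedulerStep() throws {
        calls.append("schedule")
        learningRate += 0.25
    }
    func optimizerStep() throws { calls.append("optimize") }
    func lazyResetGrad() throws { calls.append("reset") }
    func getLearningRate() throws -> Float { return learningRate }
}
```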
-
Returns the current learning rate being used by the training session.
Declaration
Objective-C
- (float)getLearningRateWithError:(NSError *_Nullable *_Nullable)error;
Swift
func getLearningRate() throws -> Float
Parameters
error
Optional error information set if an error occurs.
Return Value
The current learning rate or 0.0f if an error occurs.
-
Sets the learning rate being used by the training session.
The current learning rate is maintained by the training session and can be overwritten by invoking this method with the desired learning rate. This function should not be used when a valid learning rate scheduler is registered. It should be used either to set the learning rate derived from a custom learning rate scheduler or to set a constant learning rate to be used throughout the training session.
Note
This method does not set the initial learning rate that may be needed by the predefined learning rate schedulers. To set the initial learning rate for a learning rate scheduler, use the registerLinearLRScheduler method.
Declaration
Objective-C
- (BOOL)setLearningRate:(float)lr error:(NSError *_Nullable *_Nullable)error;
Swift
func setLearningRate(_ lr: Float) throws
Parameters
lr
The learning rate to be used by the training session.
error
Optional error information set if an error occurs.
Return Value
YES if the learning rate was set successfully, NO otherwise.
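When the schedule is custom, the rate can be computed in Swift and passed to setLearningRate before each optimizer step. The step-decay rule, `stepDecayLR`, `applySchedule` and the `LearningRateSettable` protocol below are hypothetical; only `setLearningRate(_:)` mirrors the declaration above.

```swift
// Hypothetical protocol mirroring the documented Swift declaration
// setLearningRate(_:), so a mock can stand in for ORTTrainingSession.
protocol LearningRateSettable {
    func setLearningRate(_ lr: Float) throws
}

/// Hypothetical custom schedule (not an ONNX Runtime scheduler):
/// multiplies `base` by `gamma` once every `stepSize` steps.
func stepDecayLR(base: Float, gamma: Float, stepSize: Int, step: Int) -> Float {
    precondition(stepSize > 0 && step >= 0, "stepSize must be positive and step non-negative")
    var lr = base
    for _ in 0..<(step / stepSize) {
        lr *= gamma
    }
    return lr
}

/// Computes the rate for `step` with `schedule` and hands it to the session,
/// e.g. once per round before optimizerStep().
func applySchedule<S: LearningRateSettable>(_ session: S, step: Int, schedule: (Int) -> Float) throws {
    try session.setLearningRate(schedule(step))
}

// Mock that records every rate it receives.
final class LearningRateRecorder: LearningRateSettable {
    var values: [Float] = []
    func setLearningRate(_ lr: Float) throws { values.append(lr) }
}
```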
-
Loads the training session model parameters from a contiguous buffer.
Declaration
Objective-C
- (BOOL)fromBufferWithValue:(nonnull ORTValue *)buffer error:(NSError *_Nullable *_Nullable)error;
Swift
func fromBuffer(with buffer: ORTValue) throws
Parameters
buffer
Contiguous buffer to load the parameters from.
error
Optional error information set if an error occurs.
Return Value
YES if the parameters were loaded successfully, NO otherwise.
-
Returns a contiguous buffer that holds a copy of the training state parameters (either all of them or only the trainable ones).
Declaration
Objective-C
- (nullable ORTValue *)toBufferWithTrainable:(BOOL)onlyTrainable error: (NSError *_Nullable *_Nullable)error;
Swift
func toBuffer(withTrainable onlyTrainable: Bool) throws -> ORTValue
Parameters
onlyTrainable
If YES, returns a buffer that holds only the trainable parameters, otherwise returns a buffer that holds all the parameters.
error
Optional error information set if an error occurs.
Return Value
A contiguous buffer that holds a copy of the requested training state parameters.
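toBuffer and fromBuffer together allow taking an in-memory snapshot of the parameters and restoring it later, for example to keep the weights with the lowest validation loss. This sketch hides the session behind a hypothetical `ParameterBuffering` protocol whose requirements mirror the two Swift declarations above, with the buffer type left generic so a mock can stand in for `ORTValue`. It snapshots all parameters (`withTrainable: false`) so the same buffer can be loaded back.

```swift
// Hypothetical protocol mirroring ORTTrainingSession's Swift declarations
// toBuffer(withTrainable:) and fromBuffer(with:); Buffer stands in for ORTValue.
protocol ParameterBuffering {
    associatedtype Buffer
    func toBuffer(withTrainable onlyTrainable: Bool) throws -> Buffer
    func fromBuffer(with buffer: Buffer) throws
}

/// Runs `epochs` epochs, snapshots the parameters whenever the loss returned
/// by `runEpoch` improves, and restores the best snapshot at the end.
/// Returns the best loss, or nil if no epoch ran.
func trainKeepingBest<S: ParameterBuffering>(
    _ session: S,
    epochs: Int,
    runEpoch: () throws -> Float
) throws -> Float? {
    guard epochs > 0 else { return nil }
    var bestLoss: Float?
    var bestSnapshot: S.Buffer?
    for _ in 0..<epochs {
        let loss = try runEpoch()
        if bestLoss.map({ loss < $0 }) ?? true {
            bestLoss = loss
            // Copy all parameters so the buffer can later be loaded back with fromBuffer.
            bestSnapshot = try session.toBuffer(withTrainable: false)
        }
    }
    if let snapshot = bestSnapshot {
        try session.fromBuffer(with: snapshot)
    }
    return bestLoss
}

// Mock session whose "parameters" are a plain Float array.
final class ArrayParameterSession: ParameterBuffering {
    var parameters: [Float] = [0]
    func toBuffer(withTrainable onlyTrainable: Bool) throws -> [Float] { return parameters }
    func fromBuffer(with buffer: [Float]) throws { parameters = buffer }
}
```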
-
Exports the training session model that can be used for inference.
If the training session was provided with an eval model, it can generate an inference model, provided it knows the inference graph outputs. The given graph output names are used to prune the eval model so that the inference model's outputs align with them. The exported model is saved at the provided path and can be used for inference with ORTSession.
Note
The method reloads the eval model from the path provided to the initializer and expects this path to still be valid.
Declaration
Objective-C
- (BOOL) exportModelForInferenceWithOutputPath:(nonnull NSString *)inferenceModelPath graphOutputNames: (nonnull NSArray<NSString *> *)graphOutputNames error:(NSError *_Nullable *_Nullable)error;
Swift
func exportModelForInference(withOutputPath inferenceModelPath: String, graphOutputNames: [String]) throws
Parameters
inferenceModelPath
The path at which to save the serialized inference model.
graphOutputNames
The names of the outputs that are needed in the inference model.
error
Optional error information set if an error occurs.
Return Value
YES if the inference model was exported successfully, NO otherwise.
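Exporting is usually the last step after training. The sketch below wraps the documented Swift call behind a hypothetical `InferenceExporting` protocol so the path handling can be checked without a real session; the default file name `inference.onnx` is an illustrative assumption, and the graph output names must be the ones your own model uses.

```swift
import Foundation

// Hypothetical protocol mirroring the documented Swift declaration of
// exportModelForInference(withOutputPath:graphOutputNames:).
protocol InferenceExporting {
    func exportModelForInference(withOutputPath inferenceModelPath: String, graphOutputNames: [String]) throws
}

/// Exports an inference model into `directory` and returns the file URL.
/// The default file name is an ASSUMPTION; any writable path works.
func exportInferenceModel<E: InferenceExporting>(
    _ exporter: E,
    to directory: URL,
    fileName: String = "inference.onnx",
    graphOutputNames: [String]
) throws -> URL {
    precondition(!graphOutputNames.isEmpty, "at least one graph output name is needed")
    let modelURL = directory.appendingPathComponent(fileName)
    try exporter.exportModelForInference(withOutputPath: modelURL.path, graphOutputNames: graphOutputNames)
    return modelURL
}

// Mock that records what would have been exported.
final class ExportRecorder: InferenceExporting {
    var exportedPath: String?
    var exportedNames: [String] = []
    func exportModelForInference(withOutputPath inferenceModelPath: String, graphOutputNames: [String]) throws {
        exportedPath = inferenceModelPath
        exportedNames = graphOutputNames
    }
}
```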