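To train a model with onnxruntime, the following training artifacts must be generated first:
- the training ONNX model
- the checkpoint state
- the optimizer ONNX model
- the eval ONNX model (optional)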
These training artifacts can be generated in an offline step using the Python utilities provided by the onnxruntime-training Python package.
After these artifacts have been generated, the C and C++ utilities listed in this documentation can be leveraged to perform training.
The OrtTrainingApi C structure contains functions that enable users to perform training with onnxruntime. It is obtained from the C API through GetTrainingApi.
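```c
#include <stdbool.h>
#include <onnxruntime_training_c_api.h>

/* logging_level, logid, run_options, inputs/outputs and all *_path variables are
   placeholders supplied by the caller; checking of the returned OrtStatus* is omitted. */
const OrtApi* g_ort_api = OrtGetApiBase()->GetApi(ORT_API_VERSION);
const OrtTrainingApi* g_ort_training_api = g_ort_api->GetTrainingApi(ORT_API_VERSION);

OrtEnv* env = NULL;
g_ort_api->CreateEnv(logging_level, logid, &env);
OrtSessionOptions* session_options = NULL;
g_ort_api->CreateSessionOptions(&session_options);

OrtCheckpointState* state = NULL;
g_ort_training_api->LoadCheckpoint(path_to_checkpoint, &state);

OrtTrainingSession* training_session = NULL;
g_ort_training_api->CreateTrainingSession(env, session_options, state, training_model_path,
                                          eval_model_path, optimizer_model_path,
                                          &training_session);
/* Training loop: repeat for each batch */
{
    g_ort_training_api->TrainStep(training_session, run_options, inputs_len, inputs,
                                  outputs_len, outputs);
    g_ort_training_api->OptimizerStep(training_session, run_options);
    g_ort_training_api->LazyResetGrad(training_session);
}

g_ort_training_api->ExportModelForInferencing(training_session, inference_model_path,
                                              graph_outputs_len, graph_output_names);
g_ort_training_api->SaveCheckpoint(state, path_to_checkpoint, false);

g_ort_training_api->ReleaseTrainingSession(training_session);
g_ort_training_api->ReleaseCheckpointState(state);
g_ort_api->ReleaseSessionOptions(session_options);
g_ort_api->ReleaseEnv(env);
```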
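The self-contained toy sketch below shows what the three calls in the training loop are for. It is plain C and does not use onnxruntime. It fits a one-parameter model y = w * x with mean squared error. The names ToySession, toy_train_step, toy_optimizer_step, toy_reset_grad and toy_train are hypothetical. Under our own assumptions they mimic the roles of TrainStep (forward pass plus gradient computation), OptimizerStep (parameter update, here plain SGD) and LazyResetGrad (clearing the gradients, done eagerly here for simplicity).

```c
#include <assert.h>
#include <stddef.h>

/* Hypothetical stand-in for a training session: one trainable parameter w
   of the model y = w * x, plus its gradient. This is NOT onnxruntime code. */
typedef struct {
    double w;             /* trainable parameter */
    double grad;          /* gradient of the loss with respect to w */
    double learning_rate; /* step size used by the toy optimizer */
} ToySession;

/* Plays the role of TrainStep: forward pass (mean squared error) and
   backward pass (gradient of the loss w.r.t. w). Returns the loss. */
static double toy_train_step(ToySession *s, const double *x, const double *y, size_t n) {
    assert(n > 0);
    double loss = 0.0, grad = 0.0;
    for (size_t i = 0; i < n; ++i) {
        double err = s->w * x[i] - y[i];
        loss += err * err;
        grad += 2.0 * err * x[i];
    }
    s->grad += grad / (double)n; /* gradients accumulate until reset */
    return loss / (double)n;
}

/* Plays the role of OptimizerStep: plain SGD update of w from its gradient. */
static void toy_optimizer_step(ToySession *s) {
    s->w -= s->learning_rate * s->grad;
}

/* Plays the role of LazyResetGrad, except that the reset happens immediately
   instead of being deferred to the next train step. */
static void toy_reset_grad(ToySession *s) {
    s->grad = 0.0;
}

/* The loop from the sample: train step, optimizer step, reset. Returns the last loss. */
static double toy_train(ToySession *s, const double *x, const double *y, size_t n, int epochs) {
    double loss = 0.0;
    for (int epoch = 0; epoch < epochs; ++epoch) {
        loss = toy_train_step(s, x, y, n);
        toy_optimizer_step(s);
        toy_reset_grad(s);
    }
    return loss;
}
```

For example, starting from w = 0 with a learning rate of 0.05 on the points (1, 2), (2, 4), (3, 6), the loop drives w toward 2.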
Declarations used in the C sample:
- `OrtEnv`, `OrtSessionOptions` (onnxruntime_c_api.h) and `OrtCheckpointState`, `OrtTrainingSession` (onnxruntime_training_c_api.h): opaque handle types.
- `ORT_API_VERSION`: the API version defined in the header.
- `const OrtApiBase* OrtGetApiBase(void)`: the onnxruntime library's entry point to the C API.
- `const OrtApi* (*GetApi)(uint32_t version)`: gets a pointer to the requested version of `OrtApi`, the C API.
- `const OrtTrainingApi* (*GetTrainingApi)(uint32_t version)`: gets `OrtTrainingApi`, the Training C API struct that holds the onnxruntime training function pointers.
- `OrtStatus* CreateEnv(OrtLoggingLevel log_severity_level, const char* logid, OrtEnv** out)`: creates an `OrtEnv`.
- `OrtStatus* CreateSessionOptions(OrtSessionOptions** options)`: creates an `OrtSessionOptions` object.
- `OrtStatus* LoadCheckpoint(const char* checkpoint_path, OrtCheckpointState** checkpoint_state)`: loads a checkpoint state from a file on disk into `checkpoint_state`.
- `OrtStatus* CreateTrainingSession(const OrtEnv* env, const OrtSessionOptions* options, OrtCheckpointState* checkpoint_state, const char* train_model_path, const char* eval_model_path, const char* optimizer_model_path, OrtTrainingSession** out)`: creates a training session that can be used to begin or resume training.
- `OrtStatus* TrainStep(OrtTrainingSession* sess, const OrtRunOptions* run_options, size_t inputs_len, const OrtValue* const* inputs, size_t outputs_len, OrtValue** outputs)`: computes the outputs of the training model and the gradients of the trainable parameters.
- `OrtStatus* OptimizerStep(OrtTrainingSession* sess, const OrtRunOptions* run_options)`: performs the weight updates for the trainable parameters using the optimizer model.
- `OrtStatus* LazyResetGrad(OrtTrainingSession* session)`: lazily resets the gradients of all trainable parameters to zero.
- `OrtStatus* ExportModelForInferencing(OrtTrainingSession* sess, const char* inference_model_path, size_t graph_outputs_len, const char* const* graph_output_names)`: exports a model that can be used for inferencing.
- `OrtStatus* SaveCheckpoint(OrtCheckpointState* checkpoint_state, const char* checkpoint_path, const bool include_optimizer_state)`: saves the given state to a checkpoint file on disk.
- `void ReleaseTrainingSession(OrtTrainingSession* input)`: frees the memory used by the training session.
- `void ReleaseCheckpointState(OrtCheckpointState* input)`: frees the memory used by the checkpoint state.
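LazyResetGrad resets gradients lazily. The reset is only scheduled, and the gradients are actually cleared when the next ones are computed by TrainStep. The C sketch below models that behaviour on a hypothetical gradient buffer. GradBuffer, accumulate_grads and lazy_reset_grads are illustrative names, not onnxruntime API. For illustration, it also assumes that gradients from successive train steps are summed until a reset. That assumption is what would allow gradients from several micro-batches to be accumulated before a single optimizer step.

```c
#include <assert.h>
#include <stddef.h>

#define NUM_PARAMS 3

/* Hypothetical gradient buffer for a model with NUM_PARAMS trainable parameters.
   Illustrates lazy reset only; this is not onnxruntime code. */
typedef struct {
    double grad[NUM_PARAMS];
    int reset_pending; /* set by lazy_reset_grads, consumed by accumulate_grads */
} GradBuffer;

/* Called once per "train step" with the freshly computed gradients. */
static void accumulate_grads(GradBuffer *b, const double *new_grad) {
    assert(b != NULL && new_grad != NULL);
    for (size_t i = 0; i < NUM_PARAMS; ++i) {
        /* A scheduled reset means overwrite; otherwise add to the running sum. */
        b->grad[i] = b->reset_pending ? new_grad[i] : b->grad[i] + new_grad[i];
    }
    b->reset_pending = 0;
}

/* Lazy reset: nothing is zeroed now; the next accumulate_grads overwrites instead. */
static void lazy_reset_grads(GradBuffer *b) {
    assert(b != NULL);
    b->reset_pending = 1;
}
```

Because the reset is deferred, the accumulated gradients remain readable after lazy_reset_grads until the next step overwrites them.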
These C++ classes and functions enable users to perform training with onnxruntime.
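```cpp
#include <onnxruntime_training_cxx_api.h>

// The *_path variables, inputs and graph_output_names are placeholders supplied by the caller.
Ort::Env env;
Ort::SessionOptions session_options;

auto state = Ort::CheckpointState::LoadCheckpoint(path_to_checkpoint);
auto training_session = Ort::TrainingSession(env, session_options, state, training_model_path,
                                             eval_model_path, optimizer_model_path);
// Training loop: repeat for each batch
{
    auto outputs = training_session.TrainStep(inputs);  // inputs: std::vector<Ort::Value>
    training_session.OptimizerStep();
    training_session.LazyResetGrad();
}
training_session.ExportModelForInferencing(inference_model_path, graph_output_names);
Ort::CheckpointState::SaveCheckpoint(state, path_to_checkpoint, false);
// The wrappers release their underlying handles when they go out of scope.
```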
Declarations used in the C++ sample:
- `Ort::Env`: the Env (Environment).
- `Ort::SessionOptions`: wrapper around `OrtSessionOptions`.
- `Ort::TrainingSession`: trainer class that provides training, evaluation and optimizer methods for training an ONNX model.
- `static CheckpointState LoadCheckpoint(const std::basic_string<char>& path_to_checkpoint)`: loads a checkpoint state from a file on disk and returns it.
- `static void SaveCheckpoint(const CheckpointState& checkpoint_state, const std::basic_string<char>& path_to_checkpoint, const bool include_optimizer_state = false)`: saves the given state to a checkpoint file on disk.
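Both SaveCheckpoint variants take an include_optimizer_state flag, which defaults to false in C++. The C sketch below is a toy illustration of what such a flag could change in a checkpoint. Parameters are always written. Optimizer state, here a hypothetical per-parameter momentum, is written only on request. A state saved without it comes back with that state zeroed. ToyState, toy_save and toy_load are hypothetical, and the byte layout is invented for the example. It is not the onnxruntime checkpoint format.

```c
#include <assert.h>
#include <stddef.h>
#include <string.h>

#define NUM_PARAMS 2

/* Hypothetical toy state: model parameters plus optional optimizer state. */
typedef struct {
    double params[NUM_PARAMS];
    double momentum[NUM_PARAMS]; /* optimizer state, e.g. SGD momentum */
} ToyState;

/* Serializes s into buf and returns the number of bytes written.
   Layout: 1 flag byte, the parameters, then (only if requested) the optimizer state. */
static size_t toy_save(const ToyState *s, int include_optimizer_state,
                       unsigned char *buf, size_t cap) {
    size_t need = 1 + sizeof s->params + (include_optimizer_state ? sizeof s->momentum : 0);
    size_t n = 0;
    assert(cap >= need);
    buf[n++] = (unsigned char)(include_optimizer_state != 0);
    memcpy(buf + n, s->params, sizeof s->params);
    n += sizeof s->params;
    if (include_optimizer_state) {
        memcpy(buf + n, s->momentum, sizeof s->momentum);
        n += sizeof s->momentum;
    }
    return n;
}

/* Restores s from buf. If no optimizer state was saved, it is zero-initialized,
   so training would resume with fresh optimizer statistics. */
static void toy_load(ToyState *s, const unsigned char *buf) {
    size_t n = 0;
    int has_optimizer_state = buf[n++];
    memcpy(s->params, buf + n, sizeof s->params);
    n += sizeof s->params;
    if (has_optimizer_state) {
        memcpy(s->momentum, buf + n, sizeof s->momentum);
    } else {
        memset(s->momentum, 0, sizeof s->momentum);
    }
}
```

Leaving the optimizer state out makes the checkpoint smaller. In this sketch, the cost is that momentum-style statistics restart from zero when training resumes.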