ONNX Runtime
Ort::CheckpointState Class Reference

Holds the state of the training session.

#include <onnxruntime_training_cxx_api.h>

Inherits Ort::detail::Base< OrtCheckpointState >.

Public Member Functions

 CheckpointState ()=delete
 
- Public Member Functions inherited from Ort::detail::Base< OrtCheckpointState >
constexpr Base ()=default
 
constexpr Base (contained_type *p) noexcept
 
 Base (const Base &)=delete
 
 Base (Base &&v) noexcept
 
 ~Base ()
 
Base & operator= (const Base &)=delete
 
Base & operator= (Base &&v) noexcept
 
constexpr operator contained_type * () const noexcept
 
contained_type * release ()
 Relinquishes ownership of the contained C object pointer. The underlying object is not destroyed.
 

Accessing The Training Session State

void AddProperty (const std::string &property_name, const Property &property_value)
 Adds the given property to the checkpoint state.
 
Property GetProperty (const std::string &property_name)
 Gets the property value associated with the given name from the checkpoint state.
 
static CheckpointState LoadCheckpoint (const std::basic_string< char > &path_to_checkpoint)
 Load a checkpoint state from a file on disk.
 
static CheckpointState LoadCheckpointFromBuffer (const std::vector< uint8_t > &buffer)
 Load a checkpoint state from a buffer.
 
static void SaveCheckpoint (const CheckpointState &checkpoint_state, const std::basic_string< char > &path_to_checkpoint, const bool include_optimizer_state=false)
 Save the given state to a checkpoint file on disk.
 

Additional Inherited Members

- Public Types inherited from Ort::detail::Base< OrtCheckpointState >
using contained_type = OrtCheckpointState
 
- Protected Attributes inherited from Ort::detail::Base< OrtCheckpointState >
contained_type * p_
 

Detailed Description

Holds the state of the training session.

This class holds the entire training session state, which includes the model parameters, their gradients, the optimizer parameters, and user-defined properties. Ort::TrainingSession works directly on the training state contained in an Ort::CheckpointState, reading and updating it in place.

Note
A training session created with a checkpoint state uses that state to store the entire training state (model parameters, their gradients, the optimizer states, and the properties). Ort::TrainingSession does not hold a copy of the Ort::CheckpointState; therefore the checkpoint state must outlive the training session.

Constructor & Destructor Documentation

◆ CheckpointState()

Ort::CheckpointState::CheckpointState ( )
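delete

The default constructor is deleted. Obtain an Ort::CheckpointState from LoadCheckpoint() or LoadCheckpointFromBuffer() instead.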

Member Function Documentation

◆ AddProperty()

void Ort::CheckpointState::AddProperty ( const std::string & property_name,
const Property & property_value
)

Adds the given property to the checkpoint state.

Users can add runtime properties, such as the current epoch, training step, or best score, to the checkpoint state by calling this function with the corresponding property name and value. The property name must be unique for the property to be added successfully.

Parameters
[in] property_name: Unique name of the property being added.
[in] property_value: Property value associated with the given name.

◆ GetProperty()

Property Ort::CheckpointState::GetProperty ( const std::string &  property_name)

Gets the property value associated with the given name from the checkpoint state.

Gets the value of an existing entry in the checkpoint state. The property must already exist in the checkpoint state for it to be retrieved successfully.

Parameters
[in] property_name: Unique name of the property being retrieved.
Returns
Property value associated with the given property name.

◆ LoadCheckpoint()

static CheckpointState Ort::CheckpointState::LoadCheckpoint ( const std::basic_string< char > &  path_to_checkpoint)
static

Load a checkpoint state from a file on disk.

This function parses a checkpoint file, loads the training state it contains, and returns it as an instance of Ort::CheckpointState. Passing this checkpoint state to the Ort::TrainingSession constructor creates a training session that resumes training from the loaded state.

Parameters
[in] path_to_checkpoint: Path to the checkpoint file.
Returns
Ort::CheckpointState object that holds the training session state.

◆ LoadCheckpointFromBuffer()

static CheckpointState Ort::CheckpointState::LoadCheckpointFromBuffer ( const std::vector< uint8_t > &  buffer)
static

Load a checkpoint state from a buffer.

This function parses a checkpoint buffer, loads the training state it contains, and returns it as an instance of Ort::CheckpointState. Passing this checkpoint state to the Ort::TrainingSession constructor creates a training session that resumes training from the loaded state.

Parameters
[in] buffer: Buffer containing the checkpoint data.
Returns
Ort::CheckpointState object that holds the training session state.

◆ SaveCheckpoint()

static void Ort::CheckpointState::SaveCheckpoint ( const CheckpointState & checkpoint_state,
const std::basic_string< char > & path_to_checkpoint,
const bool include_optimizer_state = false
)
static

Save the given state to a checkpoint file on disk.

This function serializes the provided checkpoint state to a file on disk. The file can later be loaded with Ort::CheckpointState::LoadCheckpoint to resume training from this snapshot of the state.

Parameters
[in] checkpoint_state: The checkpoint state to save.
[in] path_to_checkpoint: Path to the checkpoint file.
[in] include_optimizer_state: Whether to also save the optimizer state (default: false).