ONNX Runtime
Ort::CheckpointState Class Reference

Holds the state of the training session.

#include <onnxruntime_training_cxx_api.h>

Inheritance diagram for Ort::CheckpointState:
Ort::detail::Base< OrtCheckpointState >

Public Member Functions

 CheckpointState ()=delete
 
- Public Member Functions inherited from Ort::detail::Base< OrtCheckpointState >
constexpr Base ()=default
 
constexpr Base (contained_type *p) noexcept
 
 Base (const Base &)=delete
 
 Base (Base &&v) noexcept
 
 ~Base ()
 
Base & operator= (const Base &)=delete
 
Base & operator= (Base &&v) noexcept
 
constexpr operator contained_type * () const noexcept
 
contained_type * release ()
 Relinquishes ownership of the contained C object pointer. The underlying object is not destroyed.
 

Accessing The Training Session State

void AddProperty (const std::string &property_name, const Property &property_value)
 Adds or updates the given property to/in the checkpoint state.
 
Property GetProperty (const std::string &property_name)
 Gets the property value associated with the given name from the checkpoint state.
 
void UpdateParameter (const std::string &parameter_name, const Value &parameter)
 Updates the data associated with the model parameter in the checkpoint state for the given parameter name.
 
Value GetParameter (const std::string &parameter_name)
 Gets the data associated with the model parameter from the checkpoint state for the given parameter name.
 
static CheckpointState LoadCheckpoint (const std::basic_string< char > &path_to_checkpoint)
 Load a checkpoint state from a file on disk.
 
static CheckpointState LoadCheckpointFromBuffer (const std::vector< uint8_t > &buffer)
 Load a checkpoint state from a buffer.
 
static void SaveCheckpoint (const CheckpointState &checkpoint_state, const std::basic_string< char > &path_to_checkpoint, const bool include_optimizer_state=false)
 Save the given state to a checkpoint file on disk.
 

Additional Inherited Members

- Public Types inherited from Ort::detail::Base< OrtCheckpointState >
using contained_type = OrtCheckpointState
 
- Protected Attributes inherited from Ort::detail::Base< OrtCheckpointState >
contained_type * p_
 

Detailed Description

Holds the state of the training session.

This class holds the entire training session state that includes model parameters, their gradients, optimizer parameters, and user properties. The Ort::TrainingSession leverages the Ort::CheckpointState by accessing and updating the contained training state.

Note
A training session created with a checkpoint state uses that state to store the entire training state (model parameters, their gradients, the optimizer states, and the properties). The Ort::TrainingSession does not hold a copy of the Ort::CheckpointState, so the checkpoint state must outlive the training session.
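
One way to satisfy this lifetime requirement structurally is to declare the checkpoint state before the training session, since C++ destroys members in reverse declaration order. The following is a minimal sketch and not part of the ONNX Runtime API: StateThenDependent, TrainingBundle, and MakeTrainingBundle are hypothetical names. The ONNX Runtime part sits behind a __has_include guard so that the generic helper compiles on its own. It assumes the Ort::TrainingSession constructor takes an environment, session options, the checkpoint state, and a training model path, and that the caller's Ort::Env outlives the bundle.

```cpp
#include <string>
#include <utility>

// Holds a state object and something built from it. Members are destroyed in
// reverse declaration order, so `dependent` is destroyed before `state`.
template <typename State, typename Dependent>
struct StateThenDependent {
  State state;
  Dependent dependent;

  // `make_dependent` is called with a reference to the stored state.
  template <typename Factory>
  StateThenDependent(State s, Factory&& make_dependent)
      : state(std::move(s)), dependent(make_dependent(state)) {}
};

#if __has_include(<onnxruntime_training_cxx_api.h>)
#include <onnxruntime_training_cxx_api.h>

// Hypothetical alias: the session is always destroyed before its state.
using TrainingBundle =
    StateThenDependent<Ort::CheckpointState, Ort::TrainingSession>;

// Hypothetical factory; both paths are supplied by the caller.
inline TrainingBundle MakeTrainingBundle(
    const Ort::Env& env, const Ort::SessionOptions& options,
    const std::basic_string<ORTCHAR_T>& checkpoint_path,
    const std::basic_string<ORTCHAR_T>& train_model_path) {
  return TrainingBundle(
      Ort::CheckpointState::LoadCheckpoint(checkpoint_path),
      [&](Ort::CheckpointState& state) {
        return Ort::TrainingSession(env, options, state, train_model_path);
      });
}
#endif
```

Keeping both objects in one struct means neither can be destroyed or go out of scope independently of the other.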

Constructor & Destructor Documentation

◆ CheckpointState()

Ort::CheckpointState::CheckpointState ( )
delete

Member Function Documentation

◆ AddProperty()

void Ort::CheckpointState::AddProperty ( const std::string &  property_name,
const Property & property_value 
)

Adds or updates the given property to/in the checkpoint state.

Runtime properties such as the epoch, training step, and best score can be added to the checkpoint state by calling this function with the property name and value. Each property name identifies a single entry in the checkpoint state; calling this function with a name that is already present updates that entry.

Parameters
[in] property_name: Name of the property being added or updated.
[in] property_value: Property value associated with the given name.
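
As an illustration of recording run progress, the sketch below builds a few named properties and adds them to a checkpoint state. It assumes Ort::Property is a std::variant over int64_t, float, and std::string; the stand-in alias is used only when the ONNX Runtime headers are not available. MakeProgressProperties, RecordProgress, and the property names are hypothetical.

```cpp
#include <cstdint>
#include <string>
#include <utility>
#include <variant>
#include <vector>

#if __has_include(<onnxruntime_training_cxx_api.h>)
#include <onnxruntime_training_cxx_api.h>
using Property = Ort::Property;
#else
// Stand-in for builds without ONNX Runtime; assumed to mirror Ort::Property.
using Property = std::variant<int64_t, float, std::string>;
#endif

using NamedProperty = std::pair<std::string, Property>;

// Builds the properties to record. Every value has an explicit type, so it is
// unambiguous which alternative of the variant is stored.
inline std::vector<NamedProperty> MakeProgressProperties(
    int64_t epoch, float best_score, const std::string& run_tag) {
  return {
      {"epoch", Property{epoch}},
      {"best_score", Property{best_score}},
      {"run_tag", Property{run_tag}},
  };
}

#if __has_include(<onnxruntime_training_cxx_api.h>)
// Adds (or updates) each property in the given checkpoint state.
inline void RecordProgress(Ort::CheckpointState& state, int64_t epoch,
                           float best_score, const std::string& run_tag) {
  for (const auto& [name, value] :
       MakeProgressProperties(epoch, best_score, run_tag)) {
    state.AddProperty(name, value);
  }
}
#endif
```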

◆ GetParameter()

Value Ort::CheckpointState::GetParameter ( const std::string &  parameter_name)

Gets the data associated with the model parameter from the checkpoint state for the given parameter name.

This function retrieves the model parameter data from the checkpoint state for the given parameter name. The parameter data is copied into the returned Ort::Value. The training session must already have been created with the checkpoint state that contains the parameter being retrieved, and the parameter must exist in the checkpoint state for the retrieval to succeed.

Parameters
[in] parameter_name: Name of the parameter being retrieved.
Returns
The parameter data that is retrieved from the checkpoint state.
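
A retrieved parameter can be inspected like any other tensor value. The sketch below computes the L2 norm of a parameter, assuming it holds float32 data in CPU-accessible memory and that the training session has already been created with this checkpoint state. L2Norm and ParameterNorm are hypothetical helper names; the ONNX Runtime part is compiled only when the headers are available.

```cpp
#include <cmath>
#include <cstddef>
#include <string>

// Euclidean (L2) norm of a flat buffer of floats.
inline double L2Norm(const float* data, size_t count) {
  double sum = 0.0;
  for (size_t i = 0; i < count; ++i) {
    sum += static_cast<double>(data[i]) * static_cast<double>(data[i]);
  }
  return std::sqrt(sum);
}

#if __has_include(<onnxruntime_training_cxx_api.h>)
#include <onnxruntime_training_cxx_api.h>

// Reads the named parameter and returns its L2 norm.
// Assumption: the parameter is a float32 tensor readable from the CPU.
inline double ParameterNorm(Ort::CheckpointState& state,
                            const std::string& name) {
  Ort::Value value = state.GetParameter(name);
  const size_t count = value.GetTensorTypeAndShapeInfo().GetElementCount();
  return L2Norm(value.GetTensorData<float>(), count);
}
#endif
```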

◆ GetProperty()

Property Ort::CheckpointState::GetProperty ( const std::string &  property_name)

Gets the property value associated with the given name from the checkpoint state.

Gets the property value from an existing entry in the checkpoint state. The property must exist in the checkpoint state for the retrieval to succeed.

Parameters
[in] property_name: Name of the property being retrieved.
Returns
Property value associated with the given property name.
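
Because the returned Property can hold one of several types, callers typically check which type is stored before using the value. The following sketch does that with std::get_if, again assuming Ort::Property is a std::variant over int64_t, float, and std::string. PropertyAs and ReadEpoch are hypothetical helpers, and "epoch" is an example property name.

```cpp
#include <cstdint>
#include <stdexcept>
#include <string>
#include <variant>

#if __has_include(<onnxruntime_training_cxx_api.h>)
#include <onnxruntime_training_cxx_api.h>
using Property = Ort::Property;
#else
// Stand-in for builds without ONNX Runtime; assumed to mirror Ort::Property.
using Property = std::variant<int64_t, float, std::string>;
#endif

// Returns the stored value if the property holds type T; throws otherwise.
template <typename T>
T PropertyAs(const Property& property, const std::string& name) {
  if (const T* value = std::get_if<T>(&property)) {
    return *value;
  }
  throw std::runtime_error("property '" + name +
                           "' does not hold the requested type");
}

#if __has_include(<onnxruntime_training_cxx_api.h>)
// Reads the "epoch" property as an integer.
inline int64_t ReadEpoch(Ort::CheckpointState& state) {
  return PropertyAs<int64_t>(state.GetProperty("epoch"), "epoch");
}
#endif
```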

◆ LoadCheckpoint()

static CheckpointState Ort::CheckpointState::LoadCheckpoint ( const std::basic_string< char > &  path_to_checkpoint)
static

Load a checkpoint state from a file on disk.

This function parses a checkpoint file, loads the training state from it, and returns an instance of Ort::CheckpointState. This checkpoint state can then be used to create a training session by instantiating Ort::TrainingSession; the training session then resumes training from the given checkpoint state.

Parameters
[in] path_to_checkpoint: Path to the checkpoint file.
Returns
Ort::CheckpointState object which holds the state of the training session parameters.

◆ LoadCheckpointFromBuffer()

static CheckpointState Ort::CheckpointState::LoadCheckpointFromBuffer ( const std::vector< uint8_t > &  buffer)
static

Load a checkpoint state from a buffer.

This function parses a checkpoint buffer, loads the training state from it, and returns an instance of Ort::CheckpointState. This checkpoint state can then be used to create a training session by instantiating Ort::TrainingSession; the training session then resumes training from the given checkpoint state.

Parameters
[in] buffer: Buffer containing the checkpoint data.
Returns
Ort::CheckpointState object which holds the state of the training session parameters.

◆ SaveCheckpoint()

static void Ort::CheckpointState::SaveCheckpoint ( const CheckpointState & checkpoint_state,
const std::basic_string< char > &  path_to_checkpoint,
const bool  include_optimizer_state = false 
)
static

Save the given state to a checkpoint file on disk.

This function serializes the provided checkpoint state to a file on disk. This checkpoint can later be loaded by invoking Ort::CheckpointState::LoadCheckpoint to resume training from this snapshot of the state.

Parameters
[in] checkpoint_state: The checkpoint state to save.
[in] path_to_checkpoint: Path to the checkpoint file.
[in] include_optimizer_state: Flag to indicate whether to save the optimizer state or not.
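
As an illustration, the sketch below writes one checkpoint file per epoch. The file naming scheme, the ".ckpt" extension, and the helper names CheckpointPathForEpoch and SaveEpochCheckpoint are assumptions made for this example; the code also assumes checkpoint paths are plain char strings, as in the signature documented here.

```cpp
#include <cstdint>
#include <filesystem>
#include <string>

// Builds a per-epoch file name such as <directory>/checkpoint_epoch_12.ckpt.
inline std::string CheckpointPathForEpoch(const std::string& directory,
                                          int64_t epoch) {
  const std::string file_name =
      "checkpoint_epoch_" + std::to_string(epoch) + ".ckpt";
  return (std::filesystem::path(directory) / file_name).string();
}

#if __has_include(<onnxruntime_training_cxx_api.h>)
#include <onnxruntime_training_cxx_api.h>

// Saves `state` under the per-epoch name and returns the path that was used.
inline std::string SaveEpochCheckpoint(const Ort::CheckpointState& state,
                                       const std::string& directory,
                                       int64_t epoch,
                                       bool include_optimizer_state = false) {
  const std::string path = CheckpointPathForEpoch(directory, epoch);
  Ort::CheckpointState::SaveCheckpoint(state, path, include_optimizer_state);
  return path;
}
#endif
```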

◆ UpdateParameter()

void Ort::CheckpointState::UpdateParameter ( const std::string &  parameter_name,
const Value & parameter 
)

Updates the data associated with the model parameter in the checkpoint state for the given parameter name.

This function updates a model parameter in the checkpoint state with the given parameter data. The training session must already have been created with the checkpoint state that contains the parameter being updated. The given parameter is copied to the device registered for the training session. The parameter must exist in the checkpoint state for the update to succeed.

Parameters
[in] parameter_name: Name of the parameter being updated.
[in] parameter: The parameter data that should replace the existing parameter data.
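
The replacement data has to be wrapped in an Ort::Value before it can be passed in. The sketch below wraps a CPU buffer of floats as a tensor and uses it to overwrite a parameter, after checking that the buffer size matches the shape. ElementCount and OverwriteParameter are hypothetical helpers; the sketch assumes a float32 parameter and a training session that has already been created with this checkpoint state.

```cpp
#include <cstddef>
#include <cstdint>
#include <stdexcept>
#include <string>
#include <vector>

// Number of elements implied by a tensor shape (product of its dimensions).
// An empty shape denotes a scalar and yields 1.
inline size_t ElementCount(const std::vector<int64_t>& shape) {
  size_t count = 1;
  for (int64_t dim : shape) {
    if (dim < 0) {
      throw std::invalid_argument("negative dimension in shape");
    }
    count *= static_cast<size_t>(dim);
  }
  return count;
}

#if __has_include(<onnxruntime_training_cxx_api.h>)
#include <onnxruntime_training_cxx_api.h>

// Overwrites the named parameter with `data`, interpreted with `shape`.
// The tensor only wraps `data`, so `data` must stay alive during the call.
inline void OverwriteParameter(Ort::CheckpointState& state,
                               const std::string& name,
                               std::vector<float>& data,
                               const std::vector<int64_t>& shape) {
  if (data.size() != ElementCount(shape)) {
    throw std::invalid_argument("data size does not match shape");
  }
  Ort::MemoryInfo cpu =
      Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
  Ort::Value tensor = Ort::Value::CreateTensor<float>(
      cpu, data.data(), data.size(), shape.data(), shape.size());
  state.UpdateParameter(name, tensor);
}
#endif
```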