Holds the state of the training session.
#include <onnxruntime_training_cxx_api.h>
- void AddProperty (const std::string &property_name, const Property &property_value)
  Adds a property to the checkpoint state or updates an existing one.
- Property GetProperty (const std::string &property_name)
  Gets the property value associated with the given name from the checkpoint state.
- void UpdateParameter (const std::string &parameter_name, const Value &parameter)
  Updates the data of the named model parameter in the checkpoint state.
- Value GetParameter (const std::string &parameter_name)
  Gets the data of the named model parameter from the checkpoint state.
- static CheckpointState LoadCheckpoint (const std::basic_string< char > &path_to_checkpoint)
  Loads a checkpoint state from a file on disk.
- static CheckpointState LoadCheckpointFromBuffer (const std::vector< uint8_t > &buffer)
  Loads a checkpoint state from a buffer.
- static void SaveCheckpoint (const CheckpointState &checkpoint_state, const std::basic_string< char > &path_to_checkpoint, const bool include_optimizer_state=false)
  Saves the given state to a checkpoint file on disk.
|
This class holds the entire training session state that includes model parameters, their gradients, optimizer parameters, and user properties. The Ort::TrainingSession leverages the Ort::CheckpointState by accessing and updating the contained training state.
- Note
  - A training session created with a checkpoint state uses that state to store the entire training state, including the model parameters, their gradients, the optimizer states, and the properties. Ort::TrainingSession does not hold a copy of the Ort::CheckpointState, so the checkpoint state must outlive the training session.
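The lifetime rule above follows from ordinary C++ scoping: local objects are destroyed in reverse order of declaration, so declaring the checkpoint state before the training session guarantees that the state outlives the session. The sketch below makes that ordering visible with two hypothetical stand-in types, FakeCheckpointState and FakeTrainingSession, which are not part of the ONNX Runtime API; in real code the state would come from Ort::CheckpointState::LoadCheckpoint and the session would be an Ort::TrainingSession constructed from it.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Records construction and destruction so the ordering can be inspected.
std::vector<std::string> g_events;

// Hypothetical stand-in for Ort::CheckpointState.
struct FakeCheckpointState {
  FakeCheckpointState() { g_events.push_back("state created"); }
  ~FakeCheckpointState() { g_events.push_back("state destroyed"); }
};

// Hypothetical stand-in for Ort::TrainingSession. Like the real session
// (per the note above), it refers to the state instead of copying it.
struct FakeTrainingSession {
  explicit FakeTrainingSession(FakeCheckpointState& state) : state_(state) {
    g_events.push_back("session created");
  }
  ~FakeTrainingSession() { g_events.push_back("session destroyed"); }
  FakeCheckpointState& state_;
};

void RunTraining() {
  FakeCheckpointState state;           // declared first, so destroyed last
  FakeTrainingSession session(state);  // destroyed first, while state is still alive
  // ... training steps would use `session` here ...
}
```

Keeping the state in a shorter-lived scope than the session (for example, a helper function that loads the state and returns only the session) would leave the session referring to a destroyed object, which is the situation the note warns against.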
◆ CheckpointState()
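Ort::CheckpointState::CheckpointState() = delete
The default constructor is deleted; instances are obtained through the static LoadCheckpoint() or LoadCheckpointFromBuffer() functions.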
◆ AddProperty()
void Ort::CheckpointState::AddProperty(const std::string &property_name, const Property &property_value)
Adds a property to the checkpoint state or updates an existing one.
Users can store runtime properties such as the epoch, the training step, or the best score in the checkpoint state by calling this function with the corresponding property name and value. The property name must be unique within the checkpoint state for the property to be added successfully.
- Parameters
  - [in] property_name: Name of the property being added or updated.
  - [in] property_value: Property value associated with the given name.
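This sketch assumes that Ort::Property is the alias std::variant<int64_t, float, std::string> declared in onnxruntime_training_cxx_api.h, so each runtime property must be one of those three types. The hypothetical TrainingProgress struct and ToProperties helper below show one way to package progress values with the exact alternative types; each resulting pair would then be passed to state.AddProperty(name, value).

```cpp
#include <cassert>
#include <cstdint>
#include <string>
#include <utility>
#include <variant>
#include <vector>

// Assumption: mirrors the Ort::Property alias from onnxruntime_training_cxx_api.h.
using Property = std::variant<int64_t, float, std::string>;

// Hypothetical record of training progress kept by the caller.
struct TrainingProgress {
  int64_t epoch;
  int64_t step;
  float best_score;
  std::string run_name;
};

// Packages the progress as (name, value) pairs. Each pair is intended to be
// passed to Ort::CheckpointState::AddProperty(name, value).
std::vector<std::pair<std::string, Property>> ToProperties(const TrainingProgress& p) {
  return {
      {"epoch", Property{p.epoch}},            // stored as int64_t
      {"step", Property{p.step}},              // stored as int64_t
      {"best_score", Property{p.best_score}},  // stored as float
      {"run_name", Property{p.run_name}},      // stored as std::string
  };
}
```

Constructing each value from a variable of the exact alternative type keeps overload resolution unambiguous. A bare `0.5` literal is a `double` and matches none of the alternatives, so `Property{0.5}` does not compile; write `0.5f` instead, and prefer `int64_t{...}` over a plain `int`, whose conversion is ambiguous with older standard libraries.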
◆ GetParameter()
Value Ort::CheckpointState::GetParameter(const std::string &parameter_name)
Gets the data of the named model parameter from the checkpoint state.
This function retrieves the data of the named model parameter from the checkpoint state. The parameter data is copied into a new Ort::Value that is returned to the caller. The training session must already have been created with the checkpoint state that contains the parameter, and the parameter must exist in the checkpoint state for the retrieval to succeed.
- Parameters
  - [in] parameter_name: Name of the parameter being retrieved.
- Returns
  - The parameter data retrieved from the checkpoint state.
◆ GetProperty()
Property Ort::CheckpointState::GetProperty(const std::string &property_name)
Gets the property value associated with the given name from the checkpoint state.
Gets the property value from an existing entry in the checkpoint state. The property must exist in the checkpoint state to be able to retrieve it successfully.
- Parameters
  - [in] property_name: Name of the property being retrieved.
- Returns
  - Property value associated with the given property name.
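Because GetProperty returns the variant rather than a concrete type, the caller has to check which alternative it holds. The sketch below again assumes Property is std::variant<int64_t, float, std::string> and shows two hypothetical helpers: PropertyToString uses std::visit to format any alternative, and GetIntOr uses std::get_if to read an integer property, such as an epoch counter, without throwing when the stored type differs.

```cpp
#include <cassert>
#include <cstdint>
#include <string>
#include <type_traits>
#include <variant>

// Assumption: mirrors the Ort::Property alias from onnxruntime_training_cxx_api.h.
using Property = std::variant<int64_t, float, std::string>;

// Hypothetical helper: renders any Property alternative as text,
// for example to log a value returned by CheckpointState::GetProperty.
std::string PropertyToString(const Property& value) {
  return std::visit(
      [](const auto& v) -> std::string {
        using T = std::decay_t<decltype(v)>;
        if constexpr (std::is_same_v<T, std::string>) {
          return v;  // strings are returned unchanged
        } else {
          return std::to_string(v);  // int64_t and float are converted
        }
      },
      value);
}

// Hypothetical helper: returns the stored int64_t, or `fallback` when the
// property holds a float or a string instead.
int64_t GetIntOr(const Property& value, int64_t fallback) {
  if (const auto* i = std::get_if<int64_t>(&value)) {
    return *i;
  }
  return fallback;
}
```

Calling std::get<int64_t> directly would instead throw std::bad_variant_access if the property had been stored as a float or a string.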
◆ LoadCheckpoint()
static CheckpointState Ort::CheckpointState::LoadCheckpoint(const std::basic_string< char > &path_to_checkpoint)
Loads a checkpoint state from a file on disk.
This function parses a checkpoint file, loads the training state it contains, and returns it as an Ort::CheckpointState instance. Creating an Ort::TrainingSession from this checkpoint state resumes training from the saved state.
- Parameters
  - [in] path_to_checkpoint: Path to the checkpoint file.
- Returns
  - Ort::CheckpointState object that holds the state of the training session parameters.
◆ LoadCheckpointFromBuffer()
static CheckpointState Ort::CheckpointState::LoadCheckpointFromBuffer(const std::vector< uint8_t > &buffer)
Loads a checkpoint state from a buffer.
This function parses a checkpoint buffer, loads the training state it contains, and returns it as an Ort::CheckpointState instance. Creating an Ort::TrainingSession from this checkpoint state resumes training from the saved state.
- Parameters
  - [in] buffer: Buffer containing the checkpoint data.
- Returns
  - Ort::CheckpointState object that holds the state of the training session parameters.
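LoadCheckpointFromBuffer expects the raw checkpoint bytes as a std::vector<uint8_t>. When those bytes come from a file, a helper like the hypothetical ReadFileBytes below can produce the buffer using only the standard library; it is a sketch, not part of the ONNX Runtime API.

```cpp
#include <cassert>
#include <cstdint>
#include <fstream>
#include <iterator>
#include <stdexcept>
#include <string>
#include <vector>

// Hypothetical helper: reads an entire file into the byte vector expected by
// Ort::CheckpointState::LoadCheckpointFromBuffer.
std::vector<uint8_t> ReadFileBytes(const std::string& path) {
  std::ifstream file(path, std::ios::binary);  // binary mode: no newline translation
  if (!file) {
    throw std::runtime_error("cannot open " + path);
  }
  // Copy every byte from the stream buffer into the vector.
  return std::vector<uint8_t>(std::istreambuf_iterator<char>(file),
                              std::istreambuf_iterator<char>());
}
```

With it, `Ort::CheckpointState::LoadCheckpointFromBuffer(ReadFileBytes(path))` should load the same state that `LoadCheckpoint(path)` reads directly; the buffer variant is mainly useful when the bytes come from memory, a network stream, or an embedded resource rather than from a file path.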
◆ SaveCheckpoint()
static void Ort::CheckpointState::SaveCheckpoint(const CheckpointState &checkpoint_state, const std::basic_string< char > &path_to_checkpoint, const bool include_optimizer_state = false)
Saves the given state to a checkpoint file on disk.
This function serializes the provided checkpoint state to a file on disk. This checkpoint can later be loaded by invoking Ort::CheckpointState::LoadCheckpoint to resume training from this snapshot of the state.
- Parameters
  - [in] checkpoint_state: The checkpoint state to save.
  - [in] path_to_checkpoint: Path to the checkpoint file.
  - [in] include_optimizer_state: Whether to also save the optimizer state. Defaults to false.
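The description above does not say whether SaveCheckpoint writes the file atomically. If a crash during a save must not destroy the previous checkpoint, a common technique is to write to a temporary path first and rename the result into place. The sketch below assumes the checkpoint is the single file named by path_to_checkpoint, as this page describes; SaveCheckpointAtomically is a hypothetical helper, and its save callback stands in for a call to Ort::CheckpointState::SaveCheckpoint.

```cpp
#include <cassert>
#include <filesystem>
#include <functional>
#include <string>

namespace fs = std::filesystem;

// Hypothetical wrapper: writes the checkpoint to "<final_path>.tmp" and only
// renames it to final_path once the write has returned without throwing.
// In real code `save` could call, for example:
//   Ort::CheckpointState::SaveCheckpoint(state, tmp_path, include_optimizer_state);
void SaveCheckpointAtomically(const fs::path& final_path,
                              const std::function<void(const std::string&)>& save) {
  fs::path tmp_path = final_path;
  tmp_path += ".tmp";
  save(tmp_path.string());           // if this throws, final_path is left untouched
  fs::rename(tmp_path, final_path);  // replaces any previous checkpoint at final_path
}
```

On POSIX systems, renaming within the same file system replaces the target atomically, so readers see either the old checkpoint or the complete new one. If `save` throws, the rename is skipped and the previous checkpoint is kept, although the partial `.tmp` file may remain.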
◆ UpdateParameter()
void Ort::CheckpointState::UpdateParameter(const std::string &parameter_name, const Value &parameter)
Updates the data associated with the model parameter in the checkpoint state for the given parameter name.
This function updates a model parameter in the checkpoint state with the given parameter data. The training session must be already created with the checkpoint state that contains the parameter being updated. The given parameter is copied over to the registered device for the training session. The parameter must exist in the checkpoint state to be able to update it successfully.
- Parameters
  - [in] parameter_name: Name of the parameter being updated.
  - [in] parameter: The parameter data that should replace the existing parameter data.