ONNX Runtime
OrtTrainingApi Struct Reference

The Training C API that holds onnxruntime training function pointers.

#include <onnxruntime_training_c_api.h>

Accessing The Training Session State

OrtStatus * LoadCheckpoint (const char *checkpoint_path, OrtCheckpointState **checkpoint_state)
 Load a checkpoint state from a file on disk into checkpoint_state.
 
OrtStatus * SaveCheckpoint (OrtCheckpointState *checkpoint_state, const char *checkpoint_path, const bool include_optimizer_state)
 Save the given state to a checkpoint file on disk.
 
OrtStatus * GetParametersSize (OrtTrainingSession *sess, size_t *out, bool trainable_only)
 Retrieves the size of all the parameters.
 
OrtStatus * CopyParametersToBuffer (OrtTrainingSession *sess, OrtValue *parameters_buffer, bool trainable_only)
 Copy all parameters to a contiguous buffer held by the argument parameters_buffer.
 
OrtStatus * CopyBufferToParameters (OrtTrainingSession *sess, OrtValue *parameters_buffer, bool trainable_only)
 Copy parameter values from the given contiguous buffer held by parameters_buffer to the training state.
 
OrtStatus * AddProperty (OrtCheckpointState *checkpoint_state, const char *property_name, enum OrtPropertyType property_type, void *property_value)
 Adds or updates the given property to/in the checkpoint state.
 
OrtStatus * GetProperty (const OrtCheckpointState *checkpoint_state, const char *property_name, OrtAllocator *allocator, enum OrtPropertyType *property_type, void **property_value)
 Gets the property value associated with the given name from the checkpoint state.
 
OrtStatus * LoadCheckpointFromBuffer (const void *checkpoint_buffer, const size_t num_bytes, OrtCheckpointState **checkpoint_state)
 Load a checkpoint state from a buffer into checkpoint_state.
 
OrtStatus * GetParameterTypeAndShape (const OrtCheckpointState *checkpoint_state, const char *parameter_name, OrtTensorTypeAndShapeInfo **parameter_type_and_shape)
 Retrieves the type and shape information of the parameter associated with the given parameter name.
 
OrtStatus * UpdateParameter (OrtCheckpointState *checkpoint_state, const char *parameter_name, OrtValue *parameter)
 Updates the data associated with the model parameter in the checkpoint state for the given parameter name.
 
OrtStatus * GetParameter (const OrtCheckpointState *checkpoint_state, const char *parameter_name, OrtAllocator *allocator, OrtValue **parameter)
 Gets the data associated with the model parameter from the checkpoint state for the given parameter name.
 

Implementing The Training Loop

OrtStatus * CreateTrainingSession (const OrtEnv *env, const OrtSessionOptions *options, OrtCheckpointState *checkpoint_state, const char *train_model_path, const char *eval_model_path, const char *optimizer_model_path, OrtTrainingSession **out)
 Create a training session that can be used to begin or resume training.
 
OrtStatus * CreateTrainingSessionFromBuffer (const OrtEnv *env, const OrtSessionOptions *options, OrtCheckpointState *checkpoint_state, const void *train_model_data, size_t train_data_length, const void *eval_model_data, size_t eval_data_length, const void *optim_model_data, size_t optim_data_length, OrtTrainingSession **out)
 Create a training session that can be used to begin or resume training. This api provides a way to load all the training artifacts from buffers instead of files.
 
OrtStatus * LazyResetGrad (OrtTrainingSession *session)
 Reset the gradients of all trainable parameters to zero lazily.
 
OrtStatus * TrainStep (OrtTrainingSession *sess, const OrtRunOptions *run_options, size_t inputs_len, const OrtValue *const *inputs, size_t outputs_len, OrtValue **outputs)
 Computes the outputs of the training model and the gradients of the trainable parameters for the given inputs.
 
OrtStatus * EvalStep (const OrtTrainingSession *sess, const OrtRunOptions *run_options, size_t inputs_len, const OrtValue *const *inputs, size_t outputs_len, OrtValue **outputs)
 Computes the outputs for the eval model for the given inputs.
 
OrtStatus * SetLearningRate (OrtTrainingSession *sess, float learning_rate)
 Sets the learning rate for this training session.
 
OrtStatus * GetLearningRate (OrtTrainingSession *sess, float *learning_rate)
 Gets the current learning rate for this training session.
 
OrtStatus * OptimizerStep (OrtTrainingSession *sess, const OrtRunOptions *run_options)
 Performs the weight updates for the trainable parameters using the optimizer model.
 
OrtStatus * RegisterLinearLRScheduler (OrtTrainingSession *sess, const int64_t warmup_step_count, const int64_t total_step_count, const float initial_lr)
 Registers a linear learning rate scheduler for the training session.
 
OrtStatus * SchedulerStep (OrtTrainingSession *sess)
 Update the learning rate based on the registered learning rate scheduler.
 

Model IO Information

OrtStatus * TrainingSessionGetTrainingModelOutputCount (const OrtTrainingSession *sess, size_t *out)
 Retrieves the number of user outputs in the training model.
 
OrtStatus * TrainingSessionGetEvalModelOutputCount (const OrtTrainingSession *sess, size_t *out)
 Retrieves the number of user outputs in the eval model.
 
OrtStatus * TrainingSessionGetTrainingModelOutputName (const OrtTrainingSession *sess, size_t index, OrtAllocator *allocator, char **output)
 Retrieves the name of the user output at the given index in the training model.
 
OrtStatus * TrainingSessionGetEvalModelOutputName (const OrtTrainingSession *sess, size_t index, OrtAllocator *allocator, char **output)
 Retrieves the name of the user output at the given index in the eval model.
 
OrtStatus * TrainingSessionGetTrainingModelInputCount (const OrtTrainingSession *sess, size_t *out)
 Retrieves the number of user inputs in the training model.
 
OrtStatus * TrainingSessionGetEvalModelInputCount (const OrtTrainingSession *sess, size_t *out)
 Retrieves the number of user inputs in the eval model.
 
OrtStatus * TrainingSessionGetTrainingModelInputName (const OrtTrainingSession *sess, size_t index, OrtAllocator *allocator, char **output)
 Retrieves the name of the user input at given index in the training model.
 
OrtStatus * TrainingSessionGetEvalModelInputName (const OrtTrainingSession *sess, size_t index, OrtAllocator *allocator, char **output)
 Retrieves the name of the user input at given index in the eval model.
 

Release Training Resources

void ReleaseTrainingSession (OrtTrainingSession *input)
 Frees up the memory used up by the training session.
 
void ReleaseCheckpointState (OrtCheckpointState *input)
 Frees up the memory used up by the checkpoint state.
 

Prepare For Inferencing

OrtStatus * ExportModelForInferencing (OrtTrainingSession *sess, const char *inference_model_path, size_t graph_outputs_len, const char *const *graph_output_names)
 Export a model that can be used for inferencing.
 

Training Utilities

OrtStatus * SetSeed (const int64_t seed)
 Sets the seed used for random number generation in Onnxruntime.
 

Detailed Description

The Training C API that holds onnxruntime training function pointers.

All the Training C API functions are defined inside this structure as pointers to functions. Call OrtApi::GetTrainingApi to get a pointer to this struct.

Member Function Documentation

◆ AddProperty()

OrtStatus * OrtTrainingApi::AddProperty ( OrtCheckpointState *  checkpoint_state,
const char *  property_name,
enum OrtPropertyType  property_type,
void *  property_value 
)

Adds or updates the given property to/in the checkpoint state.

Runtime properties such as epoch, training step, best score, and others can be added to the checkpoint state by the user by calling this function with the corresponding property name and value. The given property name must be unique to be able to successfully add the property.

Parameters
[in]  checkpoint_state  The checkpoint state which should hold the property.
[in]  property_name  Name of the property being added or updated.
[in]  property_type  Type of the property associated with the given name.
[in]  property_value  Property value associated with the given name.

Returns
If no error, nullptr will be returned. If there is an error, a pointer to an OrtStatus that contains error details will be returned. Use OrtApi::ReleaseStatus to free this pointer.

◆ CopyBufferToParameters()

OrtStatus * OrtTrainingApi::CopyBufferToParameters ( OrtTrainingSession *  sess,
OrtValue *  parameters_buffer,
bool  trainable_only 
)

Copy parameter values from the given contiguous buffer held by parameters_buffer to the training state.

The parameters_buffer argument has to be of the size given by the OrtTrainingApi::GetParametersSize api call, with a matching setting for the trainable_only argument. All the target parameters must be of the same datatype. This is a complementary function to OrtTrainingApi::CopyParametersToBuffer and can be used to load updated buffer values onto the training state. Parameter ordering is preserved. User is responsible for allocating and freeing the resources used by the parameters_buffer.

Parameters
[in]  sess  The this pointer to the training session.
[in]  trainable_only  Whether to skip non-trainable parameters.
[in]  parameters_buffer  The pre-allocated OrtValue buffer to copy from.

Returns
If no error, nullptr will be returned. If there is an error, a pointer to an OrtStatus that contains error details will be returned. Use OrtApi::ReleaseStatus to free this pointer.

◆ CopyParametersToBuffer()

OrtStatus * OrtTrainingApi::CopyParametersToBuffer ( OrtTrainingSession *  sess,
OrtValue *  parameters_buffer,
bool  trainable_only 
)

Copy all parameters to a contiguous buffer held by the argument parameters_buffer.

The parameters_buffer has to be of the size given by GetParametersSize api call, with matching setting for the argument trainable_only. All the target parameters must be of the same datatype. The OrtValue must be pre-allocated onto the desired device. This is a complementary function to OrtTrainingApi::CopyBufferToParameters. Parameter ordering is preserved. User is responsible for allocating and freeing the resources used by the parameters_buffer.

Parameters
[in]  sess  The this pointer to the training session.
[in]  trainable_only  Whether to skip non-trainable parameters.
[out]  parameters_buffer  The pre-allocated OrtValue buffer to copy onto.

Returns
If no error, nullptr will be returned. If there is an error, a pointer to an OrtStatus that contains error details will be returned. Use OrtApi::ReleaseStatus to free this pointer.
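
As a rough illustration of the layout that GetParametersSize, CopyParametersToBuffer and CopyBufferToParameters describe, the sketch below packs parameters back to back into one contiguous buffer in their original order and then restores them. The Param struct and the three helper functions are hypothetical stand-ins written for this example, not ONNX Runtime types or calls, and float is assumed as the common datatype.

```c
#include <stdbool.h>
#include <stddef.h>
#include <string.h>

/* Hypothetical stand-in for one model parameter; not an ONNX Runtime type. */
typedef struct {
  float *data;    /* parameter elements, all of the same datatype */
  size_t count;   /* number of elements in this parameter */
  bool trainable; /* whether the parameter is trainable */
} Param;

/* Total number of elements, optionally counting trainable parameters only. */
size_t params_size(const Param *params, size_t n, bool trainable_only) {
  size_t total = 0;
  for (size_t i = 0; i < n; ++i) {
    if (!trainable_only || params[i].trainable) total += params[i].count;
  }
  return total;
}

/* Pack the selected parameters back to back, preserving their order.
   The caller owns `buffer` and must size it with params_size(). */
void copy_params_to_buffer(const Param *params, size_t n, float *buffer,
                           bool trainable_only) {
  size_t offset = 0;
  for (size_t i = 0; i < n; ++i) {
    if (trainable_only && !params[i].trainable) continue;
    memcpy(buffer + offset, params[i].data, params[i].count * sizeof(float));
    offset += params[i].count;
  }
}

/* Inverse operation: unpack the buffer into the selected parameters. */
void copy_buffer_to_params(Param *params, size_t n, const float *buffer,
                           bool trainable_only) {
  size_t offset = 0;
  for (size_t i = 0; i < n; ++i) {
    if (trainable_only && !params[i].trainable) continue;
    memcpy(params[i].data, buffer + offset, params[i].count * sizeof(float));
    offset += params[i].count;
  }
}
```

The same trainable_only value has to be used for the size query and for both copies, otherwise the offsets no longer line up.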

◆ CreateTrainingSession()

OrtStatus * OrtTrainingApi::CreateTrainingSession ( const OrtEnv *  env,
const OrtSessionOptions *  options,
OrtCheckpointState *  checkpoint_state,
const char *  train_model_path,
const char *  eval_model_path,
const char *  optimizer_model_path,
OrtTrainingSession **  out 
)

Create a training session that can be used to begin or resume training.

This function creates a training session based on the env and session options provided that can begin or resume training from a given checkpoint state for the given onnx models. The checkpoint state represents the parameters of the training session which will be moved to the device specified by the user through the session options (if necessary). The training session requires four training artifacts:

  • The training onnx model
  • The evaluation onnx model (optional)
  • The optimizer onnx model
  • The checkpoint file

These artifacts can be generated using the onnxruntime-training python utility.

Parameters
[in]  env  Environment to be used for the training session.
[in]  options  Session options that the user can customize for this training session.
[in]  checkpoint_state  Training states that the training session uses as a starting point for training.
[in]  train_model_path  Model to be used to perform training.
[in]  eval_model_path  Model to be used to perform evaluation.
[in]  optimizer_model_path  Model to be used to perform gradient descent.
[out]  out  Created training session.

Returns
If no error, nullptr will be returned. If there is an error, a pointer to an OrtStatus that contains error details will be returned. Use OrtApi::ReleaseStatus to free this pointer.

◆ CreateTrainingSessionFromBuffer()

OrtStatus * OrtTrainingApi::CreateTrainingSessionFromBuffer ( const OrtEnv *  env,
const OrtSessionOptions *  options,
OrtCheckpointState *  checkpoint_state,
const void *  train_model_data,
size_t  train_data_length,
const void *  eval_model_data,
size_t  eval_data_length,
const void *  optim_model_data,
size_t  optim_data_length,
OrtTrainingSession **  out 
)

Create a training session that can be used to begin or resume training. This api provides a way to load all the training artifacts from buffers instead of files.

Parameters
[in]  env  Environment to be used for the training session.
[in]  options  Session options that the user can customize for this training session.
[in]  checkpoint_state  Training states that the training session uses as a starting point for training.
[in]  train_model_data  Buffer containing the model data to be used to perform training.
[in]  train_data_length  Length of the buffer containing train_model_data.
[in]  eval_model_data  Buffer containing the model data to be used to perform evaluation.
[in]  eval_data_length  Length of the buffer containing eval_model_data.
[in]  optim_model_data  Buffer containing the model data to be used to perform weight update.
[in]  optim_data_length  Length of the buffer containing optim_model_data.
[out]  out  Created training session.

◆ EvalStep()

OrtStatus * OrtTrainingApi::EvalStep ( const OrtTrainingSession *  sess,
const OrtRunOptions *  run_options,
size_t  inputs_len,
const OrtValue *const *  inputs,
size_t  outputs_len,
OrtValue **  outputs 
)

Computes the outputs for the eval model for the given inputs.

This function performs an eval step that computes the outputs of the eval model for the given inputs. The eval step is performed based on the eval model that was provided to the training session.

Parameters
[in]  sess  The this pointer to the training session.
[in]  run_options  Run options for this eval step.
[in]  inputs_len  Number of user inputs to the eval model.
[in]  inputs  The user inputs to the eval model.
[in]  outputs_len  Number of user outputs expected from this eval step.
[out]  outputs  User outputs computed by eval step.

Returns
If no error, nullptr will be returned. If there is an error, a pointer to an OrtStatus that contains error details will be returned. Use OrtApi::ReleaseStatus to free this pointer.

◆ ExportModelForInferencing()

OrtStatus * OrtTrainingApi::ExportModelForInferencing ( OrtTrainingSession *  sess,
const char *  inference_model_path,
size_t  graph_outputs_len,
const char *const *  graph_output_names 
)

Export a model that can be used for inferencing.

If the training session was provided with an eval model, the training session can generate an inference model if it knows the inference graph outputs. The input inference graph outputs are used to prune the eval model so that the inference model's outputs align with the provided outputs. The exported model is saved at the path provided and can be used for inferencing with InferenceSession.

Note
Note that the function re-loads the eval model from the path provided to OrtTrainingApi::CreateTrainingSession and expects that this path still be valid.
Parameters
[in]  sess  The this pointer to the training session.
[in]  inference_model_path  Path where the inference model should be serialized to.
[in]  graph_outputs_len  Size of the graph output names array.
[in]  graph_output_names  Names of the outputs that are needed in the inference model.

Returns
If no error, nullptr will be returned. If there is an error, a pointer to an OrtStatus that contains error details will be returned. Use OrtApi::ReleaseStatus to free this pointer.

◆ GetLearningRate()

OrtStatus * OrtTrainingApi::GetLearningRate ( OrtTrainingSession *  sess,
float *  learning_rate 
)

Gets the current learning rate for this training session.

This function allows users to get the learning rate for the training session. The current learning rate is maintained by the training session, and users can query it for the purpose of implementing their own learning rate schedulers.

Parameters
[in]  sess  The this pointer to the training session.
[out]  learning_rate  Learning rate currently in use by the training session.

Returns
If no error, nullptr will be returned. If there is an error, a pointer to an OrtStatus that contains error details will be returned. Use OrtApi::ReleaseStatus to free this pointer.

◆ GetParameter()

OrtStatus * OrtTrainingApi::GetParameter ( const OrtCheckpointState *  checkpoint_state,
const char *  parameter_name,
OrtAllocator *  allocator,
OrtValue **  parameter 
)

Gets the data associated with the model parameter from the checkpoint state for the given parameter name.

This function retrieves the model parameter data from the checkpoint state for the given parameter name. The parameter is copied over and returned as an OrtValue. The training session must be already created with the checkpoint state that contains the parameter being retrieved. The parameter must exist in the checkpoint state to be able to retrieve it successfully.

Parameters
[in]  checkpoint_state  The checkpoint state.
[in]  parameter_name  Name of the parameter being retrieved.
[in]  allocator  Allocator used to allocate the memory for the parameter.
[out]  parameter  The parameter data that is retrieved from the checkpoint state.

Returns
If no error, nullptr will be returned. If there is an error, a pointer to an OrtStatus that contains error details will be returned. Use OrtApi::ReleaseStatus to free this pointer.

◆ GetParametersSize()

OrtStatus * OrtTrainingApi::GetParametersSize ( OrtTrainingSession *  sess,
size_t *  out,
bool  trainable_only 
)

Retrieves the size of all the parameters.

Calculates the total number of primitive (datatype of the parameters) elements of all the parameters in the training state. When trainable_only argument is true, the size is calculated for trainable params only.

Parameters
[in]  sess  The this pointer to the training session.
[out]  out  Size of all parameter elements.
[in]  trainable_only  Whether to skip non-trainable parameters.

Returns
If no error, nullptr will be returned. If there is an error, a pointer to an OrtStatus that contains error details will be returned. Use OrtApi::ReleaseStatus to free this pointer.

◆ GetParameterTypeAndShape()

OrtStatus * OrtTrainingApi::GetParameterTypeAndShape ( const OrtCheckpointState *  checkpoint_state,
const char *  parameter_name,
OrtTensorTypeAndShapeInfo **  parameter_type_and_shape 
)

Retrieves the type and shape information of the parameter associated with the given parameter name.

This function retrieves the type and shape of the parameter associated with the given parameter name. The parameter must exist in the checkpoint state to be able to retrieve its type and shape information successfully.

Parameters
[in]  checkpoint_state  The checkpoint state.
[in]  parameter_name  Name of the parameter being retrieved.
[out]  parameter_type_and_shape  The type and shape of the parameter being retrieved.

Returns
If no error, nullptr will be returned. If there is an error, a pointer to an OrtStatus that contains error details will be returned. Use OrtApi::ReleaseStatus to free this pointer.

◆ GetProperty()

OrtStatus * OrtTrainingApi::GetProperty ( const OrtCheckpointState *  checkpoint_state,
const char *  property_name,
OrtAllocator *  allocator,
enum OrtPropertyType *  property_type,
void **  property_value 
)

Gets the property value associated with the given name from the checkpoint state.

Gets the property value from an existing entry in the checkpoint state. The property must exist in the checkpoint state to be able to retrieve it successfully.

Parameters
[in]  checkpoint_state  The checkpoint state that is currently holding the property.
[in]  property_name  Name of the property being retrieved.
[in]  allocator  Allocator used to allocate the memory for the property_value.
[out]  property_type  Type of the property associated with the given name.
[out]  property_value  Property value associated with the given name.

Returns
If no error, nullptr will be returned. If there is an error, a pointer to an OrtStatus that contains error details will be returned. Use OrtApi::ReleaseStatus to free this pointer.

◆ LazyResetGrad()

OrtStatus * OrtTrainingApi::LazyResetGrad ( OrtTrainingSession *  session)

Reset the gradients of all trainable parameters to zero lazily.

This function sets the internal state of the training session such that the gradients of the trainable parameters in the OrtCheckpointState are scheduled to be reset just before the new gradients are computed on the next invocation of OrtTrainingApi::TrainStep.

Parameters
[in]  session  The this pointer to the training session.

Returns
If no error, nullptr will be returned. If there is an error, a pointer to an OrtStatus that contains error details will be returned. Use OrtApi::ReleaseStatus to free this pointer.
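
The deferred reset described above can be pictured with a small stand-alone sketch: the reset call only records a flag, and the next gradient computation overwrites the stored gradients instead of adding to them. The GradBuffer struct and both functions are hypothetical and written for this illustration only. That gradients accumulate across steps when no reset is pending is an assumption of the sketch, not something this page states.

```c
#include <stdbool.h>
#include <stddef.h>

/* Hypothetical gradient store; not an ONNX Runtime type. */
typedef struct {
  float *grad;        /* stored gradients, one per element */
  size_t count;       /* number of elements */
  bool reset_pending; /* set by lazy_reset_grad, cleared by the next step */
} GradBuffer;

/* Only records the request; the gradient memory is not touched here. */
void lazy_reset_grad(GradBuffer *g) { g->reset_pending = true; }

/* Called with the freshly computed gradients of one training step.
   Assumption: without a pending reset, new gradients are accumulated. */
void accumulate_grad(GradBuffer *g, const float *new_grad) {
  for (size_t i = 0; i < g->count; ++i) {
    /* A pending reset discards the old value instead of summing onto it. */
    g->grad[i] = g->reset_pending ? new_grad[i] : g->grad[i] + new_grad[i];
  }
  g->reset_pending = false;
}
```

Deferring the reset this way avoids a separate pass that writes zeros over every gradient element.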

◆ LoadCheckpoint()

OrtStatus * OrtTrainingApi::LoadCheckpoint ( const char *  checkpoint_path,
OrtCheckpointState **  checkpoint_state 
)

Load a checkpoint state from a file on disk into checkpoint_state.

This function will parse a checkpoint file, pull relevant data and load the training state into the checkpoint_state. This checkpoint state can then be used to create the training session by invoking OrtTrainingApi::CreateTrainingSession. By doing so, the training session will resume training from the given checkpoint state.

Note
Note that the training session created with a checkpoint state uses this state to store the entire training state (including model parameters, its gradients, the optimizer states and the properties). As a result, it is required that the checkpoint state outlive the lifetime of the training session.
Parameters
[in]  checkpoint_path  Path to the checkpoint file.
[out]  checkpoint_state  Checkpoint state that contains the states of the training session.

Returns
If no error, nullptr will be returned. If there is an error, a pointer to an OrtStatus that contains error details will be returned. Use OrtApi::ReleaseStatus to free this pointer.

◆ LoadCheckpointFromBuffer()

OrtStatus * OrtTrainingApi::LoadCheckpointFromBuffer ( const void *  checkpoint_buffer,
const size_t  num_bytes,
OrtCheckpointState **  checkpoint_state 
)

Load a checkpoint state from a buffer into checkpoint_state.

This function will parse a checkpoint bytes buffer, pull relevant data and load the training state into the checkpoint_state. This checkpoint state can then be used to create the training session by invoking OrtTrainingApi::CreateTrainingSession. By doing so, the training session will resume training from the given checkpoint state.

Note
Note that the training session created with a checkpoint state uses this state to store the entire training state (including model parameters, its gradients, the optimizer states and the properties). As a result, it is required that the checkpoint state outlive the lifetime of the training session.
Parameters
[in]  checkpoint_buffer  Pointer to the checkpoint bytes buffer.
[in]  num_bytes  Number of bytes in the checkpoint buffer.
[out]  checkpoint_state  Checkpoint state that contains the states of the training session.

Returns
If no error, nullptr will be returned. If there is an error, a pointer to an OrtStatus that contains error details will be returned. Use OrtApi::ReleaseStatus to free this pointer.

◆ OptimizerStep()

OrtStatus * OrtTrainingApi::OptimizerStep ( OrtTrainingSession *  sess,
const OrtRunOptions *  run_options 
)

Performs the weight updates for the trainable parameters using the optimizer model.

This function performs the weight update step that updates the trainable parameters such that they take a step in the direction opposite to their gradients (gradient descent). The optimizer step is performed based on the optimizer model that was provided to the training session. The updated parameters are stored inside the training state so that they can be used by the next OrtTrainingApi::TrainStep function call.

Parameters
[in]  sess  The this pointer to the training session.
[in]  run_options  Run options for this optimizer step.

Returns
If no error, nullptr will be returned. If there is an error, a pointer to an OrtStatus that contains error details will be returned. Use OrtApi::ReleaseStatus to free this pointer.
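
To make the idea of a weight update concrete, here is a minimal sketch of the simplest form of gradient descent, where each weight moves against its gradient by learning_rate times the gradient. The sgd_step function is hypothetical and stands in for whatever the optimizer model actually computes; a real optimizer model may implement a different rule with additional state.

```c
#include <stddef.h>

/* Hypothetical plain gradient descent update, for illustration only.
   Each weight is moved against its gradient, scaled by the learning rate. */
void sgd_step(float *weights, const float *grads, size_t count,
              float learning_rate) {
  for (size_t i = 0; i < count; ++i) {
    weights[i] -= learning_rate * grads[i];
  }
}
```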

◆ RegisterLinearLRScheduler()

OrtStatus * OrtTrainingApi::RegisterLinearLRScheduler ( OrtTrainingSession *  sess,
const int64_t  warmup_step_count,
const int64_t  total_step_count,
const float  initial_lr 
)

Registers a linear learning rate scheduler for the training session.

Registers a linear learning rate scheduler that decays the learning rate by a linearly updated multiplicative factor, from the initial learning rate set on the training session to 0. The decay is performed after the initial warm up phase, during which the learning rate is linearly incremented from 0 to the initial learning rate provided.

Parameters
[in]  sess  The this pointer to the training session.
[in]  warmup_step_count  Warmup steps for LR warmup.
[in]  total_step_count  Total step count.
[in]  initial_lr  The initial learning rate to be used by the training session.

Returns
If no error, nullptr will be returned. If there is an error, a pointer to an OrtStatus that contains error details will be returned. Use OrtApi::ReleaseStatus to free this pointer.
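
The schedule described above (linear warmup from 0 to the initial learning rate, then linear decay to 0) can be written as a small closed-form function. This is a sketch under stated assumptions: linear_lr_at_step is a hypothetical helper, steps are assumed to be counted from 0, and the exact step indexing used inside ONNX Runtime may differ.

```c
#include <stdint.h>

/* Hypothetical helper, not part of the ONNX Runtime API: learning rate that
   a linear warmup followed by a linear decay would give at `step`. */
float linear_lr_at_step(int64_t step, int64_t warmup_step_count,
                        int64_t total_step_count, float initial_lr) {
  if (step < 0 || total_step_count <= 0) return 0.0f;
  if (step < warmup_step_count) {
    /* Warmup phase: ramp linearly from 0 up to initial_lr. */
    return initial_lr * (float)step / (float)warmup_step_count;
  }
  if (step >= total_step_count) return 0.0f;
  /* Decay phase: ramp linearly from initial_lr down to 0,
     reaching 0 at total_step_count. */
  return initial_lr * (float)(total_step_count - step) /
         (float)(total_step_count - warmup_step_count);
}
```

With warmup_step_count = 10, total_step_count = 110 and initial_lr = 0.1, the value rises to 0.1 at step 10, is back at roughly half of that by step 60, and reaches 0 at step 110.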

◆ ReleaseCheckpointState()

void OrtTrainingApi::ReleaseCheckpointState ( OrtCheckpointState *  input)

Frees up the memory used up by the checkpoint state.

This function frees up any memory that was allocated in the checkpoint state. The checkpoint state can no longer be used after this call.

Note
Note that the checkpoint state must be released only after the training session has been released.

◆ ReleaseTrainingSession()

void OrtTrainingApi::ReleaseTrainingSession ( OrtTrainingSession *  input)

Frees up the memory used up by the training session.

This function frees up any memory that was allocated in the training session. The training session can no longer be used after this call.

◆ SaveCheckpoint()

OrtStatus * OrtTrainingApi::SaveCheckpoint ( OrtCheckpointState *  checkpoint_state,
const char *  checkpoint_path,
const bool  include_optimizer_state 
)

Save the given state to a checkpoint file on disk.

This function serializes the provided checkpoint state to a file on disk. This checkpoint can later be loaded by invoking OrtTrainingApi::LoadCheckpoint to resume training from this snapshot of the state.

Parameters
[in]  checkpoint_state  The checkpoint state to save.
[in]  checkpoint_path  Path to the checkpoint file.
[in]  include_optimizer_state  Flag to indicate whether to save the optimizer state or not.

Returns
If no error, nullptr will be returned. If there is an error, a pointer to an OrtStatus that contains error details will be returned. Use OrtApi::ReleaseStatus to free this pointer.

◆ SchedulerStep()

OrtStatus * OrtTrainingApi::SchedulerStep ( OrtTrainingSession *  sess)

Update the learning rate based on the registered learning rate scheduler.

Takes a scheduler step that updates the learning rate that is being used by the training session. This function should typically be called before invoking the optimizer step for each round, or as determined necessary to update the learning rate being used by the training session.

Note
Please note that a valid predefined learning rate scheduler must be first registered to invoke this function.
Parameters
[in]  sess  The this pointer to the training session.

Returns
If no error, nullptr will be returned. If there is an error, a pointer to an OrtStatus that contains error details will be returned. Use OrtApi::ReleaseStatus to free this pointer.

◆ SetLearningRate()

OrtStatus * OrtTrainingApi::SetLearningRate ( OrtTrainingSession *  sess,
float  learning_rate 
)

Sets the learning rate for this training session.

This function allows users to set the learning rate for the training session. The current learning rate is maintained by the training session and can be overwritten by invoking this function with the desired learning rate. This function should not be used when a valid learning rate scheduler is registered. It should be used either to set the learning rate derived from a custom learning rate scheduler or to set a constant learning rate to be used throughout the training session.

Note
Please note that this function does not set the initial learning rate that may be needed by the predefined learning rate schedulers. To set the initial learning rate for learning rate schedulers, please look at the function OrtTrainingApi::RegisterLinearLRScheduler.
Parameters
[in]  sess  The this pointer to the training session.
[in]  learning_rate  Desired learning rate to be set.

Returns
If no error, nullptr will be returned. If there is an error, a pointer to an OrtStatus that contains error details will be returned. Use OrtApi::ReleaseStatus to free this pointer.

◆ SetSeed()

OrtStatus * OrtTrainingApi::SetSeed ( const int64_t  seed)

Sets the seed used for random number generation in Onnxruntime.

Use this function to generate reproducible results. It should be noted that completely reproducible results are not guaranteed.

Parameters
[in]  seed  The seed to be set.

Returns
If no error, nullptr will be returned. If there is an error, a pointer to an OrtStatus that contains error details will be returned. Use OrtApi::ReleaseStatus to free this pointer.

◆ TrainingSessionGetEvalModelInputCount()

OrtStatus * OrtTrainingApi::TrainingSessionGetEvalModelInputCount ( const OrtTrainingSession *  sess,
size_t *  out 
)

Retrieves the number of user inputs in the eval model.

This function returns the number of inputs of the eval model so that the user can accordingly allocate the OrtValue(s) provided to the OrtTrainingApi::EvalStep function.

Parameters
[in]  sess  The this pointer to the training session.
[out]  out  Number of user inputs in the eval model.

Returns
If no error, nullptr will be returned. If there is an error, a pointer to an OrtStatus that contains error details will be returned. Use OrtApi::ReleaseStatus to free this pointer.

◆ TrainingSessionGetEvalModelInputName()

OrtStatus * OrtTrainingApi::TrainingSessionGetEvalModelInputName ( const OrtTrainingSession *  sess,
size_t  index,
OrtAllocator *  allocator,
char **  output 
)

Retrieves the name of the user input at given index in the eval model.

This function returns the names of inputs of the eval model that can be associated with the OrtValue(s) provided to the OrtTrainingApi::EvalStep function.

Parameters
[in]  sess  The this pointer to the training session.
[in]  index  The index of the eval model input name requested.
[in]  allocator  The allocator to use to allocate the memory for the requested name.
[out]  output  Name of the user input for the eval model at the given index.

Returns
If no error, nullptr will be returned. If there is an error, a pointer to an OrtStatus that contains error details will be returned. Use OrtApi::ReleaseStatus to free this pointer.

◆ TrainingSessionGetEvalModelOutputCount()

OrtStatus * OrtTrainingApi::TrainingSessionGetEvalModelOutputCount ( const OrtTrainingSession *  sess,
size_t *  out 
)

Retrieves the number of user outputs in the eval model.

This function returns the number of outputs of the eval model so that the user can allocate space for the number of outputs when OrtTrainingApi::EvalStep is invoked.

Parameters
[in]  sess  The this pointer to the training session.
[out]  out  Number of user outputs in the eval model.

Returns
If no error, nullptr will be returned. If there is an error, a pointer to an OrtStatus that contains error details will be returned. Use OrtApi::ReleaseStatus to free this pointer.

◆ TrainingSessionGetEvalModelOutputName()

OrtStatus * OrtTrainingApi::TrainingSessionGetEvalModelOutputName ( const OrtTrainingSession *  sess,
size_t  index,
OrtAllocator *  allocator,
char **  output 
)

Retrieves the name of the user output at the given index in the eval model.

This function returns the names of outputs of the eval model that can be associated with the OrtValue(s) returned by the OrtTrainingApi::EvalStep function.

Parameters
[in]  sess  The this pointer to the training session.
[in]  index  Index of the output name requested.
[in]  allocator  Allocator to use to allocate the memory for the name.
[out]  output  Name of the eval model output at the given index.

Returns
If no error, nullptr will be returned. If there is an error, a pointer to an OrtStatus that contains error details will be returned. Use OrtApi::ReleaseStatus to free this pointer.
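
The count and name functions are typically used together: query the count, then request each name by index and free it with the allocator that produced it. The following is a minimal sketch of that pattern, not a definitive implementation. The type definitions, the Mock* functions, and the JoinNames helper are hypothetical stand-ins so that the sketch compiles without ONNX Runtime; only the shapes of the two function pointers follow the signatures documented above. Because all four name getters in this reference share one signature, the same helper could be pointed at the training model getters as well.

```c
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

/* Stand-ins (assumptions) so this sketch builds without ONNX Runtime.
 * Real code includes <onnxruntime_training_c_api.h> and drops this part. */
typedef struct OrtStatus { const char* msg; } OrtStatus;
typedef struct OrtTrainingSession {
  size_t count;
  const char* const* names;
} OrtTrainingSession;
typedef struct OrtAllocator {
  void* (*Alloc)(struct OrtAllocator* this_, size_t size);
  void (*Free)(struct OrtAllocator* this_, void* p);
} OrtAllocator;

static void* MockAlloc(OrtAllocator* a, size_t n) { (void)a; return malloc(n); }
static void MockFree(OrtAllocator* a, void* p) { (void)a; free(p); }

/* Mock with the shape of TrainingSessionGetEvalModelOutputCount. */
static OrtStatus* MockGetOutputCount(const OrtTrainingSession* s, size_t* out) {
  *out = s->count;
  return NULL;
}

/* Mock with the shape of TrainingSessionGetEvalModelOutputName. */
static OrtStatus* MockGetOutputName(const OrtTrainingSession* s, size_t index,
                                    OrtAllocator* a, char** output) {
  static OrtStatus out_of_range = {"index out of range"};
  static OrtStatus no_memory = {"allocation failed"};
  if (index >= s->count) return &out_of_range;
  size_t len = strlen(s->names[index]) + 1;
  *output = (char*)a->Alloc(a, len);
  if (*output == NULL) return &no_memory;
  memcpy(*output, s->names[index], len);
  return NULL;
}

typedef OrtStatus* (*CountFn)(const OrtTrainingSession*, size_t*);
typedef OrtStatus* (*NameFn)(const OrtTrainingSession*, size_t,
                             OrtAllocator*, char**);

/* Hypothetical helper: writes all names into buf, separated by commas.
 * Returns the number of names, or -1 on error or truncation.
 * Real code must also pass any non-null status to OrtApi::ReleaseStatus. */
static int JoinNames(CountFn get_count, NameFn get_name,
                     const OrtTrainingSession* sess, OrtAllocator* alloc,
                     char* buf, size_t buf_len) {
  size_t count = 0;
  size_t used = 0;
  if (buf_len == 0 || get_count(sess, &count) != NULL) return -1;
  buf[0] = '\0';
  for (size_t i = 0; i < count; ++i) {
    char* name = NULL;
    if (get_name(sess, i, alloc, &name) != NULL) return -1;
    int n = snprintf(buf + used, buf_len - used, "%s%s", i ? "," : "", name);
    alloc->Free(alloc, name); /* free with the allocator that allocated it */
    if (n < 0 || (size_t)n >= buf_len - used) return -1; /* did not fit */
    used += (size_t)n;
  }
  return (int)count;
}
```

With the real library, the two function pointers would be taken from the OrtTrainingApi struct and the session would be an actual OrtTrainingSession rather than the stand-in struct used here.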

◆ TrainingSessionGetTrainingModelInputCount()

OrtStatus * OrtTrainingApi::TrainingSessionGetTrainingModelInputCount ( const OrtTrainingSession *  sess,
size_t *  out 
)

Retrieves the number of user inputs in the training model.

This function returns the number of inputs of the training model so that the user can accordingly allocate the OrtValue(s) provided to the OrtTrainingApi::TrainStep function.

Parameters
[in] sess: The this pointer to the training session.
[out] out: Number of user inputs in the training model.

Returns
If no error, nullptr will be returned. If there is an error, a pointer to an OrtStatus that contains error details will be returned. Use OrtApi::ReleaseStatus to free this pointer.

◆ TrainingSessionGetTrainingModelInputName()

OrtStatus * OrtTrainingApi::TrainingSessionGetTrainingModelInputName ( const OrtTrainingSession *  sess,
size_t  index,
OrtAllocator *  allocator,
char **  output 
)

Retrieves the name of the user input at the given index in the training model.

This function returns the name of the training model input at the given index, so that it can be associated with the corresponding OrtValue provided to the OrtTrainingApi::TrainStep function.

Parameters
[in] sess: The this pointer to the training session.
[in] index: The index of the training model input name requested.
[in] allocator: The allocator to use to allocate the memory for the requested name.
[out] output: Name of the user input for the training model at the given index.

Returns
If no error, nullptr will be returned. If there is an error, a pointer to an OrtStatus that contains error details will be returned. Use OrtApi::ReleaseStatus to free this pointer.

◆ TrainingSessionGetTrainingModelOutputCount()

OrtStatus * OrtTrainingApi::TrainingSessionGetTrainingModelOutputCount ( const OrtTrainingSession *  sess,
size_t *  out 
)

Retrieves the number of user outputs in the training model.

This function returns the number of outputs of the training model so that the user can allocate space for the number of outputs when OrtTrainingApi::TrainStep is invoked.

Parameters
[in] sess: The this pointer to the training session.
[out] out: Number of user outputs in the training model.

Returns
If no error, nullptr will be returned. If there is an error, a pointer to an OrtStatus that contains error details will be returned. Use OrtApi::ReleaseStatus to free this pointer.

◆ TrainingSessionGetTrainingModelOutputName()

OrtStatus * OrtTrainingApi::TrainingSessionGetTrainingModelOutputName ( const OrtTrainingSession *  sess,
size_t  index,
OrtAllocator *  allocator,
char **  output 
)

Retrieves the name of the user output at the given index in the training model.

This function returns the name of the training model output at the given index, so that it can be associated with the corresponding OrtValue returned by the OrtTrainingApi::TrainStep function.

Parameters
[in] sess: The this pointer to the training session.
[in] index: Index of the output name requested.
[in] allocator: Allocator to use to allocate the memory for the name.
[out] output: Name of the training model output at the given index.

Returns
If no error, nullptr will be returned. If there is an error, a pointer to an OrtStatus that contains error details will be returned. Use OrtApi::ReleaseStatus to free this pointer.

◆ TrainStep()

OrtStatus * OrtTrainingApi::TrainStep ( OrtTrainingSession *  sess,
const OrtRunOptions *  run_options,
size_t  inputs_len,
const OrtValue *const *  inputs,
size_t  outputs_len,
OrtValue **  outputs 
)

Computes the outputs of the training model and the gradients of the trainable parameters for the given inputs.

This function performs a training step that computes the outputs of the training model and the gradients of the trainable parameters for the given inputs. The train step is based on the training model that was provided to the training session. OrtTrainingApi::TrainStep is equivalent to running forward propagation and backward propagation in a single step. The computed gradients are stored inside the training session state so that they can later be consumed by the OrtTrainingApi::OptimizerStep function. The gradients can be lazily reset by invoking the OrtTrainingApi::LazyResetGrad function.

Parameters
[in] sess: The this pointer to the training session.
[in] run_options: Run options for this training step.
[in] inputs_len: Number of user inputs to the training model.
[in] inputs: The user inputs to the training model.
[in] outputs_len: Number of user outputs expected from this training step.
[out] outputs: User outputs computed by train step.

Returns
If no error, nullptr will be returned. If there is an error, a pointer to an OrtStatus that contains error details will be returned. Use OrtApi::ReleaseStatus to free this pointer.
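
The description above implies an ordering for one training iteration: TrainStep computes the gradients, OptimizerStep consumes them, and LazyResetGrad schedules them to be reset. The following is a minimal sketch of that loop under stated assumptions. TrainingApiSubset, the Mock* functions, TrainLoop, and the stand-in type definitions are hypothetical and exist only so that the sketch compiles without ONNX Runtime; the TrainStep and output count pointer shapes follow the signatures in this reference, while the OptimizerStep and LazyResetGrad shapes are assumed and should be checked against their own entries.

```c
#include <stdlib.h>
#include <string.h>

/* Stand-ins (assumptions) so this sketch builds without ONNX Runtime.
 * Real code includes <onnxruntime_training_c_api.h> and drops this part. */
typedef struct OrtStatus { const char* msg; } OrtStatus;
typedef struct OrtRunOptions OrtRunOptions;
typedef struct OrtValue { float v; } OrtValue;
typedef struct OrtTrainingSession {
  char log[32];     /* records the order of mock calls */
  size_t n_outputs; /* number of user outputs of the mock model */
} OrtTrainingSession;

/* Hypothetical subset of OrtTrainingApi, limited to what the loop needs. */
typedef struct TrainingApiSubset {
  OrtStatus* (*TrainingSessionGetTrainingModelOutputCount)(
      const OrtTrainingSession*, size_t*);
  OrtStatus* (*TrainStep)(OrtTrainingSession*, const OrtRunOptions*, size_t,
                          const OrtValue* const*, size_t, OrtValue**);
  OrtStatus* (*OptimizerStep)(OrtTrainingSession*, const OrtRunOptions*);
  OrtStatus* (*LazyResetGrad)(OrtTrainingSession*);
} TrainingApiSubset;

static void Log(OrtTrainingSession* s, const char* tag) {
  strncat(s->log, tag, sizeof s->log - strlen(s->log) - 1);
}
static OrtStatus* MockOutputCount(const OrtTrainingSession* s, size_t* out) {
  *out = s->n_outputs;
  return NULL;
}
static OrtStatus* MockTrainStep(OrtTrainingSession* s, const OrtRunOptions* ro,
                                size_t n_in, const OrtValue* const* in,
                                size_t n_out, OrtValue** out) {
  static OrtValue loss = {0.5f}; /* the mock returns a shared dummy value */
  (void)ro; (void)n_in; (void)in;
  for (size_t i = 0; i < n_out; ++i) out[i] = &loss;
  Log(s, "T");
  return NULL;
}
static OrtStatus* MockOptimizerStep(OrtTrainingSession* s,
                                    const OrtRunOptions* ro) {
  (void)ro;
  Log(s, "O");
  return NULL;
}
static OrtStatus* MockLazyResetGrad(OrtTrainingSession* s) {
  Log(s, "L");
  return NULL;
}

/* Runs up to `steps` iterations and returns how many completed, or -1 if
 * setup failed. Real code must pass any non-null status to
 * OrtApi::ReleaseStatus and release the returned output values. */
static int TrainLoop(const TrainingApiSubset* api, OrtTrainingSession* sess,
                     const OrtRunOptions* run_options,
                     const OrtValue* const* inputs, size_t inputs_len,
                     int steps) {
  size_t outputs_len = 0;
  if (api->TrainingSessionGetTrainingModelOutputCount(sess, &outputs_len))
    return -1;
  /* One slot per user output, sized from the queried count. */
  OrtValue** outputs =
      (OrtValue**)calloc(outputs_len ? outputs_len : 1, sizeof *outputs);
  if (outputs == NULL) return -1;
  int done = 0;
  for (; done < steps; ++done) {
    for (size_t i = 0; i < outputs_len; ++i) outputs[i] = NULL;
    if (api->TrainStep(sess, run_options, inputs_len, inputs, outputs_len,
                       outputs))
      break; /* forward and backward pass failed */
    if (api->OptimizerStep(sess, run_options)) break; /* consume gradients */
    if (api->LazyResetGrad(sess)) break; /* schedule the gradient reset */
  }
  free(outputs);
  return done;
}
```

Calling LazyResetGrad on every iteration gives plain per-batch updates; a gradient accumulation variant would call OptimizerStep and LazyResetGrad only once every few TrainStep calls.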

◆ UpdateParameter()

OrtStatus * OrtTrainingApi::UpdateParameter ( OrtCheckpointState *  checkpoint_state,
const char *  parameter_name,
OrtValue *  parameter 
)

Updates the data associated with the model parameter in the checkpoint state for the given parameter name.

This function updates a model parameter in the checkpoint state with the given parameter data. The training session must already have been created with the checkpoint state that contains the parameter being updated. The given parameter is copied over to the registered device for the training session. The update succeeds only if the parameter already exists in the checkpoint state.

Parameters
[in] checkpoint_state: The checkpoint state.
[in] parameter_name: Name of the parameter being updated.
[in] parameter: The parameter data that should replace the existing parameter data.

Returns
If no error, nullptr will be returned. If there is an error, a pointer to an OrtStatus that contains error details will be returned. Use OrtApi::ReleaseStatus to free this pointer.