8#include "onnxruntime_c_api.h"
104ORT_RUNTIME_CLASS(TrainingSession);
105ORT_RUNTIME_CLASS(CheckpointState);
142 ORT_API2_STATUS(
LoadCheckpoint, _In_
const ORTCHAR_T* checkpoint_path,
159 const bool include_optimizer_state);
192 _In_
const ORTCHAR_T* eval_model_path, _In_
const ORTCHAR_T* optimizer_model_path,
212 _In_
const void* train_model_data,
size_t train_data_length,
213 _In_
const void* eval_model_data,
size_t eval_data_length,
214 _In_
const void* optim_model_data,
size_t optim_data_length,
318 _In_
size_t inputs_len, _In_reads_(inputs_len)
const OrtValue*
const* inputs,
319 _In_
size_t outputs_len, _Inout_updates_all_(outputs_len)
OrtValue** outputs);
337 _In_
size_t inputs_len, _In_reads_(inputs_len)
const OrtValue*
const* inputs,
338 _In_
size_t outputs_len, _Inout_updates_all_(outputs_len)
OrtValue** outputs);
407 _In_
const int64_t total_step_count, _In_
const float initial_lr);
460 _Inout_
OrtValue* parameters_buffer,
bool trainable_only);
479 _Inout_
OrtValue* parameters_buffer,
bool trainable_only);
492 ORT_CLASS_RELEASE(TrainingSession);
501 ORT_CLASS_RELEASE(CheckpointState);
525 _In_
const ORTCHAR_T* inference_model_path,
size_t graph_outputs_len,
526 _In_reads_(graph_outputs_len)
const char*
const* graph_output_names);
542 ORT_API2_STATUS(
SetSeed, _In_
const int64_t seed);
626 _In_
const char* property_name, _In_
enum OrtPropertyType property_type,
627 _In_
void* property_value);
644 _In_
const char* property_name, _Inout_
OrtAllocator* allocator,
702 _In_
const char* parameter_name, _In_
OrtValue* parameter);
720 _In_
const char* parameter_name, _Inout_
OrtAllocator* allocator,
struct OrtTensorTypeAndShapeInfo OrtTensorTypeAndShapeInfo
Definition onnxruntime_c_api.h:285
struct OrtRunOptions OrtRunOptions
Definition onnxruntime_c_api.h:283
struct OrtSessionOptions OrtSessionOptions
Definition onnxruntime_c_api.h:289
struct OrtValue OrtValue
Definition onnxruntime_c_api.h:282
struct OrtEnv OrtEnv
Definition onnxruntime_c_api.h:277
struct OrtTrainingSession OrtTrainingSession
Definition onnxruntime_training_c_api.h:104
struct OrtCheckpointState OrtCheckpointState
Definition onnxruntime_training_c_api.h:105
OrtPropertyType
Type of property to be added to or returned from the OrtCheckpointState.
Definition onnxruntime_training_c_api.h:109
@ OrtIntProperty
Definition onnxruntime_training_c_api.h:110
@ OrtStringProperty
Definition onnxruntime_training_c_api.h:112
@ OrtFloatProperty
Definition onnxruntime_training_c_api.h:111
Memory allocation interface.
Definition onnxruntime_c_api.h:317
The Training C API that holds onnxruntime training function pointers.
Definition onnxruntime_training_c_api.h:122
OrtStatus * CreateTrainingSessionFromBuffer(const OrtEnv *env, const OrtSessionOptions *options, OrtCheckpointState *checkpoint_state, const void *train_model_data, size_t train_data_length, const void *eval_model_data, size_t eval_data_length, const void *optim_model_data, size_t optim_data_length, OrtTrainingSession **out)
Create a training session that can be used to begin or resume training. This api provides a way to lo...
OrtStatus * CopyBufferToParameters(OrtTrainingSession *sess, OrtValue *parameters_buffer, bool trainable_only)
Copy parameter values from the given contiguous buffer held by parameters_buffer to the training stat...
OrtStatus * EvalStep(const OrtTrainingSession *sess, const OrtRunOptions *run_options, size_t inputs_len, const OrtValue *const *inputs, size_t outputs_len, OrtValue **outputs)
Computes the outputs for the eval model for the given inputs.
OrtStatus * LazyResetGrad(OrtTrainingSession *session)
Reset the gradients of all trainable parameters to zero lazily.
OrtStatus * TrainingSessionGetEvalModelInputName(const OrtTrainingSession *sess, size_t index, OrtAllocator *allocator, char **output)
Retrieves the name of the user input at given index in the eval model.
OrtStatus * CreateTrainingSession(const OrtEnv *env, const OrtSessionOptions *options, OrtCheckpointState *checkpoint_state, const char *train_model_path, const char *eval_model_path, const char *optimizer_model_path, OrtTrainingSession **out)
Create a training session that can be used to begin or resume training.
OrtStatus * TrainingSessionGetTrainingModelInputCount(const OrtTrainingSession *sess, size_t *out)
Retrieves the number of user inputs in the training model.
OrtStatus * LoadCheckpoint(const char *checkpoint_path, OrtCheckpointState **checkpoint_state)
Load a checkpoint state from a file on disk into checkpoint_state.
OrtStatus * TrainStep(OrtTrainingSession *sess, const OrtRunOptions *run_options, size_t inputs_len, const OrtValue *const *inputs, size_t outputs_len, OrtValue **outputs)
Computes the outputs of the training model and the gradients of the trainable parameters for the give...
OrtStatus * ExportModelForInferencing(OrtTrainingSession *sess, const char *inference_model_path, size_t graph_outputs_len, const char *const *graph_output_names)
Export a model that can be used for inferencing.
OrtStatus * GetLearningRate(OrtTrainingSession *sess, float *learning_rate)
Gets the current learning rate for this training session.
OrtStatus * TrainingSessionGetEvalModelInputCount(const OrtTrainingSession *sess, size_t *out)
Retrieves the number of user inputs in the eval model.
OrtStatus * TrainingSessionGetEvalModelOutputCount(const OrtTrainingSession *sess, size_t *out)
Retrieves the number of user outputs in the eval model.
OrtStatus * RegisterLinearLRScheduler(OrtTrainingSession *sess, const int64_t warmup_step_count, const int64_t total_step_count, const float initial_lr)
Registers a linear learning rate scheduler for the training session.
OrtStatus * CopyParametersToBuffer(OrtTrainingSession *sess, OrtValue *parameters_buffer, bool trainable_only)
Copy all parameters to a contiguous buffer held by the argument parameters_buffer.
OrtStatus * GetParameterTypeAndShape(const OrtCheckpointState *checkpoint_state, const char *parameter_name, OrtTensorTypeAndShapeInfo **parameter_type_and_shape)
Retrieves the type and shape information of the parameter associated with the given parameter name.
OrtStatus * UpdateParameter(OrtCheckpointState *checkpoint_state, const char *parameter_name, OrtValue *parameter)
Updates the data associated with the model parameter in the checkpoint state for the given parameter ...
OrtStatus * SetLearningRate(OrtTrainingSession *sess, float learning_rate)
Sets the learning rate for this training session.
OrtStatus * TrainingSessionGetTrainingModelOutputCount(const OrtTrainingSession *sess, size_t *out)
Retrieves the number of user outputs in the training model.
OrtStatus * TrainingSessionGetTrainingModelOutputName(const OrtTrainingSession *sess, size_t index, OrtAllocator *allocator, char **output)
Retrieves the name of the user output at the given index in the training model.
OrtStatus * TrainingSessionGetTrainingModelInputName(const OrtTrainingSession *sess, size_t index, OrtAllocator *allocator, char **output)
Retrieves the name of the user input at given index in the training model.
OrtStatus * TrainingSessionGetEvalModelOutputName(const OrtTrainingSession *sess, size_t index, OrtAllocator *allocator, char **output)
Retrieves the name of the user output at the given index in the eval model.
OrtStatus * AddProperty(OrtCheckpointState *checkpoint_state, const char *property_name, enum OrtPropertyType property_type, void *property_value)
Adds or updates the given property to/in the checkpoint state.
OrtStatus * SchedulerStep(OrtTrainingSession *sess)
Updates the learning rate based on the registered learning rate scheduler.
OrtStatus * GetParametersSize(OrtTrainingSession *sess, size_t *out, bool trainable_only)
Retrieves the size of all the parameters.
OrtStatus * SetSeed(const int64_t seed)
Sets the seed used for random number generation in Onnxruntime.
OrtStatus * LoadCheckpointFromBuffer(const void *checkpoint_buffer, const size_t num_bytes, OrtCheckpointState **checkpoint_state)
Load a checkpoint state from a buffer into checkpoint_state.
OrtStatus * GetProperty(const OrtCheckpointState *checkpoint_state, const char *property_name, OrtAllocator *allocator, enum OrtPropertyType *property_type, void **property_value)
Gets the property value associated with the given name from the checkpoint state.
OrtStatus * SaveCheckpoint(OrtCheckpointState *checkpoint_state, const char *checkpoint_path, const bool include_optimizer_state)
Save the given state to a checkpoint file on disk.
OrtStatus * GetParameter(const OrtCheckpointState *checkpoint_state, const char *parameter_name, OrtAllocator *allocator, OrtValue **parameter)
Gets the data associated with the model parameter from the checkpoint state for the given parameter n...
OrtStatus * OptimizerStep(OrtTrainingSession *sess, const OrtRunOptions *run_options)
Performs the weight updates for the trainable parameters using the optimizer model.