ONNX Runtime
Loading...
Searching...
No Matches
onnxruntime_training_c_api.h
1// Copyright (c) Microsoft Corporation. All rights reserved.
2// Licensed under the MIT License.
3
4// This file contains the training c apis.
5
6#pragma once
7#include <stdbool.h>
8#include "onnxruntime_c_api.h"
9
// Handle types used throughout the training API. Per the definition entries on
// this page, these declare `struct OrtTrainingSession OrtTrainingSession` and
// `struct OrtCheckpointState OrtCheckpointState`; ORT_RUNTIME_CLASS itself is
// defined in onnxruntime_c_api.h (not shown here).
104ORT_RUNTIME_CLASS(TrainingSession); // Session handle used to train (and evaluate) the given user models.
105ORT_RUNTIME_CLASS(CheckpointState); // Holds the training state (parameters, properties) used by a training session.
106
114
125
/// Members of this table are function pointers (the OrtTrainingApi holds the
/// training function pointers); their declaration order fixes the struct
/// layout, so members must not be reordered or inserted in the middle.
///
/// LoadCheckpoint: load a checkpoint state from the file at `checkpoint_path`
/// into `*checkpoint_state` (an output handle, per `_Outptr_`).
142 ORT_API2_STATUS(LoadCheckpoint, _In_ const ORTCHAR_T* checkpoint_path,
143 _Outptr_ OrtCheckpointState** checkpoint_state);
144
/// SaveCheckpoint: save the given state to a checkpoint file at
/// `checkpoint_path`. `include_optimizer_state` presumably controls whether
/// optimizer state is written as well — confirm against the implementation.
158 ORT_API2_STATUS(SaveCheckpoint, _In_ OrtCheckpointState* checkpoint_state, _In_ const ORTCHAR_T* checkpoint_path,
159 const bool include_optimizer_state);
160
162
165
/// CreateTrainingSession: create a training session that can be used to begin
/// or resume training, from the training, eval and optimizer model files at
/// the given paths. `checkpoint_state` is `_Inout_` (read and possibly
/// modified). Per `_Outptr_result_maybenull_`, `*out` may be NULL on return.
190 ORT_API2_STATUS(CreateTrainingSession, _In_ const OrtEnv* env, _In_ const OrtSessionOptions* options,
191 _Inout_ OrtCheckpointState* checkpoint_state, _In_ const ORTCHAR_T* train_model_path,
192 _In_ const ORTCHAR_T* eval_model_path, _In_ const ORTCHAR_T* optimizer_model_path,
193 _Outptr_result_maybenull_ OrtTrainingSession** out);
194
/// CreateTrainingSessionFromBuffer: variant of CreateTrainingSession that takes
/// each model as an in-memory buffer (`*_model_data` plus its `*_data_length`)
/// instead of a file path. Same `checkpoint_state` / `out` conventions.
210 ORT_API2_STATUS(CreateTrainingSessionFromBuffer, _In_ const OrtEnv* env,
211 _In_ const OrtSessionOptions* options, _Inout_ OrtCheckpointState* checkpoint_state,
212 _In_ const void* train_model_data, size_t train_data_length,
213 _In_ const void* eval_model_data, size_t eval_data_length,
214 _In_ const void* optim_model_data, size_t optim_data_length,
215 _Outptr_result_maybenull_ OrtTrainingSession** out);
216
218
221
/// Retrieve the number of user outputs in the training model into `*out`.
233 ORT_API2_STATUS(TrainingSessionGetTrainingModelOutputCount, _In_ const OrtTrainingSession* sess, _Out_ size_t* out);
234
/// Retrieve the number of user outputs in the eval model into `*out`.
246 ORT_API2_STATUS(TrainingSessionGetEvalModelOutputCount, _In_ const OrtTrainingSession* sess, _Out_ size_t* out);
247
/// Retrieve the name of the user output at `index` in the training model.
/// NOTE(review): `*output` is presumably allocated with `allocator` and must be
/// released by the caller through it — confirm against the implementation.
261 ORT_API2_STATUS(TrainingSessionGetTrainingModelOutputName, _In_ const OrtTrainingSession* sess, size_t index, _Inout_ OrtAllocator* allocator, _Outptr_ char** output);
262
/// Retrieve the name of the user output at `index` in the eval model
/// (same allocation caveat as the training-model variant above).
276 ORT_API2_STATUS(TrainingSessionGetEvalModelOutputName, _In_ const OrtTrainingSession* sess, size_t index, _Inout_ OrtAllocator* allocator, _Outptr_ char** output);
277
279
282
/// Reset the gradients of all trainable parameters to zero lazily.
/// NOTE(review): this parameter is named `session` while its siblings use
/// `sess`; cosmetic only.
294 ORT_API2_STATUS(LazyResetGrad, _Inout_ OrtTrainingSession* session);
295
/// Compute the outputs of the training model and the gradients of the
/// trainable parameters for the given inputs. `inputs` holds `inputs_len`
/// values (`_In_reads_`); all `outputs_len` slots of `outputs` are written
/// (`_Inout_updates_all_`). `run_options` may be NULL (`_In_opt_`).
317 ORT_API2_STATUS(TrainStep, _Inout_ OrtTrainingSession* sess, _In_opt_ const OrtRunOptions* run_options,
318 _In_ size_t inputs_len, _In_reads_(inputs_len) const OrtValue* const* inputs,
319 _In_ size_t outputs_len, _Inout_updates_all_(outputs_len) OrtValue** outputs);
320
/// Compute the outputs of the eval model for the given inputs. Takes a const
/// session; array and `run_options` conventions match TrainStep.
336 ORT_API2_STATUS(EvalStep, _In_ const OrtTrainingSession* sess, _In_opt_ const OrtRunOptions* run_options,
337 _In_ size_t inputs_len, _In_reads_(inputs_len) const OrtValue* const* inputs,
338 _In_ size_t outputs_len, _Inout_updates_all_(outputs_len) OrtValue** outputs);
339
/// Set the learning rate for this training session.
358 ORT_API2_STATUS(SetLearningRate, _Inout_ OrtTrainingSession* sess, _In_ float learning_rate);
359
/// Get the current learning rate for this training session into
/// `*learning_rate`.
/// NOTE(review): the session pointer is non-const even though this is a getter;
/// const-qualifying it would change the function-pointer type, so it stays.
372 ORT_API2_STATUS(GetLearningRate, _Inout_ OrtTrainingSession* sess, _Out_ float* learning_rate);
373
/// Perform the weight updates for the trainable parameters using the
/// optimizer model. `run_options` may be NULL (`_In_opt_`).
388 ORT_API2_STATUS(OptimizerStep, _Inout_ OrtTrainingSession* sess,
389 _In_opt_ const OrtRunOptions* run_options);
390
/// Register a linear learning rate scheduler for the training session,
/// configured by the warmup and total step counts and the initial learning rate.
406 ORT_API2_STATUS(RegisterLinearLRScheduler, _Inout_ OrtTrainingSession* sess, _In_ const int64_t warmup_step_count,
407 _In_ const int64_t total_step_count, _In_ const float initial_lr);
408
/// Update the learning rate based on the registered learning rate scheduler.
422 ORT_API2_STATUS(SchedulerStep, _Inout_ OrtTrainingSession* sess);
423
425
428
/// Retrieve the size of the parameters into `*out`. `trainable_only`
/// presumably restricts it to trainable parameters — confirm.
/// NOTE(review): the unit of `*out` (elements vs bytes) is not stated in this
/// declaration — confirm before using it to size `parameters_buffer`.
441 ORT_API2_STATUS(GetParametersSize, _Inout_ OrtTrainingSession* sess, _Out_ size_t* out, bool trainable_only);
442
/// Copy the parameters into the contiguous buffer held by `parameters_buffer`.
459 ORT_API2_STATUS(CopyParametersToBuffer, _Inout_ OrtTrainingSession* sess,
460 _Inout_ OrtValue* parameters_buffer, bool trainable_only);
461
/// Copy parameter values from the contiguous buffer held by
/// `parameters_buffer` back into the training state.
478 ORT_API2_STATUS(CopyBufferToParameters, _Inout_ OrtTrainingSession* sess,
479 _Inout_ OrtValue* parameters_buffer, bool trainable_only);
480
482
485
/// Release entries for the OrtTrainingSession and OrtCheckpointState handles.
/// ORT_CLASS_RELEASE is defined in onnxruntime_c_api.h (not shown);
/// presumably it declares ReleaseTrainingSession / ReleaseCheckpointState
/// members — confirm against that header.
492 ORT_CLASS_RELEASE(TrainingSession);
493
501 ORT_CLASS_RELEASE(CheckpointState);
502
504
507
/// Export a model that can be used for inferencing to `inference_model_path`.
/// `graph_output_names` supplies `graph_outputs_len` output names
/// (`_In_reads_`).
524 ORT_API2_STATUS(ExportModelForInferencing, _Inout_ OrtTrainingSession* sess,
525 _In_ const ORTCHAR_T* inference_model_path, size_t graph_outputs_len,
526 _In_reads_(graph_outputs_len) const char* const* graph_output_names);
527
529
532
/// Set the seed used for random number generation in ONNX Runtime. Takes no
/// session argument, so it is not scoped to one session (exact scope of effect,
/// e.g. process-wide — confirm).
542 ORT_API2_STATUS(SetSeed, _In_ const int64_t seed);
543
545
548
/// Retrieve the number of user inputs in the training model into `*out`.
559 ORT_API2_STATUS(TrainingSessionGetTrainingModelInputCount, _In_ const OrtTrainingSession* sess, _Out_ size_t* out);
560
/// Retrieve the number of user inputs in the eval model into `*out`.
572 ORT_API2_STATUS(TrainingSessionGetEvalModelInputCount, _In_ const OrtTrainingSession* sess, _Out_ size_t* out);
573
/// Retrieve the name of the user input at `index` in the training model.
/// NOTE(review): `allocator` is annotated `_In_` here but `_Inout_` on the
/// output-name getters; `*output` is presumably allocated with it — confirm.
587 ORT_API2_STATUS(TrainingSessionGetTrainingModelInputName, _In_ const OrtTrainingSession* sess, size_t index,
588 _In_ OrtAllocator* allocator, _Outptr_ char** output);
589
/// Retrieve the name of the user input at `index` in the eval model
/// (same allocator caveat as the training-model variant above).
603 ORT_API2_STATUS(TrainingSessionGetEvalModelInputName, _In_ const OrtTrainingSession* sess, size_t index,
604 _In_ OrtAllocator* allocator, _Outptr_ char** output);
605
607
610
/// Add the named property to the checkpoint state, or update it if present.
/// `property_type` is an OrtPropertyType (OrtIntProperty, OrtFloatProperty,
/// OrtStringProperty); `property_value` presumably points to a value of the
/// matching C type — confirm. NOTE(review): `property_value` is `_In_` but not
/// const-qualified.
625 ORT_API2_STATUS(AddProperty, _Inout_ OrtCheckpointState* checkpoint_state,
626 _In_ const char* property_name, _In_ enum OrtPropertyType property_type,
627 _In_ void* property_value);
628
/// Get the value of the named property from the checkpoint state.
/// `*property_type` receives its type and `*property_value` its value; the
/// value is presumably allocated with `allocator` — confirm ownership.
643 ORT_API2_STATUS(GetProperty, _In_ const OrtCheckpointState* checkpoint_state,
644 _In_ const char* property_name, _Inout_ OrtAllocator* allocator,
645 _Out_ enum OrtPropertyType* property_type, _Outptr_ void** property_value);
646
648
651
/// Load a checkpoint state from the `num_bytes`-byte buffer
/// `checkpoint_buffer` into `*checkpoint_state` (`_Outptr_`).
669 ORT_API2_STATUS(LoadCheckpointFromBuffer, _In_ const void* checkpoint_buffer,
670 _In_ const size_t num_bytes, _Outptr_ OrtCheckpointState** checkpoint_state);
671
/// Retrieve the type and shape information of the parameter with the given
/// name into `*parameter_type_and_shape`.
684 ORT_API2_STATUS(GetParameterTypeAndShape, _In_ const OrtCheckpointState* checkpoint_state,
685 _In_ const char* parameter_name, _Outptr_ OrtTensorTypeAndShapeInfo** parameter_type_and_shape);
686
/// Update the data of the named model parameter in the checkpoint state from
/// `parameter`. NOTE(review): `parameter` is `_In_` but not const-qualified.
701 ORT_API2_STATUS(UpdateParameter, _Inout_ OrtCheckpointState* checkpoint_state,
702 _In_ const char* parameter_name, _In_ OrtValue* parameter);
703
/// Get the data of the named model parameter from the checkpoint state into
/// `*parameter`; presumably allocated via `allocator` — confirm who releases it.
719 ORT_API2_STATUS(GetParameter, _In_ const OrtCheckpointState* checkpoint_state,
720 _In_ const char* parameter_name, _Inout_ OrtAllocator* allocator,
721 _Outptr_ OrtValue** parameter);
722
724};
725
727
struct OrtTensorTypeAndShapeInfo OrtTensorTypeAndShapeInfo
Definition onnxruntime_c_api.h:285
struct OrtRunOptions OrtRunOptions
Definition onnxruntime_c_api.h:283
struct OrtSessionOptions OrtSessionOptions
Definition onnxruntime_c_api.h:289
struct OrtValue OrtValue
Definition onnxruntime_c_api.h:282
struct OrtEnv OrtEnv
Definition onnxruntime_c_api.h:277
struct OrtTrainingSession OrtTrainingSession
Definition onnxruntime_training_c_api.h:104
struct OrtCheckpointState OrtCheckpointState
Definition onnxruntime_training_c_api.h:105
OrtPropertyType
Type of property to be added to or returned from the OrtCheckpointState.
Definition onnxruntime_training_c_api.h:109
@ OrtIntProperty
Definition onnxruntime_training_c_api.h:110
@ OrtStringProperty
Definition onnxruntime_training_c_api.h:112
@ OrtFloatProperty
Definition onnxruntime_training_c_api.h:111
Memory allocation interface.
Definition onnxruntime_c_api.h:317
The Training C API that holds onnxruntime training function pointers.
Definition onnxruntime_training_c_api.h:122
OrtStatus * CreateTrainingSessionFromBuffer(const OrtEnv *env, const OrtSessionOptions *options, OrtCheckpointState *checkpoint_state, const void *train_model_data, size_t train_data_length, const void *eval_model_data, size_t eval_data_length, const void *optim_model_data, size_t optim_data_length, OrtTrainingSession **out)
Create a training session that can be used to begin or resume training. This API provides a way to lo...
OrtStatus * CopyBufferToParameters(OrtTrainingSession *sess, OrtValue *parameters_buffer, bool trainable_only)
Copy parameter values from the given contiguous buffer held by parameters_buffer to the training stat...
OrtStatus * EvalStep(const OrtTrainingSession *sess, const OrtRunOptions *run_options, size_t inputs_len, const OrtValue *const *inputs, size_t outputs_len, OrtValue **outputs)
Computes the outputs for the eval model for the given inputs.
OrtStatus * LazyResetGrad(OrtTrainingSession *session)
Reset the gradients of all trainable parameters to zero lazily.
OrtStatus * TrainingSessionGetEvalModelInputName(const OrtTrainingSession *sess, size_t index, OrtAllocator *allocator, char **output)
Retrieves the name of the user input at the given index in the eval model.
OrtStatus * CreateTrainingSession(const OrtEnv *env, const OrtSessionOptions *options, OrtCheckpointState *checkpoint_state, const char *train_model_path, const char *eval_model_path, const char *optimizer_model_path, OrtTrainingSession **out)
Create a training session that can be used to begin or resume training.
OrtStatus * TrainingSessionGetTrainingModelInputCount(const OrtTrainingSession *sess, size_t *out)
Retrieves the number of user inputs in the training model.
OrtStatus * LoadCheckpoint(const char *checkpoint_path, OrtCheckpointState **checkpoint_state)
Load a checkpoint state from a file on disk into checkpoint_state.
OrtStatus * TrainStep(OrtTrainingSession *sess, const OrtRunOptions *run_options, size_t inputs_len, const OrtValue *const *inputs, size_t outputs_len, OrtValue **outputs)
Computes the outputs of the training model and the gradients of the trainable parameters for the give...
OrtStatus * ExportModelForInferencing(OrtTrainingSession *sess, const char *inference_model_path, size_t graph_outputs_len, const char *const *graph_output_names)
Export a model that can be used for inferencing.
OrtStatus * GetLearningRate(OrtTrainingSession *sess, float *learning_rate)
Gets the current learning rate for this training session.
OrtStatus * TrainingSessionGetEvalModelInputCount(const OrtTrainingSession *sess, size_t *out)
Retrieves the number of user inputs in the eval model.
OrtStatus * TrainingSessionGetEvalModelOutputCount(const OrtTrainingSession *sess, size_t *out)
Retrieves the number of user outputs in the eval model.
OrtStatus * RegisterLinearLRScheduler(OrtTrainingSession *sess, const int64_t warmup_step_count, const int64_t total_step_count, const float initial_lr)
Registers a linear learning rate scheduler for the training session.
OrtStatus * CopyParametersToBuffer(OrtTrainingSession *sess, OrtValue *parameters_buffer, bool trainable_only)
Copy all parameters to a contiguous buffer held by the argument parameters_buffer.
OrtStatus * GetParameterTypeAndShape(const OrtCheckpointState *checkpoint_state, const char *parameter_name, OrtTensorTypeAndShapeInfo **parameter_type_and_shape)
Retrieves the type and shape information of the parameter associated with the given parameter name.
OrtStatus * UpdateParameter(OrtCheckpointState *checkpoint_state, const char *parameter_name, OrtValue *parameter)
Updates the data associated with the model parameter in the checkpoint state for the given parameter ...
OrtStatus * SetLearningRate(OrtTrainingSession *sess, float learning_rate)
Sets the learning rate for this training session.
OrtStatus * TrainingSessionGetTrainingModelOutputCount(const OrtTrainingSession *sess, size_t *out)
Retrieves the number of user outputs in the training model.
OrtStatus * TrainingSessionGetTrainingModelOutputName(const OrtTrainingSession *sess, size_t index, OrtAllocator *allocator, char **output)
Retrieves the name of the user output at the given index in the training model.
OrtStatus * TrainingSessionGetTrainingModelInputName(const OrtTrainingSession *sess, size_t index, OrtAllocator *allocator, char **output)
Retrieves the name of the user input at the given index in the training model.
OrtStatus * TrainingSessionGetEvalModelOutputName(const OrtTrainingSession *sess, size_t index, OrtAllocator *allocator, char **output)
Retrieves the name of the user output at the given index in the eval model.
OrtStatus * AddProperty(OrtCheckpointState *checkpoint_state, const char *property_name, enum OrtPropertyType property_type, void *property_value)
Adds the given property to the checkpoint state, or updates it if it is already present.
OrtStatus * SchedulerStep(OrtTrainingSession *sess)
Update the learning rate based on the registered learning rate scheduler.
OrtStatus * GetParametersSize(OrtTrainingSession *sess, size_t *out, bool trainable_only)
Retrieves the size of all the parameters.
OrtStatus * SetSeed(const int64_t seed)
Sets the seed used for random number generation in ONNX Runtime.
OrtStatus * LoadCheckpointFromBuffer(const void *checkpoint_buffer, const size_t num_bytes, OrtCheckpointState **checkpoint_state)
Load a checkpoint state from a buffer into checkpoint_state.
OrtStatus * GetProperty(const OrtCheckpointState *checkpoint_state, const char *property_name, OrtAllocator *allocator, enum OrtPropertyType *property_type, void **property_value)
Gets the property value associated with the given name from the checkpoint state.
OrtStatus * SaveCheckpoint(OrtCheckpointState *checkpoint_state, const char *checkpoint_path, const bool include_optimizer_state)
Save the given state to a checkpoint file on disk.
OrtStatus * GetParameter(const OrtCheckpointState *checkpoint_state, const char *parameter_name, OrtAllocator *allocator, OrtValue **parameter)
Gets the data associated with the model parameter from the checkpoint state for the given parameter n...
OrtStatus * OptimizerStep(OrtTrainingSession *sess, const OrtRunOptions *run_options)
Performs the weight updates for the trainable parameters using the optimizer model.