ONNX Runtime
Loading...
Searching...
No Matches
onnxruntime_training_cxx_api.h
1// Copyright (c) Microsoft Corporation. All rights reserved.
2// Licensed under the MIT License.
3
4#pragma once
5#include "onnxruntime_training_c_api.h"
6#include <optional>
7#include <variant>
8
9namespace Ort::detail {
10
11#define ORT_DECLARE_TRAINING_RELEASE(NAME) \
12 void OrtRelease(Ort##NAME* ptr);
13
14// These release methods must be forward declared before including onnxruntime_cxx_api.h
15// otherwise class Base won't be aware of them
16ORT_DECLARE_TRAINING_RELEASE(CheckpointState);
17ORT_DECLARE_TRAINING_RELEASE(TrainingSession);
18
19} // namespace Ort::detail
20
21#include "onnxruntime_cxx_api.h"
22
23namespace Ort {
24
31
32namespace detail {
33
34#define ORT_DEFINE_TRAINING_RELEASE(NAME) \
35 inline void OrtRelease(Ort##NAME* ptr) { GetTrainingApi().Release##NAME(ptr); }
36
37ORT_DEFINE_TRAINING_RELEASE(CheckpointState);
38ORT_DEFINE_TRAINING_RELEASE(TrainingSession);
39
40#undef ORT_DECLARE_TRAINING_RELEASE
41#undef ORT_DEFINE_TRAINING_RELEASE
42
43} // namespace detail
44
/// \brief Value type of a checkpoint-state property.
///
/// A property holds exactly one of: a 64-bit signed integer, a float, or a string.
/// It is the value type taken by CheckpointState::AddProperty and returned by
/// CheckpointState::GetProperty.
using Property = std::variant<int64_t, float, std::string>;
46
63class CheckpointState : public detail::Base<OrtCheckpointState> {
64 private:
65 CheckpointState(OrtCheckpointState* checkpoint_state) { p_ = checkpoint_state; }
66
67 public:
68 // Construct the checkpoint state by loading the checkpoint by calling LoadCheckpoint
69 CheckpointState() = delete;
70
73
85 static CheckpointState LoadCheckpoint(const std::basic_string<ORTCHAR_T>& path_to_checkpoint);
86
98 static CheckpointState LoadCheckpointFromBuffer(const std::vector<uint8_t>& buffer);
99
111 static void SaveCheckpoint(const CheckpointState& checkpoint_state,
112 const std::basic_string<ORTCHAR_T>& path_to_checkpoint,
113 const bool include_optimizer_state = false);
114
125 void AddProperty(const std::string& property_name, const Property& property_value);
126
136 Property GetProperty(const std::string& property_name);
137
149 void UpdateParameter(const std::string& parameter_name, const Value& parameter);
150
162 Value GetParameter(const std::string& parameter_name);
163
165};
166
178class TrainingSession : public detail::Base<OrtTrainingSession> {
179 private:
180 size_t training_model_output_count_, eval_model_output_count_;
181
182 public:
185
200 TrainingSession(const Env& env, const SessionOptions& session_options, CheckpointState& checkpoint_state,
201 const std::basic_string<ORTCHAR_T>& train_model_path,
202 const std::optional<std::basic_string<ORTCHAR_T>>& eval_model_path = std::nullopt,
203 const std::optional<std::basic_string<ORTCHAR_T>>& optimizer_model_path = std::nullopt);
204
216 TrainingSession(const Env& env, const SessionOptions& session_options, CheckpointState& checkpoint_state,
217 const std::vector<uint8_t>& train_model_data, const std::vector<uint8_t>& eval_model_data = {},
218 const std::vector<uint8_t>& optim_model_data = {});
220
223
239 std::vector<Value> TrainStep(const std::vector<Value>& input_values);
240
249
259 std::vector<Value> EvalStep(const std::vector<Value>& input_values);
260
276 void SetLearningRate(float learning_rate);
277
287 float GetLearningRate() const;
288
301 void RegisterLinearLRScheduler(int64_t warmup_step_count, int64_t total_step_count,
302 float initial_lr);
303
314
325
327
330
344 void ExportModelForInferencing(const std::basic_string<ORTCHAR_T>& inference_model_path,
345 const std::vector<std::string>& graph_output_names);
346
348
351
361 std::vector<std::string> InputNames(const bool training);
362
373 std::vector<std::string> OutputNames(const bool training);
374
376
379
386 Value ToBuffer(const bool only_trainable);
387
392 void FromBuffer(Value& buffer);
393
395};
396
/// \brief This function sets the seed for generating random numbers.
///
/// \param seed The seed value.
void SetSeed(const int64_t seed);
411} // namespace Ort
412
413#include "onnxruntime_training_cxx_inline.h"
Holds the state of the training session.
Definition onnxruntime_training_cxx_api.h:63
Value GetParameter(const std::string &parameter_name)
Gets the data associated with the model parameter from the checkpoint state for the given parameter n...
static CheckpointState LoadCheckpointFromBuffer(const std::vector< uint8_t > &buffer)
Load a checkpoint state from a buffer.
void AddProperty(const std::string &property_name, const Property &property_value)
Adds or updates the given property to/in the checkpoint state.
static CheckpointState LoadCheckpoint(const std::basic_string< char > &path_to_checkpoint)
Load a checkpoint state from a file on disk into checkpoint_state.
void UpdateParameter(const std::string &parameter_name, const Value &parameter)
Updates the data associated with the model parameter in the checkpoint state for the given parameter ...
static void SaveCheckpoint(const CheckpointState &checkpoint_state, const std::basic_string< char > &path_to_checkpoint, const bool include_optimizer_state=false)
Save the given state to a checkpoint file on disk.
Property GetProperty(const std::string &property_name)
Gets the property value associated with the given name from the checkpoint state.
Trainer class that provides training, evaluation and optimizer methods for training an ONNX model.
Definition onnxruntime_training_cxx_api.h:178
void OptimizerStep()
Performs the weight updates for the trainable parameters using the optimizer model.
void RegisterLinearLRScheduler(int64_t warmup_step_count, int64_t total_step_count, float initial_lr)
Registers a linear learning rate scheduler for the training session.
std::vector< Value > EvalStep(const std::vector< Value > &input_values)
Computes the outputs for the eval model for the given inputs.
std::vector< std::string > InputNames(const bool training)
Retrieves the names of the user inputs for the training and eval models.
float GetLearningRate() const
Gets the current learning rate for this training session.
void ExportModelForInferencing(const std::basic_string< char > &inference_model_path, const std::vector< std::string > &graph_output_names)
Export a model that can be used for inferencing.
void LazyResetGrad()
Reset the gradients of all trainable parameters to zero lazily.
Value ToBuffer(const bool only_trainable)
Returns a contiguous buffer that holds a copy of all training state parameters.
std::vector< Value > TrainStep(const std::vector< Value > &input_values)
Computes the outputs of the training model and the gradients of the trainable parameters for the give...
TrainingSession(const Env &env, const SessionOptions &session_options, CheckpointState &checkpoint_state, const std::basic_string< char > &train_model_path, const std::optional< std::basic_string< char > > &eval_model_path=std::nullopt, const std::optional< std::basic_string< char > > &optimizer_model_path=std::nullopt)
Create a training session that can be used to begin or resume training.
void SchedulerStep()
Update the learning rate based on the registered learning rate scheduler.
TrainingSession(const Env &env, const SessionOptions &session_options, CheckpointState &checkpoint_state, const std::vector< uint8_t > &train_model_data, const std::vector< uint8_t > &eval_model_data={}, const std::vector< uint8_t > &optim_model_data={})
Create a training session that can be used to begin or resume training. This constructor allows the u...
std::vector< std::string > OutputNames(const bool training)
Retrieves the names of the user outputs for the training and eval models.
void FromBuffer(Value &buffer)
Loads the training session model parameters from a contiguous buffer.
void SetLearningRate(float learning_rate)
Sets the learning rate for this training session.
#define ORT_API_VERSION
The API version defined in this header.
Definition onnxruntime_c_api.h:41
struct OrtCheckpointState OrtCheckpointState
Definition onnxruntime_training_c_api.h:105
void SetSeed(const int64_t seed)
This function sets the seed for generating random numbers.
Definition onnxruntime_cxx_api.h:499
All C++ Onnxruntime APIs are defined inside this namespace.
Definition onnxruntime_cxx_api.h:47
const OrtApi & GetApi() noexcept
This returns a reference to the OrtApi interface in use.
Definition onnxruntime_cxx_api.h:124
std::variant< int64_t, float, std::string > Property
Definition onnxruntime_training_cxx_api.h:45
const OrtTrainingApi & GetTrainingApi()
This function returns the C training api struct with the pointers to the ort training C functions....
Definition onnxruntime_training_cxx_api.h:30
The Env (Environment)
Definition onnxruntime_cxx_api.h:697
Wrapper around OrtSessionOptions.
Definition onnxruntime_cxx_api.h:913
Wrapper around OrtValue.
Definition onnxruntime_cxx_api.h:1608
Used internally by the C++ API. C++ wrapper types inherit from this. This is a zero cost abstraction ...
Definition onnxruntime_cxx_api.h:556
contained_type * p_
Definition onnxruntime_cxx_api.h:584
const OrtTrainingApi *(* GetTrainingApi)(uint32_t version)
Gets the Training C Api struct.
Definition onnxruntime_c_api.h:3685
The Training C API that holds onnxruntime training function pointers.
Definition onnxruntime_training_c_api.h:122