56 void Train(const arma::mat &data_x, const arma::mat &data_y);
64 arma::mat Predict(const arma::mat &data_x);
72 arma::mat Classify(const arma::mat &data_x);
81 arma::mat Regress(const arma::mat &data_x);
89 void Load(std::string const &filename, std::string const &label);
97 void Save(std::string const &filename, std::string const &label);
A class representing a general neural network.
Definition general_net.hpp:19
void Load(std::string const &filename, std::string const &label)
Loads a previously saved neural network model from a file.
Definition general_net.cpp:74
void Train(const arma::mat &data_x, const arma::mat &data_y)
Trains the current neural network model using the given input and output data.
Definition general_net.cpp:48
void ResetModel()
Resets the current neural network model.
Definition general_net.cpp:10
bool enable_logging
A flag indicating whether informational logging is enabled.
Definition general_net.hpp:31
void Save(std::string const &filename, std::string const &label)
Saves the current neural network model to a file.
Definition general_net.cpp:78
arma::mat Predict(const arma::mat &data_x)
Predicts output values based on input data.
Definition general_net.cpp:59
void AddLayer(MLPackUtils::LayerType layer_name,...)
Adds a layer to the current neural network model.
Definition general_net.cpp:12
double gradient_init_param
Machine learning hyper-parameters.
Definition general_net.hpp:27
mlpack::FFN model
The ANN model.
Definition general_net.hpp:21
int batch_size
Machine learning hyper-parameters.
Definition general_net.hpp:24
int epochs
Machine learning hyper-parameters.
Definition general_net.hpp:28
double decay_rate_norm
Machine learning hyper-parameters.
Definition general_net.hpp:26
double decay_rate_moment
Machine learning hyper-parameters.
Definition general_net.hpp:25
double stop_tol
Machine learning hyper-parameters.
Definition general_net.hpp:29
double step_size
Machine learning hyper-parameters.
Definition general_net.hpp:23
arma::mat Regress(const arma::mat &data_x)
Performs regression on input data using the current neural network model.
Definition general_net.cpp:70
arma::mat Classify(const arma::mat &data_x)
Classifies input data based on the current neural network model.
Definition general_net.cpp:65
LayerType
An enumeration of neural network layer types.
Definition mlpack_utils.hpp:19
Definition bond_entry.hpp:7