NetDEM v1.0
Loading...
Searching...
No Matches
general_net.hpp
Go to the documentation of this file.
1#pragma once
2
3// netdem
4#include "mlpack_utils.hpp"
5
6// std
7#include <cstdarg>
8#include <string>
9
10namespace netdem {
11
20public:
21 mlpack::FFN<> model;
22
23 double step_size{0.01};
24 int batch_size{32};
25 double decay_rate_moment{0.9};
26 double decay_rate_norm{0.9};
27 double gradient_init_param{1e-8};
28 int epochs{100};
29 double stop_tol{1e-8};
30
31 bool enable_logging{true};
32
36 void ResetModel();
37
47 void AddLayer(MLPackUtils::LayerType layer_name, ...);
48
56 void Train(const arma::mat &data_x, const arma::mat &data_y);
57
64 arma::mat Predict(const arma::mat &data_x);
65
72 arma::mat Classify(const arma::mat &data_x);
73
81 arma::mat Regress(const arma::mat &data_x);
82
89 void Load(std::string const &filename, std::string const &label);
90
97 void Save(std::string const &filename, std::string const &label);
98};
99
100} // namespace netdem
A class representing a general neural network.
Definition general_net.hpp:19
void Load(std::string const &filename, std::string const &label)
Loads a previously saved neural network model from a file.
Definition general_net.cpp:74
void Train(const arma::mat &data_x, const arma::mat &data_y)
Trains the current neural network model using the given input and output data.
Definition general_net.cpp:48
void ResetModel()
Resets the current neural network model.
Definition general_net.cpp:10
bool enable_logging
Flag indicating whether informational messages are logged (default: true).
Definition general_net.hpp:31
void Save(std::string const &filename, std::string const &label)
Saves the current neural network model to a file.
Definition general_net.cpp:78
arma::mat Predict(const arma::mat &data_x)
Predicts output values based on input data.
Definition general_net.cpp:59
void AddLayer(MLPackUtils::LayerType layer_name,...)
Adds a layer to the current neural network model.
Definition general_net.cpp:12
double gradient_init_param
Machine learning hyper-parameter: gradient initialization parameter (default: 1e-8).
Definition general_net.hpp:27
mlpack::FFN<> model
The ANN model.
Definition general_net.hpp:21
int batch_size
Machine learning hyper-parameter: batch size (default: 32).
Definition general_net.hpp:24
int epochs
Machine learning hyper-parameter: number of training epochs (default: 100).
Definition general_net.hpp:28
double decay_rate_norm
Machine learning hyper-parameter: decay rate of the norm estimate (default: 0.9).
Definition general_net.hpp:26
double decay_rate_moment
Machine learning hyper-parameter: decay rate of the moment estimate (default: 0.9).
Definition general_net.hpp:25
double stop_tol
Machine learning hyper-parameter: stopping tolerance (default: 1e-8).
Definition general_net.hpp:29
double step_size
Machine learning hyper-parameter: step size (default: 0.01).
Definition general_net.hpp:23
arma::mat Regress(const arma::mat &data_x)
Performs regression on input data using the current neural network model.
Definition general_net.cpp:70
arma::mat Classify(const arma::mat &data_x)
Classifies input data based on the current neural network model.
Definition general_net.cpp:65
LayerType
An enumeration of neural network layer types.
Definition mlpack_utils.hpp:19
Definition bond_entry.hpp:7