A class representing a general neural network.
#include <general_net.hpp>
void | ResetModel () | Resets the current neural network model.
void | AddLayer (MLPackUtils::LayerType layer_name,...) | Adds a layer to the current neural network model.
void | Train (const arma::mat &data_x, const arma::mat &data_y) | Trains the current neural network model using the given input and output data.
arma::mat | Predict (const arma::mat &data_x) | Predicts output values based on input data.
arma::mat | Classify (const arma::mat &data_x) | Classifies input data based on the current neural network model.
arma::mat | Regress (const arma::mat &data_x) | Performs regression on input data using the current neural network model.
void | Load (std::string const &filename, std::string const &label) | Loads a previously saved neural network model from a file.
void | Save (std::string const &filename, std::string const &label) | Saves the current neural network model to a file.

mlpack::FFN | model | The ANN model.
double | step_size {0.01} | Machine learning hyper-parameter: optimizer step size (learning rate).
int | batch_size {32} | Machine learning hyper-parameter: number of samples per mini-batch.
double | decay_rate_moment {0.9} | Machine learning hyper-parameter: decay rate of the moment estimate.
double | decay_rate_norm {0.9} | Machine learning hyper-parameter: decay rate of the norm estimate.
double | gradient_init_param {1e-8} | Machine learning hyper-parameter.
int | epochs {100} | Machine learning hyper-parameter: number of training epochs.
double | stop_tol {1e-8} | Machine learning hyper-parameter: stopping tolerance.
bool | enable_logging {true} | A flag indicating whether to log information.
The GeneralNet class represents a general neural network and provides methods for adding layers, training the network, and making predictions. The network itself is stored in the model member, an mlpack::FFN.
- Examples
- 02_train_classifier_trimesh_plane.cpp, 03_test_classifier_trimesh_plane.cpp, 06_test_ann_vs_geom_trimesh_plane.cpp, 12_train_classifier_trimesh.cpp, 13_test_classifier_trimesh.cpp, 16_test_ann_vs_geom_trimesh.cpp, 17_test_ann_vs_geom_trimesh.cpp, 22_train_classifier_ellipsoid_plane.cpp, 23_test_classifier_ellipsoid_plane.cpp, 26_test_ann_vs_geom_ellipsoid_plane.cpp, 32_train_classifier_ellipsoid.cpp, 33_test_classifier_ellipsoid.cpp, 36_test_ann_vs_geom_ellipsoid.cpp, and 37_test_ann_vs_geom_ellipsoid.cpp.
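◆ AddLayer()
void netdem::GeneralNet::AddLayer(MLPackUtils::LayerType layer_name,...)
Adds a layer to the current neural network model.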
◆ Classify()
arma::mat netdem::GeneralNet::Classify(const arma::mat &data_x)
Classifies input data based on the current neural network model.
- Parameters
- data_x: The input data to classify.
- Returns
- The classification results.
- Examples
- 02_train_classifier_trimesh_plane.cpp, 03_test_classifier_trimesh_plane.cpp, 06_test_ann_vs_geom_trimesh_plane.cpp, 12_train_classifier_trimesh.cpp, 13_test_classifier_trimesh.cpp, 16_test_ann_vs_geom_trimesh.cpp, 17_test_ann_vs_geom_trimesh.cpp, 22_train_classifier_ellipsoid_plane.cpp, 23_test_classifier_ellipsoid_plane.cpp, 26_test_ann_vs_geom_ellipsoid_plane.cpp, 32_train_classifier_ellipsoid.cpp, 33_test_classifier_ellipsoid.cpp, 36_test_ann_vs_geom_ellipsoid.cpp, and 37_test_ann_vs_geom_ellipsoid.cpp.
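This page does not say how Classify turns the network outputs into class labels. A common approach, sketched here purely as an assumption, picks the row with the largest score in each column, since mlpack treats each matrix column as one data point. ArgmaxPerColumn is a hypothetical helper that works on a plain column-major buffer instead of an arma::mat, so it compiles without Armadillo.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical helper (not part of GeneralNet): scores are stored
// column-major with one column per sample, as in an arma::mat.
// Returns, for each column, the row index holding the largest score.
std::vector<std::size_t> ArgmaxPerColumn(const std::vector<double>& scores,
                                         std::size_t n_rows,
                                         std::size_t n_cols) {
  assert(n_rows > 0 && scores.size() == n_rows * n_cols);
  std::vector<std::size_t> labels(n_cols, 0);
  for (std::size_t c = 0; c < n_cols; ++c) {
    const double* col = scores.data() + c * n_rows;  // first entry of column c
    std::size_t best = 0;
    for (std::size_t r = 1; r < n_rows; ++r) {
      if (col[r] > col[best]) best = r;  // strict '>' keeps the first maximum on ties
    }
    labels[c] = best;
  }
  return labels;
}
```

Ties resolve to the lowest row index. With Armadillo available, arma::index_max(scores, 0) returns the per-column indices of the maxima directly.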
◆ Load()
void netdem::GeneralNet::Load(std::string const &filename, std::string const &label)
Loads a previously saved neural network model from a file.
◆ Predict()
arma::mat netdem::GeneralNet::Predict(const arma::mat &data_x)
Predicts output values based on input data.
- Parameters
- data_x: The input data to use for prediction.
- Returns
- The predicted output values.
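Predict takes its input as an arma::mat. Assuming GeneralNet follows the mlpack convention of one sample per column (its model member is an mlpack::FFN), samples held row-wise need to be rearranged into that layout first. The hypothetical ToColumnMajor helper below builds the column-major buffer Armadillo uses internally, with n_rows equal to the number of features and n_cols equal to the number of samples.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical helper (not part of GeneralNet): packs samples given as
// separate vectors into one column-major buffer, one sample per column.
std::vector<double> ToColumnMajor(const std::vector<std::vector<double>>& samples,
                                  std::size_t n_features) {
  std::vector<double> buffer;
  buffer.reserve(samples.size() * n_features);
  for (const auto& sample : samples) {
    assert(sample.size() == n_features);  // every sample must have the same width
    // Column c occupies buffer[c * n_features, (c + 1) * n_features).
    buffer.insert(buffer.end(), sample.begin(), sample.end());
  }
  return buffer;
}
```

The buffer can then be copied into a matrix with Armadillo's auxiliary-memory constructor, e.g. arma::mat data_x(buffer.data(), n_features, samples.size());, which copies the data by default.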
◆ Regress()
arma::mat netdem::GeneralNet::Regress(const arma::mat &data_x)
Performs regression on input data using the current neural network model.
- Parameters
- data_x: The input data to use for regression.
- Returns
- The regression results.
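One way to check Regress output against known targets is the mean squared error. The sketch below assumes both matrices have the same shape and have been flattened in the same column-major order; MeanSquaredError is a hypothetical helper, not part of GeneralNet.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical helper: mean squared error between regression outputs and
// reference values, both flattened in the same element order.
double MeanSquaredError(const std::vector<double>& predicted,
                        const std::vector<double>& expected) {
  assert(!predicted.empty() && predicted.size() == expected.size());
  double sum = 0.0;
  for (std::size_t i = 0; i < predicted.size(); ++i) {
    const double diff = predicted[i] - expected[i];
    sum += diff * diff;  // accumulate squared residuals
  }
  return sum / static_cast<double>(predicted.size());
}
```

With Armadillo, arma::accu(arma::square(pred - ref)) / pred.n_elem computes the same quantity.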
◆ ResetModel()
void netdem::GeneralNet::ResetModel()
Resets the current neural network model.
◆ Save()
void netdem::GeneralNet::Save(std::string const &filename, std::string const &label)
Saves the current neural network model to a file.
◆ Train()
void netdem::GeneralNet::Train(const arma::mat &data_x, const arma::mat &data_y)
Trains the current neural network model using the given input and output data.
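This page does not document how Train uses the hyper-parameter attributes. As a rough picture of how step_size, batch_size, epochs and stop_tol typically interact, here is plain mini-batch gradient descent on a one-weight linear model. It is not GeneralNet's implementation, which is built on mlpack; the TrainSettings struct, Loss, TrainLinear, and the reading of stop_tol as a threshold on the epoch-to-epoch loss change are all assumptions for illustration.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical settings bundle mirroring a few GeneralNet attribute defaults.
struct TrainSettings {
  double step_size = 0.01;
  std::size_t batch_size = 32;
  int epochs = 100;
  double stop_tol = 1e-8;
};

// Mean squared error of the model y = w * x over the whole data set.
double Loss(double w, const std::vector<double>& x, const std::vector<double>& y) {
  double sum = 0.0;
  for (std::size_t i = 0; i < x.size(); ++i) {
    const double diff = w * x[i] - y[i];
    sum += diff * diff;
  }
  return sum / static_cast<double>(x.size());
}

// Mini-batch gradient descent: each epoch walks the data in chunks of
// batch_size; training stops early once the loss changes by less than
// stop_tol between consecutive epochs (assumed meaning of stop_tol).
double TrainLinear(const std::vector<double>& x, const std::vector<double>& y,
                   const TrainSettings& s) {
  assert(!x.empty() && x.size() == y.size() && s.batch_size > 0);
  double w = 0.0;
  double prev_loss = Loss(w, x, y);
  for (int epoch = 0; epoch < s.epochs; ++epoch) {
    for (std::size_t start = 0; start < x.size(); start += s.batch_size) {
      const std::size_t end = std::min(start + s.batch_size, x.size());
      double grad = 0.0;  // derivative of the batch-mean squared error w.r.t. w
      for (std::size_t i = start; i < end; ++i) {
        grad += 2.0 * (w * x[i] - y[i]) * x[i];
      }
      grad /= static_cast<double>(end - start);
      w -= s.step_size * grad;
    }
    const double loss = Loss(w, x, y);
    if (std::abs(prev_loss - loss) < s.stop_tol) break;
    prev_loss = loss;
  }
  return w;
}
```

For example, fitting y = 3x on the ten points x = 0.1, 0.2, ..., 1.0 with step_size = 0.1, batch_size = 4 and at most 500 epochs brings the weight within 1e-3 of 3.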
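◆ batch_size
int netdem::GeneralNet::batch_size {32}
Machine learning hyper-parameter: number of samples per mini-batch.
◆ decay_rate_moment
double netdem::GeneralNet::decay_rate_moment {0.9}
Machine learning hyper-parameter: decay rate of the moment estimate.
◆ decay_rate_norm
double netdem::GeneralNet::decay_rate_norm {0.9}
Machine learning hyper-parameter: decay rate of the norm estimate.
◆ enable_logging
bool netdem::GeneralNet::enable_logging {true}
A flag indicating whether to log information.
◆ epochs
int netdem::GeneralNet::epochs {100}
Machine learning hyper-parameter: number of training epochs.
◆ gradient_init_param
double netdem::GeneralNet::gradient_init_param {1e-8}
Machine learning hyper-parameter.
◆ model
mlpack::FFN netdem::GeneralNet::model
The ANN model.
◆ step_size
double netdem::GeneralNet::step_size {0.01}
Machine learning hyper-parameter: optimizer step size (learning rate).
◆ stop_tol
double netdem::GeneralNet::stop_tol {1e-8}
Machine learning hyper-parameter: stopping tolerance.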
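The names decay_rate_moment, decay_rate_norm and gradient_init_param resemble the β1, β2 and ε constants of the Adam optimizer, with step_size as its learning rate. That mapping is an assumption, since this page does not name the optimizer. Under it, a single Adam update of one scalar parameter could be sketched as follows; AdamState and AdamStep are hypothetical names.

```cpp
#include <cassert>
#include <cmath>

// Optimizer state for one scalar parameter.
struct AdamState {
  double m = 0.0;  // running first-moment (mean) estimate of the gradient
  double v = 0.0;  // running second-moment (squared) estimate of the gradient
  int t = 0;       // number of updates applied so far
};

// One Adam step. Assumed mapping to GeneralNet attributes:
// step_size -> learning rate, decay_rate_moment -> beta1,
// decay_rate_norm -> beta2, gradient_init_param -> epsilon.
double AdamStep(double param, double grad, AdamState& s, double step_size,
                double decay_rate_moment, double decay_rate_norm,
                double gradient_init_param) {
  assert(decay_rate_moment < 1.0 && decay_rate_norm < 1.0);
  s.t += 1;
  s.m = decay_rate_moment * s.m + (1.0 - decay_rate_moment) * grad;
  s.v = decay_rate_norm * s.v + (1.0 - decay_rate_norm) * grad * grad;
  // Bias-correct both estimates, which start at zero.
  const double m_hat = s.m / (1.0 - std::pow(decay_rate_moment, s.t));
  const double v_hat = s.v / (1.0 - std::pow(decay_rate_norm, s.t));
  return param - step_size * m_hat / (std::sqrt(v_hat) + gradient_init_param);
}
```

With the class defaults (step_size 0.01, both decay rates 0.9, gradient_init_param 1e-8), the first bias-corrected step moves the parameter by almost exactly step_size against the sign of the gradient, whatever the gradient's magnitude.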
The documentation for this class was generated from the following files:
- /Users/lzhshou/Documents/Research/myProjects/apaam/repo/netdem/src/mlpack/general_net.hpp
- /Users/lzhshou/Documents/Research/myProjects/apaam/repo/netdem/src/mlpack/general_net.cpp