A class representing a general neural network.
#include <general_net.hpp>
| Return type | Function | Description |
|---|---|---|
| void | ResetModel () | Resets the current neural network model. |
| void | AddLayer (MLPackUtils::LayerType layer_name,...) | Adds a layer to the current neural network model. |
| void | Train (const arma::mat &data_x, const arma::mat &data_y) | Trains the current neural network model using the given input and output data. |
| arma::mat | Predict (const arma::mat &data_x) | Predicts output values based on input data. |
| arma::mat | Classify (const arma::mat &data_x) | Classifies input data based on the current neural network model. |
| arma::mat | Regress (const arma::mat &data_x) | Performs regression on input data using the current neural network model. |
| void | Load (std::string const &filename, std::string const &label) | Loads a previously saved neural network model from a file. |
| void | Save (std::string const &filename, std::string const &label) | Saves the current neural network model to a file. |
| Type | Attribute | Description |
|---|---|---|
| mlpack::FFN | model | The ANN model. |
| double | step_size {0.01} | Machine learning hyper-parameters. |
| int | batch_size {32} | Machine learning hyper-parameters. |
| double | decay_rate_moment {0.9} | Machine learning hyper-parameters. |
| double | decay_rate_norm {0.9} | Machine learning hyper-parameters. |
| double | gradient_init_param {1e-8} | Machine learning hyper-parameters. |
| int | epochs {100} | Machine learning hyper-parameters. |
| double | stop_tol {1e-8} | Machine learning hyper-parameters. |
| bool | enable_logging {true} | A flag indicating whether to log info. |
| |
The GeneralNet class is used to represent a general neural network, and provides methods for adding layers, training the network, and making predictions.
- Examples
- 02_train_classifier_trimesh_plane.cpp, 03_test_classifier_trimesh_plane.cpp, 06_test_ann_vs_geom_trimesh_plane.cpp, 12_train_classifier_trimesh.cpp, 13_test_classifier_trimesh.cpp, 16_test_ann_vs_geom_trimesh.cpp, 17_test_ann_vs_geom_trimesh.cpp, 22_train_classifier_ellipsoid_plane.cpp, 23_test_classifier_ellipsoid_plane.cpp, 26_test_ann_vs_geom_ellipsoid_plane.cpp, 32_train_classifier_ellipsoid.cpp, 33_test_classifier_ellipsoid.cpp, 36_test_ann_vs_geom_ellipsoid.cpp, and 37_test_ann_vs_geom_ellipsoid.cpp.
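◆ AddLayer()
void netdem::GeneralNet::AddLayer (MLPackUtils::LayerType layer_name,...)
Adds a layer to the current neural network model.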
◆ Classify()
arma::mat netdem::GeneralNet::Classify (const arma::mat &data_x)
Classifies input data based on the current neural network model.
- Parameters
- data_x: The input data to classify.
- Returns
- The classification results.
- Examples
- 02_train_classifier_trimesh_plane.cpp, 03_test_classifier_trimesh_plane.cpp, 06_test_ann_vs_geom_trimesh_plane.cpp, 12_train_classifier_trimesh.cpp, 13_test_classifier_trimesh.cpp, 16_test_ann_vs_geom_trimesh.cpp, 17_test_ann_vs_geom_trimesh.cpp, 22_train_classifier_ellipsoid_plane.cpp, 23_test_classifier_ellipsoid_plane.cpp, 26_test_ann_vs_geom_ellipsoid_plane.cpp, 32_train_classifier_ellipsoid.cpp, 33_test_classifier_ellipsoid.cpp, 36_test_ann_vs_geom_ellipsoid.cpp, and 37_test_ann_vs_geom_ellipsoid.cpp.
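The reference above does not say how the classification results are derived from the network output. A common convention, assumed here rather than confirmed by the source, is to take the index of the largest score for each sample, with one sample per column as in mlpack's column-major data layout. The helper `ArgmaxPerColumn` below is hypothetical and uses only the standard library, so it illustrates the idea without calling netdem itself.

```cpp
#include <cstddef>
#include <vector>

// Hypothetical helper, not part of netdem.
// Assumption: scores is a column-major matrix with n_rows class scores per
// sample and n_cols samples, so column j occupies
// scores[j * n_rows] .. scores[j * n_rows + n_rows - 1].
// Returns, for each column, the row index holding the largest score.
std::vector<std::size_t> ArgmaxPerColumn(const std::vector<double>& scores,
                                         std::size_t n_rows,
                                         std::size_t n_cols) {
  std::vector<std::size_t> labels(n_cols, 0);
  for (std::size_t j = 0; j < n_cols; ++j) {
    const double* col = scores.data() + j * n_rows;
    std::size_t best = 0;
    for (std::size_t i = 1; i < n_rows; ++i) {
      if (col[i] > col[best]) best = i;  // ties keep the earlier index
    }
    labels[j] = best;
  }
  return labels;
}
```

Whether `Classify()` returns raw scores, probabilities, or labels of this kind should be checked against general_net.cpp or the listed example programs.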
◆ Load()
void netdem::GeneralNet::Load (std::string const &filename, std::string const &label)
Loads a previously saved neural network model from a file.
◆ Predict()
arma::mat netdem::GeneralNet::Predict (const arma::mat &data_x)
Predicts output values based on input data.
- Parameters
- data_x: The input data to use for prediction.
- Returns
- The predicted output values.
◆ Regress()
arma::mat netdem::GeneralNet::Regress (const arma::mat &data_x)
Performs regression on input data using the current neural network model.
- Parameters
- data_x: The input data to use for regression.
- Returns
- The regression results.
◆ ResetModel()
void netdem::GeneralNet::ResetModel ()
Resets the current neural network model.
◆ Save()
void netdem::GeneralNet::Save (std::string const &filename, std::string const &label)
Saves the current neural network model to a file.
◆ Train()
void netdem::GeneralNet::Train (const arma::mat &data_x, const arma::mat &data_y)
Trains the current neural network model using the given input and output data.
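The `batch_size` and `epochs` attributes suggest mini-batch training, although this page does not describe how `Train()` consumes them. As a rough sketch under that assumption, the hypothetical helpers below (not part of netdem) show how a data set splits into mini-batches and how many parameter updates a full run would perform if every epoch visits each batch once.

```cpp
#include <algorithm>
#include <cstddef>
#include <utility>
#include <vector>

// Hypothetical helper, not part of netdem.
// Splits n_points samples into half-open index ranges [begin, end) of at
// most batch_size samples each. The last range may be shorter.
std::vector<std::pair<std::size_t, std::size_t>> BatchRanges(
    std::size_t n_points, std::size_t batch_size) {
  std::vector<std::pair<std::size_t, std::size_t>> ranges;
  if (batch_size == 0) return ranges;  // nothing sensible to return
  for (std::size_t begin = 0; begin < n_points; begin += batch_size) {
    ranges.emplace_back(begin, std::min(begin + batch_size, n_points));
  }
  return ranges;
}

// Assumption: one parameter update per mini-batch, repeated for every epoch,
// with no early stopping.
std::size_t TotalUpdates(std::size_t n_points, std::size_t batch_size,
                         std::size_t epochs) {
  return BatchRanges(n_points, batch_size).size() * epochs;
}
```

With the defaults listed on this page (`batch_size` 32, `epochs` 100), 100 samples would give 4 batches per epoch and 400 updates in total under this assumption. A tolerance such as `stop_tol` would typically end training earlier.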
◆ batch_size
int netdem::GeneralNet::batch_size {32}
Machine learning hyper-parameters.
◆ decay_rate_moment
double netdem::GeneralNet::decay_rate_moment {0.9}
Machine learning hyper-parameters.
◆ decay_rate_norm
double netdem::GeneralNet::decay_rate_norm {0.9}
Machine learning hyper-parameters.
◆ enable_logging
bool netdem::GeneralNet::enable_logging {true}
A flag indicating whether to log info.
◆ epochs
int netdem::GeneralNet::epochs {100}
Machine learning hyper-parameters.
◆ gradient_init_param
double netdem::GeneralNet::gradient_init_param {1e-8}
Machine learning hyper-parameters.
◆ model
mlpack::FFN netdem::GeneralNet::model
The ANN model.
◆ step_size
double netdem::GeneralNet::step_size {0.01}
Machine learning hyper-parameters.
◆ stop_tol
double netdem::GeneralNet::stop_tol {1e-8}
Machine learning hyper-parameters.
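The hyper-parameter names resemble the arguments of an Adam-style optimizer, but this page does not name the optimizer that `Train()` uses. The sketch below is a textbook Adam update written against the standard library only, under the assumption that `step_size` is the learning rate, `decay_rate_moment` and `decay_rate_norm` are the two exponential decay rates, and `gradient_init_param` is the small stabilizing constant. `AdamState` and `AdamStep` are hypothetical names, not part of netdem or mlpack.

```cpp
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical optimizer state, not part of netdem.
struct AdamState {
  std::vector<double> m;  // running estimate of the gradient mean
  std::vector<double> v;  // running estimate of the squared gradient
  std::size_t t = 0;      // number of steps taken so far
};

// One textbook Adam step. Default values mirror the defaults listed on this
// page; the mapping of names to Adam's parameters is an assumption.
void AdamStep(std::vector<double>& params, const std::vector<double>& grad,
              AdamState& state, double step_size = 0.01,
              double decay_rate_moment = 0.9, double decay_rate_norm = 0.9,
              double gradient_init_param = 1e-8) {
  if (state.m.size() != params.size()) {
    state.m.assign(params.size(), 0.0);
    state.v.assign(params.size(), 0.0);
    state.t = 0;
  }
  ++state.t;
  const double t = static_cast<double>(state.t);
  // Bias-correction factors for the zero-initialized running estimates.
  const double bias1 = 1.0 - std::pow(decay_rate_moment, t);
  const double bias2 = 1.0 - std::pow(decay_rate_norm, t);
  for (std::size_t i = 0; i < params.size(); ++i) {
    state.m[i] = decay_rate_moment * state.m[i] +
                 (1.0 - decay_rate_moment) * grad[i];
    state.v[i] = decay_rate_norm * state.v[i] +
                 (1.0 - decay_rate_norm) * grad[i] * grad[i];
    const double m_hat = state.m[i] / bias1;
    const double v_hat = state.v[i] / bias2;
    params[i] -= step_size * m_hat / (std::sqrt(v_hat) + gradient_init_param);
  }
}
```

In this formulation the first step moves each parameter by approximately `step_size` in magnitude, regardless of the gradient's scale. The optimizer actually used by `GeneralNet` may differ in detail, so general_net.cpp remains the authority.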
The documentation for this class was generated from the following files:
- netdem/src/mlpack/general_net.hpp
- netdem/src/mlpack/general_net.cpp