NetDEM v1.0
Loading...
Searching...
No Matches
17_test_ann_vs_geom_trimesh.cpp

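This example shows how to use the NetDEM library to compare ANN-based contact detection and resolution against the geometric (boolean triangle-mesh) solver for two triangle-mesh particles.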

#include "general_net.hpp"
#include "igl_wrapper.hpp"
#include "mlpack_utils.hpp"
#include "particle.hpp"
#include "utils_math.hpp"
#include <fstream>
#include <iostream>
#include <random>
#include <sstream>
#include <string>
using namespace netdem;
using namespace std;
void TestANNvsGeometricTrimesh_v2() {
// load particle
TriMesh tri_mesh_1;
tri_mesh_1.InitFromSTL("data/particle_template.stl");
tri_mesh_1.Decimate(200);
tri_mesh_1.AlignAxes();
tri_mesh_1.SetSize(1.0);
Particle obj_p1 = Particle(&tri_mesh_1);
obj_p1.need_update_stl_model = true;
Particle obj_p2 = Particle(&tri_mesh_1);
obj_p2.need_update_stl_model = true;
cout << "particle created ... " << endl;
obj_p2.SetRodrigues(1.2, 0, 1, 0);
obj_p2.SetPosition(0, 0, 0.94);
SolverBooleanPP cnt_solver;
VolumeBased cnt_model;
string root_dir = "local/examples/netdem/ann_models/trimesh_trimesh/";
GeneralNet classifier;
classifier.Load(root_dir + "ann_classifier.xml", "detection");
RegressionNet regressor;
regressor.Load(root_dir + "ann_regressor.xml", "resolution");
// random generator
UniformDistribution uniform_dist(0.0, 1.0);
// random cases
for (int trial = 0; trial < 100; trial++) {
Vec3d pos;
pos[0] = obj_p2.pos[0] + (uniform_dist.Get() - 0.5) * 0.001;
pos[1] = obj_p2.pos[1] + (uniform_dist.Get() - 0.5) * 0.001;
pos[2] = obj_p2.pos[2] + (uniform_dist.Get() - 0.5) * 0.001;
double rot_angle{1.0e-4};
Vec3d rot_axis;
rot_axis[0] = uniform_dist.Get() * 2.0 - 1.0;
rot_axis[1] = uniform_dist.Get() * 2.0 - 1.0;
rot_axis[2] = uniform_dist.Get() * 2.0 - 1.0;
auto dquat = Math::Quaternion::FromRodrigues(rot_angle, rot_axis);
auto quat = Math::Quaternion::Multiply(dquat, obj_p2.quaternion);
Math::Quaternion::Normalize(&quat);
// apply to particle
obj_p2.SetPosition(pos[0], pos[1], pos[2]);
obj_p2.SetQuaternion(quat[0], quat[1], quat[2], quat[3]);
// contact detection and resolution
arma::mat input(7, 1, arma::fill::zeros);
input(0, 0) = obj_p2.pos[0];
input(1, 0) = obj_p2.pos[1];
input(2, 0) = obj_p2.pos[2];
input(3, 0) = obj_p2.quaternion[0] * Math::Sign(obj_p2.quaternion[0]);
input(4, 0) = obj_p2.quaternion[1] * Math::Sign(obj_p2.quaternion[0]);
input(5, 0) = obj_p2.quaternion[2] * Math::Sign(obj_p2.quaternion[0]);
input(6, 0) = obj_p2.quaternion[3] * Math::Sign(obj_p2.quaternion[0]);
auto output = classifier.Classify(input);
cnt_solver.Init(&obj_p1, &obj_p2);
auto cnt_flag_geo = cnt_solver.Detect();
cout << "ann vs geometric: " << output(0) << ", " << cnt_flag_geo << endl;
if (cnt_flag_geo) {
auto cnt = ContactPP(&obj_p1, &obj_p2);
cnt.SetCollisionModel(&cnt_model);
cnt_solver.ResolveInit(&cnt, 1.0e-4);
auto &cnt_geoms = cnt.collision_entries[0].cnt_geoms;
// skip the contact if overlap is too large
if (cnt_geoms.vol * cnt_geoms.sn > 6.0e-4)
continue;
auto output = regressor.Predict(input);
cout << ">> ann: " << pow(output(0, 0) / 40.0, 2.0) << ", "
<< output(1, 0) / 40.0 << ", " << output(2, 0) << ", "
<< output(3, 0) << ", " << output(4, 0) << ", " << output(5, 0)
<< ", " << output(6, 0) << ", " << output(7, 0) << endl;
cout << ">> geo: " << cnt_geoms.vol << ", " << cnt_geoms.sn << ", "
<< cnt_geoms.dir_n[0] << ", " << cnt_geoms.dir_n[1] << ", "
<< cnt_geoms.dir_n[2] << ", " << cnt_geoms.pos[0] << ", "
<< cnt_geoms.pos[1] << ", " << cnt_geoms.pos[2] << endl;
}
}
}
A class representing a contact between two particles.
Definition contact_pp.hpp:20
A class representing a general neural network.
Definition general_net.hpp:19
void Load(std::string const &filename, std::string const &label)
Loads a previously saved neural network model from a file.
Definition general_net.cpp:74
arma::mat Classify(const arma::mat &data_x)
Classifies input data based on the current neural network model.
Definition general_net.cpp:65
Definition particle.hpp:26
virtual void SetQuaternion(double q_0, double q_1, double q_2, double q_3)
Sets the orientation of the particle using a quaternion.
Definition particle.cpp:103
virtual void SetRodrigues(double angle, double axis_x, double axis_y, double axis_z)
Sets the orientation of the particle using a Rodrigues rotation vector.
Definition particle.cpp:95
virtual void SetPosition(double pos_x, double pos_y, double pos_z)
Sets the position of the particle.
Definition particle.cpp:83
Vec4d quaternion
The quaternion of the particle.
Definition particle.hpp:108
Vec3d pos
The position of the particle.
Definition particle.hpp:103
bool need_update_stl_model
Flag indicating whether STL-model intersection-based contact detection and resolution is needed.
Definition particle.hpp:207
A class that represents a feedforward neural network for regression.
Definition regression_net.hpp:21
void Load(std::string const &filename, std::string const &label)
Loads the neural network model from disk.
Definition regression_net.cpp:95
arma::mat Predict(const arma::mat &data_x)
Predicts with the neural network model using input data.
Definition regression_net.cpp:62
Solver for triangle mesh contacts between two particles using boolean operations.
Definition solver_boolean_pp.hpp:18
bool Detect() override
Detects collisions between two particles using boolean operations on their triangle meshes.
Definition solver_boolean_pp.cpp:44
void ResolveInit(ContactPP *const cnt, double timestep) override
Initializes the contact point between two particles at time t = 0.
Definition solver_boolean_pp.cpp:74
void Init(Particle *const p1, Particle *const p2) override
Initializes the collision solver with two particles.
Definition solver_boolean_pp.cpp:22
A class representing a triangular mesh in 3D space.
Definition shape_trimesh.hpp:23
void InitFromSTL(std::string const &file)
Initialize the TriMesh object from an STL file.
void Decimate(int num_nodes)
Decimate the TriMesh object.
Definition shape_trimesh.cpp:121
void AlignAxes()
Align the axes of the TriMesh object.
Definition shape_trimesh.cpp:107
void SetSize(double d) override
Set the size of the TriMesh object.
Definition shape_trimesh.cpp:207
Generates random numbers from a uniform distribution.
Definition distribution_uniform.hpp:15
double Get() override
Get a single random number from the uniform distribution.
Definition distribution_uniform.hpp:58
Contact model that evaluates forces and moments based on volume overlap and relative velocity.
Definition model_volume_based.hpp:13
Definition bond_entry.hpp:7
std::array< double, 3 > Vec3d
Definition utils_macros.hpp:18