ml.h
上传用户:soukeisyuu
上传日期:2022-07-03
资源大小:5943k
文件大小:71k
源码类别:

波变换

开发平台:

Visual C++

  1. /*M///////////////////////////////////////////////////////////////////////////////////////
  2. //
  3. //  IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
  4. //
  5. //  By downloading, copying, installing or using the software you agree to this license.
  6. //  If you do not agree to this license, do not download, install,
  7. //  copy or use the software.
  8. //
  9. //
  10. //                        Intel License Agreement
  11. //
  12. // Copyright (C) 2000, Intel Corporation, all rights reserved.
  13. // Third party copyrights are property of their respective owners.
  14. //
  15. // Redistribution and use in source and binary forms, with or without modification,
  16. // are permitted provided that the following conditions are met:
  17. //
  18. //   * Redistribution's of source code must retain the above copyright notice,
  19. //     this list of conditions and the following disclaimer.
  20. //
  21. //   * Redistribution's in binary form must reproduce the above copyright notice,
  22. //     this list of conditions and the following disclaimer in the documentation
  23. //     and/or other materials provided with the distribution.
  24. //
  25. //   * The name of Intel Corporation may not be used to endorse or promote products
  26. //     derived from this software without specific prior written permission.
  27. //
  28. // This software is provided by the copyright holders and contributors "as is" and
  29. // any express or implied warranties, including, but not limited to, the implied
  30. // warranties of merchantability and fitness for a particular purpose are disclaimed.
  31. // In no event shall the Intel Corporation or contributors be liable for any direct,
  32. // indirect, incidental, special, exemplary, or consequential damages
  33. // (including, but not limited to, procurement of substitute goods or services;
  34. // loss of use, data, or profits; or business interruption) however caused
  35. // and on any theory of liability, whether in contract, strict liability,
  36. // or tort (including negligence or otherwise) arising in any way out of
  37. // the use of this software, even if advised of the possibility of such damage.
  38. //
  39. //M*/
  40. #ifndef __ML_H__
  41. #define __ML_H__
  42. // disable deprecation warning which appears in VisualStudio 8.0
  43. #if _MSC_VER >= 1400
  44. #pragma warning( disable : 4996 )
  45. #endif
  46. #ifndef SKIP_INCLUDES
  47.   #include "cxcore.h"
  48.   #include <limits.h>
  49.   #if defined WIN32 || defined _WIN32 || defined WIN64 || defined _WIN64
  50.     #include <windows.h>
  51.   #endif
  52. #else // SKIP_INCLUDES
  53.   #if defined WIN32 || defined _WIN32 || defined WIN64 || defined _WIN64
  54.     #define CV_CDECL __cdecl
  55.     #define CV_STDCALL __stdcall
  56.   #else
  57.     #define CV_CDECL
  58.     #define CV_STDCALL
  59.   #endif
  60.   #ifndef CV_EXTERN_C
  61.     #ifdef __cplusplus
  62.       #define CV_EXTERN_C extern "C"
  63.       #define CV_DEFAULT(val) = val
  64.     #else
  65.       #define CV_EXTERN_C
  66.       #define CV_DEFAULT(val)
  67.     #endif
  68.   #endif
  69.   #ifndef CV_EXTERN_C_FUNCPTR
  70.     #ifdef __cplusplus
  71.       #define CV_EXTERN_C_FUNCPTR(x) extern "C" { typedef x; }
  72.     #else
  73.       #define CV_EXTERN_C_FUNCPTR(x) typedef x
  74.     #endif
  75.   #endif
  76.   #ifndef CV_INLINE
  77.     #if defined __cplusplus
  78.       #define CV_INLINE inline
  79.     #elif (defined WIN32 || defined _WIN32 || defined WIN64 || defined _WIN64) && !defined __GNUC__
  80.       #define CV_INLINE __inline
  81.     #else
  82.       #define CV_INLINE static
  83.     #endif
  84.   #endif /* CV_INLINE */
  85.   #if (defined WIN32 || defined _WIN32 || defined WIN64 || defined _WIN64) && defined CVAPI_EXPORTS
  86.     #define CV_EXPORTS __declspec(dllexport)
  87.   #else
  88.     #define CV_EXPORTS
  89.   #endif
  90.   #ifndef CVAPI
  91.     #define CVAPI(rettype) CV_EXTERN_C CV_EXPORTS rettype CV_CDECL
  92.   #endif
  93. #endif // SKIP_INCLUDES
  94. #ifdef __cplusplus
  95. // Apple defines a check() macro somewhere in the debug headers
  96. // that interferes with a method definiton in this header
  97. using namespace std;
  98. #undef check
  99. /****************************************************************************************
  100. *                               Main struct definitions                                  *
  101. ****************************************************************************************/
  102. /* log(2*PI) */
  103. #define CV_LOG2PI (1.8378770664093454835606594728112)
  104. /* columns of <trainData> matrix are training samples */
  105. #define CV_COL_SAMPLE 0
  106. /* rows of <trainData> matrix are training samples */
  107. #define CV_ROW_SAMPLE 1
  108. #define CV_IS_ROW_SAMPLE(flags) ((flags) & CV_ROW_SAMPLE)
  109. struct CvVectors
  110. {
  111.     int type;
  112.     int dims, count;
  113.     CvVectors* next;
  114.     union
  115.     {
  116.         uchar** ptr;
  117.         float** fl;
  118.         double** db;
  119.     } data;
  120. };
  121. #if 0
  122. /* A structure, representing the lattice range of statmodel parameters.
  123.    It is used for optimizing statmodel parameters by cross-validation method.
  124.    The lattice is logarithmic, so <step> must be greater then 1. */
  125. typedef struct CvParamLattice
  126. {
  127.     double min_val;
  128.     double max_val;
  129.     double step;
  130. }
  131. CvParamLattice;
  132. CV_INLINE CvParamLattice cvParamLattice( double min_val, double max_val,
  133.                                          double log_step )
  134. {
  135.     CvParamLattice pl;
  136.     pl.min_val = MIN( min_val, max_val );
  137.     pl.max_val = MAX( min_val, max_val );
  138.     pl.step = MAX( log_step, 1. );
  139.     return pl;
  140. }
  141. CV_INLINE CvParamLattice cvDefaultParamLattice( void )
  142. {
  143.     CvParamLattice pl = {0,0,0};
  144.     return pl;
  145. }
  146. #endif
  147. /* Variable type */
  148. #define CV_VAR_NUMERICAL    0
  149. #define CV_VAR_ORDERED      0
  150. #define CV_VAR_CATEGORICAL  1
  151. #define CV_TYPE_NAME_ML_SVM         "opencv-ml-svm"
  152. #define CV_TYPE_NAME_ML_KNN         "opencv-ml-knn"
  153. #define CV_TYPE_NAME_ML_NBAYES      "opencv-ml-bayesian"
  154. #define CV_TYPE_NAME_ML_EM          "opencv-ml-em"
  155. #define CV_TYPE_NAME_ML_BOOSTING    "opencv-ml-boost-tree"
  156. #define CV_TYPE_NAME_ML_TREE        "opencv-ml-tree"
  157. #define CV_TYPE_NAME_ML_ANN_MLP     "opencv-ml-ann-mlp"
  158. #define CV_TYPE_NAME_ML_CNN         "opencv-ml-cnn"
  159. #define CV_TYPE_NAME_ML_RTREES      "opencv-ml-random-trees"
  160. #define CV_TRAIN_ERROR  0
  161. #define CV_TEST_ERROR   1
  162. class CV_EXPORTS CvStatModel
  163. {
  164. public:
  165.     CvStatModel();
  166.     virtual ~CvStatModel();
  167.     virtual void clear();
  168.     virtual void save( const char* filename, const char* name=0 ) const;
  169.     virtual void load( const char* filename, const char* name=0 );
  170.     virtual void write( CvFileStorage* storage, const char* name ) const;
  171.     virtual void read( CvFileStorage* storage, CvFileNode* node );
  172. protected:
  173.     const char* default_model_name;
  174. };
  175. /****************************************************************************************
  176. *                                 Normal Bayes Classifier                                *
  177. ****************************************************************************************/
  178. /* The structure, representing the grid range of statmodel parameters.
  179.    It is used for optimizing statmodel accuracy by varying model parameters,
  180.    the accuracy estimate being computed by cross-validation.
  181.    The grid is logarithmic, so <step> must be greater then 1. */
  182. class CvMLData;
  183. struct CV_EXPORTS CvParamGrid
  184. {
  185.     // SVM params type
  186.     enum { SVM_C=0, SVM_GAMMA=1, SVM_P=2, SVM_NU=3, SVM_COEF=4, SVM_DEGREE=5 };
  187.     CvParamGrid()
  188.     {
  189.         min_val = max_val = step = 0;
  190.     }
  191.     CvParamGrid( double _min_val, double _max_val, double log_step )
  192.     {
  193.         min_val = _min_val;
  194.         max_val = _max_val;
  195.         step = log_step;
  196.     }
  197.     //CvParamGrid( int param_id );
  198.     bool check() const;
  199.     double min_val;
  200.     double max_val;
  201.     double step;
  202. };
  203. class CV_EXPORTS CvNormalBayesClassifier : public CvStatModel
  204. {
  205. public:
  206.     CvNormalBayesClassifier();
  207.     virtual ~CvNormalBayesClassifier();
  208.     CvNormalBayesClassifier( const CvMat* _train_data, const CvMat* _responses,
  209.         const CvMat* _var_idx=0, const CvMat* _sample_idx=0 );
  210.     
  211.     virtual bool train( const CvMat* _train_data, const CvMat* _responses,
  212.         const CvMat* _var_idx = 0, const CvMat* _sample_idx=0, bool update=false );
  213.    
  214.     virtual float predict( const CvMat* _samples, CvMat* results=0 ) const;
  215.     virtual void clear();
  216. #ifndef SWIG
  217.     CvNormalBayesClassifier( const cv::Mat& _train_data, const cv::Mat& _responses,
  218.                             const cv::Mat& _var_idx=cv::Mat(), const cv::Mat& _sample_idx=cv::Mat() );
  219.     virtual bool train( const cv::Mat& _train_data, const cv::Mat& _responses,
  220.                        const cv::Mat& _var_idx = cv::Mat(), const cv::Mat& _sample_idx=cv::Mat(),
  221.                        bool update=false );
  222.     virtual float predict( const cv::Mat& _samples, cv::Mat* results=0 ) const;
  223. #endif
  224.     
  225.     virtual void write( CvFileStorage* storage, const char* name ) const;
  226.     virtual void read( CvFileStorage* storage, CvFileNode* node );
  227. protected:
  228.     int     var_count, var_all;
  229.     CvMat*  var_idx;
  230.     CvMat*  cls_labels;
  231.     CvMat** count;
  232.     CvMat** sum;
  233.     CvMat** productsum;
  234.     CvMat** avg;
  235.     CvMat** inv_eigen_values;
  236.     CvMat** cov_rotate_mats;
  237.     CvMat*  c;
  238. };
  239. /****************************************************************************************
  240. *                          K-Nearest Neighbour Classifier                                *
  241. ****************************************************************************************/
  242. // k Nearest Neighbors
  243. class CV_EXPORTS CvKNearest : public CvStatModel
  244. {
  245. public:
  246.     CvKNearest();
  247.     virtual ~CvKNearest();
  248.     CvKNearest( const CvMat* _train_data, const CvMat* _responses,
  249.                 const CvMat* _sample_idx=0, bool _is_regression=false, int max_k=32 );
  250.     virtual bool train( const CvMat* _train_data, const CvMat* _responses,
  251.                         const CvMat* _sample_idx=0, bool is_regression=false,
  252.                         int _max_k=32, bool _update_base=false );
  253.     
  254.     virtual float find_nearest( const CvMat* _samples, int k, CvMat* results=0,
  255.         const float** neighbors=0, CvMat* neighbor_responses=0, CvMat* dist=0 ) const;
  256.     
  257. #ifndef SWIG
  258.     CvKNearest( const cv::Mat& _train_data, const cv::Mat& _responses,
  259.                const cv::Mat& _sample_idx=cv::Mat(), bool _is_regression=false, int max_k=32 );
  260.     
  261.     virtual bool train( const cv::Mat& _train_data, const cv::Mat& _responses,
  262.                        const cv::Mat& _sample_idx=cv::Mat(), bool is_regression=false,
  263.                        int _max_k=32, bool _update_base=false );    
  264.     
  265.     virtual float find_nearest( const cv::Mat& _samples, int k, cv::Mat* results=0,
  266.                                 const float** neighbors=0,
  267.                                 cv::Mat* neighbor_responses=0,
  268.                                 cv::Mat* dist=0 ) const;
  269. #endif
  270.     
  271.     virtual void clear();
  272.     int get_max_k() const;
  273.     int get_var_count() const;
  274.     int get_sample_count() const;
  275.     bool is_regression() const;
  276. protected:
  277.     virtual float write_results( int k, int k1, int start, int end,
  278.         const float* neighbor_responses, const float* dist, CvMat* _results,
  279.         CvMat* _neighbor_responses, CvMat* _dist, Cv32suf* sort_buf ) const;
  280.     virtual void find_neighbors_direct( const CvMat* _samples, int k, int start, int end,
  281.         float* neighbor_responses, const float** neighbors, float* dist ) const;
  282.     int max_k, var_count;
  283.     int total;
  284.     bool regression;
  285.     CvVectors* samples;
  286. };
  287. /****************************************************************************************
  288. *                                   Support Vector Machines                              *
  289. ****************************************************************************************/
  290. // SVM training parameters
  291. struct CV_EXPORTS CvSVMParams
  292. {
  293.     CvSVMParams();
  294.     CvSVMParams( int _svm_type, int _kernel_type,
  295.                  double _degree, double _gamma, double _coef0,
  296.                  double _C, double _nu, double _p,
  297.                  CvMat* _class_weights, CvTermCriteria _term_crit );
  298.     int         svm_type;
  299.     int         kernel_type;
  300.     double      degree; // for poly
  301.     double      gamma;  // for poly/rbf/sigmoid
  302.     double      coef0;  // for poly/sigmoid
  303.     double      C;  // for CV_SVM_C_SVC, CV_SVM_EPS_SVR and CV_SVM_NU_SVR
  304.     double      nu; // for CV_SVM_NU_SVC, CV_SVM_ONE_CLASS, and CV_SVM_NU_SVR
  305.     double      p; // for CV_SVM_EPS_SVR
  306.     CvMat*      class_weights; // for CV_SVM_C_SVC
  307.     CvTermCriteria term_crit; // termination criteria
  308. };
  309. struct CV_EXPORTS CvSVMKernel
  310. {
  311.     typedef void (CvSVMKernel::*Calc)( int vec_count, int vec_size, const float** vecs,
  312.                                        const float* another, float* results );
  313.     CvSVMKernel();
  314.     CvSVMKernel( const CvSVMParams* _params, Calc _calc_func );
  315.     virtual bool create( const CvSVMParams* _params, Calc _calc_func );
  316.     virtual ~CvSVMKernel();
  317.     virtual void clear();
  318.     virtual void calc( int vcount, int n, const float** vecs, const float* another, float* results );
  319.     const CvSVMParams* params;
  320.     Calc calc_func;
  321.     virtual void calc_non_rbf_base( int vec_count, int vec_size, const float** vecs,
  322.                                     const float* another, float* results,
  323.                                     double alpha, double beta );
  324.     virtual void calc_linear( int vec_count, int vec_size, const float** vecs,
  325.                               const float* another, float* results );
  326.     virtual void calc_rbf( int vec_count, int vec_size, const float** vecs,
  327.                            const float* another, float* results );
  328.     virtual void calc_poly( int vec_count, int vec_size, const float** vecs,
  329.                             const float* another, float* results );
  330.     virtual void calc_sigmoid( int vec_count, int vec_size, const float** vecs,
  331.                                const float* another, float* results );
  332. };
  333. struct CvSVMKernelRow
  334. {
  335.     CvSVMKernelRow* prev;
  336.     CvSVMKernelRow* next;
  337.     float* data;
  338. };
  339. struct CvSVMSolutionInfo
  340. {
  341.     double obj;
  342.     double rho;
  343.     double upper_bound_p;
  344.     double upper_bound_n;
  345.     double r;   // for Solver_NU
  346. };
  347. class CV_EXPORTS CvSVMSolver
  348. {
  349. public:
  350.     typedef bool (CvSVMSolver::*SelectWorkingSet)( int& i, int& j );
  351.     typedef float* (CvSVMSolver::*GetRow)( int i, float* row, float* dst, bool existed );
  352.     typedef void (CvSVMSolver::*CalcRho)( double& rho, double& r );
  353.     CvSVMSolver();
  354.     CvSVMSolver( int count, int var_count, const float** samples, schar* y,
  355.                  int alpha_count, double* alpha, double Cp, double Cn,
  356.                  CvMemStorage* storage, CvSVMKernel* kernel, GetRow get_row,
  357.                  SelectWorkingSet select_working_set, CalcRho calc_rho );
  358.     virtual bool create( int count, int var_count, const float** samples, schar* y,
  359.                  int alpha_count, double* alpha, double Cp, double Cn,
  360.                  CvMemStorage* storage, CvSVMKernel* kernel, GetRow get_row,
  361.                  SelectWorkingSet select_working_set, CalcRho calc_rho );
  362.     virtual ~CvSVMSolver();
  363.     virtual void clear();
  364.     virtual bool solve_generic( CvSVMSolutionInfo& si );
  365.     virtual bool solve_c_svc( int count, int var_count, const float** samples, schar* y,
  366.                               double Cp, double Cn, CvMemStorage* storage,
  367.                               CvSVMKernel* kernel, double* alpha, CvSVMSolutionInfo& si );
  368.     virtual bool solve_nu_svc( int count, int var_count, const float** samples, schar* y,
  369.                                CvMemStorage* storage, CvSVMKernel* kernel,
  370.                                double* alpha, CvSVMSolutionInfo& si );
  371.     virtual bool solve_one_class( int count, int var_count, const float** samples,
  372.                                   CvMemStorage* storage, CvSVMKernel* kernel,
  373.                                   double* alpha, CvSVMSolutionInfo& si );
  374.     virtual bool solve_eps_svr( int count, int var_count, const float** samples, const float* y,
  375.                                 CvMemStorage* storage, CvSVMKernel* kernel,
  376.                                 double* alpha, CvSVMSolutionInfo& si );
  377.     virtual bool solve_nu_svr( int count, int var_count, const float** samples, const float* y,
  378.                                CvMemStorage* storage, CvSVMKernel* kernel,
  379.                                double* alpha, CvSVMSolutionInfo& si );
  380.     virtual float* get_row_base( int i, bool* _existed );
  381.     virtual float* get_row( int i, float* dst );
  382.     int sample_count;
  383.     int var_count;
  384.     int cache_size;
  385.     int cache_line_size;
  386.     const float** samples;
  387.     const CvSVMParams* params;
  388.     CvMemStorage* storage;
  389.     CvSVMKernelRow lru_list;
  390.     CvSVMKernelRow* rows;
  391.     int alpha_count;
  392.     double* G;
  393.     double* alpha;
  394.     // -1 - lower bound, 0 - free, 1 - upper bound
  395.     schar* alpha_status;
  396.     schar* y;
  397.     double* b;
  398.     float* buf[2];
  399.     double eps;
  400.     int max_iter;
  401.     double C[2];  // C[0] == Cn, C[1] == Cp
  402.     CvSVMKernel* kernel;
  403.     SelectWorkingSet select_working_set_func;
  404.     CalcRho calc_rho_func;
  405.     GetRow get_row_func;
  406.     virtual bool select_working_set( int& i, int& j );
  407.     virtual bool select_working_set_nu_svm( int& i, int& j );
  408.     virtual void calc_rho( double& rho, double& r );
  409.     virtual void calc_rho_nu_svm( double& rho, double& r );
  410.     virtual float* get_row_svc( int i, float* row, float* dst, bool existed );
  411.     virtual float* get_row_one_class( int i, float* row, float* dst, bool existed );
  412.     virtual float* get_row_svr( int i, float* row, float* dst, bool existed );
  413. };
  414. struct CvSVMDecisionFunc
  415. {
  416.     double rho;
  417.     int sv_count;
  418.     double* alpha;
  419.     int* sv_index;
  420. };
  421. // SVM model
  422. class CV_EXPORTS CvSVM : public CvStatModel
  423. {
  424. public:
  425.     // SVM type
  426.     enum { C_SVC=100, NU_SVC=101, ONE_CLASS=102, EPS_SVR=103, NU_SVR=104 };
  427.     // SVM kernel type
  428.     enum { LINEAR=0, POLY=1, RBF=2, SIGMOID=3 };
  429.     // SVM params type
  430.     enum { C=0, GAMMA=1, P=2, NU=3, COEF=4, DEGREE=5 };
  431.     CvSVM();
  432.     virtual ~CvSVM();
  433.     CvSVM( const CvMat* _train_data, const CvMat* _responses,
  434.            const CvMat* _var_idx=0, const CvMat* _sample_idx=0,
  435.            CvSVMParams _params=CvSVMParams() );
  436.     virtual bool train( const CvMat* _train_data, const CvMat* _responses,
  437.                         const CvMat* _var_idx=0, const CvMat* _sample_idx=0,
  438.                         CvSVMParams _params=CvSVMParams() );
  439.     
  440.     virtual bool train_auto( const CvMat* _train_data, const CvMat* _responses,
  441.         const CvMat* _var_idx, const CvMat* _sample_idx, CvSVMParams _params,
  442.         int k_fold = 10,
  443.         CvParamGrid C_grid      = get_default_grid(CvSVM::C),
  444.         CvParamGrid gamma_grid  = get_default_grid(CvSVM::GAMMA),
  445.         CvParamGrid p_grid      = get_default_grid(CvSVM::P),
  446.         CvParamGrid nu_grid     = get_default_grid(CvSVM::NU),
  447.         CvParamGrid coef_grid   = get_default_grid(CvSVM::COEF),
  448.         CvParamGrid degree_grid = get_default_grid(CvSVM::DEGREE) );
  449.     virtual float predict( const CvMat* _sample, bool returnDFVal=false ) const;
  450. #ifndef SWIG
  451.     CvSVM( const cv::Mat& _train_data, const cv::Mat& _responses,
  452.           const cv::Mat& _var_idx=cv::Mat(), const cv::Mat& _sample_idx=cv::Mat(),
  453.           CvSVMParams _params=CvSVMParams() );
  454.     
  455.     virtual bool train( const cv::Mat& _train_data, const cv::Mat& _responses,
  456.                        const cv::Mat& _var_idx=cv::Mat(), const cv::Mat& _sample_idx=cv::Mat(),
  457.                        CvSVMParams _params=CvSVMParams() );
  458.     
  459.     virtual bool train_auto( const cv::Mat& _train_data, const cv::Mat& _responses,
  460.                             const cv::Mat& _var_idx, const cv::Mat& _sample_idx, CvSVMParams _params,
  461.                             int k_fold = 10,
  462.                             CvParamGrid C_grid      = get_default_grid(CvSVM::C),
  463.                             CvParamGrid gamma_grid  = get_default_grid(CvSVM::GAMMA),
  464.                             CvParamGrid p_grid      = get_default_grid(CvSVM::P),
  465.                             CvParamGrid nu_grid     = get_default_grid(CvSVM::NU),
  466.                             CvParamGrid coef_grid   = get_default_grid(CvSVM::COEF),
  467.                             CvParamGrid degree_grid = get_default_grid(CvSVM::DEGREE) );
  468.     virtual float predict( const cv::Mat& _sample, bool returnDFVal=false ) const;    
  469. #endif
  470.     
  471.     virtual int get_support_vector_count() const;
  472.     virtual const float* get_support_vector(int i) const;
  473.     virtual CvSVMParams get_params() const { return params; };
  474.     virtual void clear();
  475.     static CvParamGrid get_default_grid( int param_id );
  476.     virtual void write( CvFileStorage* storage, const char* name ) const;
  477.     virtual void read( CvFileStorage* storage, CvFileNode* node );
  478.     int get_var_count() const { return var_idx ? var_idx->cols : var_all; }
  479. protected:
  480.     virtual bool set_params( const CvSVMParams& _params );
  481.     virtual bool train1( int sample_count, int var_count, const float** samples,
  482.                     const void* _responses, double Cp, double Cn,
  483.                     CvMemStorage* _storage, double* alpha, double& rho );
  484.     virtual bool do_train( int svm_type, int sample_count, int var_count, const float** samples,
  485.                     const CvMat* _responses, CvMemStorage* _storage, double* alpha );
  486.     virtual void create_kernel();
  487.     virtual void create_solver();
  488.     virtual void write_params( CvFileStorage* fs ) const;
  489.     virtual void read_params( CvFileStorage* fs, CvFileNode* node );
  490.     CvSVMParams params;
  491.     CvMat* class_labels;
  492.     int var_all;
  493.     float** sv;
  494.     int sv_total;
  495.     CvMat* var_idx;
  496.     CvMat* class_weights;
  497.     CvSVMDecisionFunc* decision_func;
  498.     CvMemStorage* storage;
  499.     CvSVMSolver* solver;
  500.     CvSVMKernel* kernel;
  501. };
  502. /****************************************************************************************
  503. *                              Expectation - Maximization                                *
  504. ****************************************************************************************/
  505. struct CV_EXPORTS CvEMParams
  506. {
  507.     CvEMParams() : nclusters(10), cov_mat_type(1/*CvEM::COV_MAT_DIAGONAL*/),
  508.         start_step(0/*CvEM::START_AUTO_STEP*/), probs(0), weights(0), means(0), covs(0)
  509.     {
  510.         term_crit=cvTermCriteria( CV_TERMCRIT_ITER+CV_TERMCRIT_EPS, 100, FLT_EPSILON );
  511.     }
  512.     CvEMParams( int _nclusters, int _cov_mat_type=1/*CvEM::COV_MAT_DIAGONAL*/,
  513.                 int _start_step=0/*CvEM::START_AUTO_STEP*/,
  514.                 CvTermCriteria _term_crit=cvTermCriteria(CV_TERMCRIT_ITER+CV_TERMCRIT_EPS, 100, FLT_EPSILON),
  515.                 const CvMat* _probs=0, const CvMat* _weights=0, const CvMat* _means=0, const CvMat** _covs=0 ) :
  516.                 nclusters(_nclusters), cov_mat_type(_cov_mat_type), start_step(_start_step),
  517.                 probs(_probs), weights(_weights), means(_means), covs(_covs), term_crit(_term_crit)
  518.     {}
  519.     int nclusters;
  520.     int cov_mat_type;
  521.     int start_step;
  522.     const CvMat* probs;
  523.     const CvMat* weights;
  524.     const CvMat* means;
  525.     const CvMat** covs;
  526.     CvTermCriteria term_crit;
  527. };
  528. class CV_EXPORTS CvEM : public CvStatModel
  529. {
  530. public:
  531.     // Type of covariation matrices
  532.     enum { COV_MAT_SPHERICAL=0, COV_MAT_DIAGONAL=1, COV_MAT_GENERIC=2 };
  533.     // The initial step
  534.     enum { START_E_STEP=1, START_M_STEP=2, START_AUTO_STEP=0 };
  535.     CvEM();
  536.     CvEM( const CvMat* samples, const CvMat* sample_idx=0,
  537.           CvEMParams params=CvEMParams(), CvMat* labels=0 );
  538.     //CvEM (CvEMParams params, CvMat * means, CvMat ** covs, CvMat * weights, CvMat * probs, CvMat * log_weight_div_det, CvMat * inv_eigen_values, CvMat** cov_rotate_mats);
  539.     virtual ~CvEM();
  540.     virtual bool train( const CvMat* samples, const CvMat* sample_idx=0,
  541.                         CvEMParams params=CvEMParams(), CvMat* labels=0 );
  542.     virtual float predict( const CvMat* sample, CvMat* probs ) const;
  543. #ifndef SWIG
  544.     CvEM( const cv::Mat& samples, const cv::Mat& sample_idx=cv::Mat(),
  545.          CvEMParams params=CvEMParams(), cv::Mat* labels=0 );
  546.     
  547.     virtual bool train( const cv::Mat& samples, const cv::Mat& sample_idx=cv::Mat(),
  548.                        CvEMParams params=CvEMParams(), cv::Mat* labels=0 );
  549.     
  550.     virtual float predict( const cv::Mat& sample, cv::Mat* probs ) const;
  551. #endif
  552.     
  553.     virtual void clear();
  554.     int           get_nclusters() const;
  555.     const CvMat*  get_means()     const;
  556.     const CvMat** get_covs()      const;
  557.     const CvMat*  get_weights()   const;
  558.     const CvMat*  get_probs()     const;
  559.     inline double         get_log_likelihood     () const { return log_likelihood;     };
  560.     
  561. //    inline const CvMat *  get_log_weight_div_det () const { return log_weight_div_det; };
  562. //    inline const CvMat *  get_inv_eigen_values   () const { return inv_eigen_values;   };
  563. //    inline const CvMat ** get_cov_rotate_mats    () const { return cov_rotate_mats;    };
  564. protected:
  565.     virtual void set_params( const CvEMParams& params,
  566.                              const CvVectors& train_data );
  567.     virtual void init_em( const CvVectors& train_data );
  568.     virtual double run_em( const CvVectors& train_data );
  569.     virtual void init_auto( const CvVectors& samples );
  570.     virtual void kmeans( const CvVectors& train_data, int nclusters,
  571.                          CvMat* labels, CvTermCriteria criteria,
  572.                          const CvMat* means );
  573.     CvEMParams params;
  574.     double log_likelihood;
  575.     CvMat* means;
  576.     CvMat** covs;
  577.     CvMat* weights;
  578.     CvMat* probs;
  579.     CvMat* log_weight_div_det;
  580.     CvMat* inv_eigen_values;
  581.     CvMat** cov_rotate_mats;
  582. };
  583. /****************************************************************************************
  584. *                                      Decision Tree                                     *
  585. ****************************************************************************************/
  586. struct CvPair16u32s
  587. {
  588.     unsigned short* u;
  589.     int* i;
  590. };
  591. #define CV_DTREE_CAT_DIR(idx,subset) 
  592.     (2*((subset[(idx)>>5]&(1 << ((idx) & 31)))==0)-1)
  593. struct CvDTreeSplit
  594. {
  595.     int var_idx;
  596.     int condensed_idx;
  597.     int inversed;
  598.     float quality;
  599.     CvDTreeSplit* next;
  600.     union
  601.     {
  602.         int subset[2];
  603.         struct
  604.         {
  605.             float c;
  606.             int split_point;
  607.         }
  608.         ord;
  609.     };
  610. };
  611. struct CvDTreeNode
  612. {
  613.     int class_idx;
  614.     int Tn;
  615.     double value;
  616.     CvDTreeNode* parent;
  617.     CvDTreeNode* left;
  618.     CvDTreeNode* right;
  619.     CvDTreeSplit* split;
  620.     int sample_count;
  621.     int depth;
  622.     int* num_valid;
  623.     int offset;
  624.     int buf_idx;
  625.     double maxlr;
  626.     // global pruning data
  627.     int complexity;
  628.     double alpha;
  629.     double node_risk, tree_risk, tree_error;
  630.     // cross-validation pruning data
  631.     int* cv_Tn;
  632.     double* cv_node_risk;
  633.     double* cv_node_error;
  634.     int get_num_valid(int vi) { return num_valid ? num_valid[vi] : sample_count; }
  635.     void set_num_valid(int vi, int n) { if( num_valid ) num_valid[vi] = n; }
  636. };
  637. struct CV_EXPORTS CvDTreeParams
  638. {
  639.     int   max_categories;
  640.     int   max_depth;
  641.     int   min_sample_count;
  642.     int   cv_folds;
  643.     bool  use_surrogates;
  644.     bool  use_1se_rule;
  645.     bool  truncate_pruned_tree;
  646.     float regression_accuracy;
  647.     const float* priors;
  648.     CvDTreeParams() : max_categories(10), max_depth(INT_MAX), min_sample_count(10),
  649.         cv_folds(10), use_surrogates(true), use_1se_rule(true),
  650.         truncate_pruned_tree(true), regression_accuracy(0.01f), priors(0)
  651.     {}
  652.     CvDTreeParams( int _max_depth, int _min_sample_count,
  653.                    float _regression_accuracy, bool _use_surrogates,
  654.                    int _max_categories, int _cv_folds,
  655.                    bool _use_1se_rule, bool _truncate_pruned_tree,
  656.                    const float* _priors ) :
  657.         max_categories(_max_categories), max_depth(_max_depth),
  658.         min_sample_count(_min_sample_count), cv_folds (_cv_folds),
  659.         use_surrogates(_use_surrogates), use_1se_rule(_use_1se_rule),
  660.         truncate_pruned_tree(_truncate_pruned_tree),
  661.         regression_accuracy(_regression_accuracy),
  662.         priors(_priors)
  663.     {}
  664. };
  665. struct CV_EXPORTS CvDTreeTrainData
  666. {
  667.     CvDTreeTrainData();
  668.     CvDTreeTrainData( const CvMat* _train_data, int _tflag,
  669.                       const CvMat* _responses, const CvMat* _var_idx=0,
  670.                       const CvMat* _sample_idx=0, const CvMat* _var_type=0,
  671.                       const CvMat* _missing_mask=0,
  672.                       const CvDTreeParams& _params=CvDTreeParams(),
  673.                       bool _shared=false, bool _add_labels=false );
  674.     virtual ~CvDTreeTrainData();
  675.     virtual void set_data( const CvMat* _train_data, int _tflag,
  676.                           const CvMat* _responses, const CvMat* _var_idx=0,
  677.                           const CvMat* _sample_idx=0, const CvMat* _var_type=0,
  678.                           const CvMat* _missing_mask=0,
  679.                           const CvDTreeParams& _params=CvDTreeParams(),
  680.                           bool _shared=false, bool _add_labels=false,
  681.                           bool _update_data=false );
  682.     virtual void do_responses_copy();
  683.     virtual void get_vectors( const CvMat* _subsample_idx,
  684.          float* values, uchar* missing, float* responses, bool get_class_idx=false );
  685.     virtual CvDTreeNode* subsample_data( const CvMat* _subsample_idx );
  686.     virtual void write_params( CvFileStorage* fs ) const;
  687.     virtual void read_params( CvFileStorage* fs, CvFileNode* node );
  688.     // release all the data
  689.     virtual void clear();
  690.     int get_num_classes() const;
  691.     int get_var_type(int vi) const;
  692.     int get_work_var_count() const {return work_var_count;}
  693.     virtual void get_ord_responses( CvDTreeNode* n, float* values_buf, const float** values );    
  694.     virtual void get_class_labels( CvDTreeNode* n, int* labels_buf, const int** labels );
  695.     virtual void get_cv_labels( CvDTreeNode* n, int* labels_buf, const int** labels );
  696.     virtual void get_sample_indices( CvDTreeNode* n, int* indices_buf, const int** labels );
  697.     virtual int get_cat_var_data( CvDTreeNode* n, int vi, int* cat_values_buf, const int** cat_values );
  698.     virtual int get_ord_var_data( CvDTreeNode* n, int vi, float* ord_values_buf, int* indices_buf,
  699.         const float** ord_values, const int** indices );
  700.     virtual int get_child_buf_idx( CvDTreeNode* n );
  701.     ////////////////////////////////////
  702.     virtual bool set_params( const CvDTreeParams& params );
  703.     virtual CvDTreeNode* new_node( CvDTreeNode* parent, int count,
  704.                                    int storage_idx, int offset );
  705.     virtual CvDTreeSplit* new_split_ord( int vi, float cmp_val,
  706.                 int split_point, int inversed, float quality );
  707.     virtual CvDTreeSplit* new_split_cat( int vi, float quality );
  708.     virtual void free_node_data( CvDTreeNode* node );
  709.     virtual void free_train_data();
  710.     virtual void free_node( CvDTreeNode* node );
  711.     // inner arrays for getting predictors and responses
  712.     float* get_pred_float_buf();
  713.     int* get_pred_int_buf();
  714.     float* get_resp_float_buf();
  715.     int* get_resp_int_buf();
  716.     int* get_cv_lables_buf();
  717.     int* get_sample_idx_buf();
  718.     vector<vector<float> > pred_float_buf;
  719.     vector<vector<int> > pred_int_buf;
  720.     vector<vector<float> > resp_float_buf;
  721.     vector<vector<int> > resp_int_buf;
  722.     vector<vector<int> > cv_lables_buf;
  723.     vector<vector<int> > sample_idx_buf;
  724.     int sample_count, var_all, var_count, max_c_count;
  725.     int ord_var_count, cat_var_count, work_var_count;
  726.     bool have_labels, have_priors;
  727.     bool is_classifier;
  728.     int tflag;
  729.     const CvMat* train_data;
  730.     const CvMat* responses;
  731.     CvMat* responses_copy; // used in Boosting
  732.     int buf_count, buf_size;
  733.     bool shared;
  734.     int is_buf_16u;
  735.     
  736.     CvMat* cat_count;
  737.     CvMat* cat_ofs;
  738.     CvMat* cat_map;
  739.     CvMat* counts;
  740.     CvMat* buf;
  741.     CvMat* direction;
  742.     CvMat* split_buf;
  743.     CvMat* var_idx;
  744.     CvMat* var_type; // i-th element =
  745.                      //   k<0  - ordered
  746.                      //   k>=0 - categorical, see k-th element of cat_* arrays
  747.     CvMat* priors;
  748.     CvMat* priors_mult;
  749.     CvDTreeParams params;
  750.     CvMemStorage* tree_storage;
  751.     CvMemStorage* temp_storage;
  752.     CvDTreeNode* data_root;
  753.     CvSet* node_heap;
  754.     CvSet* split_heap;
  755.     CvSet* cv_heap;
  756.     CvSet* nv_heap;
  757.     CvRNG rng;
  758. };
  759. class CV_EXPORTS CvDTree : public CvStatModel
  760. {
  761. public:
  762.     CvDTree();
  763.     virtual ~CvDTree();
  764.     virtual bool train( const CvMat* _train_data, int _tflag,
  765.                         const CvMat* _responses, const CvMat* _var_idx=0,
  766.                         const CvMat* _sample_idx=0, const CvMat* _var_type=0,
  767.                         const CvMat* _missing_mask=0,
  768.                         CvDTreeParams params=CvDTreeParams() );
  769.     virtual bool train( CvMLData* _data, CvDTreeParams _params=CvDTreeParams() );
  770.     virtual float calc_error( CvMLData* _data, int type , vector<float> *resp = 0 ); // type in {CV_TRAIN_ERROR, CV_TEST_ERROR}
  771.     virtual bool train( CvDTreeTrainData* _train_data, const CvMat* _subsample_idx );
  772.     virtual CvDTreeNode* predict( const CvMat* _sample, const CvMat* _missing_data_mask=0,
  773.                                   bool preprocessed_input=false ) const;
  774. #ifndef SWIG
  775.     virtual bool train( const cv::Mat& _train_data, int _tflag,
  776.                        const cv::Mat& _responses, const cv::Mat& _var_idx=cv::Mat(),
  777.                        const cv::Mat& _sample_idx=cv::Mat(), const cv::Mat& _var_type=cv::Mat(),
  778.                        const cv::Mat& _missing_mask=cv::Mat(),
  779.                        CvDTreeParams params=CvDTreeParams() );
  780.     
  781.     virtual CvDTreeNode* predict( const cv::Mat& _sample, const cv::Mat& _missing_data_mask=cv::Mat(),
  782.                                   bool preprocessed_input=false ) const;
  783. #endif
  784.     
  785.     virtual const CvMat* get_var_importance();
  786.     virtual void clear();
  787.     virtual void read( CvFileStorage* fs, CvFileNode* node );
  788.     virtual void write( CvFileStorage* fs, const char* name ) const;
  789.     // special read & write methods for trees in the tree ensembles
  790.     virtual void read( CvFileStorage* fs, CvFileNode* node,
  791.                        CvDTreeTrainData* data );
  792.     virtual void write( CvFileStorage* fs ) const;
  793.     const CvDTreeNode* get_root() const;
  794.     int get_pruned_tree_idx() const;
  795.     CvDTreeTrainData* get_data();
  796. protected:
  797.     virtual bool do_train( const CvMat* _subsample_idx );
  798.     virtual void try_split_node( CvDTreeNode* n );
  799.     virtual void split_node_data( CvDTreeNode* n );
  800.     virtual CvDTreeSplit* find_best_split( CvDTreeNode* n );
  801.     virtual CvDTreeSplit* find_split_ord_class( CvDTreeNode* n, int vi, 
  802.                             float init_quality = 0, CvDTreeSplit* _split = 0 );
  803.     virtual CvDTreeSplit* find_split_cat_class( CvDTreeNode* n, int vi,
  804.                             float init_quality = 0, CvDTreeSplit* _split = 0 );
  805.     virtual CvDTreeSplit* find_split_ord_reg( CvDTreeNode* n, int vi, 
  806.                             float init_quality = 0, CvDTreeSplit* _split = 0 );
  807.     virtual CvDTreeSplit* find_split_cat_reg( CvDTreeNode* n, int vi, 
  808.                             float init_quality = 0, CvDTreeSplit* _split = 0 );
  809.     virtual CvDTreeSplit* find_surrogate_split_ord( CvDTreeNode* n, int vi );
  810.     virtual CvDTreeSplit* find_surrogate_split_cat( CvDTreeNode* n, int vi );
  811.     virtual double calc_node_dir( CvDTreeNode* node );
  812.     virtual void complete_node_dir( CvDTreeNode* node );
  813.     virtual void cluster_categories( const int* vectors, int vector_count,
  814.         int var_count, int* sums, int k, int* cluster_labels );
  815.     virtual void calc_node_value( CvDTreeNode* node );
  816.     virtual void prune_cv();
  817.     virtual double update_tree_rnc( int T, int fold );
  818.     virtual int cut_tree( int T, int fold, double min_alpha );
  819.     virtual void free_prune_data(bool cut_tree);
  820.     virtual void free_tree();
  821.     virtual void write_node( CvFileStorage* fs, CvDTreeNode* node ) const;
  822.     virtual void write_split( CvFileStorage* fs, CvDTreeSplit* split ) const;
  823.     virtual CvDTreeNode* read_node( CvFileStorage* fs, CvFileNode* node, CvDTreeNode* parent );
  824.     virtual CvDTreeSplit* read_split( CvFileStorage* fs, CvFileNode* node );
  825.     virtual void write_tree_nodes( CvFileStorage* fs ) const;
  826.     virtual void read_tree_nodes( CvFileStorage* fs, CvFileNode* node );
  827.     CvDTreeNode* root;
  828.     CvMat* var_importance;
  829.     CvDTreeTrainData* data;
  830. public:
  831.     int pruned_tree_idx;
  832. };
  833. /****************************************************************************************
  834. *                                   Random Trees Classifier                              *
  835. ****************************************************************************************/
  836. class CvRTrees;
  837. class CV_EXPORTS CvForestTree: public CvDTree
  838. {
  839. public:
  840.     CvForestTree();
  841.     virtual ~CvForestTree();
  842.     virtual bool train( CvDTreeTrainData* _train_data, const CvMat* _subsample_idx, CvRTrees* forest );
  843.     virtual int get_var_count() const {return data ? data->var_count : 0;}
  844.     virtual void read( CvFileStorage* fs, CvFileNode* node, CvRTrees* forest, CvDTreeTrainData* _data );
  845.     /* dummy methods to avoid warnings: BEGIN */
  846.     virtual bool train( const CvMat* _train_data, int _tflag,
  847.                         const CvMat* _responses, const CvMat* _var_idx=0,
  848.                         const CvMat* _sample_idx=0, const CvMat* _var_type=0,
  849.                         const CvMat* _missing_mask=0,
  850.                         CvDTreeParams params=CvDTreeParams() );
  851.     virtual bool train( CvDTreeTrainData* _train_data, const CvMat* _subsample_idx );
  852.     virtual void read( CvFileStorage* fs, CvFileNode* node );
  853.     virtual void read( CvFileStorage* fs, CvFileNode* node,
  854.                        CvDTreeTrainData* data );
  855.     /* dummy methods to avoid warnings: END */
  856. protected:
  857.     virtual CvDTreeSplit* find_best_split( CvDTreeNode* n );
  858.     CvRTrees* forest;
  859. };
  860. struct CV_EXPORTS CvRTParams : public CvDTreeParams
  861. {
  862.     //Parameters for the forest
  863.     bool calc_var_importance; // true <=> RF processes variable importance
  864.     int nactive_vars;
  865.     CvTermCriteria term_crit;
  866.     CvRTParams() : CvDTreeParams( 5, 10, 0, false, 10, 0, false, false, 0 ),
  867.         calc_var_importance(false), nactive_vars(0)
  868.     {
  869.         term_crit = cvTermCriteria( CV_TERMCRIT_ITER+CV_TERMCRIT_EPS, 50, 0.1 );
  870.     }
  871.     CvRTParams( int _max_depth, int _min_sample_count,
  872.                 float _regression_accuracy, bool _use_surrogates,
  873.                 int _max_categories, const float* _priors, bool _calc_var_importance,
  874.                 int _nactive_vars, int max_num_of_trees_in_the_forest,
  875.                 float forest_accuracy, int termcrit_type ) :
  876.         CvDTreeParams( _max_depth, _min_sample_count, _regression_accuracy,
  877.                        _use_surrogates, _max_categories, 0,
  878.                        false, false, _priors ),
  879.         calc_var_importance(_calc_var_importance),
  880.         nactive_vars(_nactive_vars)
  881.     {
  882.         term_crit = cvTermCriteria(termcrit_type,
  883.             max_num_of_trees_in_the_forest, forest_accuracy);
  884.     }
  885. };
  886. class CV_EXPORTS CvRTrees : public CvStatModel
  887. {
  888. public:
  889.     CvRTrees();
  890.     virtual ~CvRTrees();
  891.     virtual bool train( const CvMat* _train_data, int _tflag,
  892.                         const CvMat* _responses, const CvMat* _var_idx=0,
  893.                         const CvMat* _sample_idx=0, const CvMat* _var_type=0,
  894.                         const CvMat* _missing_mask=0,
  895.                         CvRTParams params=CvRTParams() );
  896.     
  897.     virtual bool train( CvMLData* data, CvRTParams params=CvRTParams() );
  898.     virtual float predict( const CvMat* sample, const CvMat* missing = 0 ) const;
  899.     virtual float predict_prob( const CvMat* sample, const CvMat* missing = 0 ) const;
  900. #ifndef SWIG
  901.     virtual bool train( const cv::Mat& _train_data, int _tflag,
  902.                        const cv::Mat& _responses, const cv::Mat& _var_idx=cv::Mat(),
  903.                        const cv::Mat& _sample_idx=cv::Mat(), const cv::Mat& _var_type=cv::Mat(),
  904.                        const cv::Mat& _missing_mask=cv::Mat(),
  905.                        CvRTParams params=CvRTParams() );
  906.     virtual float predict( const cv::Mat& sample, const cv::Mat& missing = cv::Mat() ) const;
  907.     virtual float predict_prob( const cv::Mat& sample, const cv::Mat& missing = cv::Mat() ) const;
  908. #endif
  909.     
  910.     virtual void clear();
  911.     virtual const CvMat* get_var_importance();
  912.     virtual float get_proximity( const CvMat* sample1, const CvMat* sample2,
  913.         const CvMat* missing1 = 0, const CvMat* missing2 = 0 ) const;
  914.     
  915.     virtual float calc_error( CvMLData* _data, int type , vector<float> *resp = 0 ); // type in {CV_TRAIN_ERROR, CV_TEST_ERROR}
  916.     virtual float get_train_error();    
  917.     virtual void read( CvFileStorage* fs, CvFileNode* node );
  918.     virtual void write( CvFileStorage* fs, const char* name ) const;
  919.     CvMat* get_active_var_mask();
  920.     CvRNG* get_rng();
  921.     int get_tree_count() const;
  922.     CvForestTree* get_tree(int i) const;
  923. protected:
  924.     virtual bool grow_forest( const CvTermCriteria term_crit );
  925.     // array of the trees of the forest
  926.     CvForestTree** trees;
  927.     CvDTreeTrainData* data;
  928.     int ntrees;
  929.     int nclasses;
  930.     double oob_error;
  931.     CvMat* var_importance;
  932.     int nsamples;
  933.     CvRNG rng;
  934.     CvMat* active_var_mask;
  935. };
  936. /****************************************************************************************
  937. *                           Extremely randomized trees Classifier                        *
  938. ****************************************************************************************/
  939. struct CV_EXPORTS CvERTreeTrainData : public CvDTreeTrainData
  940. {
  941.     virtual void set_data( const CvMat* _train_data, int _tflag,
  942.                           const CvMat* _responses, const CvMat* _var_idx=0,
  943.                           const CvMat* _sample_idx=0, const CvMat* _var_type=0,
  944.                           const CvMat* _missing_mask=0,
  945.                           const CvDTreeParams& _params=CvDTreeParams(),
  946.                           bool _shared=false, bool _add_labels=false,
  947.                           bool _update_data=false );
  948.     virtual int get_ord_var_data( CvDTreeNode* n, int vi, float* ord_values_buf, int* missing_buf,
  949.         const float** ord_values, const int** missing );
  950.     virtual void get_sample_indices( CvDTreeNode* n, int* indices_buf, const int** indices );
  951.     virtual void get_cv_labels( CvDTreeNode* n, int* labels_buf, const int** labels );
  952.     virtual int get_cat_var_data( CvDTreeNode* n, int vi, int* cat_values_buf, const int** cat_values );
  953.     virtual void get_vectors( const CvMat* _subsample_idx,
  954.          float* values, uchar* missing, float* responses, bool get_class_idx=false );
  955.     virtual CvDTreeNode* subsample_data( const CvMat* _subsample_idx );
  956.     const CvMat* missing_mask;
  957. };
  958. class CV_EXPORTS CvForestERTree : public CvForestTree
  959. {
  960. protected:
  961.     virtual double calc_node_dir( CvDTreeNode* node );
  962.     virtual CvDTreeSplit* find_split_ord_class( CvDTreeNode* n, int vi, 
  963.         float init_quality = 0, CvDTreeSplit* _split = 0 );
  964.     virtual CvDTreeSplit* find_split_cat_class( CvDTreeNode* n, int vi,
  965.         float init_quality = 0, CvDTreeSplit* _split = 0 );
  966.     virtual CvDTreeSplit* find_split_ord_reg( CvDTreeNode* n, int vi, 
  967.         float init_quality = 0, CvDTreeSplit* _split = 0 );
  968.     virtual CvDTreeSplit* find_split_cat_reg( CvDTreeNode* n, int vi, 
  969.         float init_quality = 0, CvDTreeSplit* _split = 0 );
  970.     //virtual void complete_node_dir( CvDTreeNode* node );
  971.     virtual void split_node_data( CvDTreeNode* n );
  972. };
  973. class CV_EXPORTS CvERTrees : public CvRTrees
  974. {
  975. public:
  976.     CvERTrees();
  977.     virtual ~CvERTrees();
  978.     virtual bool train( const CvMat* _train_data, int _tflag,
  979.                         const CvMat* _responses, const CvMat* _var_idx=0,
  980.                         const CvMat* _sample_idx=0, const CvMat* _var_type=0,
  981.                         const CvMat* _missing_mask=0,
  982.                         CvRTParams params=CvRTParams());
  983. #ifndef SWIG
  984.     virtual bool train( const cv::Mat& _train_data, int _tflag,
  985.                        const cv::Mat& _responses, const cv::Mat& _var_idx=cv::Mat(),
  986.                        const cv::Mat& _sample_idx=cv::Mat(), const cv::Mat& _var_type=cv::Mat(),
  987.                        const cv::Mat& _missing_mask=cv::Mat(),
  988.                        CvRTParams params=CvRTParams());
  989. #endif
  990.     virtual bool train( CvMLData* data, CvRTParams params=CvRTParams() );
  991. protected:
  992.     virtual bool grow_forest( const CvTermCriteria term_crit );
  993. };
  994. /****************************************************************************************
  995. *                                   Boosted tree classifier                              *
  996. ****************************************************************************************/
  997. struct CV_EXPORTS CvBoostParams : public CvDTreeParams
  998. {
  999.     int boost_type;
  1000.     int weak_count;
  1001.     int split_criteria;
  1002.     double weight_trim_rate;
  1003.     CvBoostParams();
  1004.     CvBoostParams( int boost_type, int weak_count, double weight_trim_rate,
  1005.                    int max_depth, bool use_surrogates, const float* priors );
  1006. };
  1007. class CvBoost;
  1008. class CV_EXPORTS CvBoostTree: public CvDTree
  1009. {
  1010. public:
  1011.     CvBoostTree();
  1012.     virtual ~CvBoostTree();
  1013.     virtual bool train( CvDTreeTrainData* _train_data,
  1014.                         const CvMat* subsample_idx, CvBoost* ensemble );
  1015.     virtual void scale( double s );
  1016.     virtual void read( CvFileStorage* fs, CvFileNode* node,
  1017.                        CvBoost* ensemble, CvDTreeTrainData* _data );
  1018.     virtual void clear();
  1019.     /* dummy methods to avoid warnings: BEGIN */
  1020.     virtual bool train( const CvMat* _train_data, int _tflag,
  1021.                         const CvMat* _responses, const CvMat* _var_idx=0,
  1022.                         const CvMat* _sample_idx=0, const CvMat* _var_type=0,
  1023.                         const CvMat* _missing_mask=0,
  1024.                         CvDTreeParams params=CvDTreeParams() );
  1025.     virtual bool train( CvDTreeTrainData* _train_data, const CvMat* _subsample_idx );
  1026.     virtual void read( CvFileStorage* fs, CvFileNode* node );
  1027.     virtual void read( CvFileStorage* fs, CvFileNode* node,
  1028.                        CvDTreeTrainData* data );
  1029.     /* dummy methods to avoid warnings: END */
  1030. protected:
  1031.     virtual void try_split_node( CvDTreeNode* n );
  1032.     virtual CvDTreeSplit* find_surrogate_split_ord( CvDTreeNode* n, int vi );
  1033.     virtual CvDTreeSplit* find_surrogate_split_cat( CvDTreeNode* n, int vi );
  1034.     virtual CvDTreeSplit* find_split_ord_class( CvDTreeNode* n, int vi, 
  1035.         float init_quality = 0, CvDTreeSplit* _split = 0 );
  1036.     virtual CvDTreeSplit* find_split_cat_class( CvDTreeNode* n, int vi,
  1037.         float init_quality = 0, CvDTreeSplit* _split = 0 );
  1038.     virtual CvDTreeSplit* find_split_ord_reg( CvDTreeNode* n, int vi, 
  1039.         float init_quality = 0, CvDTreeSplit* _split = 0 );
  1040.     virtual CvDTreeSplit* find_split_cat_reg( CvDTreeNode* n, int vi, 
  1041.         float init_quality = 0, CvDTreeSplit* _split = 0 );
  1042.     virtual void calc_node_value( CvDTreeNode* n );
  1043.     virtual double calc_node_dir( CvDTreeNode* n );
  1044.     CvBoost* ensemble;
  1045. };
  1046. class CV_EXPORTS CvBoost : public CvStatModel
  1047. {
  1048. public:
  1049.     // Boosting type
  1050.     enum { DISCRETE=0, REAL=1, LOGIT=2, GENTLE=3 };
  1051.     // Splitting criteria
  1052.     enum { DEFAULT=0, GINI=1, MISCLASS=3, SQERR=4 };
  1053.     CvBoost();
  1054.     virtual ~CvBoost();
  1055.     CvBoost( const CvMat* _train_data, int _tflag,
  1056.              const CvMat* _responses, const CvMat* _var_idx=0,
  1057.              const CvMat* _sample_idx=0, const CvMat* _var_type=0,
  1058.              const CvMat* _missing_mask=0,
  1059.              CvBoostParams params=CvBoostParams() );
  1060.     
  1061.     virtual bool train( const CvMat* _train_data, int _tflag,
  1062.              const CvMat* _responses, const CvMat* _var_idx=0,
  1063.              const CvMat* _sample_idx=0, const CvMat* _var_type=0,
  1064.              const CvMat* _missing_mask=0,
  1065.              CvBoostParams params=CvBoostParams(),
  1066.              bool update=false );
  1067.     
  1068.     virtual bool train( CvMLData* data,
  1069.              CvBoostParams params=CvBoostParams(),
  1070.              bool update=false );
  1071.     virtual float predict( const CvMat* _sample, const CvMat* _missing=0,
  1072.                            CvMat* weak_responses=0, CvSlice slice=CV_WHOLE_SEQ,
  1073.                            bool raw_mode=false, bool return_sum=false ) const;
  1074. #ifndef SWIG
  1075.     CvBoost( const cv::Mat& _train_data, int _tflag,
  1076.             const cv::Mat& _responses, const cv::Mat& _var_idx=cv::Mat(),
  1077.             const cv::Mat& _sample_idx=cv::Mat(), const cv::Mat& _var_type=cv::Mat(),
  1078.             const cv::Mat& _missing_mask=cv::Mat(),
  1079.             CvBoostParams params=CvBoostParams() );
  1080.     
  1081.     virtual bool train( const cv::Mat& _train_data, int _tflag,
  1082.                        const cv::Mat& _responses, const cv::Mat& _var_idx=cv::Mat(),
  1083.                        const cv::Mat& _sample_idx=cv::Mat(), const cv::Mat& _var_type=cv::Mat(),
  1084.                        const cv::Mat& _missing_mask=cv::Mat(),
  1085.                        CvBoostParams params=CvBoostParams(),
  1086.                        bool update=false );
  1087.     
  1088.     virtual float predict( const cv::Mat& _sample, const cv::Mat& _missing=cv::Mat(),
  1089.                           cv::Mat* weak_responses=0, CvSlice slice=CV_WHOLE_SEQ,
  1090.                           bool raw_mode=false, bool return_sum=false ) const;
  1091. #endif
  1092.     
  1093.     virtual float calc_error( CvMLData* _data, int type , vector<float> *resp = 0 ); // type in {CV_TRAIN_ERROR, CV_TEST_ERROR}
  1094.     virtual void prune( CvSlice slice );
  1095.     virtual void clear();
  1096.     virtual void write( CvFileStorage* storage, const char* name ) const;
  1097.     virtual void read( CvFileStorage* storage, CvFileNode* node );
  1098.     virtual const CvMat* get_active_vars(bool absolute_idx=true);
  1099.     CvSeq* get_weak_predictors();
  1100.     CvMat* get_weights();
  1101.     CvMat* get_subtree_weights();
  1102.     CvMat* get_weak_response();
  1103.     const CvBoostParams& get_params() const;
  1104.     const CvDTreeTrainData* get_data() const;
  1105. protected:
  1106.     virtual bool set_params( const CvBoostParams& _params );
  1107.     virtual void update_weights( CvBoostTree* tree );
  1108.     virtual void trim_weights();
  1109.     virtual void write_params( CvFileStorage* fs ) const;
  1110.     virtual void read_params( CvFileStorage* fs, CvFileNode* node );
  1111.     CvDTreeTrainData* data;
  1112.     CvBoostParams params;
  1113.     CvSeq* weak;
  1114.     CvMat* active_vars;
  1115.     CvMat* active_vars_abs;
  1116.     bool have_active_cat_vars;
  1117.     CvMat* orig_response;
  1118.     CvMat* sum_response;
  1119.     CvMat* weak_eval;
  1120.     CvMat* subsample_mask;
  1121.     CvMat* weights;
  1122.     CvMat* subtree_weights;
  1123.     bool have_subsample;
  1124. };
  1125. /****************************************************************************************
  1126. *                              Artificial Neural Networks (ANN)                          *
  1127. ****************************************************************************************/
  1128. /////////////////////////////////// Multi-Layer Perceptrons //////////////////////////////
  1129. struct CV_EXPORTS CvANN_MLP_TrainParams
  1130. {
  1131.     CvANN_MLP_TrainParams();
  1132.     CvANN_MLP_TrainParams( CvTermCriteria term_crit, int train_method,
  1133.                            double param1, double param2=0 );
  1134.     ~CvANN_MLP_TrainParams();
  1135.     enum { BACKPROP=0, RPROP=1 };
  1136.     CvTermCriteria term_crit;
  1137.     int train_method;
  1138.     // backpropagation parameters
  1139.     double bp_dw_scale, bp_moment_scale;
  1140.     // rprop parameters
  1141.     double rp_dw0, rp_dw_plus, rp_dw_minus, rp_dw_min, rp_dw_max;
  1142. };
  1143. class CV_EXPORTS CvANN_MLP : public CvStatModel
  1144. {
  1145. public:
  1146.     CvANN_MLP();
  1147.     CvANN_MLP( const CvMat* _layer_sizes,
  1148.                int _activ_func=SIGMOID_SYM,
  1149.                double _f_param1=0, double _f_param2=0 );
  1150.     virtual ~CvANN_MLP();
  1151.     virtual void create( const CvMat* _layer_sizes,
  1152.                          int _activ_func=SIGMOID_SYM,
  1153.                          double _f_param1=0, double _f_param2=0 );
  1154.     
  1155.     virtual int train( const CvMat* _inputs, const CvMat* _outputs,
  1156.                        const CvMat* _sample_weights, const CvMat* _sample_idx=0,
  1157.                        CvANN_MLP_TrainParams _params = CvANN_MLP_TrainParams(),
  1158.                        int flags=0 );
  1159.     virtual float predict( const CvMat* _inputs, CvMat* _outputs ) const;
  1160.     
  1161. #ifndef SWIG
  1162.     CvANN_MLP( const cv::Mat& _layer_sizes,
  1163.               int _activ_func=SIGMOID_SYM,
  1164.               double _f_param1=0, double _f_param2=0 );
  1165.     
  1166.     virtual void create( const cv::Mat& _layer_sizes,
  1167.                         int _activ_func=SIGMOID_SYM,
  1168.                         double _f_param1=0, double _f_param2=0 );    
  1169.     
  1170.     virtual int train( const cv::Mat& _inputs, const cv::Mat& _outputs,
  1171.                       const cv::Mat& _sample_weights, const cv::Mat& _sample_idx=cv::Mat(),
  1172.                       CvANN_MLP_TrainParams _params = CvANN_MLP_TrainParams(),
  1173.                       int flags=0 );    
  1174.     
  1175.     virtual float predict( const cv::Mat& _inputs, cv::Mat& _outputs ) const;
  1176. #endif
  1177.     
  1178.     virtual void clear();
  1179.     // possible activation functions
  1180.     enum { IDENTITY = 0, SIGMOID_SYM = 1, GAUSSIAN = 2 };
  1181.     // available training flags
  1182.     enum { UPDATE_WEIGHTS = 1, NO_INPUT_SCALE = 2, NO_OUTPUT_SCALE = 4 };
  1183.     virtual void read( CvFileStorage* fs, CvFileNode* node );
  1184.     virtual void write( CvFileStorage* storage, const char* name ) const;
  1185.     int get_layer_count() { return layer_sizes ? layer_sizes->cols : 0; }
  1186.     const CvMat* get_layer_sizes() { return layer_sizes; }
  1187.     double* get_weights(int layer)
  1188.     {
  1189.         return layer_sizes && weights &&
  1190.             (unsigned)layer <= (unsigned)layer_sizes->cols ? weights[layer] : 0;
  1191.     }
  1192. protected:
  1193.     virtual bool prepare_to_train( const CvMat* _inputs, const CvMat* _outputs,
  1194.             const CvMat* _sample_weights, const CvMat* _sample_idx,
  1195.             CvVectors* _ivecs, CvVectors* _ovecs, double** _sw, int _flags );
  1196.     // sequential random backpropagation
  1197.     virtual int train_backprop( CvVectors _ivecs, CvVectors _ovecs, const double* _sw );
  1198.     // RPROP algorithm
  1199.     virtual int train_rprop( CvVectors _ivecs, CvVectors _ovecs, const double* _sw );
  1200.     virtual void calc_activ_func( CvMat* xf, const double* bias ) const;
  1201.     virtual void calc_activ_func_deriv( CvMat* xf, CvMat* deriv, const double* bias ) const;
  1202.     virtual void set_activ_func( int _activ_func=SIGMOID_SYM,
  1203.                                  double _f_param1=0, double _f_param2=0 );
  1204.     virtual void init_weights();
  1205.     virtual void scale_input( const CvMat* _src, CvMat* _dst ) const;
  1206.     virtual void scale_output( const CvMat* _src, CvMat* _dst ) const;
  1207.     virtual void calc_input_scale( const CvVectors* vecs, int flags );
  1208.     virtual void calc_output_scale( const CvVectors* vecs, int flags );
  1209.     virtual void write_params( CvFileStorage* fs ) const;
  1210.     virtual void read_params( CvFileStorage* fs, CvFileNode* node );
  1211.     CvMat* layer_sizes;
  1212.     CvMat* wbuf;
  1213.     CvMat* sample_weights;
  1214.     double** weights;
  1215.     double f_param1, f_param2;
  1216.     double min_val, max_val, min_val1, max_val1;
  1217.     int activ_func;
  1218.     int max_count, max_buf_sz;
  1219.     CvANN_MLP_TrainParams params;
  1220.     CvRNG rng;
  1221. };
  1222. #if 0
  1223. /****************************************************************************************
  1224. *                            Convolutional Neural Network                                *
  1225. ****************************************************************************************/
  1226. typedef struct CvCNNLayer CvCNNLayer;
  1227. typedef struct CvCNNetwork CvCNNetwork;
  1228. #define CV_CNN_LEARN_RATE_DECREASE_HYPERBOLICALLY  1
  1229. #define CV_CNN_LEARN_RATE_DECREASE_SQRT_INV        2
  1230. #define CV_CNN_LEARN_RATE_DECREASE_LOG_INV         3
  1231. #define CV_CNN_GRAD_ESTIM_RANDOM        0
  1232. #define CV_CNN_GRAD_ESTIM_BY_WORST_IMG  1
  1233. #define ICV_CNN_LAYER                0x55550000
  1234. #define ICV_CNN_CONVOLUTION_LAYER    0x00001111
  1235. #define ICV_CNN_SUBSAMPLING_LAYER    0x00002222
  1236. #define ICV_CNN_FULLCONNECT_LAYER    0x00003333
  1237. #define ICV_IS_CNN_LAYER( layer )                                          
  1238.     ( ((layer) != NULL) && ((((CvCNNLayer*)(layer))->flags & CV_MAGIC_MASK)
  1239.         == ICV_CNN_LAYER ))
  1240. #define ICV_IS_CNN_CONVOLUTION_LAYER( layer )                              
  1241.     ( (ICV_IS_CNN_LAYER( layer )) && (((CvCNNLayer*) (layer))->flags       
  1242.         & ~CV_MAGIC_MASK) == ICV_CNN_CONVOLUTION_LAYER )
  1243. #define ICV_IS_CNN_SUBSAMPLING_LAYER( layer )                              
  1244.     ( (ICV_IS_CNN_LAYER( layer )) && (((CvCNNLayer*) (layer))->flags       
  1245.         & ~CV_MAGIC_MASK) == ICV_CNN_SUBSAMPLING_LAYER )
  1246. #define ICV_IS_CNN_FULLCONNECT_LAYER( layer )                              
  1247.     ( (ICV_IS_CNN_LAYER( layer )) && (((CvCNNLayer*) (layer))->flags       
  1248.         & ~CV_MAGIC_MASK) == ICV_CNN_FULLCONNECT_LAYER )
  1249. typedef void (CV_CDECL *CvCNNLayerForward)
  1250.     ( CvCNNLayer* layer, const CvMat* input, CvMat* output );
  1251. typedef void (CV_CDECL *CvCNNLayerBackward)
  1252.     ( CvCNNLayer* layer, int t, const CvMat* X, const CvMat* dE_dY, CvMat* dE_dX );
  1253. typedef void (CV_CDECL *CvCNNLayerRelease)
  1254.     (CvCNNLayer** layer);
  1255. typedef void (CV_CDECL *CvCNNetworkAddLayer)
  1256.     (CvCNNetwork* network, CvCNNLayer* layer);
  1257. typedef void (CV_CDECL *CvCNNetworkRelease)
  1258.     (CvCNNetwork** network);
  1259. #define CV_CNN_LAYER_FIELDS()           
  1260.     /* Indicator of the layer's type */ 
  1261.     int flags;                          
  1262.                                         
  1263.     /* Number of input images */        
  1264.     int n_input_planes;                 
  1265.     /* Height of each input image */    
  1266.     int input_height;                   
  1267.     /* Width of each input image */     
  1268.     int input_width;                    
  1269.                                         
  1270.     /* Number of output images */       
  1271.     int n_output_planes;                
  1272.     /* Height of each output image */   
  1273.     int output_height;                  
  1274.     /* Width of each output image */    
  1275.     int output_width;                   
  1276.                                         
  1277.     /* Learning rate at the first iteration */                      
  1278.     float init_learn_rate;                                          
  1279.     /* Dynamics of learning rate decreasing */                      
  1280.     int learn_rate_decrease_type;                                   
  1281.     /* Trainable weights of the layer (including bias) */           
  1282.     /* i-th row is a set of weights of the i-th output plane */     
  1283.     CvMat* weights;                                                 
  1284.                                                                     
  1285.     CvCNNLayerForward  forward;                                     
  1286.     CvCNNLayerBackward backward;                                    
  1287.     CvCNNLayerRelease  release;                                     
  1288.     /* Pointers to the previous and next layers in the network */   
  1289.     CvCNNLayer* prev_layer;                                         
  1290.     CvCNNLayer* next_layer
  1291. typedef struct CvCNNLayer
  1292. {
  1293.     CV_CNN_LAYER_FIELDS();
  1294. }CvCNNLayer;
  1295. typedef struct CvCNNConvolutionLayer
  1296. {
  1297.     CV_CNN_LAYER_FIELDS();
  1298.     // Kernel size (height and width) for convolution.
  1299.     int K;
  1300.     // connections matrix, (i,j)-th element is 1 iff there is a connection between
  1301.     // i-th plane of the current layer and j-th plane of the previous layer;
  1302.     // (i,j)-th element is equal to 0 otherwise
  1303.     CvMat *connect_mask;
  1304.     // value of the learning rate for updating weights at the first iteration
  1305. }CvCNNConvolutionLayer;
  1306. typedef struct CvCNNSubSamplingLayer
  1307. {
  1308.     CV_CNN_LAYER_FIELDS();
  1309.     // ratio between the heights (or widths - ratios are supposed to be equal)
  1310.     // of the input and output planes
  1311.     int sub_samp_scale;
  1312.     // amplitude of sigmoid activation function
  1313.     float a;
  1314.     // scale parameter of sigmoid activation function
  1315.     float s;
  1316.     // exp2ssumWX = exp(2<s>*(bias+w*(x1+...+x4))), where x1,...x4 are some elements of X
  1317.     // - is the vector used in computing of the activation function in backward
  1318.     CvMat* exp2ssumWX;
  1319.     // (x1+x2+x3+x4), where x1,...x4 are some elements of X
  1320.     // - is the vector used in computing of the activation function in backward
  1321.     CvMat* sumX;
  1322. }CvCNNSubSamplingLayer;
  1323. // Structure of the last layer.
  1324. typedef struct CvCNNFullConnectLayer
  1325. {
  1326.     CV_CNN_LAYER_FIELDS();
  1327.     // amplitude of sigmoid activation function
  1328.     float a;
  1329.     // scale parameter of sigmoid activation function
  1330.     float s;
  1331.     // exp2ssumWX = exp(2*<s>*(W*X)) - is the vector used in computing of the
  1332.     // activation function and it's derivative by the formulae
  1333.     // activ.func. = <a>(exp(2<s>WX)-1)/(exp(2<s>WX)+1) == <a> - 2<a>/(<exp2ssumWX> + 1)
  1334.     // (activ.func.)' = 4<a><s>exp(2<s>WX)/(exp(2<s>WX)+1)^2
  1335.     CvMat* exp2ssumWX;
  1336. }CvCNNFullConnectLayer;
  1337. typedef struct CvCNNetwork
  1338. {
  1339.     int n_layers;
  1340.     CvCNNLayer* layers;
  1341.     CvCNNetworkAddLayer add_layer;
  1342.     CvCNNetworkRelease release;
  1343. }CvCNNetwork;
  1344. typedef struct CvCNNStatModel
  1345. {
  1346.     CV_STAT_MODEL_FIELDS();
  1347.     CvCNNetwork* network;
  1348.     // etalons are allocated as rows, the i-th etalon has label cls_labeles[i]
  1349.     CvMat* etalons;
  1350.     // classes labels
  1351.     CvMat* cls_labels;
  1352. }CvCNNStatModel;
  1353. typedef struct CvCNNStatModelParams
  1354. {
  1355.     CV_STAT_MODEL_PARAM_FIELDS();
  1356.     // network must be created by the functions cvCreateCNNetwork and <add_layer>
  1357.     CvCNNetwork* network;
  1358.     CvMat* etalons;
  1359.     // termination criteria
  1360.     int max_iter;
  1361.     int start_iter;
  1362.     int grad_estim_type;
  1363. }CvCNNStatModelParams;
  1364. CVAPI(CvCNNLayer*) cvCreateCNNConvolutionLayer(
  1365.     int n_input_planes, int input_height, int input_width,
  1366.     int n_output_planes, int K,
  1367.     float init_learn_rate, int learn_rate_decrease_type,
  1368.     CvMat* connect_mask CV_DEFAULT(0), CvMat* weights CV_DEFAULT(0) );
  1369. CVAPI(CvCNNLayer*) cvCreateCNNSubSamplingLayer(
  1370.     int n_input_planes, int input_height, int input_width,
  1371.     int sub_samp_scale, float a, float s,
  1372.     float init_learn_rate, int learn_rate_decrease_type, CvMat* weights CV_DEFAULT(0) );
  1373. CVAPI(CvCNNLayer*) cvCreateCNNFullConnectLayer(
  1374.     int n_inputs, int n_outputs, float a, float s,
  1375.     float init_learn_rate, int learning_type, CvMat* weights CV_DEFAULT(0) );
  1376. CVAPI(CvCNNetwork*) cvCreateCNNetwork( CvCNNLayer* first_layer );
  1377. CVAPI(CvStatModel*) cvTrainCNNClassifier(
  1378.             const CvMat* train_data, int tflag,
  1379.             const CvMat* responses,
  1380.             const CvStatModelParams* params,
  1381.             const CvMat* CV_DEFAULT(0),
  1382.             const CvMat* sample_idx CV_DEFAULT(0),
  1383.             const CvMat* CV_DEFAULT(0), const CvMat* CV_DEFAULT(0) );
  1384. /****************************************************************************************
  1385. *                               Estimate classifiers algorithms                          *
  1386. ****************************************************************************************/
  1387. typedef const CvMat* (CV_CDECL *CvStatModelEstimateGetMat)
  1388.                     ( const CvStatModel* estimateModel );
  1389. typedef int (CV_CDECL *CvStatModelEstimateNextStep)
  1390.                     ( CvStatModel* estimateModel );
  1391. typedef void (CV_CDECL *CvStatModelEstimateCheckClassifier)
  1392.                     ( CvStatModel* estimateModel,
  1393.                 const CvStatModel* model,
  1394.                 const CvMat*       features,
  1395.                       int          sample_t_flag,
  1396.                 const CvMat*       responses );
  1397. typedef void (CV_CDECL *CvStatModelEstimateCheckClassifierEasy)
  1398.                     ( CvStatModel* estimateModel,
  1399.                 const CvStatModel* model );
  1400. typedef float (CV_CDECL *CvStatModelEstimateGetCurrentResult)
  1401.                     ( const CvStatModel* estimateModel,
  1402.                             float*       correlation );
  1403. typedef void (CV_CDECL *CvStatModelEstimateReset)
  1404.                     ( CvStatModel* estimateModel );
  1405. //-------------------------------- Cross-validation --------------------------------------
  1406. #define CV_CROSS_VALIDATION_ESTIMATE_CLASSIFIER_PARAM_FIELDS()    
  1407.     CV_STAT_MODEL_PARAM_FIELDS();                                 
  1408.     int     k_fold;                                               
  1409.     int     is_regression;                                        
  1410.     CvRNG*  rng
  1411. typedef struct CvCrossValidationParams
  1412. {
  1413.     CV_CROSS_VALIDATION_ESTIMATE_CLASSIFIER_PARAM_FIELDS();
  1414. } CvCrossValidationParams;
  1415. #define CV_CROSS_VALIDATION_ESTIMATE_CLASSIFIER_FIELDS()    
  1416.     CvStatModelEstimateGetMat               getTrainIdxMat; 
  1417.     CvStatModelEstimateGetMat               getCheckIdxMat; 
  1418.     CvStatModelEstimateNextStep             nextStep;       
  1419.     CvStatModelEstimateCheckClassifier      check;          
  1420.     CvStatModelEstimateGetCurrentResult     getResult;      
  1421.     CvStatModelEstimateReset                reset;          
  1422.     int     is_regression;                                  
  1423.     int     folds_all;                                      
  1424.     int     samples_all;                                    
  1425.     int*    sampleIdxAll;                                   
  1426.     int*    folds;                                          
  1427.     int     max_fold_size;                                  
  1428.     int         current_fold;                               
  1429.     int         is_checked;                                 
  1430.     CvMat*      sampleIdxTrain;                             
  1431.     CvMat*      sampleIdxEval;                              
  1432.     CvMat*      predict_results;                            
  1433.     int     correct_results;                                
  1434.     int     all_results;                                    
  1435.     double  sq_error;                                       
  1436.     double  sum_correct;                                    
  1437.     double  sum_predict;                                    
  1438.     double  sum_cc;                                         
  1439.     double  sum_pp;                                         
  1440.     double  sum_cp
  1441. typedef struct CvCrossValidationModel
  1442. {
  1443.     CV_STAT_MODEL_FIELDS();
  1444.     CV_CROSS_VALIDATION_ESTIMATE_CLASSIFIER_FIELDS();
  1445. } CvCrossValidationModel;
  1446. CVAPI(CvStatModel*)
  1447. cvCreateCrossValidationEstimateModel
  1448.            ( int                samples_all,
  1449.        const CvStatModelParams* estimateParams CV_DEFAULT(0),
  1450.        const CvMat*             sampleIdx CV_DEFAULT(0) );
  1451. CVAPI(float)
  1452. cvCrossValidation( const CvMat*             trueData,
  1453.                          int                tflag,
  1454.                    const CvMat*             trueClasses,
  1455.                          CvStatModel*     (*createClassifier)( const CvMat*,
  1456.                                                                      int,
  1457.                                                                const CvMat*,
  1458.                                                                const CvStatModelParams*,
  1459.                                                                const CvMat*,
  1460.                                                                const CvMat*,
  1461.                                                                const CvMat*,
  1462.                                                                const CvMat* ),
  1463.                    const CvStatModelParams* estimateParams CV_DEFAULT(0),
  1464.                    const CvStatModelParams* trainParams CV_DEFAULT(0),
  1465.                    const CvMat*             compIdx CV_DEFAULT(0),
  1466.                    const CvMat*             sampleIdx CV_DEFAULT(0),
  1467.                          CvStatModel**      pCrValModel CV_DEFAULT(0),
  1468.                    const CvMat*             typeMask CV_DEFAULT(0),
  1469.                    const CvMat*             missedMeasurementMask CV_DEFAULT(0) );
  1470. #endif
  1471. /****************************************************************************************
  1472. *                           Auxilary functions declarations                              *
  1473. ****************************************************************************************/
  1474. /* Generates <sample> from multivariate normal distribution, where <mean> - is an
  1475.    average row vector, <cov> - symmetric covariation matrix */
  1476. CVAPI(void) cvRandMVNormal( CvMat* mean, CvMat* cov, CvMat* sample,
  1477.                            CvRNG* rng CV_DEFAULT(0) );
  1478. /* Generates sample from gaussian mixture distribution */
  1479. CVAPI(void) cvRandGaussMixture( CvMat* means[],
  1480.                                CvMat* covs[],
  1481.                                float weights[],
  1482.                                int clsnum,
  1483.                                CvMat* sample,
  1484.                                CvMat* sampClasses CV_DEFAULT(0) );
  1485. #define CV_TS_CONCENTRIC_SPHERES 0
  1486. /* creates test set */
  1487. CVAPI(void) cvCreateTestSet( int type, CvMat** samples,
  1488.                  int num_samples,
  1489.                  int num_features,
  1490.                  CvMat** responses,
  1491.                  int num_classes, ... );
  1492. #endif
  1493. /****************************************************************************************
  1494. *                                      Data                                             *
  1495. ****************************************************************************************/
  1496. #include <map>
  1497. #include <string>
  1498. #include <iostream>
  1499. using namespace std;
  1500. #define CV_COUNT     0
  1501. #define CV_PORTION   1
  1502. struct CV_EXPORTS CvTrainTestSplit
  1503. {
  1504. public:
  1505.     CvTrainTestSplit();
  1506.     CvTrainTestSplit( int _train_sample_count, bool _mix = true);
  1507.     CvTrainTestSplit( float _train_sample_portion, bool _mix = true);
  1508.     union
  1509.     {
  1510.         int count;
  1511.         float portion;
  1512.     } train_sample_part;
  1513.     int train_sample_part_mode;
  1514.     union
  1515.     {
  1516.         int *count;
  1517.         float *portion;
  1518.     } *class_part;
  1519.     int class_part_mode;
  1520.     bool mix;    
  1521. };
  1522. class CV_EXPORTS CvMLData
  1523. {
  1524. public:
  1525.     CvMLData();
  1526.     virtual ~CvMLData();
  1527.     // returns:
  1528.     // 0 - OK  
  1529.     // 1 - file can not be opened or is not correct
  1530.     int read_csv(const char* filename);
  1531.     const CvMat* get_values(){ return values; };
  1532.     const CvMat* get_responses();
  1533.     const CvMat* get_missing(){ return missing; };
  1534.     void set_response_idx( int idx ); // idx < 0 to set all vars as predictors
  1535.     int get_response_idx() { return response_idx; }
  1536.     const CvMat* get_train_sample_idx() { return train_sample_idx; };
  1537.     const CvMat* get_test_sample_idx() { return test_sample_idx; };
  1538.     void mix_train_and_test_idx();
  1539.     void set_train_test_split( const CvTrainTestSplit * spl);
  1540.     
  1541.     const CvMat* get_var_idx();
  1542.     void chahge_var_idx( int vi, bool state );
  1543.     const CvMat* get_var_types();
  1544.     int get_var_type( int var_idx ) { return var_types->data.ptr[var_idx]; };
  1545.     // following 2 methods enable to change vars type
  1546.     // use these methods to assign CV_VAR_CATEGORICAL type for categorical variable
  1547.     // with numerical labels; in the other cases var types are correctly determined automatically
  1548.     void set_var_types( const char* str );  // str examples:
  1549.                                             // "ord[0-17],cat[18]", "ord[0,2,4,10-12], cat[1,3,5-9,13,14]",
  1550.                                             // "cat", "ord" (all vars are categorical/ordered)
  1551.     void change_var_type( int var_idx, int type); // type in { CV_VAR_ORDERED, CV_VAR_CATEGORICAL }    
  1552.  
  1553.     void set_delimiter( char ch );
  1554.     char get_delimiter() { return delimiter; };
  1555.     void set_miss_ch( char ch );
  1556.     char get_miss_ch() { return miss_ch; };
  1557.     
  1558. protected:
  1559.     virtual void clear();
  1560.     void str_to_flt_elem( const char* token, float& flt_elem, int& type);
  1561.     void free_train_test_idx();
  1562.     
  1563.     char delimiter;
  1564.     char miss_ch;
  1565.     //char flt_separator;
  1566.     CvMat* values;
  1567.     CvMat* missing;
  1568.     CvMat* var_types;
  1569.     CvMat* var_idx_mask;
  1570.     CvMat* response_out; // header
  1571.     CvMat* var_idx_out; // mat
  1572.     CvMat* var_types_out; // mat
  1573.     int response_idx;
  1574.     int train_sample_count;
  1575.     bool mix;
  1576.    
  1577.     int total_class_count;
  1578.     map<string, int> *class_map;
  1579.     CvMat* train_sample_idx;
  1580.     CvMat* test_sample_idx;
  1581.     int* sample_idx; // data of train_sample_idx and test_sample_idx
  1582.     CvRNG rng;
  1583. };
  1584. namespace cv
  1585. {
  1586.     
  1587. typedef CvStatModel StatModel;
  1588. typedef CvParamGrid ParamGrid;
  1589. typedef CvNormalBayesClassifier NormalBayesClassifier;
  1590. typedef CvKNearest KNearest;
  1591. typedef CvSVMParams SVMParams;
  1592. typedef CvSVMKernel SVMKernel;
  1593. typedef CvSVMSolver SVMSolver;
  1594. typedef CvSVM SVM;
  1595. typedef CvEMParams EMParams;
  1596. typedef CvEM ExpectationMaximization;
  1597. typedef CvDTreeParams DTreeParams;
  1598. typedef CvMLData TrainData;
  1599. typedef CvDTree DecisionTree;
  1600. typedef CvForestTree ForestTree;
  1601. typedef CvRTParams RandomTreeParams;
  1602. typedef CvRTrees RandomTrees;
  1603. typedef CvERTreeTrainData ERTreeTRainData;
  1604. typedef CvForestERTree ERTree;
  1605. typedef CvERTrees ERTrees;
  1606. typedef CvBoostParams BoostParams;
  1607. typedef CvBoostTree BoostTree;
  1608. typedef CvBoost Boost;
  1609. typedef CvANN_MLP_TrainParams ANN_MLP_TrainParams;
  1610. typedef CvANN_MLP NeuralNet_MLP;
  1611.     
  1612. }
  1613. #endif /*__ML_H__*/
  1614. /* End of file. */