pdfbprotrain_thmt.cc
上传用户:l56789
上传日期:2022-02-25
资源大小:2422k
文件大小:3k
源码类别:

图形图像处理

开发平台:

Matlab

  1. //-----------------------------------------------------------------------------
  2. // pdfbprotrain_thmt.cc  -  Train Tying HMT from a model provided
  3. //
  4. // input arguments: - nSt - number of hidden states
  5. //                  - nLev - number of levels
  6. //                  - zm - yes for zeromean, no for nonzero mean
  7. //                  - data - input data file used to train the tree model
  8. //                           (for data file format, refer to tree.hh)
  9. //                  - mod - output file for the trained model
  10. //                          (for model file format, refer to thmt.hh)
  11. //                  - mD - minimum delta to decide convergence
  12. //-----------------------------------------------------------------------------
  13. #include <stdio.h>
  14. #include <stdlib.h>
  15. #include <string.h>
  16. #include <iostream>
  17. #include <math.h>
  18. #include "tree.hh"
  19. #include "pdfbthmt.hh"
  20. #include "mex.h"
  21. //-----------------------------------------------------------------------------
  22. void pdfbprotrain_thmt(int* levndir, tree<double>* dataTree, 
  23.        const mxArray* model, double mD, 
  24.        const mxArray* stateprob)
  25. {
  26.     // Read initial model
  27.     THMT thmt(model, levndir);
  28.     // Training
  29.     thmt.batch_train(dataTree, mD);
  30.     // Output to Struct
  31.     thmt.dump_model_struct(model);
  32.     thmt.dump_state_prob(stateprob);
  33. }
  34. //-----------------------------------------------------------------------------
  35. void pdfbprotrain_thmt(int* levndir, tree<double>* dataTree,
  36.        const mxArray* model, double mD)
  37. {
  38.     // Read initial model
  39.     THMT thmt(model, levndir);
  40.     // Training
  41.     thmt.batch_train(dataTree, mD);
  42.     // Output to Struct
  43.     thmt.dump_model_struct(model);
  44. }
  45. //-----------------------------------------------------------------------------
  46. void mexFunction( int nlhs, mxArray *plhs[], 
  47.   int nrhs, const mxArray *prhs[] )     
  48. {  
  49.     double mD, *data, *templevndir; 
  50.     int ns, nl, status, i, n, rrow, rcol, *levndir; 
  51.     tree<double>* dataTree;
  52.     const mxArray* datacell;
  53.     
  54.     /* Check for proper number of arguments */
  55.     
  56.     if ((nrhs != 6) && (nrhs != 7)) { 
  57. mexErrMsgTxt("Wrong number of arguments."); 
  58.     } else if (nlhs > 1) {
  59. mexErrMsgTxt("Too many output arguments."); 
  60.     } 
  61.     ns = (int)mxGetScalar(prhs[0]);  
  62.     nl = (int)mxGetScalar(prhs[1]);
  63.     levndir = new int[nl];
  64.     templevndir = mxGetPr(prhs[2]);
  65.     for(i = 0; i<nl; i++)
  66.       levndir[i] = (int)(templevndir[i]);
  67.     datacell= mxGetCell(prhs[3], 0);
  68.     rrow = mxGetM(datacell);
  69.     rcol = mxGetN(datacell);
  70.     if (rcol!=1)
  71.       mexErrMsgTxt("Error: Data Tree from matlab has incorrect format.");
  72.     dataTree = new tree<double>(rrow, 4, nl, 0);
  73.     for (i = 0; i < nl; i++)
  74.       {
  75. datacell = mxGetCell(prhs[3], i);
  76. data = mxGetPr(datacell);
  77.  
  78. for (n = 0; n < (*dataTree)[i].size(); n++)
  79.   (*dataTree)[i][n] = *data++;
  80.       }
  81.     mD = mxGetScalar(prhs[4]);
  82.     /* Do the actual computations in a subroutine */
  83.     if (nrhs == 7)
  84.       pdfbprotrain_thmt(levndir, dataTree, prhs[5], mD, prhs[6]);
  85.     else
  86.       pdfbprotrain_thmt(levndir, dataTree, prhs[5], mD); 
  87.     
  88.     return;
  89.     
  90. }