pdfbtrain_thmt.cc
上传用户:l56789
上传日期:2022-02-25
资源大小:2422k
文件大小:4k
源码类别:

图形图像处理

开发平台:

Matlab

  1. //-----------------------------------------------------------------------------
  2. // pdfbtrain_thmt.cc  -  Train Tying HMT from a random model
  3. //
  4. // input arguments: - nSt - number of hidden states
  5. //                  - nLev - number of levels
  6. //                  - zm - yes for zeromean, no for nonzero mean
  7. //                  - data - input data file used to train the tree model
  8. //                           (for data file format, refer to tree.hh)
  9. //                  - mod - output file for the trained model
  10. //                          (for model file format, refer to thmt.hh)
  11. //                  - mD - minimum delta to decide convergence
  12. //-----------------------------------------------------------------------------
  13. #include <stdio.h>
  14. #include <stdlib.h>
  15. #include <string.h>
  16. #include <iostream>
  17. #include <math.h>
  18. #include "tree.hh"
  19. #include "pdfbthmt.hh"
  20. #include "mex.h"
  21. //-----------------------------------------------------------------------------
  22. void pdfbtrain_thmt(int nSt, int nLev, int* levndir, bool zm, 
  23.     tree<double>* dataTree, const mxArray* model, double mD, 
  24.     const mxArray* stateprob)
  25. {
  26.    
  27.     // A random model
  28.     THMT thmt(nSt, nLev, levndir, zm);
  29.     // Training
  30.     thmt.batch_train(dataTree, mD);
  31.     // Saving
  32.     thmt.dump_model_struct(model);
  33.     thmt.dump_state_prob(stateprob);
  34. }
  35. //-----------------------------------------------------------------------------
  36. void pdfbtrain_thmt(int nSt, int nLev, int* levndir, bool zm, 
  37.     tree<double>* dataTree, const mxArray* model, double mD)
  38. {
  39.    
  40.     // A random model
  41.     THMT thmt(nSt, nLev, levndir, zm);
  42.     // Training
  43.     thmt.batch_train(dataTree, mD);
  44.     // Saving
  45.     thmt.dump_model_struct(model);
  46. }
  47. //-----------------------------------------------------------------------------
  48. void mexFunction( int nlhs, mxArray *plhs[], 
  49.   int nrhs, const mxArray *prhs[] )     
  50. {
  51.     double mD, *data, *templevndir; 
  52.     int ns, nl, status, zeromeanlen, datafilelen, modelfilelen, 
  53.       spfilelen, i, n, rrow, rcol, *levndir;
  54.     char *zeromean, *tempsp;
  55.     tree<double>* dataTree;
  56.     const mxArray *datacell;
  57.     bool zm;
  58.     
  59.     /* Check for proper number of arguments */
  60.     
  61.     if ((nrhs != 7) && (nrhs != 8)) { 
  62. mexErrMsgTxt("Wrong number of arguments."); 
  63.     } else if (nlhs > 1) {
  64. mexErrMsgTxt("Too many output arguments."); 
  65.     } 
  66.     ns = (int)mxGetScalar(prhs[0]); 
  67.     nl = (int)mxGetScalar(prhs[1]);
  68.     levndir = new int[nl];
  69.     templevndir = mxGetPr(prhs[2]);
  70.     for(i = 0; i<nl; i++)
  71.       levndir[i] = (int)(templevndir[i]);
  72.     
  73.     zeromeanlen = (mxGetM(prhs[3]) * mxGetN(prhs[3])) + 1;
  74.     zeromean = (char*)mxCalloc(zeromeanlen, sizeof(char)); 
  75.     status = mxGetString(prhs[3], zeromean, zeromeanlen);
  76.     if(status != 0) 
  77.       mexWarnMsgTxt("Not enough space. String is truncated.");
  78.     datacell = mxGetCell(prhs[4],0);
  79.     rrow = mxGetM(datacell);
  80.     rcol = mxGetN(datacell);
  81.     if (rcol!=1)
  82.       mexErrMsgTxt("Error: Data Tree from matlab has incorrect format.");
  83.     dataTree = new tree<double>(rrow, 4, nl, 0);
  84.     for (i = 0; i < nl; i++)
  85.       {
  86. datacell = mxGetCell(prhs[4], i);
  87. data = mxGetPr(datacell);
  88.  
  89. for (n = 0; n < (*dataTree)[i].size(); n++)
  90.   (*dataTree)[i][n] = *data++;
  91.       }
  92.     if (strcmp(zeromean, "yes") == 0) 
  93.       zm = true;
  94.     else if (strcmp(zeromean, "no") == 0)
  95.       zm = false;
  96.     else
  97.       mexErrMsgTxt("Either yes or no for zeromean");
  98.     mD = mxGetScalar(prhs[5]);
  99.     if (nrhs == 8) {
  100.       /* Do the actual computations in a subroutine */
  101.       pdfbtrain_thmt(ns, nl, levndir, zm, dataTree, prhs[6], mD, prhs[7]);
  102.     }
  103.     else
  104.       /* Do the actual computations in a subroutine */
  105.       pdfbtrain_thmt(ns, nl, levndir, zm, dataTree, prhs[6], mD);
  106.     return;    
  107. }