pdfbthmt.cc
上传用户:l56789
上传日期:2022-02-25
资源大小:2422k
文件大小:41k
源码类别:

图形图像处理

开发平台:

Matlab

  1. //-----------------------------------------------------------------------------
  2. // pdfbthmt.cc  -  Tying Hidden Markov Tree Models for PDFB
  3. //-----------------------------------------------------------------------------
  4. #include <stdlib.h>
  5. #include <stdio.h>
  6. #include <math.h>
  7. #include <time.h>
  8. #include <iostream>
  9. #include "pdfbthmt.hh"
  10. #include "utils.hh"
  11. #include "mex.h"
  12. using std::cerr;
  13. using std::endl;
  14. //-----------------------------------------------------------------------------
  15. THMT::THMT(int ns, int nl, int* levndir, bool zm)
  16.     :M(ns), nCh(4), nLev(nl), zeromean(zm),
  17.      model_trans(nLev, vector<matrix<double> >() ),
  18.      model_mean(nLev, vector< vector<double> >() ), 
  19.      model_stdv(nLev, vector< vector<double> >() )
  20. {
  21.     int i, J, k;
  22.     
  23.     for(i=0; i<nLev; i++){
  24.       model_trans[i] = vector< matrix<double> >((int)pow(2.0,levndir[i]
  25.  -levndir[0]), 
  26. matrix<double>(M,M));
  27.       model_mean[i] = vector< vector<double> >((int)pow(2.0,levndir[i]
  28.  -levndir[0]), 
  29.        vector<double>(M));
  30.       model_stdv[i] = vector< vector<double> >((int)pow(2.0,levndir[i]
  31.  -levndir[0]), 
  32.        vector<double>(M));
  33.     }
  34.     rnd_init_model();
  35.     // subbandtree is a tree structure used to keep track of which
  36.     // directional subband does each coefficient in a tree belong to
  37.     subbandtree = tree<int> (1, nCh, nLev, 0);
  38.     // setting up the subbandtree tree
  39.     for(J = 0; J<nLev-1; J++)
  40.       for(i=0; i<subbandtree[J].size(); i++) 
  41. if(levndir[J+1]==levndir[J])
  42.   for(k=0; k<nCh; k++)
  43.     subbandtree[J+1][i*nCh+k] = subbandtree[J][i];
  44. else
  45.   for(k=0; k<nCh; k++)
  46.     if (k<nCh/2)
  47.       subbandtree[J+1][i*nCh+k] = subbandtree[J][i]*2;
  48.     else
  49.       subbandtree[J+1][i*nCh+k] = subbandtree[J][i]*2+1;
  50. }
  51. //-----------------------------------------------------------------------------
  52. THMT::THMT(const mxArray* initmodel, int* levndir)
  53. {
  54.     register int J, MM, m, mm, n, i, k;   
  55.     double* doubleptr;
  56.     mxArray* pointer, *pointer2, *pointer3;
  57.     int numfields, numels, zmlen;
  58.     char* zm;
  59.     numfields = mxGetNumberOfFields(initmodel);
  60.     numels = mxGetNumberOfElements(initmodel);
  61.     if ((numfields != 6) && (numfields != 7))
  62.       mexErrMsgTxt("ERROR: number of fields in struct model is incorrect");
  63.     if (numels != 1)
  64.       mexErrMsgTxt("ERROR: Too many elements");
  65.     if (strcmp(mxGetFieldNameByNumber(initmodel, 0), "nstates") != 0)
  66.       mexErrMsgTxt("Field 0 has wrong name");
  67.     pointer = mxGetFieldByNumber(initmodel, 0, 0);
  68.     doubleptr = mxGetPr(pointer);
  69.     M = (int)(*doubleptr);
  70.     nCh = 4;
  71.     if (strcmp(mxGetFieldNameByNumber(initmodel, 1), "nlevels") != 0)
  72.       mexErrMsgTxt("Field 1 has wrong name");    
  73.     pointer = mxGetFieldByNumber(initmodel, 0, 1);
  74.     doubleptr = mxGetPr(pointer);
  75.     nLev = (int)*doubleptr;
  76.     if (strcmp(mxGetFieldNameByNumber(initmodel, 2), "zeromean") != 0)
  77.       mexErrMsgTxt("Field 2 has wrong name");
  78.     pointer = mxGetFieldByNumber(initmodel, 0, 2);
  79.     zmlen = mxGetM(pointer) * mxGetN(pointer) + 1; 
  80.     zm = (char*)mxCalloc(zmlen, sizeof(char));
  81.     if (mxGetString(pointer, zm, zmlen) != 0)
  82.       mexErrMsgTxt("Error: cannot get zeromean variable");
  83.     if (strcmp(zm, "yes") == 0)
  84.       zeromean = true;
  85.     else if (strcmp(zm, "no") == 0)
  86.       zeromean = false;
  87.     else
  88.       mexErrMsgTxt("Error: Zeromean can only be yes or no");
  89.     model_trans = vector<vector<matrix<double> > >(nLev, 
  90.   vector<matrix<double> >() );
  91.     model_mean = vector<vector<vector<double> > >(nLev, 
  92.  vector< vector<double> >() );
  93.     model_stdv = vector<vector<vector<double> > >(nLev, 
  94.  vector< vector<double> >() );
  95.     for(i=0; i<nLev; i++){
  96.       model_trans[i] = vector< matrix<double> >((int)pow(2.0,levndir[i]
  97.  -levndir[0]), 
  98. matrix<double>(M,M));
  99.       model_mean[i] = vector< vector<double> >((int)pow(2.0,levndir[i]
  100.  -levndir[0]), 
  101.        vector<double>(M));
  102.       model_stdv[i] = vector< vector<double> >((int)pow(2.0,levndir[i]
  103.  -levndir[0]), 
  104.        vector<double>(M));
  105.     }
  106.     if (strcmp(mxGetFieldNameByNumber(initmodel, 3), "rootprob") != 0)
  107.       mexErrMsgTxt("Field 3 has wrong name");
  108.     pointer = mxGetFieldByNumber(initmodel, 0, 3);
  109.     if (pointer == NULL) {
  110.       mexPrintf("%s%dn",
  111.      "FIELD:", 3);
  112.       mexErrMsgTxt("Above field is empty!"); 
  113.     }
  114.     doubleptr = mxGetPr(pointer);
  115.     for (m = 0; m < M; m++)
  116.          model_trans[0][0][m][0]= doubleptr[m];
  117.  
  118.     if (strcmp(mxGetFieldNameByNumber(initmodel, 4), "transprob") != 0)
  119.       mexErrMsgTxt("Field 4 has wrong name");
  120.     pointer = mxGetFieldByNumber(initmodel, 0, 4);
  121.     if (pointer == NULL) {
  122.       mexPrintf("%s%dn",
  123.      "FIELD:", 4);
  124.       mexErrMsgTxt("Above field is empty!"); 
  125.     }
  126.     // Trans probs
  127.     for (J =  1; J < nLev; J++)
  128.     { 
  129.       pointer2 = mxGetCell(pointer,J-1);
  130.       for (m = 0; m < M; m++){
  131. for (mm = 0; mm < model_trans[J].size(); mm++) {
  132.     pointer3 = mxGetCell(pointer2, mm);
  133.     doubleptr = mxGetPr(pointer3);
  134.     for (n = 0; n < M; n++)
  135.        model_trans[J][mm][m][n] = doubleptr[n*M+m];
  136. }
  137.       }
  138.     }
  139.     // Mean
  140.     if (!zeromean)
  141.     {
  142.         if (strcmp(mxGetFieldNameByNumber(initmodel, 5), "mean") != 0)
  143.   mexErrMsgTxt("Field 5 has wrong name");
  144.         pointer = mxGetFieldByNumber(initmodel, 0, 5);
  145. if (pointer == NULL) {
  146.   mexPrintf("%s%dn",
  147.     "FIELD:", 5);
  148.   mexErrMsgTxt("Above field is empty!");
  149. }
  150. for (J = 0; J < nLev; J++)
  151. {
  152.   pointer2 = mxGetCell(pointer, J);
  153.   for (mm=0; mm < model_mean[J].size(); mm++){
  154.     pointer3 = mxGetCell(pointer2, mm);
  155.     doubleptr = mxGetPr(pointer3);
  156.     for (m=0; m<M; m++)
  157.        model_mean[J][mm][m] = doubleptr[m];
  158.   }
  159. }
  160.     }
  161.     if(!zeromean){
  162.       if (strcmp(mxGetFieldNameByNumber(initmodel, 6), "stdv") != 0)
  163. mexErrMsgTxt("Field 6 has wrong name");
  164.       pointer = mxGetFieldByNumber(initmodel, 0, 6);
  165.       if (pointer == NULL) {
  166. mexPrintf("%s%dn",
  167.   "FIELD:", 6);
  168. mexErrMsgTxt("Above field is empty!");
  169.       }
  170.     }
  171.     else{
  172.       if (strcmp(mxGetFieldNameByNumber(initmodel, 5), "stdv") != 0)
  173. mexErrMsgTxt("Field 5 has wrong name");
  174.       pointer = mxGetFieldByNumber(initmodel, 0, 5);
  175.       if (pointer == NULL) {
  176. mexPrintf("%s%dn",
  177.   "FIELD:", 5);
  178. mexErrMsgTxt("Above field is empty!");
  179.       } 
  180.     }
  181.     // Standard deviation
  182.     for (J =  0; J < nLev; J++)
  183.     {
  184.       pointer2 = mxGetCell(pointer, J);
  185.       for (mm=0; mm < model_stdv[J].size(); mm++){
  186. pointer3 = mxGetCell(pointer2, mm);
  187. doubleptr = mxGetPr(pointer3);
  188. for (m = 0; m < M; m++) 
  189.     model_stdv[J][mm][m] = doubleptr[m];
  190.       }
  191.     }
  192.     subbandtree = tree<int> (1, nCh, nLev, 0);
  193.     // setting up the subbandtree tree
  194.     for(J = 0; J<nLev-1; J++)
  195.       for(i=0; i<subbandtree[J].size(); i++) 
  196. if(levndir[J+1]==levndir[J])
  197.   for(k=0; k<nCh; k++)
  198.     subbandtree[J+1][i*nCh+k] = subbandtree[J][i];
  199. else
  200.   for(k=0; k<nCh; k++)
  201.     if (k<nCh/2)
  202.       subbandtree[J+1][i*nCh+k] = subbandtree[J][i]*2;
  203.     else
  204.       subbandtree[J+1][i*nCh+k] = subbandtree[J][i]*2+1;
  205. }
  206. //-----------------------------------------------------------------------------
  207. THMT::THMT(char *filename, int* levndir, long& startpos)
  208. {
  209.     FILE *fp;
  210.     register int J, MM, m, mm, n, i, k;
  211.     char szm[10];
  212.    
  213.     fp = fopen(filename,"a+");
  214.     if (!fp) 
  215.     {
  216. cerr << "ERROR: can not open for reading " << filename << endl;
  217. exit(-1);
  218.     }
  219.     fseek(fp, startpos, SEEK_SET);
  220.     if (fscanf(fp, "nStates: %dn", &M) != 1) 
  221.     {
  222. cerr << "ERROR: problem reading nStates field of " 
  223.      << filename << endl;
  224. exit(-1);
  225.     }
  226.     nCh = 4;
  227.     if (fscanf(fp, "nLevels: %dn", &nLev) != 1) 
  228.     {
  229. cerr << "ERROR: problem reading nLevels field of " 
  230.      << filename << endl;
  231. exit(-1);
  232.     }
  233.     if (fscanf(fp, "zeroMean: %sn", &szm) != 1) 
  234.     {
  235. cerr << "ERROR: problem reading zeroMean field of " 
  236.      << filename << endl;
  237. exit(-1);
  238.     }
  239.     if (strcmp(szm, "yes") == 0)
  240. zeromean = true;
  241.     else
  242. zeromean = false;
  243.     // Allocate space for model parameters
  244.     model_trans = vector<vector<matrix<double> > >(nLev, 
  245.   vector<matrix<double> >() );
  246.     model_mean = vector<vector<vector<double> > >(nLev, 
  247.  vector< vector<double> >() );
  248.     model_stdv = vector<vector<vector<double> > >(nLev, 
  249.  vector< vector<double> >() );
  250.     for(i=0; i<nLev; i++){
  251.       model_trans[i] = vector< matrix<double> >((int)pow(2.0,levndir[i]
  252.  -levndir[0]), 
  253. matrix<double>(M,M));
  254.       model_mean[i] = vector< vector<double> >((int)pow(2.0,levndir[i]
  255.  -levndir[0]), 
  256.        vector<double>(M));
  257.       model_stdv[i] = vector< vector<double> >((int)pow(2.0,levndir[i]
  258.  -levndir[0]), 
  259.        vector<double>(M));
  260.     }
  261.     // Initial probs
  262.     fscanf(fp, "n");
  263.     for (m = 0; m < M; m++)
  264. fscanf(fp, "%lf ", &model_trans[0][0][m][0]);  
  265.     fscanf(fp, "nn");
  266.   
  267.     // Trans prob
  268.     for (J = 1; J < nLev; J++) 
  269.     {
  270.       for (m = 0; m < M; m++){
  271. for (mm = 0; mm < model_trans[J].size(); mm++)
  272.   {
  273.     for (n = 0; n < M; n++)
  274.       fscanf(fp, "%lf ", &model_trans[J][mm][m][n]);
  275.   }
  276. fscanf(fp, "n");
  277.       }
  278.     }
  279.     fscanf(fp, "n");
  280.     // Mean
  281.     if (zeromean)
  282.     {
  283.         for (J = 0; J < nLev; J++){
  284.   for (mm=0; mm<model_mean[J].size(); mm++)
  285.     for (m = 0; m < M; m++) 
  286. model_mean[J][mm][m] = 0.0;
  287. }
  288.     }
  289.     else
  290.     {
  291. for (J = 0; J < nLev; J++)
  292. {
  293.   for (mm=0; mm<model_mean[J].size(); mm++)
  294.     for (m = 0; m < M; m++) 
  295. fscanf(fp, "%lf ", &model_mean[J][mm][m]);
  296.     
  297.   fscanf(fp, "n");
  298. }
  299. fscanf(fp, "n");
  300.     }
  301.     // Standard deviation
  302.     for (J=0; J < nLev; J++)
  303.     {
  304.       for (mm=0; mm<model_stdv[J].size(); mm++)
  305. for (m = 0; m < M; m++) 
  306.     fscanf(fp, "%lf ", &model_stdv[J][mm][m]);
  307.   
  308.       fscanf(fp, "n");
  309.     }
  310.     fscanf(fp, "n");
  311.     startpos = ftell(fp);
  312.    
  313.     fclose(fp);
  314.     subbandtree = tree<int> (1, nCh, nLev, 0);
  315.     // setting up the subbandtree tree
  316.     for(J = 1; J<nLev-1; J++)
  317.       for(i=0; i<subbandtree[J].size(); i++) 
  318. if(J%2==0)
  319.   for(k=0; k<nCh; k++)
  320.     subbandtree[J+1][i*nCh+k] = subbandtree[J][i];
  321. else
  322.   for(k=0; k<nCh; k++)
  323.     if (k<nCh/2)
  324.       subbandtree[J+1][i*nCh+k] = subbandtree[J][i]*2;
  325.     else
  326.       subbandtree[J+1][i*nCh+k] = subbandtree[J][i]*2+1;
  327. }
  328. //-----------------------------------------------------------------------------
  329. void THMT::allocate_training()
  330. {
  331.     int i, J, k;
  332.     if (alpha.nlev() == 0) // avoid re-allocate
  333.       {
  334. alpha = 
  335.     tree< vector<double> > (1, nCh, nLev, vector<double>(M));
  336. beta = beta_par = 
  337.     tree< vector<double> > (1, nCh, nLev, vector<double>(M));
  338. state_prob = tree< vector<double> > (nObs, nCh, nLev, 
  339.      vector<double>(M));
  340. scaling = tree<double> (1, nCh, nLev);
  341. sum_prob = sum_mean = sum_stdv = 
  342.     vector< vector< vector<double> > >(nLev, vector< vector<double> >());
  343. sum_trans = 
  344.   vector< vector< matrix<double> > > 
  345.   (nLev, vector<matrix<double> >());
  346. for(i=0; i<nLev; i++) {
  347.     sum_prob[i] = vector< vector<double> >(model_stdv[i].size(), 
  348.    vector<double>(M));
  349.     sum_mean[i] = vector< vector<double> >(model_stdv[i].size(), 
  350.    vector<double>(M));
  351.     sum_stdv[i] = vector< vector<double> >(model_stdv[i].size(), 
  352.    vector<double>(M));
  353.     sum_trans[i] = vector< matrix<double> >(model_stdv[i].size(), 
  354.     matrix<double>(M, M)); 
  355. }
  356.     }    
  357. }  
  358. //-----------------------------------------------------------------------------
  359. void THMT::allocate_testing()
  360. {
  361.     int i, J, k;
  362.     if (beta.nlev() == 0) // avoid re-allocate
  363.     {
  364. beta = beta_par = 
  365.     tree< vector<double> > (1, nCh, nLev, vector<double>(M));
  366. scaling = tree<double> (1, nCh, nLev);
  367.     }
  368. }
  369. //-----------------------------------------------------------------------------
  370. void THMT::compute_beta(int ob_ind)
  371. {
  372.     register int J, nNode, i, m, mm, n, c;
  373.     register double o;
  374.     register double sum, prod;
  375.   
  376.     // Initialization of beta down-tree
  377.     J = nLev - 1; // finest scale
  378.     for (i = 0, nNode = beta[J].size(); i < nNode; i++)
  379.     {
  380. // o_{J,i}^{ob_ind}
  381. o = (*obs)[J][ob_ind * ipow(nCh,J) + i];
  382. mm = subbandtree[J][i];
  383. for (m = 0; m < M; m++) 
  384.   beta[J][i][m] = compute_g(o, model_mean[J][mm][m], 
  385.       model_stdv[J][mm][m]);
  386.     } // end initialization
  387.     // Rescale beta
  388.     rescale(J);
  389.     while (J > 0) 
  390.     {
  391. // Compute $beta_{i,p(i)}$
  392. for (i = 0, nNode = beta_par[J].size(); i < nNode; i++) 
  393. {   
  394.   mm = subbandtree[J][i];
  395.   for (m = 0; m < M; m++) 
  396.     {
  397.       sum = 0.0;
  398.       for (n = 0; n < M; n++)
  399. sum += model_trans[J][mm][n][m] * beta[J][i][n];
  400.       beta_par[J][i][m] = sum;
  401.     }
  402. }
  403. // Compute $beta_{p(i)}(m)$
  404. J--;
  405. for (i = 0, nNode = beta[J].size(); i < nNode; i++) 
  406. {
  407.     // o_{J,i}^{ob_ind}
  408.     o = (*obs)[J][ob_ind * ipow(nCh,J) + i];
  409.     mm = subbandtree[J][i];
  410.     for (m = 0; m < M; m++) 
  411.     {
  412. prod = compute_g(o, model_mean[J][mm][m], 
  413.  model_stdv[J][mm][m]);
  414. for (c = 0; c < nCh; c++)  // look for child of $p(i)$
  415.     prod *= beta_par[J+1][i*nCh+c][m];
  416. beta[J][i][m] = prod;
  417.     }
  418. }
  419. // Rescale $beta_{p(i)}(m)$
  420. rescale(J);
  421.     }
  422. }
  423. //-----------------------------------------------------------------------------
  424. void THMT::rescale(int J)
  425. {
  426.     register int i, m, nNode = beta[J].size();
  427.     register double sum;
  428.     // Rescale for each node in this level (J)
  429.     for (i = 0; i < nNode; i++) 
  430.     {
  431. sum = 0.0;
  432. for (m = 0; m < M; m++) 
  433.     sum += beta[J][i][m];
  434.     
  435. scaling[J][i] = sum;
  436. for (m = 0; m < M; m++) 
  437.     beta[J][i][m] /= sum;
  438.     }
  439. }
  440. //-----------------------------------------------------------------------------
  441. void THMT::compute_alpha()
  442. {
  443.     register int J, i, m, mm, n, nNode;
  444.     register double sum;
  445.     // Initialize the coarsest level
  446.     for (m = 0; m < M; m++)
  447. alpha[0][0][m] = model_trans[0][0][m][0];
  448.     for (J = 1; J < nLev; J++) 
  449.     {
  450. for (i = 0, nNode = alpha[J].size(); i < nNode; i++) 
  451. {
  452.   mm = subbandtree[J][i];
  453.   for (m = 0; m < M; m++) 
  454.     {
  455.       sum = 0.0;
  456.       for (n = 0; n < M; n++) 
  457. sum += model_trans[J][mm][m][n] *
  458.   alpha[J-1][i/nCh][n] *
  459.   beta[J-1][i/nCh][n] /
  460.   beta_par[J][i][n];
  461. alpha[J][i][m] = sum;
  462.     } // end for m
  463. } // end for i
  464.     } // end for J
  465. }
  466. //-----------------------------------------------------------------------------
  467. double THMT::compute_likelihood()
  468. {
  469.     register int J, i, nNode, m;
  470.     register double f;
  471.     register double log_scale;
  472.     for (m = 0, f = 0.0; m < M; m++)
  473. f += beta[0][0][m] * model_trans[0][0][m][0];
  474.     // Re-scale back using saved scaling factors
  475.     for (J = 0, log_scale = 0.0; J < nLev; J++)
  476. for (i = 0, nNode = scaling[J].size(); i < nNode; i++) 
  477.     log_scale += log(scaling[J][i]);
  478.     // Final result
  479.     return (log_scale + log(f));
  480. }
  481. //-----------------------------------------------------------------------------
  482. void THMT::update_model(double& delta, double& avf)
  483. {
  484.     register int J, MM, i, m, mm, n, nNode;
  485.     register int ob_ind;           // observation index
  486.     register double o, denom, prob1, prob2, newval;
  487.     // Reinitialize
  488.     delta = avf = 0.0;
  489.     for (J = 0; J < nLev; J++)
  490.       for (mm=0; mm<sum_prob[J].size(); mm++)
  491. for (m = 0; m < M; m++) 
  492. {
  493.     sum_prob[J][mm][m] = 0.0;
  494.     sum_mean[J][mm][m] = 0.0;
  495.     sum_stdv[J][mm][m] = 0.0;
  496.     for (n = 0; n < M; n++)
  497.       sum_trans[J][mm][m][n] = 0.0;
  498. }
  499.   
  500.     // For each training tree
  501.     for (ob_ind = 0; ob_ind < nObs; ob_ind++) 
  502.     {
  503. // Compute $alpha, beta$
  504. compute_beta(ob_ind);
  505. compute_alpha();
  506. // Denominator for restimated probs
  507. denom = 0.0;
  508. for (m = 0; m < M; m++)
  509.     denom += beta[0][0][m] * alpha[0][0][m];
  510. // Compute state probabilities for denoising
  511. for (J = 0; J < nLev; J++)
  512.     for (i = 0, nNode = alpha[J].size(); i < nNode; i++)
  513.       for (m = 0; m < M; m++)
  514. {
  515.   state_prob[J][ob_ind*ipow(nCh,J)+i][m] = 
  516.     alpha[J][i][m]*beta[J][i][m]/denom;
  517. }
  518. // Update total log-likelihood
  519. avf += compute_likelihood();
  520. for (J = 0; J < nLev; J++)
  521.             for (i = 0, nNode = alpha[J].size(); i < nNode; i++) 
  522.            {
  523.         // o_{J,i}^{ob_ind}
  524.         o = (*obs)[J][ob_ind * ipow(nCh,J) + i];
  525.         for (m = 0; m < M; m++) 
  526.      {
  527.     // Compute $prob(S_i | O)$
  528.             prob1 = alpha[J][i][m] * beta[J][i][m]
  529.         / denom;
  530.     mm = subbandtree[J][i];
  531.     // Summing for all trees
  532.             sum_prob[J][mm][m] += prob1;
  533.             sum_mean[J][mm][m] += o * prob1;
  534.             sum_stdv[J][mm][m] += (o - model_mean[J][mm][m]) *
  535.         (o - model_mean[J][mm][m]) *
  536.         prob1;
  537.         }
  538.            }
  539. for (J = 1; J < nLev; J++)
  540.     for (i = 0, nNode = alpha[J].size(); i < nNode; i++)
  541. for (m = 0; m < M; m++)
  542.     for (n = 0; n < M; n++) 
  543.     {
  544.         mm = subbandtree[J][i];
  545. // Compute $prob(S_i, S_{p(i) | O)$
  546. prob2 = beta[J][i][m] *
  547.   model_trans[J][mm][m][n] *
  548.   alpha[J-1][i/nCh][n] *
  549.   beta[J-1][i/nCh][n] /
  550.   beta_par[J][i][n]
  551.   / denom;
  552. sum_trans[J][mm][m][n] += prob2;
  553.     }
  554.     } // end ob_ind
  555.     // Average log-likelihood
  556.     avf = avf / double(nObs);
  557.     // Normalize and update model parameters
  558.     for (J = 0; J < nLev; J++)
  559.       for (mm = 0; mm < model_stdv[J].size(); mm++)
  560. for (m = 0; m < M; m++) 
  561. {
  562.     if (!zeromean) // Only update means for non-zeromean model
  563.     {
  564. newval = sum_mean[J][mm][m] / sum_prob[J][mm][m];
  565. delta += fabs(newval - model_mean[J][mm][m]);
  566. model_mean[J][mm][m] = newval;
  567.     }
  568.     newval = sqrt(sum_stdv[J][mm][m] / sum_prob[J][mm][m]);
  569.     delta += fabs(newval - model_stdv[J][mm][m]);
  570.     model_stdv[J][mm][m] = newval;
  571. }
  572.     //state probs 
  573.     for (m = 0; m < M; m++) 
  574.     {
  575. newval = sum_prob[0][0][m] / nObs;
  576. delta += fabs(newval - model_trans[0][0][m][0]);
  577. model_trans[0][0][m][0] = newval;
  578.     }
  579.     // And transition probs 
  580.     for (J = 1; J < nLev; J++){
  581.       for (mm=0; mm < model_trans[J].size(); mm++)
  582. for (m = 0; m < M; m++)
  583.     for (n = 0; n < M; n++) 
  584.     {
  585.       if (sum_prob[J].size()>sum_prob[J-1].size())
  586. newval = sum_trans[J][mm][m][n] / sum_prob[J-1][mm/2][n]
  587.   /(nCh/2);
  588.       else
  589. newval =sum_trans[J][mm][m][n] / sum_prob[J-1][mm][n]/nCh;
  590.       delta += fabs(newval - model_trans[J][mm][m][n]);
  591.       model_trans[J][mm][m][n] = newval;
  592.     }
  593.     }
  594. }
  595. //-----------------------------------------------------------------------------
  596. void THMT::reorder_model()
  597. {
  598.   int level, dir, state, state2, largeststate, k;
  599.   double largeststdv, tempdouble;
  600.   // for each node
  601.   for(level = 0; level < nLev; level++){
  602.     for(dir = 0; dir < model_stdv[level].size(); dir++) {
  603.       for(state=0; state<M-1; state++) {
  604. // initialize
  605. largeststdv = -1;
  606. largeststate = M;
  607. for(state2=state; state2<M; state2++){
  608.   // search for the state with the largest standard deviation
  609.   if (model_stdv[level][dir][state2]>largeststdv){
  610.     largeststdv = model_stdv[level][dir][state2];
  611.     largeststate = state2;
  612.   }
  613. }
  614. // if the current state is not the state with the largest stdv, then
  615. // need to swap the order
  616. if(largeststate != state){
  617.   // swap the order of the stdv
  618.   model_stdv[level][dir][largeststate] = model_stdv[level][dir][state];
  619.   model_stdv[level][dir][state] = largeststdv;
  620.   // swap the order of the mean
  621.   if (!zeromean) {
  622.     tempdouble = model_mean[level][dir][largeststate];
  623.     model_mean[level][dir][largeststate] = model_mean[level][dir]
  624.       [state];
  625.     model_mean[level][dir][state] = tempdouble;
  626.   }
  627.   // swap the order of the transition matrix with parent
  628.   for(k=0; k<M; k++){
  629.     tempdouble = model_trans[level][dir][largeststate][k];
  630.     model_trans[level][dir][largeststate][k] = model_trans[level][dir]
  631.       [state][k];
  632.     model_trans[level][dir][state][k] = tempdouble;
  633.   }
  634.   // swap the order of the transition matrix with children if this
  635.   // is not a leaf node
  636.   if(level != nLev-1){
  637.     if(model_trans[level].size()==model_trans[level+1].size())
  638.       for(k=0; k<M; k++){
  639. tempdouble = model_trans[level+1][dir][k][largeststate];
  640. model_trans[level+1][dir][k][largeststate] = 
  641.   model_trans[level+1][dir][k][state];
  642. model_trans[level+1][dir][k][state] = tempdouble;
  643.       }
  644.     else if (model_trans[level].size()*2==model_trans[level+1].size())
  645.       for (k=0; k<M; k++){
  646. tempdouble =model_trans[level+1][2*dir][k][largeststate];
  647. model_trans[level+1][2*dir][k][largeststate] = 
  648.   model_trans[level+1][2*dir][k][state];
  649. model_trans[level+1][2*dir][k][state] = tempdouble;
  650. tempdouble =model_trans[level+1][2*dir+1][k][largeststate];
  651. model_trans[level+1][2*dir+1][k][largeststate] = 
  652.   model_trans[level+1][2*dir+1][k][state];
  653. model_trans[level+1][2*dir+1][k][state] = tempdouble;
  654.       } 
  655.   }
  656. }
  657.       }
  658.     }   
  659.   }
  660. }
  661. //-----------------------------------------------------------------------------
  662. void THMT::batch_train(tree<double> *trainTree, double min_delta)
  663. {
  664.   
  665.     // Assign data pointer
  666.     obs = trainTree;
  667.     // Train HMT model
  668.     train_all(min_delta);
  669. }
  670. //-----------------------------------------------------------------------------
  671. void THMT::train_all(double min_delta)
  672. {
  673.     register int count = 0, J, i, m, nNode;
  674.     register double delta = min_delta;
  675.     register double avf;
  676.     register double last_avf = -10e6;
  677.  
  678.     if (obs->nlev() == 0)
  679.       mexErrMsgTxt("ERROR in THMT::train_all(): empty training data");
  680.     if ((obs->nlev() != nLev) || (obs->nch() != nCh))
  681.       mexErrMsgTxt("ERROR in THMT::train_all(): incompatible training data");
  682.     nObs = obs->nrt();
  683.     // Allocate space for training
  684.     allocate_training();
  685.     while ((delta >= min_delta) && (count++ <= MAX_ITR)) 
  686.     {
  687. update_model(delta, avf);
  688. //if (avf < last_avf){
  689. //  mexWarnMsgTxt("WARNING: Log-likelihood decreases in training!");
  690. //  break;
  691. //}
  692. last_avf = avf;
  693. //#ifdef DEBUG
  694. //mexPrintf("count = %dndelta = %fnavf = %fn", count, delta, avf);
  695. //#endif
  696.     }
  697.     // change the model so that the state 1 always has the largest variance
  698.     // state 2 the second, and so on.....
  699.     //reorder_model();
  700.     mexPrintf("done batch-train:ncount = %dndelta = %fnavf = %fn", 
  701.       count, delta, avf);
  702. }
  703. //-----------------------------------------------------------------------------
  704. double THMT::batch_test(tree<double> *testTree)
  705. {
  706.     // Assign data pointer
  707.   obs = testTree;
  708.  
  709.   // Compute average log-likelihood
  710.   return test_all();
  711. }
  712. //-----------------------------------------------------------------------------
  713. double THMT::batch_test(char *filename)
  714. {
  715.     // Read data from file
  716.     obs = new tree<double> (filename);
  717.     // Compute average log-likelihood
  718.     double avf = test_all();
  719.     // Delete data
  720.     delete obs;
  721.     return avf;
  722. }
  723. //-----------------------------------------------------------------------------
  724. double THMT::test_all()
  725. {
  726.     register int ob_ind;
  727.     register double f, sumf = 0.0;
  728.     
  729.     
  730.     if (obs->nlev() == 0)
  731.     {
  732.       //cerr << "ERROR in THMT::test_all(): empty data"
  733.       //     << endl;
  734. return 0.0;
  735.     }
  736.     
  737.    
  738.     if ((obs->nlev() != nLev) || (obs->nch() != nCh))
  739.     {
  740.       //cerr << "ERROR in THMT::test_all(): incompatible data"
  741.       //      << endl;
  742. return 0.0;
  743.     }
  744.   
  745.     nObs = obs->nrt();
  746.   
  747.     // Allocate space for training
  748.     allocate_testing();
  749.     for (ob_ind = 0; ob_ind < nObs; ob_ind++) 
  750.     {
  751. compute_beta(ob_ind);
  752. f = compute_likelihood();
  753. //#ifdef DEBUG
  754. //cout << "ob_ind = " << ob_ind << "tf = " << f << endl;
  755. //#endif
  756. sumf += f;
  757.     }
  758.     return (sumf / double(nObs));
  759. }
  760. //-----------------------------------------------------------------------------
  761. void THMT::rnd_init_model()
  762. {
  763.     const double MAX_MEAN = 100.0;
  764.     const double MAX_STDV = 100.0;
  765.     register int J, MM, m, mm, n;
  766.     double temp;
  767.     vector<double> vprob(M);
  768.     int idum = -time(NULL);    // random seed
  769.     for (J = 0; J < nLev; J++){
  770.       for (mm=0; mm<model_mean[J].size(); mm++)
  771. for (m = 0; m < M; m++) 
  772.   {
  773.     if (zeromean)
  774.       model_mean[J][mm][m] = 0.0;
  775.     else
  776.       model_mean[J][mm][m] = (2.0*ran1(idum) - 1.0) * MAX_MEAN;
  777.     model_stdv[J][mm][m]  = ran1(idum) * MAX_STDV;
  778.     ranprobs(vprob, idum);
  779.     for (n = 0; n < M; n++)
  780.       model_trans[J][mm][n][m] = vprob[n];
  781.   }
  782.     }
  783.     
  784. }
  785. //-----------------------------------------------------------------------------
  786. void THMT::generate_data(tree<double>&obs, int n)
  787. {
  788.   tree<double> aTree(1, nCh, nLev);
  789.   register int ob_ind, J;
  790.   int idum = -time(NULL);     // random seed
  791.   // Resize output tree if necessary
  792.   if ((obs.nlev() != nLev) || (obs.nch() != nCh) || (obs.nrt() != n))
  793.     obs = tree<double>(n, nCh, nLev);
  794.   
  795.   for (ob_ind = 0; ob_ind < n; ob_ind++) {
  796.     generate_one(aTree, idum);
  797.     
  798.     for (J = 0; J < nLev; J++)
  799.       copy(aTree[J].begin(), aTree[J].end(),
  800.    obs[J].begin() + ob_ind * ipow(nCh,J));
  801.   }
  802. }
  803. //-----------------------------------------------------------------------------
  804. void THMT::generate_data(char *filename, int n)
  805. {
  806.     tree<double> genTree(n, nCh, nLev);
  807.     tree<double> aTree(1, nCh, nLev);
  808.     register int ob_ind, J;
  809.     int idum = -time(NULL);     // random seed
  810.     for (ob_ind = 0; ob_ind < n; ob_ind++) {
  811. generate_one(aTree, idum);
  812. for (J = 0; J < nLev; J++)
  813.     copy(aTree[J].begin(), aTree[J].end(),
  814.  genTree[J].begin() + ob_ind * ipow(nCh,J));
  815.     }
  816.     genTree.dump(filename);
  817. }
  818. //-----------------------------------------------------------------------------
  819. void THMT::generate_one(tree<double> &aTree, int& idum)
  820. {
  821.     tree<int> states(1, nCh, nLev);
  822.     vector<double> vprob(M);
  823.     register int J, i, nNode, m, mm;
  824.     register double mean, stdv;
  825.     // Initial state
  826.     for (m = 0; m < M; m++)
  827. vprob[m] = model_trans[0][0][m][0];
  828.     
  829.     states[0][0] = ranind(vprob, idum);
  830.     mean = model_mean[0][0][states[0][0]];
  831.     stdv = model_stdv[0][0][states[0][0]];
  832.     aTree[0][0] = mean + stdv * rangas(idum);
  833.   
  834.     // All others
  835.     for (J = 1; J < nLev; J++) 
  836.     {
  837. for (i = 0, nNode = aTree[J].size(); i < nNode; i++) 
  838. {
  839.   mm = subbandtree[J][i];
  840.   // build vector prob
  841.   for (m = 0; m < M; m++)
  842.     vprob[m] = model_trans[J][mm][m][states[J-1][i/nCh]];
  843.   states[J][i] = ranind(vprob, idum);
  844.   mean = model_mean[J][mm][states[J][i]];
  845.   stdv = model_stdv[J][mm][states[J][i]];
  846.   aTree[J][i] = mean + stdv * rangas(idum);
  847. }
  848.     }
  849. }
  850. //-----------------------------------------------------------------------------
  851. void THMT::generate_one(tree<double> &aTree, int& idum, double initval)
  852. {
  853.     tree<int> states(1, nCh, nLev);
  854.     vector<double> vprob(M);
  855.     register int J, i, nNode, m,mm;
  856.     register double mean, stdv;
  857.     aTree[0][0] = initval;
  858.     /********* WRONG!!!
  859.     // Find initial states from given initial value
  860.     for (m = 0; m < M; m++)
  861. vprob[m] = model_trans[0][m][0] *
  862.     compute_g(initval, model_mean[0][m], model_stdv[0][m]);
  863.     // It is the state with highest probability
  864.     states[0][0] = 0;
  865.     double maxprob = vprob[0];
  866.     for (m = 1; m < M; m++)
  867.     {
  868. if (vprob[m] > maxprob)
  869. {
  870.     states[0][0] = m;
  871.     maxprob = vprob[m];
  872. }
  873.     }
  874.     **********/
  875.     /***** HACK (before the end of the Millennium!!!) *****/
  876.     if (M != 2)
  877.       mexErrMsgTxt("Only works for 2 states");
  878.     int smallState = (model_stdv[0][0][0] < model_stdv[0][0][1]) ? 0 : 1;
  879.     // Cumunative probability
  880.     double cumprob = 0.0;
  881.     for (m = 0; m < M; m++)
  882. cumprob += model_trans[0][0][m][0] *
  883.     Psi(fabs(initval - model_mean[0][0][m]) / model_stdv[0][0][m]);
  884.     if (cumprob > (0.5 + 0.5 * model_trans[0][0][smallState][0]))
  885. states[0][0] = smallState;
  886.     else
  887. states[0][0] = 1 - smallState; // largeState
  888.     /***** END HACK *****/
  889.   
  890.     // All others
  891.     for (J = 1; J < nLev; J++) 
  892.     {
  893. for (i = 0, nNode = aTree[J].size(); i < nNode; i++) 
  894. {
  895.   mm = subbandtree[J][i];
  896.   // build vector prob
  897.   for (m = 0; m < M; m++)
  898.     vprob[m] = model_trans[J][mm][m][states[J-1][i/nCh]];
  899.   states[J][i] = ranind(vprob, idum);
  900.   mean = model_mean[J][mm][states[J][i]];
  901.   stdv = model_stdv[J][mm][states[J][i]];
  902.   aTree[J][i] = mean + stdv * rangas(idum);
  903. }
  904.     }
  905. }
  906. //-----------------------------------------------------------------------------
  907. tree<double>* THMT::denoise(double nvar, tree<double>* source, 
  908.     const mxArray* stateprob)
  909. {
  910.   double temp = 0, *doubleptr;
  911.   int J, i, m, mm, nNode;
  912.   mxArray* stateprobcell;
  913.   // Read data from file
  914.   obs = new tree<double> (*source);
  915.   nObs = obs->nrt();
  916.   state_prob = tree< vector<double> >(nObs, nCh, nLev, 
  917.       vector<double>(M));
  918.   for (J=0; J<nLev; J++){
  919.     stateprobcell = mxGetCell(stateprob, J);
  920.     doubleptr = mxGetPr(stateprobcell);
  921.     for (i = 0, nNode = state_prob[J].size(); i < nNode; i++) {
  922.       for (m = 0; m < M; m++){
  923. state_prob[J][i][m] = doubleptr[m*nNode+i];
  924.       }
  925.     }
  926.   }
  927.  
  928.   for (J = 0; J < nLev; J++) {
  929.     for (i = 0, nNode = state_prob[J].size(); i < nNode; i++) {
  930.       mm = subbandtree[J][i%ipow(nCh,J)];
  931.       for (m = 0; m < M; m++){
  932. temp += state_prob[J][i][m]*model_stdv[J][mm][m]*model_stdv[J][mm][m]
  933.   /(nvar+model_stdv[J][mm][m]*model_stdv[J][mm][m]);
  934.       }
  935.       (*obs)[J][i] = (*obs)[J][i]*temp;
  936.       temp = 0;
  937.     }
  938.   }
  939.   return obs;
  940. }
  941. //-----------------------------------------------------------------------------
  942. void THMT::dump_model(char *filename)
  943. {
  944.     FILE *fp;
  945.     register int J, MM, m, mm, n;
  946.     fp = fopen(filename,"a");
  947.     if (!fp) 
  948.       mexErrMsgTxt("ERROR: can not open for writing");
  949.     fprintf(fp, "nStates: %dn", M);
  950.     fprintf(fp, "nLevels: %dn", nLev);
  951.     if (zeromean)
  952. fprintf(fp, "zeroMean: yesn");
  953.     else 
  954. fprintf(fp, "zeroMean: non");
  955.     // Initial probs
  956.     fprintf(fp, "n");
  957.     for (m = 0; m < M; m++)
  958. fprintf(fp, "%f ", model_trans[0][0][m][0]);
  959.     fprintf(fp, "nn");
  960.     // Trans probs
  961.     for (J =  1; J < nLev; J++)
  962.     { 
  963.       for (m = 0; m < M; m++){
  964. for (mm = 0; mm < model_trans[J].size(); mm++) 
  965.     for (n = 0; n < M; n++)
  966.       fprintf(fp, "%f ", model_trans[J][mm][m][n]);
  967. fprintf(fp, "n");
  968.       }
  969.       fprintf(fp, "n");
  970.     }
  971.     fprintf(fp, "n");
  972.     // Mean
  973.     if (!zeromean){
  974. for (J = 0; J < nLev; J++)
  975. {
  976.   for (mm=0; mm < model_mean[J].size(); mm++)
  977.     for (m=0; m<M; m++)
  978.       fprintf(fp, "%f ", model_mean[J][mm][m]);
  979.   fprintf(fp, "n");
  980. }
  981. fprintf(fp, "n");
  982.     }
  983.     // Standard deviation
  984.     for (J =  0; J < nLev; J++)
  985.     {
  986.       for (mm=0; mm < model_stdv[J].size(); mm++)
  987. for (m = 0; m < M; m++) 
  988.     fprintf(fp, "%f ", model_stdv[J][mm][m]);
  989.   
  990.       fprintf(fp, "n");
  991.     }
  992.     fprintf(fp, "n");
  993.     fclose(fp);
  994. }
  995. //-----------------------------------------------------------------------------
  996. void THMT::dump_model_struct(const mxArray* model)
  997. {
  998.     double* doubleptr;
  999.     mxArray* pointer, *pointer2, *pointer3, *assignptr;
  1000.     int numfields, numels;
  1001.     register int J, MM, m, mm, n;
  1002.     numfields = mxGetNumberOfFields(model);
  1003.     numels = mxGetNumberOfElements(model);
  1004.     if (((numfields != 6) && (zeromean)) || 
  1005. ((numfields != 7) && (!zeromean)))
  1006.       mexErrMsgTxt("ERROR: number of fields in struct model is incorrect");
  1007.     if (numels != 1)
  1008.       mexErrMsgTxt("ERROR: Too many elements");
  1009.     if (strcmp(mxGetFieldNameByNumber(model, 0), "nstates") != 0)
  1010.       mexErrMsgTxt("Field 0 has wrong name");
  1011.     pointer = mxGetFieldByNumber(model, 0, 0);
  1012.     doubleptr = mxGetPr(pointer);
  1013.     *doubleptr = (double)M;
  1014.     if (strcmp(mxGetFieldNameByNumber(model, 1), "nlevels") != 0)
  1015.       mexErrMsgTxt("Field 1 has wrong name");    
  1016.     pointer = mxGetFieldByNumber(model, 0, 1);
  1017.     doubleptr = mxGetPr(pointer);
  1018.     *doubleptr = (double)nLev;
  1019.     if (strcmp(mxGetFieldNameByNumber(model, 2), "zeromean") != 0)
  1020.       mexErrMsgTxt("Field 2 has wrong name");
  1021.     if (zeromean){
  1022.         assignptr = mxCreateString("yes");
  1023.         mxSetFieldByNumber((mxArray*)model, 0, 2, assignptr);
  1024.     }
  1025.     else {
  1026.         assignptr = mxCreateString("no");
  1027.         mxSetFieldByNumber((mxArray*)model, 0, 2, assignptr);
  1028.     }
  1029.     if (strcmp(mxGetFieldNameByNumber(model, 3), "rootprob") != 0)
  1030.       mexErrMsgTxt("Field 3 has wrong name");
  1031.     pointer = mxGetFieldByNumber(model, 0, 3);
  1032.     if (pointer == NULL) {
  1033.       mexPrintf("%s%dn",
  1034.      "FIELD:", 3);
  1035.       mexErrMsgTxt("Above field is empty!"); 
  1036.     }
  1037.     doubleptr = mxGetPr(pointer);
  1038.     for (m = 0; m < M; m++)
  1039.         doubleptr[m] = model_trans[0][0][m][0];
  1040.  
  1041.     if (strcmp(mxGetFieldNameByNumber(model, 4), "transprob") != 0)
  1042.       mexErrMsgTxt("Field 4 has wrong name");
  1043.     pointer = mxGetFieldByNumber(model, 0, 4);
  1044.     if (pointer == NULL) {
  1045.       mexPrintf("%s%dn",
  1046.      "FIELD:", 4);
  1047.       mexErrMsgTxt("Above field is empty!"); 
  1048.     }
  1049.     // Trans probs
  1050.     for (J =  1; J < nLev; J++)
  1051.     { 
  1052.       pointer2 = mxGetCell(pointer,J-1);
  1053.       for (m = 0; m < M; m++){
  1054. for (mm = 0; mm < model_trans[J].size(); mm++) {
  1055.     pointer3 = mxGetCell(pointer2, mm);
  1056.     doubleptr = mxGetPr(pointer3);
  1057.     for (n = 0; n < M; n++)
  1058.       doubleptr[n*M+m] = model_trans[J][mm][m][n];
  1059. }
  1060.       }
  1061.     }
  1062.     // Mean
  1063.     if (!zeromean)
  1064.     {
  1065.         if (strcmp(mxGetFieldNameByNumber(model, 5), "mean") != 0)
  1066.   mexErrMsgTxt("Field 5 has wrong name");
  1067.         pointer = mxGetFieldByNumber(model, 0, 5);
  1068. if (pointer == NULL) {
  1069.   mexPrintf("%s%dn",
  1070.     "FIELD:", 5);
  1071.   mexErrMsgTxt("Above field is empty!");
  1072. }
  1073. for (J = 0; J < nLev; J++)
  1074. {
  1075.   pointer2 = mxGetCell(pointer, J);
  1076.   for (mm=0; mm < model_mean[J].size(); mm++){
  1077.     pointer3 = mxGetCell(pointer2, mm);
  1078.     doubleptr = mxGetPr(pointer3);
  1079.     for (m=0; m<M; m++)
  1080.       doubleptr[m] = model_mean[J][mm][m];
  1081.   }
  1082. }
  1083.     }
  1084.     if(!zeromean){
  1085.       if (strcmp(mxGetFieldNameByNumber(model, 6), "stdv") != 0)
  1086. mexErrMsgTxt("Field 6 has wrong name");
  1087.       pointer = mxGetFieldByNumber(model, 0, 6);
  1088.       if (pointer == NULL) {
  1089. mexPrintf("%s%dn",
  1090.   "FIELD:", 6);
  1091. mexErrMsgTxt("Above field is empty!");
  1092.       }
  1093.     }
  1094.     else{
  1095.       if (strcmp(mxGetFieldNameByNumber(model, 5), "stdv") != 0)
  1096. mexErrMsgTxt("Field 5 has wrong name");
  1097.       pointer = mxGetFieldByNumber(model, 0, 5);
  1098.       if (pointer == NULL) {
  1099. mexPrintf("%s%dn",
  1100.   "FIELD:", 5);
  1101. mexErrMsgTxt("Above field is empty!");
  1102.       } 
  1103.     }
  1104.     // Standard deviation
  1105.     for (J =  0; J < nLev; J++)
  1106.     {
  1107.       pointer2 = mxGetCell(pointer, J);
  1108.       for (mm=0; mm < model_stdv[J].size(); mm++){
  1109. pointer3 = mxGetCell(pointer2, mm);
  1110. doubleptr = mxGetPr(pointer3);
  1111. for (m = 0; m < M; m++) 
  1112.     doubleptr[m] = model_stdv[J][mm][m];
  1113.       }
  1114.     }
  1115. }
  1116. //-----------------------------------------------------------------------------
  1117. void THMT::dump_state_prob(const mxArray* stateprob)
  1118. {
  1119.     int i, J, m, nNode;
  1120.     mxArray* tempptr;
  1121.     double* dataptr;
  1122.     for (J=0; J<nLev; J++){
  1123.       tempptr = mxGetCell(stateprob, J);
  1124.       dataptr = mxGetPr(tempptr);
  1125.       for (i = 0, nNode = state_prob[J].size(); i < nNode; i++) 
  1126. for (m = 0; m < M; m++)
  1127.   dataptr[m*state_prob[J].size()+i] = state_prob[J][i][m];
  1128.     }
  1129. }  
  1130. //-----------------------------------------------------------------------------
  1131. double KLD_est(THMT model1, THMT model2, int nObservations)
  1132. {
  1133.     tree<double> genTree;
  1134.     double logden1, logden2;
  1135.     
  1136.     model1.generate_data(genTree, nObservations);
  1137.     logden1 = model1.batch_test(&genTree);
  1138.     logden2 = model2.batch_test(&genTree);
  1139.     
  1140.     return (logden1 - logden2);
  1141. }
  1142. //-----------------------------------------------------------------------------
  1143. // Compute the Kullback-Leibler distance between two discrete 
  1144. // probality mass functions
  1145. double KLD_disc(const vector<double>& prob1, const vector<double>& prob2)
  1146. {
  1147.     if (prob1.size() != prob2.size())
  1148. cerr << "KLD_disc: Two probability vectors have different length"
  1149.      << endl;
  1150.     double d = 0.0;
  1151.     for (int i = 0; i < prob1.size(); i++)
  1152. if ((prob1[i] != 0.0) && (prob2[0] != 0.0))
  1153.     d += prob1[i] * log(prob1[i] / prob2[i]);
  1154.     return d;
  1155. }
  1156. //-----------------------------------------------------------------------------
  1157. // Compute the Kullback-Leibler distance between two continuous 
  1158. // Gaussian probability desity functions
  1159. double KLD_gauss(double mean1, double stdv1, double mean2, double stdv2)
  1160. {
  1161.     double r1 = stdv1 / stdv2;
  1162.     double r2 = (mean1 - mean2) / stdv2;
  1163.     return 0.5 * (-2*log(r1) - 1 + r1*r1 + r2*r2); 
  1164. }
  1165. //-----------------------------------------------------------------------------
  1166. double KLD_upb(THMT model1, THMT model2)
  1167. {
  1168.     int J, m, n, dir, maxdir;
  1169.     int M, nCh, nLev;
  1170.     double test_sum;
  1171.    
  1172.     M = model1.M;
  1173.     if (model2.M != M)
  1174. mexErrMsgTxt("KLD_upb: Incompatible models.");
  1175.     nCh = model1.nCh;
  1176.     if (model2.nCh != nCh)
  1177. mexErrMsgTxt("KLD_upb: Incompatible models.");
  1178.     nLev = model1.nLev;
  1179.     if (model2.nLev != nLev)
  1180. mexErrMsgTxt("KLD_upb: Incompatible models.");
  1181.     // assume the lowest level has the largest number of directions
  1182.     maxdir = model1.model_stdv[nLev-1].size();
  1183.     vector<vector<double> > D(maxdir, vector<double>(M));
  1184.     vector<vector<double> > d(maxdir, vector<double>(M));
  1185.     vector<double> trans1(M);
  1186.     vector<double> trans2(M);
  1187.  
  1188.     // Initial: lowest level
  1189.     J = nLev - 1;
  1190.     for (dir = 0; dir < model1.model_stdv[J].size(); dir++){
  1191.       for (m = 0; m < M; m++)
  1192. D[dir][m] = KLD_gauss(model1.model_mean[J][dir][m], 
  1193.  model1.model_stdv[J][dir][m],
  1194.  model2.model_mean[J][dir][m], 
  1195.  model2.model_stdv[J][dir][m]);
  1196.     } 
  1197.       
  1198.     // DEBUG
  1199.     //mexPrintf("%sn", "Lowest level: ");
  1200.     //for (dir = 0; dir < model1.model_stdv[J].size(); dir++)
  1201.     //  for (m = 0; m < M; m++)
  1202.     // mexPrintf("%d %d %f n", dir, m, D[dir][m]);
  1203.     // Induction:
  1204.     for (J = nLev-1; J > 0; J--) {
  1205.       for (dir = 0; dir < model1.model_stdv[J-1].size(); dir++){
  1206. for (m = 0; m < M; m++) {
  1207.   d[dir][m] = KLD_gauss(model1.model_mean[J-1][dir][m], 
  1208. model1.model_stdv[J-1][dir][m],
  1209. model2.model_mean[J-1][dir][m], 
  1210. model2.model_stdv[J-1][dir][m]);
  1211.  
  1212.   if ( model1.model_stdv[J].size() == model1.model_stdv[J-1].size()){
  1213.     for (n=0; n<M; n++){
  1214.       trans1[n] = model1.model_trans[J][dir][n][m];
  1215.       trans2[n] = model2.model_trans[J][dir][n][m];
  1216.     }
  1217.     d[dir][m] += nCh * KLD_disc(trans1,trans2);
  1218.   }
  1219.   else if (model1.model_stdv[J].size() == 
  1220.    2*model1.model_stdv[J-1].size()){
  1221.     for (n=0; n<M; n++){
  1222.       trans1[n] = model1.model_trans[J][dir*2][n][m];
  1223.       trans2[n] = model2.model_trans[J][dir*2][n][m];
  1224.     }
  1225.     d[dir][m] += nCh/2 * KLD_disc(trans1,trans2);
  1226.     for (n=0; n<M; n++){
  1227.       trans1[n] = model1.model_trans[J][dir*2+1][n][m];
  1228.       trans2[n] = model2.model_trans[J][dir*2+1][n][m];
  1229.     }
  1230.     d[dir][m] += nCh/2 * KLD_disc(trans1,trans2);
  1231.   }
  1232.   else
  1233.     mexErrMsgTxt("Error: Multiple parents for one child");
  1234.   for (n = 0; n < M; n++)
  1235.     if ( model1.model_stdv[J].size() == model1.model_stdv[J-1].size())
  1236.       d[dir][m] += nCh * model1.model_trans[J][dir][n][m] 
  1237. * D[dir][n]; 
  1238.     else {
  1239.       d[dir][m] += nCh/2 * model1.model_trans[J][dir*2][n][m] 
  1240. * D[dir*2][n];
  1241.       d[dir][m] += nCh/2 * model1.model_trans[J][dir*2+1][n][m] 
  1242. * D[dir*2+1][n];
  1243.     }
  1244. }
  1245.       }
  1246.       // DEBUG:
  1247.       //for (dir = 0; dir < model1.model_stdv[J-1].size(); dir++)
  1248.       // for (m = 0; m < M; m++)
  1249.       //   mexPrintf("%d %d %f %fn", dir, m, D[dir][m], d[dir][m]);
  1250.       // updating the temporary distance vector D
  1251.       D = d;
  1252.     }
  1253.  
  1254.     // Final:
  1255.     double dist;
  1256.     for (n=0; n<M; n++){
  1257.       trans1[n] = model1.model_trans[0][0][n][0];
  1258.       trans2[n] = model2.model_trans[0][0][n][0];
  1259.     }
  1260.     dist = KLD_disc(trans1, trans2);
  1261.     // DEBUG
  1262.     //mexPrintf("%s %f", "Final: KLD_disc = ", dist);
  1263.     for (m = 0; m < M; m++)
  1264. dist += model1.model_trans[0][0][m][0] * D[0][m];
  1265.     return dist;
  1266. }
  1267. /*********************************************************************
  1268.     TEMPLATES INSTANCIATION !!!
  1269. *************************************************************/
  1270. /*template class matrix<float>;
  1271. template class vector<double>;
  1272. template class tree<double>;
  1273. template class tree<int>;
  1274. template class tree<vector<double> >; */