learn_kalman.m
上传用户:mozhenmi
上传日期:2008-02-18
资源大小:13k
文件大小:5k
源码类别:

其他小程序

开发平台:

Matlab

  1. function [A, C, Q, R, initx, initV, LL] = ...
  2.     learn_kalman(data, A, C, Q, R, initx, initV, max_iter, diagQ, diagR, ARmode, constr_fun, varargin)
  3. % LEARN_KALMAN Find the ML parameters of a stochastic Linear Dynamical System using EM.
  4. %
  5. % [A, C, Q, R, INITX, INITV, LL] = LEARN_KALMAN(DATA, A0, C0, Q0, R0, INITX0, INITV0) fits
  6. % the parameters which are defined as follows
  7. %   x(t+1) = A*x(t) + w(t),  w ~ N(0, Q),  x(0) ~ N(init_x, init_V)
  8. %   y(t)   = C*x(t) + v(t),  v ~ N(0, R)
  9. % A0 is the initial value, A is the final value, etc.
  10. % DATA(:,t,l) is the observation vector at time t for sequence l. If the sequences are of
  11. % different lengths, you can pass in a cell array, so DATA{l} is an O*T matrix.
  12. % LL is the "learning curve": a vector of the log lik. values at each iteration.
  13. % LL might go positive, since prob. densities can exceed 1, although this probably
  14. % indicates that something has gone wrong e.g., a variance has collapsed to 0.
  15. %
  16. % There are several optional arguments, that should be passed in the following order.
  17. % LEARN_KALMAN(DATA, A0, C0, Q0, R0, INITX0, INITV0, MAX_ITER, DIAGQ, DIAGR, ARmode)
  18. % MAX_ITER specifies the maximum number of EM iterations (default 10).
  19. % DIAGQ=1 specifies that the Q matrix should be diagonal. (Default 0).
  20. % DIAGR=1 specifies that the R matrix should also be diagonal. (Default 0).
  21. % ARMODE=1 specifies that C=I, R=0. i.e., a Gauss-Markov process. (Default 0).
  22. % This problem has a global MLE. Hence the initial parameter values are not important.
  23. % LEARN_KALMAN(DATA, A0, C0, Q0, R0, INITX0, INITV0, MAX_ITER, DIAGQ, DIAGR, F, P1, P2, ...)
  24. % calls [A,C,Q,R,initx,initV] = f(A,C,Q,R,initx,initV,P1,P2,...) after every M step. f can be
  25. % used to enforce any constraints on the params. 
  26. %
  27. % For details, see
  28. % - Ghahramani and Hinton, "Parameter Estimation for LDS", U. Toronto tech. report, 1996
  29. % - Digalakis, Rohlicek and Ostendorf, "ML Estimation of a stochastic linear system with the EM
  30. %      algorithm and its application to speech recognition",
  31. %       IEEE Trans. Speech and Audio Proc., 1(4):431--442, 1993.
  32. %    learn_kalman(data, A, C, Q, R, initx, initV, max_iter, diagQ, diagR, ARmode, constr_fun, varargin)
  33. if nargin < 8, max_iter = 10; end
  34. if nargin < 9, diagQ = 0; end
  35. if nargin < 10, diagR = 0; end
  36. if nargin < 11, ARmode = 0; end
  37. if nargin < 12, constr_fun = []; end
  38. verbose = 1;
  39. thresh = 1e-4;
  40. if ~iscell(data)
  41.   N = size(data, 3);
  42.   data = num2cell(data, [1 2]); % each elt of the 3rd dim gets its own cell
  43. else
  44.   N = length(data);
  45. end
  46. N = length(data);
  47. ss = size(A, 1);
  48. os = size(C,1);
  49. alpha = zeros(os, os);
  50. Tsum = 0;
  51. for ex = 1:N
  52.   %y = data(:,:,ex);
  53.   y = data{ex};
  54.   T = length(y);
  55.   Tsum = Tsum + T;
  56.   alpha_temp = zeros(os, os);
  57.   for t=1:T
  58.     alpha_temp = alpha_temp + y(:,t)*y(:,t)';
  59.   end
  60.   alpha = alpha + alpha_temp;
  61. end
  62. previous_loglik = -inf;
  63. loglik = 0;
  64. converged = 0;
  65. num_iter = 1;
  66. LL = [];
  67. % Convert to inline function as needed.
  68. if ~isempty(constr_fun)
  69.   constr_fun = fcnchk(constr_fun,length(varargin));
  70. end
  71. while ~converged & (num_iter <= max_iter) 
  72.   %%% E step
  73.   
  74.   delta = zeros(os, ss);
  75.   gamma = zeros(ss, ss);
  76.   gamma1 = zeros(ss, ss);
  77.   gamma2 = zeros(ss, ss);
  78.   beta = zeros(ss, ss);
  79.   P1sum = zeros(ss, ss);
  80.   x1sum = zeros(ss, 1);
  81.   loglik = 0;
  82.   
  83.   for ex = 1:N
  84.     y = data{ex};
  85.     T = length(y);
  86.     [beta_t, gamma_t, delta_t, gamma1_t, gamma2_t, x1, V1, loglik_t] = ...
  87. Estep(y, A, C, Q, R, initx, initV, ARmode);
  88.     beta = beta + beta_t;
  89.     gamma = gamma + gamma_t;
  90.     delta = delta + delta_t;
  91.     gamma1 = gamma1 + gamma1_t;
  92.     gamma2 = gamma2 + gamma2_t;
  93.     P1sum = P1sum + V1 + x1*x1';
  94.     x1sum = x1sum + x1;
  95.     %fprintf(1, 'example %d, ll/T %5.3fn', ex, loglik_t/T);
  96.     loglik = loglik + loglik_t;
  97.   end
  98.   LL = [LL loglik];
  99.   if verbose, fprintf(1, 'iteration %d, loglik = %fn', num_iter, loglik); end
  100.   %fprintf(1, 'iteration %d, loglik/NT = %fn', num_iter, loglik/Tsum);
  101.   num_iter =  num_iter + 1;
  102.   
  103.   %%% M step
  104.   
  105.   % Tsum =  N*T
  106.   % Tsum1 = N*(T-1);
  107.   Tsum1 = Tsum - N;
  108.   A = beta * inv(gamma1);
  109.   Q = (gamma2 - A*beta') / Tsum1;
  110.   if diagQ
  111.     Q = diag(diag(Q));
  112.   end
  113.   if ~ARmode
  114.     C = delta * inv(gamma);
  115.     R = (alpha - C*delta') / Tsum;
  116.     if diagR
  117.       R = diag(diag(R));
  118.     end
  119.   end
  120.   initx = x1sum / N;
  121.   initV = P1sum/N - initx*initx';
  122.   if ~isempty(constr_fun)
  123.     [A,C,Q,R,initx,initV] = feval(constr_fun, A, C, Q, R, initx, initV, varargin{:});
  124.   end
  125.   
  126.   converged = em_converged(loglik, previous_loglik, thresh);
  127.   previous_loglik = loglik;
  128. end
  129. %%%%%%%%%
  130. function [beta, gamma, delta, gamma1, gamma2, x1, V1, loglik] = ...
  131.     Estep(y, A, C, Q, R, initx, initV, ARmode)
  132. %
  133. % Compute the (expected) sufficient statistics for a single Kalman filter sequence.
  134. %
  135. [os T] = size(y);
  136. ss = length(A);
  137. if ARmode
  138.   xsmooth = y;
  139.   Vsmooth = zeros(ss, ss, T); % no uncertainty about the hidden states
  140.   VVsmooth = zeros(ss, ss, T);
  141.   loglik = 0;
  142. else
  143.   [xsmooth, Vsmooth, VVsmooth, loglik] = kalman_smoother(y, A, C, Q, R, initx, initV);
  144. end
  145. delta = zeros(os, ss);
  146. gamma = zeros(ss, ss);
  147. beta = zeros(ss, ss);
  148. for t=1:T
  149.   delta = delta + y(:,t)*xsmooth(:,t)';
  150.   gamma = gamma + xsmooth(:,t)*xsmooth(:,t)' + Vsmooth(:,:,t);
  151.   if t>1 beta = beta + xsmooth(:,t)*xsmooth(:,t-1)' + VVsmooth(:,:,t); end
  152. end
  153. gamma1 = gamma - xsmooth(:,T)*xsmooth(:,T)' - Vsmooth(:,:,T);
  154. gamma2 = gamma - xsmooth(:,1)*xsmooth(:,1)' - Vsmooth(:,:,1);
  155. x1 = xsmooth(:,1);
  156. V1 = Vsmooth(:,:,1);