rjnn.m
上传用户:fsbooksir
上传日期:2013-10-19
资源大小:14k
文件大小:8k
源码类别:

matlab例程

开发平台:

Matlab

  1. function [k,mu,alpha,sigma,nabla,delta,ypred,ypredv,post] = rjnn(x,y,chainLength,Ndata,bFunction,par,xv,yv);
  2. %
  3. % =============================
  4. if nargin < 5, error('Not enough input arguments.'); end;
  5. if ((nargin==5) | (nargin==7)),
  6.   if nargin == 5
  7.     Validation = 0;
  8.   else
  9.     Validation = 1;
  10.   end;
  11.   hyper.a = 2;                      % Hyperparameter for delta.  
  12.   hyper.b = 10;                     % Hyperparameter for delta.
  13.   hyper.e1 = 0.0001;                % Hyperparameter for nabla.    
  14.   hyper.e2 = 0.0001;                % Hyperparameter for nabla.   
  15.   hyper.v = 0;                      % Hyperparameter for sigma    
  16.   hyper.gamma = 0;                  % Hyperparameter for sigma. 
  17.   kMax = 50;                        % Maximum number of basis.
  18.   arbC = 0.5;                       % Constant for birth and death moves.
  19.   doPlot = 1;                       % To plot or not to plot? Thats ...
  20.   sigStar = .1;                     % Merge-split parameter.
  21.   sWalk = .001;
  22.   Lambda = .5;
  23.   walkPer = 0.1;
  24. elseif ((nargin==6) | (nargin==8))
  25.   if nargin == 6
  26.     Validation = 0;
  27.   else
  28.     Validation = 1;
  29.   end;
  30.   hyper.a = par.a;                 
  31.   hyper.b = par.b;                
  32.   hyper.e1 = par.e1;           
  33.   hyper.e2 = par.e2;           
  34.   hyper.v = par.v;                 
  35.   hyper.gamma = par.gamma;             
  36.   kMax = par.kMax;                   
  37.   arbC = par.arbC;
  38.   doPlot = par.doPlot;    
  39.   sigStar = par.merge;
  40.   sWalk = par.sRW;
  41.   Lambda = par.Lambda;
  42.   walkPer = par.walkPer;
  43. else
  44.  error('Wrong Number of input arguments.');
  45. end;
  46. if Validation,
  47.   [Nv,dv] = size(xv);   % Nv = number of test data, dv = dimension of xv.
  48. end;
  49. [N,d] = size(x);      % N = number of train data, d = dimension of x.
  50. [N,c] = size(y);      % c = dimension of y, i.e. number of outputs.
  51. if Ndata ~= N, error('input must me N by d and output N by c.'); end;
  52. % INITIALISATION:
  53. % ==============
  54. post = ones(chainLength,1);       % p(centres,k|y).
  55. if Validation,
  56.   ypredv = zeros(Nv,c,chainLength);  % Output fit (test set).
  57. end;
  58. ypred = zeros(N,c,chainLength);   % Output fit (train set).
  59. nabla = zeros(chainLength,1);     % Poisson parameter.
  60. delta = zeros(chainLength,c);     % Regularisation parameter.
  61. k = ones(chainLength,1);          % Model order - number of basis.
  62. sigma = ones(chainLength,c);      % Output noise variance.
  63. mu = cell(chainLength,1);         % Radial basis centres.
  64. alpha = cell(chainLength,c);      % Radial basis coefficients.
  65. % DEFINE WALK INTERVAL FOR MU:
  66. % ===========================
  67. walk = walkPer*(max(x)-min(x));
  68. walkInt=zeros(d,1);
  69. for i=1:d,
  70.   walkInt(i,1) = (max(x(:,i))-min(x(:,i))) + 2*walk(i);
  71. end;
  72. % SAMPLE INITIAL CONDITIONS FROM THEIR PRIORS:
  73. % ===========================================
  74. nabla(1) = gengamma(0.5 + hyper.e1,hyper.e2);
  75. k(1) = poissrnd(nabla(1));
  76. k(1) = 40;                              % TEMPORARY: for demo1 comparison.
  77. k(1) = max(k(1),1);
  78. k(1) = min(k(1),kMax);
  79. for i=1:c
  80.   delta(1,i) = inv(gengamma(hyper.a,hyper.b));
  81.   sigma(1,i) = inv(gengamma(hyper.v/2,hyper.gamma/2));
  82.   alpha{1,i} = mvnrnd(zeros(1,k(1)+d+1),sigma(1,i)*delta(1,i)*eye(k(1)+d+1),1)';
  83. end;
  84. % DRAW THE INITIAL RADIAL CENTRES:
  85. % ===============================
  86. mu{1}=zeros(k(1),d);
  87. for i=1:d,
  88.   mu{1}(:,i)= (min(x(:,i))-walk(i))*ones(k(1),1) + ((max(x(:,i))+walk(i))-(min(x(:,i))-walk(i)))*rand(k(1),1);
  89. end;
  90. % FILL THE REGRESSION MATRIX:
  91. % ==========================
  92. M=zeros(N,k(1)+d+1);
  93. M(:,1) = ones(N,1);
  94. M(:,2:d+1) = x;
  95. for j=d+2:k(1)+d+1,
  96.   M(:,j) = feval(bFunction,mu{1}(j-d-1,:),x);
  97. end;
  98. for i=1:c,
  99.   ypred(:,i,1) = M*alpha{1,i};
  100. end;
  101. if Validation
  102.   Mv=zeros(Nv,k(1)+d+1);
  103.   Mv(:,1) = ones(Nv,1);
  104.   Mv(:,2:d+1) = xv;
  105.   for j=d+2:k(1)+d+1,
  106.     Mv(:,j) = feval(bFunction,mu{1}(j-d-1,:),xv);
  107.   end;
  108.   for i=1:c,
  109.     ypredv(:,i,1) = Mv*alpha{1,i};
  110.   end;
  111. end;
  112. % INITIALISE COUNTERS:
  113. % ===================
  114. aUpdate=0;
  115. rUpdate=0;
  116. aBirth=0;
  117. rBirth=0;
  118. aDeath=0;
  119. rDeath=0;
  120. aMerge=0;
  121. rMerge=0;
  122. aSplit=0;
  123. rSplit=0;
  124. aRW=0;
  125. rRW=0;
  126. match=0;
  127. if doPlot
  128.   figure(3)
  129.   clf;
  130. end;
  131. % ITERATE THE MARKOV CHAIN:
  132. % ========================
  133. for t=1:chainLength-1,
  134.   iteration=t
  135.   % COMPUTE THE CENTRES AND DIMENSION WITH METROPOLIS, BIRTH AND DEATH MOVES:
  136.   % ========================================================================
  137.   decision=rand(1);
  138.   birth=arbC*min(1,(nabla(t)/(k(t)+1)));
  139.   death=arbC*min(1,((k(t)+1)/nabla(t)));
  140.   if ((decision <= birth) & (k(t)<kMax)),
  141.     [k,mu,M,match,aBirth,rBirth] = radialBirth(match,aBirth,rBirth,k,mu,M,delta,x,y,hyper,t,bFunction,walkInt,walk);
  142.   elseif ((decision <= birth+death) & (k(t)>0)),
  143.     [k,mu,M,aDeath,rDeath] = radialDeath(aDeath,rDeath,k,mu,M,delta,x,y,hyper,t,nabla);
  144.   elseif ((decision <= 2*birth+death) & (k(t)<kMax) & (k(t)>1)),
  145.     [k,mu,M,aSplit,rSplit] = radialSplit(aSplit,rSplit,k,mu,M,delta,x,y,hyper,t,bFunction,sigStar,walkInt,walk);
  146.   elseif ((decision <= 2*birth+2*death) & (k(t)>1)),
  147.     [k,mu,M,aMerge,rMerge] = radialMerge(aMerge,rMerge,k,mu,M,delta,x,y,hyper,t,bFunction,sigStar,walkInt);
  148.   else
  149.     uLambda = rand(1);
  150.     if ((uLambda>Lambda) & (k(t)>0))
  151.       [k,mu,M,match,aRW,rRW] = radialRW(match,aRW,rRW,k,mu,M,delta,x,y,hyper,t,bFunction,sWalk,walk);
  152.     else  
  153.       [k,mu,M,match,aUpdate,rUpdate] = radialUpdate(match,aUpdate,rUpdate,k,mu,M,delta,x,y,hyper,t,bFunction,walkInt,walk);
  154.     end;
  155.   end;
  156.   % UPDATE OTHER PARAMETERS WITH GIBBS:
  157.   % ==================================
  158.   H=zeros(k(t+1)+1+d,k(t+1)+1+d,c);
  159.   F=zeros(k(t+1)+1+d,c);
  160.   P=zeros(N,N,c);
  161.   for i=1:c,
  162.     H(:,:,i) = inv(M'*M + (1/delta(t,i))*eye(k(t+1)+1+d));
  163.     F(:,i) = H(:,:,i)*M'*y(:,i);
  164.     P(:,:,i) = eye(N) - M*H(:,:,i)*M';
  165.     sigma(t+1,i) = inv(gengamma((hyper.v+N)/2,(hyper.gamma+y(:,i)'*P(:,:,i)*y(:,i))/2));
  166.     alpha{t+1,i} = mvnrnd(F(:,i),sigma(t+1,i)*H(:,:,i),1)';
  167.     delta(t+1,i) = inv(gengamma(hyper.a+(k(t+1)+d+1)/2,hyper.b+inv(2*sigma(t+1,i))*alpha{t+1,i}'*alpha{t+1,i}));
  168.   end;
  169.   nabla(t+1) = gengamma(0.5+hyper.e1+k(t+1),1+hyper.e2); 
  170.   % COMPUTE THE POSTERIOR FOR MONITORING:
  171.   % ==================================== 
  172.   posterior  =exp(-nabla(t+1)) * delta(t+1,1)^(-(d+k(t+1)+1)/2) * inv(prod(1:k(t+1)) * prod(walkInt)^(k(t+1))) * nabla(t+1)^(k(t+1)) * sqrt(det(H(:,:,1))) * (hyper.gamma+y(:,1)'*P(:,:,1)*y(:,1))^(-(hyper.v+N)/2);
  173.   for i=2:c,
  174.     newpost = delta(t+1,i)^(-(d+k(t+1)+1)/2) * sqrt(det(H(:,:,i))) * (hyper.gamma+y(:,i)'*P(:,:,i)*y(:,i))^(-(hyper.v+N)/2);  
  175.     posterior  = posterior * newpost;
  176.   end;
  177.   post(t+1) = log(posterior);
  178.   % PLOT FOR FUN AND MONITORING:
  179.   % ============================ 
  180.   for i=1:c,
  181.     ypred(:,i,t+1) = M*alpha{t+1,i};
  182.   end;
  183.   msError = inv(N) * trace((y-ypred(:,:,t+1))'*(y-ypred(:,:,t+1)));
  184. %  NRMSE = sqrt((y-ypred(:,:,t+1))'*(y-ypred(:,:,t+1))*inv((y-mean(y)*ones(size(y)))'*(y-mean(y)*ones(size(y)))))
  185.   if Validation,
  186.     % FILL THE VALIDATION REGRESSION MATRIX: 
  187.     % ======================================
  188.     Mv=zeros(Nv,k(t+1)+d+1);
  189.     Mv(:,1) = ones(Nv,1);
  190.     Mv(:,2:d+1) = xv;
  191.     for j=d+2:k(t+1)+d+1,
  192.       Mv(:,j) = feval(bFunction,mu{t+1}(j-d-1,:),xv);
  193.     end;
  194.     for i=1:c,
  195.       ypredv(:,i,t+1) = Mv*alpha{t+1,i};
  196.     end;
  197.     msErrorv = inv(Nv) * trace((yv-ypredv(:,:,t+1))'*(yv-ypredv(:,:,t+1)));
  198.   end;
  199.   if doPlot,
  200.     figure(1)  
  201.     clf
  202.     if (c==2),
  203.       plot(x(:,1),y(:,1),'b+',x(:,2),y(:,2),'r+',x(:,1),ypred(:,1,t+1),'bo',x(:,2),ypred(:,2,t+1),'ro');
  204.     elseif c==1,
  205.      plot(x,y,'b+',x,ypred(:,:,t+1),'ro');
  206.     end;
  207.     errorv = sum(abs(yv-ypredv(:,:,t+1)))*100*inv(Nv);
  208.     ylabel('Output','fontsize',15)
  209.     xlabel('Input','fontsize',15)
  210.     figure(3)
  211.     subplot(511);
  212.     hold on;
  213.     plot(t,k(t),'*');
  214.     ylabel('k','fontsize',15);
  215.     subplot(512);
  216.     hold on;
  217.     plot(t,post(t+1),'*');
  218.     ylabel('p(k,mu|y)','fontsize',15);  
  219.     subplot(513);
  220.     hold on;
  221.     plot(t,msError,'r*');
  222.     ylabel('Train error','fontsize',15);
  223.     subplot(514);
  224.     hold on;
  225.     plot(t,msErrorv,'r*');
  226.     ylabel('Test error','fontsize',15);
  227.     subplot(515);
  228.     hold on;
  229.     bar([1 2 3 4 5 6 7 8 9 10 11 12 13],[match aUpdate rUpdate aBirth rBirth aDeath rDeath aMerge rMerge aSplit rSplit aRW rRW]);
  230.     ylabel('Acceptance','fontsize',15);
  231.     xlabel('match aU rU aB rB aD rD aM rM aS rS aRW rRW','fontsize',15)
  232.   end;
  233. end;