AnnBP.cpp
上传用户:cdscwht
上传日期:2022-07-27
资源大小:264k
文件大小:11k
源码类别:

图形/文字识别

开发平台:

Visual Basic

  1. // AnnBP.cpp: implementation of the CAnnBP class.
  2. //
  3. //////////////////////////////////////////////////////////////////////
  4. #include "StdAfx.h"
  5. #include "AnnBP.h"
  6. #include "math.h"
  7. //////////////////////////////////////////////////////////////////////
  8. // Construction/Destruction
  9. //////////////////////////////////////////////////////////////////////
  10. CAnnBP::CAnnBP()
  11. {
  12. eta1=0.3;
  13. momentum1=0.3;
  14. }
  15. CAnnBP::~CAnnBP()
  16. {
  17. }
  18. double CAnnBP::drnd()
  19. {
  20. return ((double) rand() / (double) BIGRND);
  21. }
  22. /*** 返回-1.0到1.0之间的双精度随机数 ***/
  23. double CAnnBP::dpn1()
  24. {
  25. return (double) (rand())/(32767/2)-1;
  26. }
  27. /*** 作用函数,目前是S型函数 ***/
  28. double CAnnBP::squash(double x)
  29. {
  30. return (1.0 / (1.0 + exp(-x)));
  31. }
  32. /*** 申请1维双精度实数数组 ***/
  33. double* CAnnBP::alloc_1d_dbl(int n)
  34. {
  35. double *new1;
  36. new1 = (double *) malloc ((unsigned) (n * sizeof (double)));
  37. if (new1 == NULL) {
  38. AfxMessageBox("ALLOC_1D_DBL: Couldn't allocate array of doublesn");
  39. return (NULL);
  40. }
  41. return (new1);
  42. }
  43. /*** 申请2维双精度实数数组 ***/
  44. double** CAnnBP::alloc_2d_dbl(int m, int n)
  45. {
  46. int i;
  47. double **new1;
  48. new1 = (double **) malloc ((unsigned) (m * sizeof (double *)));
  49. if (new1 == NULL) {
  50. AfxMessageBox("ALLOC_2D_DBL: Couldn't allocate array of dbl ptrsn");
  51. return (NULL);
  52. }
  53. for (i = 0; i < m; i++) {
  54. new1[i] = alloc_1d_dbl(n);
  55. }
  56. return (new1);
  57. }
  58. /*** 随机初始化权值 ***/
  59. void CAnnBP::bpnn_randomize_weights(double **w, int m, int n)
  60. {
  61. int i, j;
  62. for (i = 0; i <= m; i++) {
  63. for (j = 0; j <= n; j++) {
  64. w[i][j] = dpn1();
  65. }
  66. }
  67. }
  68. /*** 0初始化权值 ***/
  69. void CAnnBP::bpnn_zero_weights(double **w, int m, int n)
  70. {
  71. int i, j;
  72. for (i = 0; i <= m; i++) {
  73. for (j = 0; j <= n; j++) {
  74. w[i][j] = 0.0;
  75. }
  76. }
  77. }
  78. /*** 设置随机数种子 ***/
  79. void CAnnBP::bpnn_initialize(int seed)
  80. {
  81. CString msg,s;
  82. msg="Random number generator seed:";
  83. s.Format("%d",seed);
  84. AfxMessageBox(msg+s);
  85. srand(seed);
  86. }
  87. /*** 创建BP网络 ***/
  88. BPNN* CAnnBP::bpnn_internal_create(int n_in, int n_hidden, int n_out)
  89. {
  90. BPNN *newnet;
  91. newnet = (BPNN *) malloc (sizeof (BPNN));
  92. if (newnet == NULL) {
  93. printf("BPNN_CREATE: Couldn't allocate neural networkn");
  94. return (NULL);
  95. }
  96. newnet->input_n = n_in;
  97. newnet->hidden_n = n_hidden;
  98. newnet->output_n = n_out;
  99. newnet->input_units = alloc_1d_dbl(n_in + 1);
  100. newnet->hidden_units = alloc_1d_dbl(n_hidden + 1);
  101. newnet->output_units = alloc_1d_dbl(n_out + 1);
  102. newnet->hidden_delta = alloc_1d_dbl(n_hidden + 1);
  103. newnet->output_delta = alloc_1d_dbl(n_out + 1);
  104. newnet->target = alloc_1d_dbl(n_out + 1);
  105. newnet->input_weights = alloc_2d_dbl(n_in + 1, n_hidden + 1);
  106. newnet->hidden_weights = alloc_2d_dbl(n_hidden + 1, n_out + 1);
  107. newnet->input_prev_weights = alloc_2d_dbl(n_in + 1, n_hidden + 1);
  108. newnet->hidden_prev_weights = alloc_2d_dbl(n_hidden + 1, n_out + 1);
  109. return (newnet);
  110. }
  111. /* 释放BP网络所占地内存空间 */
  112. void CAnnBP::bpnn_free(BPNN *net)
  113. {
  114. int n1, n2, i;
  115. n1 = net->input_n;
  116. n2 = net->hidden_n;
  117. free((char *) net->input_units);
  118. free((char *) net->hidden_units);
  119. free((char *) net->output_units);
  120. free((char *) net->hidden_delta);
  121. free((char *) net->output_delta);
  122. free((char *) net->target);
  123. for (i = 0; i <= n1; i++) {
  124. free((char *) net->input_weights[i]);
  125. free((char *) net->input_prev_weights[i]);
  126. }
  127. free((char *) net->input_weights);
  128. free((char *) net->input_prev_weights);
  129. for (i = 0; i <= n2; i++) {
  130. free((char *) net->hidden_weights[i]);
  131. free((char *) net->hidden_prev_weights[i]);
  132. }
  133. free((char *) net->hidden_weights);
  134. free((char *) net->hidden_prev_weights);
  135. free((char *) net);
  136. }
  137. /*** 创建一个BP网络,并初始化权值***/
  138. BPNN* CAnnBP::bpnn_create(int n_in, int n_hidden, int n_out)
  139. {
  140. BPNN *newnet;
  141. newnet = bpnn_internal_create(n_in, n_hidden, n_out);
  142. #ifdef INITZERO
  143. bpnn_zero_weights(newnet->input_weights, n_in, n_hidden);
  144. #else
  145. bpnn_randomize_weights(newnet->input_weights, n_in, n_hidden);
  146. #endif
  147. bpnn_randomize_weights(newnet->hidden_weights, n_hidden, n_out);
  148. bpnn_zero_weights(newnet->input_prev_weights, n_in, n_hidden);
  149. bpnn_zero_weights(newnet->hidden_prev_weights, n_hidden, n_out);
  150. return (newnet);
  151. }
  152. void CAnnBP::bpnn_layerforward(double *l1, double *l2, double **conn, int n1, int n2)
  153. {
  154. double sum;
  155. int j, k;
  156. /*** 设置阈值 ***/
  157. l1[0] = 1.0;
  158. /*** 对于第二层的每个神经元 ***/
  159. for (j = 1; j <= n2; j++) {
  160. /*** 计算输入的加权总和 ***/
  161. sum = 0.0;
  162. for (k = 0; k <= n1; k++) {
  163. sum += conn[k][j] * l1[k];
  164. }
  165. l2[j] = squash(sum);
  166. }
  167. }
  168. /* 输出误差 */
  169. void CAnnBP::bpnn_output_error(double *delta, double *target, double *output, int nj, double *err)
  170. {
  171. int j;
  172. double o, t, errsum;
  173. errsum = 0.0;
  174. for (j = 1; j <= nj; j++) {
  175. o = output[j];
  176. t = target[j];
  177. delta[j] = o * (1.0 - o) * (t - o);
  178. errsum += ABS(delta[j]);
  179. }
  180. *err = errsum;
  181. }
  182. /* 隐含层误差 */
  183. void CAnnBP::bpnn_hidden_error(double *delta_h, int nh, double *delta_o, int no, double **who, double *hidden, double *err)
  184. {
  185. int j, k;
  186. double h, sum, errsum;
  187. errsum = 0.0;
  188. for (j = 1; j <= nh; j++) {
  189. h = hidden[j];
  190. sum = 0.0;
  191. for (k = 1; k <= no; k++) {
  192. sum += delta_o[k] * who[j][k];
  193. }
  194. delta_h[j] = h * (1.0 - h) * sum;
  195. errsum += ABS(delta_h[j]);
  196. }
  197. *err = errsum;
  198. }
  199. /* 调整权值 */
  200. void CAnnBP::bpnn_adjust_weights(double *delta, int ndelta, double *ly, int nly, double **w, double **oldw, double eta, double momentum)
  201. {
  202. double new_dw;
  203. int k, j;
  204. ly[0] = 1.0;
  205. for (j = 1; j <= ndelta; j++) {
  206. for (k = 0; k <= nly; k++) {
  207. new_dw = ((eta * delta[j] * ly[k]) + (momentum * oldw[k][j]));
  208. w[k][j] += new_dw;
  209. oldw[k][j] = new_dw;
  210. }
  211. }
  212. }
  213. /* 进行前向运算 */
  214. void CAnnBP::bpnn_feedforward(BPNN *net)
  215. {
  216. int in, hid, out;
  217. in = net->input_n;
  218. hid = net->hidden_n;
  219. out = net->output_n;
  220. /*** Feed forward input activations. ***/
  221. bpnn_layerforward(net->input_units, net->hidden_units,
  222. net->input_weights, in, hid);
  223. bpnn_layerforward(net->hidden_units, net->output_units,
  224. net->hidden_weights, hid, out);
  225. }
  226. /* 训练BP网络 */
  227. void CAnnBP::bpnn_train(BPNN *net, double eta, double momentum, double *eo, double *eh)
  228. {
  229. int in, hid, out;
  230. double out_err, hid_err;
  231. in = net->input_n;
  232. hid = net->hidden_n;
  233. out = net->output_n;
  234. /*** 前向输入激活 ***/
  235. bpnn_layerforward(net->input_units, net->hidden_units,
  236. net->input_weights, in, hid);
  237. bpnn_layerforward(net->hidden_units, net->output_units,
  238. net->hidden_weights, hid, out);
  239. /*** 计算隐含层和输出层误差 ***/
  240. bpnn_output_error(net->output_delta, net->target, net->output_units,
  241. out, &out_err);
  242. bpnn_hidden_error(net->hidden_delta, hid, net->output_delta, out,
  243. net->hidden_weights, net->hidden_units, &hid_err);
  244. *eo = out_err;
  245. *eh = hid_err;
  246. /*** 调整输入层和隐含层权值 ***/
  247. bpnn_adjust_weights(net->output_delta, out, net->hidden_units, hid,
  248. net->hidden_weights, net->hidden_prev_weights, eta, momentum);
  249. bpnn_adjust_weights(net->hidden_delta, hid, net->input_units, in,
  250. net->input_weights, net->input_prev_weights, eta, momentum);
  251. }
  252. /* 保存BP网络 */
  253. void CAnnBP::bpnn_save(BPNN *net, char *filename)
  254. {
  255. CFile file;
  256. char *mem;
  257. int n1, n2, n3, i, j, memcnt;
  258. double dvalue, **w;
  259. n1 = net->input_n;  n2 = net->hidden_n;  n3 = net->output_n;
  260. printf("Saving %dx%dx%d network to '%s'n", n1, n2, n3, filename);
  261. try
  262. {
  263. file.Open(filename,CFile::modeWrite|CFile::modeCreate|CFile::modeNoTruncate);
  264. }
  265. catch(CFileException* e)
  266. {
  267. e->ReportError();
  268. e->Delete();
  269. }
  270. file.Write(&n1,sizeof(int));
  271. file.Write(&n2,sizeof(int));
  272. file.Write(&n3,sizeof(int));
  273. memcnt = 0;
  274. w = net->input_weights;
  275. mem = (char *) malloc ((unsigned) ((n1+1) * (n2+1) * sizeof(double)));
  276. // mem = (char *) malloc (((n1+1) * (n2+1) * sizeof(double)));
  277. for (i = 0; i <= n1; i++) {
  278. for (j = 0; j <= n2; j++) {
  279. dvalue = w[i][j];
  280. //fastcopy(&mem[memcnt], &dvalue, sizeof(double));
  281. fastcopy(&mem[memcnt], &dvalue, sizeof(double));
  282. memcnt += sizeof(double);
  283. }
  284. }
  285. file.Write(mem,sizeof(double)*(n1+1)*(n2+1));
  286. free(mem);
  287. memcnt = 0;
  288. w = net->hidden_weights;
  289. mem = (char *) malloc ((unsigned) ((n2+1) * (n3+1) * sizeof(double)));
  290. // mem = (char *) malloc (((n2+1) * (n3+1) * sizeof(double)));
  291. for (i = 0; i <= n2; i++) {
  292. for (j = 0; j <= n3; j++) {
  293. dvalue = w[i][j];
  294. fastcopy(&mem[memcnt], &dvalue, sizeof(double));
  295. // fastcopy(&mem[memcnt], &dvalue, sizeof(double));
  296. memcnt += sizeof(double);
  297. }
  298. }
  299. file.Write(mem, (n2+1) * (n3+1) * sizeof(double));
  300. // free(mem);
  301. file.Close();
  302. return;
  303. }
  304. /* 从文件中读取BP网络 */
  305. BPNN* CAnnBP::bpnn_read(char *filename)
  306. {
  307. char *mem;
  308. BPNN *new1;
  309. int n1, n2, n3, i, j, memcnt;
  310. CFile file;
  311. try
  312. {
  313. file.Open(filename,CFile::modeRead|CFile::modeCreate|CFile::modeNoTruncate);
  314. }
  315. catch(CFileException* e)
  316. {
  317. e->ReportError();
  318. e->Delete();
  319. }
  320. // printf("Reading '%s'n", filename);// fflush(stdout);
  321. file.Read(&n1, sizeof(int));
  322. file.Read(&n2, sizeof(int));
  323. file.Read(&n3, sizeof(int));
  324. new1 = bpnn_internal_create(n1, n2, n3);
  325. // printf("'%s' contains a %dx%dx%d networkn", filename, n1, n2, n3);
  326. // printf("Reading input weights..."); // fflush(stdout);
  327. memcnt = 0;
  328. mem = (char *) malloc (((n1+1) * (n2+1) * sizeof(double)));
  329. file.Read(mem, ((n1+1)*(n2+1))*sizeof(double));
  330. for (i = 0; i <= n1; i++) {
  331. for (j = 0; j <= n2; j++) {
  332. //fastcopy(&(new1->input_weights[i][j]), &mem[memcnt], sizeof(double));
  333. fastcopy(&(new1->input_weights[i][j]), &mem[memcnt], sizeof(double));
  334. memcnt += sizeof(double);
  335. }
  336. }
  337. free(mem);
  338. // printf("DonenReading hidden weights...");  //fflush(stdout);
  339. memcnt = 0;
  340. mem = (char *) malloc (((n2+1) * (n3+1) * sizeof(double)));
  341. file.Read(mem, (n2+1) * (n3+1) * sizeof(double));
  342. for (i = 0; i <= n2; i++) {
  343. for (j = 0; j <= n3; j++) {
  344. //fastcopy(&(new1->hidden_weights[i][j]), &mem[memcnt], sizeof(double));
  345. fastcopy(&(new1->hidden_weights[i][j]), &mem[memcnt], sizeof(double));
  346. memcnt += sizeof(double);
  347. }
  348. }
  349. free(mem);
  350. file.Close();
  351. printf("Donen");  //fflush(stdout);
  352. bpnn_zero_weights(new1->input_prev_weights, n1, n2);
  353. bpnn_zero_weights(new1->hidden_prev_weights, n2, n3);
  354. return (new1);
  355. }
  356. void CAnnBP::CreateBP(int n_in, int n_hidden, int n_out)
  357. {
  358. net=bpnn_create(n_in,n_hidden,n_out);
  359. }
  360. void CAnnBP::FreeBP()
  361. {
  362. bpnn_free(net);
  363. }
  364. void CAnnBP::Train(double *input_unit,int input_num, double *target,int target_num, double *eo, double *eh)
  365. {
  366. for(int i=1;i<=input_num;i++)
  367. {
  368. net->input_units[i]=input_unit[i-1];
  369. }
  370. for(int j=1;j<=target_num;j++)
  371. {
  372. net->target[j]=target[j-1];
  373. }
  374. bpnn_train(net,eta1,momentum1,eo,eh);
  375. }
  376. void CAnnBP::Identify(double *input_unit,int input_num,double *target,int target_num)
  377. {
  378. for(int i=1;i<=input_num;i++)
  379. {
  380. net->input_units[i]=input_unit[i-1];
  381. }
  382. bpnn_feedforward(net);
  383. for(int j=1;j<=target_num;j++)
  384. {
  385. target[j-1]=net->output_units[j];
  386. }
  387. }
  388. void CAnnBP::Save(char *filename)
  389. {
  390. bpnn_save(net,filename);
  391. }
  392. void CAnnBP::Read(char *filename)
  393. {
  394. net=bpnn_read(filename);
  395. }
  396. void CAnnBP::SetBParm(double eta, double momentum)
  397. {
  398. eta1=eta;
  399. momentum1=momentum;
  400. }
  401. void CAnnBP::Initialize(int seed)
  402. {
  403. bpnn_initialize(seed);
  404. }