svm-predict.c
上传用户:xgw_05
上传日期:2014-12-08
资源大小:2726k
文件大小:3k
源码类别:

.net编程

开发平台:

Java

  1. #include <stdio.h>
  2. #include <ctype.h>
  3. #include <stdlib.h>
  4. #include <string.h>
  5. #include "svm.h"
  6. char* line;
  7. int max_line_len = 1024;
  8. struct svm_node *x;
  9. int max_nr_attr = 64;
  10. struct svm_model* model;
  11. char* readline(FILE *input)
  12. {
  13. int len;
  14. if(fgets(line,max_line_len,input) == NULL)
  15. return NULL;
  16. while(strrchr(line,'n') == NULL)
  17. {
  18. max_line_len *= 2;
  19. line = (char *) realloc(line, max_line_len);
  20. len = strlen(line);
  21. if(fgets(line+len,max_line_len-len,input) == NULL)
  22. break;
  23. }
  24. return line;
  25. }
  26. void predict(FILE *input, FILE *output)
  27. {
  28. int correct = 0;
  29. int total = 0;
  30. double error = 0;
  31. double sumv = 0, sumy = 0, sumvv = 0, sumyy = 0, sumvy = 0;
  32. #define SKIP_TARGET
  33. while(isspace(*p)) ++p;
  34. while(!isspace(*p)) ++p;
  35. #define SKIP_ELEMENT
  36. while(*p!=':') ++p;
  37. ++p;
  38. while(isspace(*p)) ++p;
  39. while(*p && !isspace(*p)) ++p;
  40. while(readline(input)!=NULL)
  41. {
  42. int i = 0;
  43. double target,v;
  44. const char *p = line;
  45. if(sscanf(p,"%lf",&target)!=1) break;
  46. SKIP_TARGET
  47. while(sscanf(p,"%d:%lf",&x[i].index,&x[i].value)==2)
  48. {
  49. SKIP_ELEMENT;
  50. ++i;
  51. if(i>=max_nr_attr-1) // need one more for index = -1
  52. {
  53. max_nr_attr *= 2;
  54. x = (struct svm_node *) realloc(x,max_nr_attr*sizeof(struct svm_node));
  55. }
  56. }
  57. x[i].index = -1;
  58. v = svm_predict(model,x);
  59. if(v == target)
  60. ++correct;
  61. error += (v-target)*(v-target);
  62. sumv += v;
  63. sumy += target;
  64. sumvv += v*v;
  65. sumyy += target*target;
  66. sumvy += v*target;
  67. ++total;
  68. fprintf(output,"%gn",v);
  69. }
  70. printf("Accuracy = %g%% (%d/%d) (classification)n",
  71. (double)correct/total*100,correct,total);
  72. printf("Mean squared error = %g (regression)n",error/total);
  73. printf("Squared correlation coefficient = %g (regression)n",
  74. ((total*sumvy-sumv*sumy)*(total*sumvy-sumv*sumy))/
  75. ((total*sumvv-sumv*sumv)*(total*sumyy-sumy*sumy))
  76. );
  77. }
  78. int main(int argc, char **argv)
  79. {
  80. FILE *input, *output;
  81. if(argc!=4)
  82. {
  83. fprintf(stderr,"usage: svm-predict test_file model_file output_filen");
  84. exit(1);
  85. }
  86. input = fopen(argv[1],"r");
  87. if(input == NULL)
  88. {
  89. fprintf(stderr,"can't open input file %sn",argv[1]);
  90. exit(1);
  91. }
  92. output = fopen(argv[3],"w");
  93. if(output == NULL)
  94. {
  95. fprintf(stderr,"can't open output file %sn",argv[3]);
  96. exit(1);
  97. }
  98. if((model=svm_load_model(argv[2]))==0)
  99. {
  100. fprintf(stderr,"can't open model file %sn",argv[2]);
  101. exit(1);
  102. }
  103. line = (char *) malloc(max_line_len*sizeof(char));
  104. x = (struct svm_node *) malloc(max_nr_attr*sizeof(struct svm_node));
  105. predict(input,output);
  106. svm_destroy_model(model);
  107. free(line);
  108. free(x);
  109. fclose(input);
  110. fclose(output);
  111. return 0;
  112. }