svm_predict.java
上传用户:xgw_05
上传日期:2014-12-08
资源大小:2726k
文件大小:2k
源码类别:

.net编程

开发平台:

Java

  1. import libsvm.*;
  2. import java.io.*;
  3. import java.util.*;
  4. class svm_predict {
  5. private static double atof(String s)
  6. {
  7. return Double.valueOf(s).doubleValue();
  8. }
  9. private static int atoi(String s)
  10. {
  11. return Integer.parseInt(s);
  12. }
  13. private static void predict(BufferedReader input, DataOutputStream output, svm_model model) throws IOException
  14. {
  15. int correct = 0;
  16. int total = 0;
  17. double error = 0;
  18. double sumv = 0, sumy = 0, sumvv = 0, sumyy = 0, sumvy = 0;
  19. while(true)
  20. {
  21. String line = input.readLine();
  22. if(line == null) break;
  23. StringTokenizer st = new StringTokenizer(line," tnrf:");
  24. double target = atof(st.nextToken());
  25. int m = st.countTokens()/2;
  26. svm_node[] x = new svm_node[m];
  27. for(int j=0;j<m;j++)
  28. {
  29. x[j] = new svm_node();
  30. x[j].index = atoi(st.nextToken());
  31. x[j].value = atof(st.nextToken());
  32. }
  33. double v = svm.svm_predict(model,x);
  34. if(v == target)
  35. ++correct;
  36. error += (v-target)*(v-target);
  37. sumv += v;
  38. sumy += target;
  39. sumvv += v*v;
  40. sumyy += target*target;
  41. sumvy += v*target;
  42. ++total;
  43. output.writeBytes(v+"n");
  44. }
  45. System.out.print("Accuracy = "+(double)correct/total*100+
  46.  "% ("+correct+"/"+total+") (classification)n");
  47. System.out.print("Mean squared error = "+error/total+" (regression)n");
  48. System.out.print("Squared correlation coefficient = "+
  49. ((total*sumvy-sumv*sumy)*(total*sumvy-sumv*sumy))/
  50. ((total*sumvv-sumv*sumv)*(total*sumyy-sumy*sumy))+" (regression)n"
  51. );
  52. }
  53. public static void main(String argv[]) throws IOException
  54. {
  55. if(argv.length != 3)
  56. {
  57. System.err.print("usage: svm-predict test_file model_file output_filen");
  58. System.exit(1);
  59. }
  60. BufferedReader input = new BufferedReader(new FileReader(argv[0]));
  61. DataOutputStream output = new DataOutputStream(new FileOutputStream(argv[2]));
  62. svm_model model = svm.svm_load_model(argv[1]);
  63. predict(input,output,model);
  64. }
  65. }