svm_train.java
上传用户:xgw_05
上传日期:2014-12-08
资源大小:2726k
文件大小:8k
源码类别:

.net编程

开发平台:

Java

  1. import libsvm.*;
  2. import java.io.*;
  3. import java.util.*;
  4. class svm_train {
  5. private svm_parameter param; // set by parse_command_line
  6. private svm_problem prob; // set by read_problem
  7. private svm_model model;
  8. private String input_file_name; // set by parse_command_line
  9. private String model_file_name; // set by parse_command_line
  10. private int cross_validation = 0;
  11. private int nr_fold;
  12. private static void exit_with_help()
  13. {
  14. System.out.print(
  15.  "Usage: svm-train [options] training_set_file [model_file]n"
  16. +"options:n"
  17. +"-s svm_type : set type of SVM (default 0)n"
  18. +" 0 -- C-SVCn"
  19. +" 1 -- nu-SVCn"
  20. +" 2 -- one-class SVMn"
  21. +" 3 -- epsilon-SVRn"
  22. +" 4 -- nu-SVRn"
  23. +"-t kernel_type : set type of kernel function (default 2)n"
  24. +" 0 -- linear: u'*vn"
  25. +" 1 -- polynomial: (gamma*u'*v + coef0)^degreen"
  26. +" 2 -- radial basis function: exp(-gamma*|u-v|^2)n"
  27. +" 3 -- sigmoid: tanh(gamma*u'*v + coef0)n"
  28. +"-d degree : set degree in kernel function (default 3)n"
  29. +"-g gamma : set gamma in kernel function (default 1/k)n"
  30. +"-r coef0 : set coef0 in kernel function (default 0)n"
  31. +"-c cost : set the parameter C of C-SVC, epsilon-SVR, and nu-SVR (default 1)n"
  32. +"-n nu : set the parameter nu of nu-SVC, one-class SVM, and nu-SVR (default 0.5)n"
  33. +"-p epsilon : set the epsilon in loss function of epsilon-SVR (default 0.1)n"
  34. +"-m cachesize : set cache memory size in MB (default 40)n"
  35. +"-e epsilon : set tolerance of termination criterion (default 0.001)n"
  36. +"-h shrinking: whether to use the shrinking heuristics, 0 or 1 (default 1)n"
  37. +"-wi weight: set the parameter C of class i to weight*C, for C-SVC (default 1)n"
  38. +"-v n: n-fold cross validation moden"
  39. );
  40. System.exit(1);
  41. }
  42. private void do_cross_validation()
  43. {
  44. int i;
  45. int total_correct = 0;
  46. double total_error = 0;
  47. double sumv = 0, sumy = 0, sumvv = 0, sumyy = 0, sumvy = 0;
  48. // random shuffle
  49. for(i=0;i<prob.l;i++)
  50. {
  51. int j = (int)(Math.random()*(prob.l-i));
  52. svm_node[] tx;
  53. double ty;
  54. tx = prob.x[i];
  55. prob.x[i] = prob.x[j];
  56. prob.x[j] = tx;
  57. ty = prob.y[i];
  58. prob.y[i] = prob.y[j];
  59. prob.y[j] = ty;
  60. }
  61. for(i=0;i<nr_fold;i++)
  62. {
  63. int begin = i*prob.l/nr_fold;
  64. int end = (i+1)*prob.l/nr_fold;
  65. int j,k;
  66. svm_problem subprob = new svm_problem();
  67. subprob.l = prob.l-(end-begin);
  68. subprob.x = new svm_node[subprob.l][];
  69. subprob.y = new double[subprob.l];
  70. k=0;
  71. for(j=0;j<begin;j++)
  72. {
  73. subprob.x[k] = prob.x[j];
  74. subprob.y[k] = prob.y[j];
  75. ++k;
  76. }
  77. for(j=end;j<prob.l;j++)
  78. {
  79. subprob.x[k] = prob.x[j];
  80. subprob.y[k] = prob.y[j];
  81. ++k;
  82. }
  83. if(param.svm_type == svm_parameter.EPSILON_SVR ||
  84.    param.svm_type == svm_parameter.NU_SVR)
  85. {
  86. svm_model submodel = svm.svm_train(subprob,param);
  87. double error = 0;
  88. for(j=begin;j<end;j++)
  89. {
  90. double v = svm.svm_predict(submodel,prob.x[j]);
  91. double y = prob.y[j];
  92. error += (v-y)*(v-y);
  93. sumv += v;
  94. sumy += y;
  95. sumvv += v*v;
  96. sumyy += y*y;
  97. sumvy += v*y;
  98. }
  99. System.out.print("Mean squared error = "+error/(end-begin)+"n");
  100. total_error += error;
  101. }
  102. else
  103. {
  104. svm_model submodel = svm.svm_train(subprob,param);
  105. int correct = 0;
  106. for(j=begin;j<end;j++)
  107. {
  108. double v = svm.svm_predict(submodel,prob.x[j]);
  109. if(v == prob.y[j])
  110. ++correct;
  111. }
  112. System.out.print("Accuracy = "+100.0*correct/(end-begin)+"% ("+correct+"/"+(end-begin)+")n");
  113. total_correct += correct;
  114. }
  115. }
  116. if(param.svm_type == svm_parameter.EPSILON_SVR || param.svm_type == svm_parameter.NU_SVR)
  117. {
  118. System.out.print("Cross Validation Mean squared error = "+total_error/prob.l+"n");
  119. System.out.print("Cross Validation Squared correlation coefficient = "+
  120. ((prob.l*sumvy-sumv*sumy)*(prob.l*sumvy-sumv*sumy))/
  121. ((prob.l*sumvv-sumv*sumv)*(prob.l*sumyy-sumy*sumy))+"n"
  122. );
  123. }
  124. else
  125. System.out.print("Cross Validation Accuracy = "+100.0*total_correct/prob.l+"%n");
  126. }
  127. private void run(String argv[]) throws IOException
  128. {
  129. parse_command_line(argv);
  130. read_problem();
  131. if(cross_validation != 0)
  132. {
  133. do_cross_validation();
  134. }
  135. else
  136. {
  137. model = svm.svm_train(prob,param);
  138. svm.svm_save_model(model_file_name,model);
  139. }
  140. }
  141. public static void main(String argv[]) throws IOException
  142. {
  143. svm_train t = new svm_train();
  144. t.run(argv);
  145. }
  146. private static double atof(String s)
  147. {
  148. return Double.valueOf(s).doubleValue();
  149. }
  150. private static int atoi(String s)
  151. {
  152. return Integer.parseInt(s);
  153. }
  154. private void parse_command_line(String argv[])
  155. {
  156. int i;
  157. param = new svm_parameter();
  158. // default values
  159. param.svm_type = svm_parameter.C_SVC;
  160. param.kernel_type = svm_parameter.RBF;
  161. param.degree = 3;
  162. param.gamma = 0; // 1/k
  163. param.coef0 = 0;
  164. param.nu = 0.5;
  165. param.cache_size = 40;
  166. param.C = 1;
  167. param.eps = 1e-3;
  168. param.p = 0.1;
  169. param.shrinking = 1;
  170. param.nr_weight = 0;
  171. param.weight_label = new int[0];
  172. param.weight = new double[0];
  173. // parse options
  174. for(i=0;i<argv.length;i++)
  175. {
  176. if(argv[i].charAt(0) != '-') break;
  177. ++i;
  178. switch(argv[i-1].charAt(1))
  179. {
  180. case 's':
  181. param.svm_type = atoi(argv[i]);
  182. break;
  183. case 't':
  184. param.kernel_type = atoi(argv[i]);
  185. break;
  186. case 'd':
  187. param.degree = atof(argv[i]);
  188. break;
  189. case 'g':
  190. param.gamma = atof(argv[i]);
  191. break;
  192. case 'r':
  193. param.coef0 = atof(argv[i]);
  194. break;
  195. case 'n':
  196. param.nu = atof(argv[i]);
  197. break;
  198. case 'm':
  199. param.cache_size = atof(argv[i]);
  200. break;
  201. case 'c':
  202. param.C = atof(argv[i]);
  203. break;
  204. case 'e':
  205. param.eps = atof(argv[i]);
  206. break;
  207. case 'p':
  208. param.p = atof(argv[i]);
  209. break;
  210. case 'h':
  211. param.shrinking = atoi(argv[i]);
  212. break;
  213. case 'w':
  214. ++param.nr_weight;
  215. {
  216. int[] old = param.weight_label;
  217. param.weight_label = new int[param.nr_weight];
  218. System.arraycopy(old,0,param.weight_label,0,param.nr_weight-1);
  219. }
  220. {
  221. double[] old = param.weight;
  222. param.weight = new double[param.nr_weight];
  223. System.arraycopy(old,0,param.weight,0,param.nr_weight-1);
  224. }
  225. param.weight_label[param.nr_weight-1] = atoi(argv[i-1].substring(2));
  226. param.weight[param.nr_weight-1] = atof(argv[i]);
  227. break;
  228. case 'v':
  229. cross_validation = 1;
  230. nr_fold = atoi(argv[i]);
  231. if(nr_fold < 2)
  232. {
  233. System.err.print("n-fold cross validation: n must >= 2n");
  234. exit_with_help();
  235. }
  236. break;
  237. default:
  238. System.err.print("unknown optionn");
  239. exit_with_help();
  240. }
  241. }
  242. // determine filenames
  243. if(i>=argv.length)
  244. exit_with_help();
  245. input_file_name = argv[i];
  246. if(i<argv.length-1)
  247. model_file_name = argv[i+1];
  248. else
  249. {
  250. int p = argv[i].lastIndexOf('/');
  251. ++p; // whew...
  252. model_file_name = argv[i].substring(p)+".model";
  253. }
  254. }
  255. // read in a problem (in svmlight format)
  256. private void read_problem() throws IOException
  257. {
  258. BufferedReader fp = new BufferedReader(new FileReader(input_file_name));
  259. Vector vy = new Vector();
  260. Vector vx = new Vector();
  261. int max_index = 0;
  262. while(true)
  263. {
  264. String line = fp.readLine();
  265. if(line == null) break;
  266. StringTokenizer st = new StringTokenizer(line," tnrf:");
  267. vy.addElement(st.nextToken());
  268. int m = st.countTokens()/2;
  269. svm_node[] x = new svm_node[m];
  270. for(int j=0;j<m;j++)
  271. {
  272. x[j] = new svm_node();
  273. x[j].index = atoi(st.nextToken());
  274. x[j].value = atof(st.nextToken());
  275. }
  276. if(m>0) max_index = Math.max(max_index, x[m-1].index);
  277. vx.addElement(x);
  278. }
  279. prob = new svm_problem();
  280. prob.l = vy.size();
  281. prob.x = new svm_node[prob.l][];
  282. for(int i=0;i<prob.l;i++)
  283. prob.x[i] = (svm_node[])vx.elementAt(i);
  284. prob.y = new double[prob.l];
  285. for(int i=0;i<prob.l;i++)
  286. prob.y[i] = atof((String)vy.elementAt(i));
  287. if(param.gamma == 0)
  288. param.gamma = 1.0/max_index;
  289. fp.close();
  290. }
  291. }