svm-train.c
上传用户:xgw_05
上传日期:2014-12-08
资源大小:2726k
文件大小:8k
源码类别:

.net编程

开发平台:

Java

  1. #include <stdio.h>
  2. #include <stdlib.h>
  3. #include <string.h>
  4. #include <ctype.h>
  5. #include "svm.h"
  6. #define Malloc(type,n) (type *)malloc((n)*sizeof(type))
  7. void exit_with_help()
  8. {
  9. printf(
  10. "Usage: svm-train [options] training_set_file [model_file]n"
  11. "options:n"
  12. "-s svm_type : set type of SVM (default 0)n"
  13. " 0 -- C-SVCn"
  14. " 1 -- nu-SVCn"
  15. " 2 -- one-class SVMn"
  16. " 3 -- epsilon-SVRn"
  17. " 4 -- nu-SVRn"
  18. "-t kernel_type : set type of kernel function (default 2)n"
  19. " 0 -- linear: u'*vn"
  20. " 1 -- polynomial: (gamma*u'*v + coef0)^degreen"
  21. " 2 -- radial basis function: exp(-gamma*|u-v|^2)n"
  22. " 3 -- sigmoid: tanh(gamma*u'*v + coef0)n"
  23. "-d degree : set degree in kernel function (default 3)n"
  24. "-g gamma : set gamma in kernel function (default 1/k)n"
  25. "-r coef0 : set coef0 in kernel function (default 0)n"
  26. "-c cost : set the parameter C of C-SVC, epsilon-SVR, and nu-SVR (default 1)n"
  27. "-n nu : set the parameter nu of nu-SVC, one-class SVM, and nu-SVR (default 0.5)n"
  28. "-p epsilon : set the epsilon in loss function of epsilon-SVR (default 0.1)n"
  29. "-m cachesize : set cache memory size in MB (default 40)n"
  30. "-e epsilon : set tolerance of termination criterion (default 0.001)n"
  31. "-h shrinking: whether to use the shrinking heuristics, 0 or 1 (default 1)n"
  32. "-wi weight: set the parameter C of class i to weight*C, for C-SVC (default 1)n"
  33. "-v n: n-fold cross validation moden"
  34. );
  35. exit(1);
  36. }
  37. void parse_command_line(int argc, char **argv, char *input_file_name, char *model_file_name);
  38. void read_problem(const char *filename);
  39. void do_cross_validation();
  40. struct svm_parameter param; // set by parse_command_line
  41. struct svm_problem prob; // set by read_problem
  42. struct svm_model *model;
  43. struct svm_node *x_space;
  44. int cross_validation = 0;
  45. int nr_fold;
  46. int main(int argc, char **argv)
  47. {
  48. char input_file_name[1024];
  49. char model_file_name[1024];
  50. parse_command_line(argc, argv, input_file_name, model_file_name);
  51. read_problem(input_file_name);
  52. if(cross_validation)
  53. {
  54. do_cross_validation();
  55. }
  56. else
  57. {
  58. model = svm_train(&prob,&param);
  59. svm_save_model(model_file_name,model);
  60. svm_destroy_model(model);
  61. }
  62. free(prob.y);
  63. free(prob.x);
  64. free(x_space);
  65. return 0;
  66. }
  67. void do_cross_validation()
  68. {
  69. int i;
  70. int total_correct = 0;
  71. double total_error = 0;
  72. double sumv = 0, sumy = 0, sumvv = 0, sumyy = 0, sumvy = 0;
  73. // random shuffle
  74. for(i=0;i<prob.l;i++)
  75. {
  76. int j = rand()%(prob.l-i);
  77. struct svm_node *tx;
  78. double ty;
  79. tx = prob.x[i];
  80. prob.x[i] = prob.x[j];
  81. prob.x[j] = tx;
  82. ty = prob.y[i];
  83. prob.y[i] = prob.y[j];
  84. prob.y[j] = ty;
  85. }
  86. for(i=0;i<nr_fold;i++)
  87. {
  88. int begin = i*prob.l/nr_fold;
  89. int end = (i+1)*prob.l/nr_fold;
  90. int j,k;
  91. struct svm_problem subprob;
  92. subprob.l = prob.l-(end-begin);
  93. subprob.x = Malloc(struct svm_node*,subprob.l);
  94. subprob.y = Malloc(double,subprob.l);
  95. k=0;
  96. for(j=0;j<begin;j++)
  97. {
  98. subprob.x[k] = prob.x[j];
  99. subprob.y[k] = prob.y[j];
  100. ++k;
  101. }
  102. for(j=end;j<prob.l;j++)
  103. {
  104. subprob.x[k] = prob.x[j];
  105. subprob.y[k] = prob.y[j];
  106. ++k;
  107. }
  108. if(param.svm_type == EPSILON_SVR ||
  109.    param.svm_type == NU_SVR)
  110. {
  111. struct svm_model *submodel = svm_train(&subprob,&param);
  112. double error = 0;
  113. for(j=begin;j<end;j++)
  114. {
  115. double v = svm_predict(submodel,prob.x[j]);
  116. double y = prob.y[j];
  117. error += (v-y)*(v-y);
  118. sumv += v;
  119. sumy += y;
  120. sumvv += v*v;
  121. sumyy += y*y;
  122. sumvy += v*y;
  123. }
  124. svm_destroy_model(submodel);
  125. printf("Mean squared error = %gn", error/(end-begin));
  126. total_error += error;
  127. }
  128. else
  129. {
  130. struct svm_model *submodel = svm_train(&subprob,&param);
  131. int correct = 0;
  132. for(j=begin;j<end;j++)
  133. {
  134. double v = svm_predict(submodel,prob.x[j]);
  135. if(v == prob.y[j])
  136. ++correct;
  137. }
  138. svm_destroy_model(submodel);
  139. printf("Accuracy = %g%% (%d/%d)n", 100.0*correct/(end-begin),correct,(end-begin));
  140. total_correct += correct;
  141. }
  142. free(subprob.x);
  143. free(subprob.y);
  144. }
  145. if(param.svm_type == EPSILON_SVR || param.svm_type == NU_SVR)
  146. {
  147. printf("Cross Validation Mean squared error = %gn",total_error/prob.l);
  148. printf("Cross Validation Squared correlation coefficient = %gn",
  149. ((prob.l*sumvy-sumv*sumy)*(prob.l*sumvy-sumv*sumy))/
  150. ((prob.l*sumvv-sumv*sumv)*(prob.l*sumyy-sumy*sumy))
  151. );
  152. }
  153. else
  154. printf("Cross Validation Accuracy = %g%%n",100.0*total_correct/prob.l);
  155. }
  156. void parse_command_line(int argc, char **argv, char *input_file_name, char *model_file_name)
  157. {
  158. int i;
  159. // default values
  160. param.svm_type = C_SVC;
  161. param.kernel_type = RBF;
  162. param.degree = 3;
  163. param.gamma = 0; // 1/k
  164. param.coef0 = 0;
  165. param.nu = 0.5;
  166. param.cache_size = 40;
  167. param.C = 1;
  168. param.eps = 1e-3;
  169. param.p = 0.1;
  170. param.shrinking = 1;
  171. param.nr_weight = 0;
  172. param.weight_label = NULL;
  173. param.weight = NULL;
  174. // parse options
  175. for(i=1;i<argc;i++)
  176. {
  177. if(argv[i][0] != '-') break;
  178. ++i;
  179. switch(argv[i-1][1])
  180. {
  181. case 's':
  182. param.svm_type = atoi(argv[i]);
  183. break;
  184. case 't':
  185. param.kernel_type = atoi(argv[i]);
  186. break;
  187. case 'd':
  188. param.degree = atof(argv[i]);
  189. break;
  190. case 'g':
  191. param.gamma = atof(argv[i]);
  192. break;
  193. case 'r':
  194. param.coef0 = atof(argv[i]);
  195. break;
  196. case 'n':
  197. param.nu = atof(argv[i]);
  198. break;
  199. case 'm':
  200. param.cache_size = atof(argv[i]);
  201. break;
  202. case 'c':
  203. param.C = atof(argv[i]);
  204. break;
  205. case 'e':
  206. param.eps = atof(argv[i]);
  207. break;
  208. case 'p':
  209. param.p = atof(argv[i]);
  210. break;
  211. case 'h':
  212. param.shrinking = atoi(argv[i]);
  213. break;
  214. case 'v':
  215. cross_validation = 1;
  216. nr_fold = atoi(argv[i]);
  217. if(nr_fold < 2)
  218. {
  219. fprintf(stderr,"n-fold cross validation: n must >= 2n");
  220. exit_with_help();
  221. }
  222. break;
  223. case 'w':
  224. ++param.nr_weight;
  225. param.weight_label = (int *)realloc(param.weight_label,sizeof(int)*param.nr_weight);
  226. param.weight = (double *)realloc(param.weight,sizeof(double)*param.nr_weight);
  227. param.weight_label[param.nr_weight-1] = atoi(&argv[i-1][2]);
  228. param.weight[param.nr_weight-1] = atof(argv[i]);
  229. break;
  230. default:
  231. fprintf(stderr,"unknown optionn");
  232. exit_with_help();
  233. }
  234. }
  235. // determine filenames
  236. if(i>=argc)
  237. exit_with_help();
  238. strcpy(input_file_name, argv[i]);
  239. if(i<argc-1)
  240. strcpy(model_file_name,argv[i+1]);
  241. else
  242. {
  243. char *p = strrchr(argv[i],'/');
  244. if(p==NULL)
  245. p = argv[i];
  246. else
  247. ++p;
  248. sprintf(model_file_name,"%s.model",p);
  249. }
  250. }
  251. // read in a problem (in svmlight format)
  252. void read_problem(const char *filename)
  253. {
  254. int elements, max_index, i, j;
  255. FILE *fp = fopen(filename,"r");
  256. if(fp == NULL)
  257. {
  258. fprintf(stderr,"can't open input file %sn",filename);
  259. exit(1);
  260. }
  261. prob.l = 0;
  262. elements = 0;
  263. while(1)
  264. {
  265. int c = fgetc(fp);
  266. switch(c)
  267. {
  268. case 'n':
  269. ++prob.l;
  270. // fall through,
  271. // count the '-1' element
  272. case ':':
  273. ++elements;
  274. break;
  275. case EOF:
  276. goto out;
  277. default:
  278. ;
  279. }
  280. }
  281. out:
  282. rewind(fp);
  283. prob.y = Malloc(double,prob.l);
  284. prob.x = Malloc(struct svm_node *,prob.l);
  285. x_space = Malloc(struct svm_node,elements);
  286. max_index = 0;
  287. j=0;
  288. for(i=0;i<prob.l;i++)
  289. {
  290. double label;
  291. prob.x[i] = &x_space[j];
  292. fscanf(fp,"%lf",&label);
  293. prob.y[i] = label;
  294. while(1)
  295. {
  296. int c;
  297. do {
  298. c = getc(fp);
  299. if(c=='n') goto out2;
  300. } while(isspace(c));
  301. ungetc(c,fp);
  302. fscanf(fp,"%d:%lf",&(x_space[j].index),&(x_space[j].value));
  303. ++j;
  304. }
  305. out2:
  306. if(j>=1 && x_space[j-1].index > max_index)
  307. max_index = x_space[j-1].index;
  308. x_space[j++].index = -1;
  309. }
  310. if(param.gamma == 0)
  311. param.gamma = 1.0/max_index;
  312. fclose(fp);
  313. }