SVM.h
上传用户:sanxfzhen
上传日期:2014-12-28
资源大小:2324k
文件大小:14k
源码类别:

多国语言处理

开发平台:

Visual C++

  1. // SVM.h: interface for the CSVM class.
  2. //
  3. //////////////////////////////////////////////////////////////////////
  4. #if !defined(AFX_SVM_H__169833C2_5291_4CDB_BE81_22DBD80908F0__INCLUDED_)
  5. #define AFX_SVM_H__169833C2_5291_4CDB_BE81_22DBD80908F0__INCLUDED_
  6. #if _MSC_VER > 1000
  7. #pragma once
  8. #endif // _MSC_VER > 1000
  9. #include "Compute_Param.h"
  10. #include "Compute_Prompt.h"
  11. #include "Compute_Result.h"
  12. # define CFLOAT  float       /* the type of float to use for caching */
  13. /* kernel evaluations. Using float saves */
  14. /* us some memory, but you can use double, too */
  15. # define FNUM    long        /* the type used for storing feature ids */
  16. # define FVAL    float       /* the type used for storing feature values */
  17. # define LINEAR  0           /* linear kernel type */
  18. # define POLY    1           /* polynoial kernel type */
  19. # define RBF     2           /* rbf kernel type */
  20. # define SIGMOID 3           /* sigmoid kernel type */
  21. # define CUSTOM  4
  22. typedef struct word
  23. {
  24. FNUM    wnum;
  25. FVAL    weight;
  26. } SVM_WORD;
  27. typedef struct doc
  28. {
  29. long    docnum;
  30. double  twonorm_sq;
  31. SVM_WORD    *words;
  32. } DOC;
  33. typedef struct learn_parm 
  34. {
  35. double svm_c;                /* upper bound C on alphas */
  36. double svm_costratio;        /* factor to multiply C for positive examples */
  37. double transduction_posratio;/* fraction of unlabeled examples to be */ /* classified as positives */
  38. long   biased_hyperplane;    /* if nonzero, use hyperplane w*x+b=0  otherwise w*x=0 */
  39. long   svm_maxqpsize;        /* size q of working set */
  40. long   svm_newvarsinqp;      /* new variables to enter the working set in each iteration */
  41. double epsilon_crit;         /* tolerable error for distances used in stopping criterion */
  42. double epsilon_shrink;       /* how much a multiplier should be above  zero for shrinking */
  43. long   svm_iter_to_shrink;   /* iterations h after which an example can be removed by shrinking */
  44. long   remove_inconsistent;  /* exclude examples with alpha at C and  retrain */
  45.  
  46. long   skip_final_opt_check; 
  47.  /* do not check KT-Conditions at the end of optimization for examples removed by  
  48.  shrinking. WARNING: This might lead to sub-optimal solutions! */
  49. long   compute_loo;          /* if nonzero, computes leave-one-out  estimates */
  50. double rho;                  /* parameter in xi/alpha-estimates and for pruning leave-one-out range [1..2] */
  51. long   xa_depth;             /* parameter in xi/alpha-estimates upper  bounding the number of SV the current alpha_t is distributed over */
  52. char predfile[200];          /* file for predicitions on unlabeled examples  in transduction */
  53. char alphafile[200];         
  54. /* file to store optimal alphas in. use
  55. empty string if alphas should not be   output */
  56. /* you probably do not want to touch the following */
  57. double epsilon_const;        /* tolerable error on eq-constraint */
  58. double epsilon_a;            /* tolerable error on alphas at bounds */
  59. double opt_precision;        /* precision of solver, set to e.g. 1e-21   if you get convergence problems */
  60. /* the following are only for internal use */
  61. long   svm_c_steps;          /* do so many steps for finding optimal C */
  62. double svm_c_factor;         /* increase C by this factor every step */
  63. double svm_costratio_unlab;
  64. double svm_unlabbound;
  65. double *svm_cost;            /* individual upper bounds for each var */
  66. long   totwords;             /* number of features */
  67. } LEARN_PARM;
  68. typedef struct kernel_parm 
  69. {
  70. long    kernel_type;   
  71. long    poly_degree;
  72. double  rbf_gamma;
  73. double  coef_lin;
  74. double  coef_const;
  75. char    custom[50];    /* for user supplied kernel */
  76. } KERNEL_PARM;
  77. typedef struct model 
  78. {
  79. long    sv_num;
  80. long    at_upper_bound;
  81. double  b;
  82. DOC     **supvec;
  83. double  *alpha;
  84. long    *index;       /* index from docnum to position in model */
  85. long    totwords;     /* number of features */
  86. long    totdoc;       /* number of training documents */
  87. KERNEL_PARM kernel_parm; /* kernel */
  88. /* the following values are not written to file */
  89. double  loo_error,loo_recall,loo_precision; /* leave-one-out estimates */
  90. double  xa_error,xa_recall,xa_precision;    /* xi/alpha estimates */
  91. double  *lin_weights;   /* weights for linear case using folding */
  92. } MODEL;
  93. typedef struct quadratic_program 
  94. {
  95. long   opt_n;            /* number of variables */
  96. long   opt_m;            /* number of linear equality constraints */
  97. double *opt_ce,*opt_ce0; /* linear equality constraints */
  98. double *opt_g;           /* hessian of objective */
  99. double *opt_g0;          /* linear part of objective */
  100. double *opt_xinit;       /* initial value for variables */
  101. double *opt_low,*opt_up; /* box constraints */
  102. } QP;
  103. typedef struct kernel_cache 
  104. {
  105. long   *index;  /* cache some kernel evalutations */
  106. CFLOAT *buffer; /* to improve speed */
  107. long   *invindex;
  108. long   *active2totdoc;
  109. long   *totdoc2active;
  110. long   *lru;
  111. long   *occu;
  112. long   elems;
  113. long   max_elems;
  114. long   time;
  115. long   activenum;
  116. long   buffsize;
  117. } KERNEL_CACHE;
  118.  
  119. typedef struct timing_profile 
  120. {
  121. long   time_kernel;
  122. long   time_opti;
  123. long   time_shrink;
  124. long   time_update;
  125. long   time_model;
  126. long   time_check;
  127. long   time_select;
  128. } TIMING;
  129.  
  130. typedef struct shrink_state 
  131. {
  132. long   *active;
  133. long   *inactive_since;
  134. long   deactnum;
  135. double **a_history;
  136. } SHRINK_STATE;
  137. typedef struct cache_parm_s {
  138.   KERNEL_CACHE *kernel_cache;
  139.   CFLOAT *cache;
  140.   DOC *docs; 
  141.   long m;
  142.   KERNEL_PARM *kernel_parm;
  143.   long offset,stepsize;
  144. } cache_parm_t;
  145. class CSVM  
  146. {
  147. public: 
  148. CSVM();
  149. virtual ~CSVM();
  150. int svm_learn_main (int);
  151. int svm_classify (int, double*);
  152. double svm_classify(DOC &doc);
  153. private:
  154. void set_learn_parameters(LEARN_PARM* learn_parm,KERNEL_PARM* kernel_parm);
  155. private:   //svm_common.h
  156. double classify_example(MODEL *, DOC *);
  157. double classify_example_linear(MODEL *, DOC *);
  158. CFLOAT kernel(KERNEL_PARM *, DOC *, DOC *); 
  159. double custom_kernel(KERNEL_PARM *, DOC *, DOC *); 
  160. double sprod_ss(SVM_WORD *, SVM_WORD *);
  161. double model_length_s(MODEL *, KERNEL_PARM *);
  162. void   clear_vector_n(double *, long);
  163. void   add_vector_ns(double *, SVM_WORD *, double);
  164. double sprod_ns(double *, SVM_WORD *);
  165. double sprod_ss1(SVM_WORD *a,SVM_WORD*b,int offset);
  166. double sprod_ss2(SVM_WORD *a,SVM_WORD*b,int offset);
  167. void   add_weight_vector_to_linear_model(MODEL *);
  168. int    read_model(char *, MODEL *, long, long);
  169. int    read_documents(char *, DOC *, long *, long, long, long *, long *, int);
  170. int    parse_document(char *, DOC *, long *, long *, long);
  171. int    nol_ll(char *, long *, long *, long *);
  172. double ktt(int ta,int tb,double pa[],double pb[]);
  173. double kt(int t,double ta[],double tb[]);
  174. double fi(double* tt);
  175. double fs(double ta[]);
  176. double sumofword(DOC* a);
  177. long   minl(long, long);
  178. long   maxl(long, long);
  179. long   get_runtime();
  180. void   *my_malloc(long); 
  181. void   copyright_notice();
  182. //void  SetInitParam();         ?????????
  183. int isnan(double);
  184. private:    //svm_common.h
  185. void printm(char*);
  186. void    printe(char*);
  187. private:    //svm_learn.h
  188. void   svm_learn(DOC *, long *, long, long, LEARN_PARM *, KERNEL_PARM *, 
  189.  KERNEL_CACHE *, MODEL *);
  190. long   optimize_to_convergence(DOC *, long *, long, long, LEARN_PARM *,
  191.        KERNEL_PARM *, KERNEL_CACHE *, SHRINK_STATE *,
  192.        MODEL *, long *, long *, double *, double *, 
  193.        TIMING *, double *, long, long);
  194. double compute_objective_function(double *, double *, long *, long *);
  195. void   clear_index(long *);
  196. void   add_to_index(long *, long);
  197. long   compute_index(long *,long, long *);
  198. void   optimize_svm(DOC *, long *, long *, long *, long *, MODEL *, long, 
  199.     long *, long, double *, double *, LEARN_PARM *, CFLOAT *, 
  200.     KERNEL_PARM *, QP *, double *);
  201. void   compute_matrices_for_optimization(DOC *, long *, long *, long *, 
  202.  long *, long *, MODEL *, double *, 
  203.  double *, long, long, LEARN_PARM *, 
  204.  CFLOAT *, KERNEL_PARM *, QP *);
  205. long   calculate_svm_model(DOC *, long *, long *, double *, double *, 
  206.    double *, LEARN_PARM *, long *, MODEL *);
  207. long   check_optimality(MODEL *, long *, long *, double *, double *,long, 
  208. LEARN_PARM *,double *, double, long *, long *, long *,
  209. long *, long, KERNEL_PARM *);
  210. long   identify_inconsistent(double *, long *, long *, long, LEARN_PARM *, 
  211.      long *, long *);
  212. long   identify_misclassified(double *, long *, long *, long,
  213.       MODEL *, long *, long *);
  214. long   identify_one_misclassified(double *, long *, long *, long,
  215.   MODEL *, long *, long *);
  216. long   incorporate_unlabeled_examples(MODEL *, long *,long *, long *,
  217.       double *, double *, long, double *,
  218.       long *, long *, long, KERNEL_PARM *,
  219.       LEARN_PARM *);
  220. void   update_linear_component(DOC *, long *, long *, double *, double *, 
  221.        long *, long, long, KERNEL_PARM *, 
  222.        KERNEL_CACHE *, double *,
  223.        CFLOAT *, double *);
  224. long   select_next_qp_subproblem_grad(long *, long *, double *, double *, long,
  225.       long, LEARN_PARM *, long *, long *, 
  226.       long *, double *, long *, KERNEL_CACHE *,
  227.       long *, long *);
  228. long   select_next_qp_subproblem_grad_cache(long *, long *, double *, double *,
  229.     long, long, LEARN_PARM *, long *, 
  230.     long *, long *, double *, long *,
  231.     KERNEL_CACHE *, long *, long *);
  232. void   select_top_n(double *, long, long *, long);
  233. void   init_shrink_state(SHRINK_STATE *, long, long);
  234. void   shrink_state_cleanup(SHRINK_STATE *);
  235. long shrink_problem(
  236. /* shrink some variables away */
  237. /* do the shrinking only if at least minshrink variables can be removed */
  238. LEARN_PARM *learn_parm,
  239. SHRINK_STATE *shrink_state,
  240. long *active2dnum,long iteration,long* last_suboptimal_at,
  241. long totdoc,long minshrink,
  242. double *a,
  243. long *inconsistent);
  244. void   reactivate_inactive_examples(long *, long *, double *, SHRINK_STATE *,
  245.     double *, long, long, long, LEARN_PARM *, 
  246.     long *, DOC *, KERNEL_PARM *,
  247.     KERNEL_CACHE *, MODEL *, CFLOAT *, 
  248.     double *, double *);
  249. /* cache kernel evalutations to improve speed */
  250. void   get_kernel_row(KERNEL_CACHE *,DOC *, long, long, long *, CFLOAT *, 
  251.       KERNEL_PARM *);
  252. void   cache_kernel_row(KERNEL_CACHE *,DOC *, long, KERNEL_PARM *);
  253. void   cache_multiple_kernel_rows(KERNEL_CACHE *,DOC *, long *, long, 
  254.   KERNEL_PARM *);
  255. void   kernel_cache_shrink(KERNEL_CACHE *,long, long, long *);
  256. void   kernel_cache_init(KERNEL_CACHE *,long, long);
  257. void   kernel_cache_reset_lru(KERNEL_CACHE *);
  258. void   kernel_cache_cleanup(KERNEL_CACHE *);
  259. long   kernel_cache_malloc(KERNEL_CACHE *);
  260. void   kernel_cache_free(KERNEL_CACHE *,long);
  261. long   kernel_cache_free_lru(KERNEL_CACHE *);
  262. CFLOAT *kernel_cache_clean_and_malloc(KERNEL_CACHE *,long);
  263. long   kernel_cache_touch(KERNEL_CACHE *,long);
  264. long   kernel_cache_check(KERNEL_CACHE *,long);
  265. void compute_xa_estimates(
  266. MODEL *model,                           /* xa-estimate of error rate, */
  267. long *label,long *unlabeled,long totdoc,          /* recall, and precision      */
  268. DOC *docs,       
  269. double *lin,double *a,                      
  270. KERNEL_PARM *kernel_parm,
  271. LEARN_PARM *learn_parm,
  272. double *error,double *recall,double *precision);
  273. double xa_estimate_error(MODEL *, long *, long *, long, DOC *, 
  274.  double *, double *, KERNEL_PARM *, 
  275.  LEARN_PARM *);
  276. double xa_estimate_recall(MODEL *, long *, long *, long, DOC *, 
  277.   double *, double *, KERNEL_PARM *, 
  278.   LEARN_PARM *);
  279. double xa_estimate_precision(MODEL *, long *, long *, long, DOC *, 
  280.      double *, double *, KERNEL_PARM *, 
  281.      LEARN_PARM *);
  282. void avg_similarity_of_sv_of_one_class(MODEL *, DOC *, double *, long *, KERNEL_PARM *, double *, double *);
  283. double most_similar_sv_of_same_class(MODEL *, DOC *, double *, long, long *, KERNEL_PARM *, LEARN_PARM *);
  284. double distribute_alpha_t_greedily(long *, long, DOC *, double *, long, long *, KERNEL_PARM *, LEARN_PARM *, double);
  285. double distribute_alpha_t_greedily_noindex(MODEL *, DOC *, double *, long, long *, KERNEL_PARM *, LEARN_PARM *, double); 
  286. void estimate_transduction_quality(MODEL *, long *, long *, long, DOC *, double *);
  287. double estimate_margin_vcdim(MODEL *, double, double, KERNEL_PARM *);
  288. double estimate_sphere(MODEL *, KERNEL_PARM *);
  289. double estimate_r_delta_average(DOC *, long, KERNEL_PARM *); 
  290. double estimate_r_delta(DOC *, long, KERNEL_PARM *); 
  291. double length_of_longest_document_vector(DOC *, long, KERNEL_PARM *); 
  292. void   write_model(char *, MODEL *);
  293. void   write_prediction(char *, MODEL *, double *, double *, long *, long *,
  294. long, LEARN_PARM *);
  295. void   write_alphas(char *, double *, long *, long);
  296. private:    //svm_hideo.h
  297. double *optimize_qp(QP *, double *, long, double *, LEARN_PARM *);
  298. int optimize_hildreth_despo(long,long,double,double,double,long,long,long,double,double *,
  299.     double *,double *,double *,double *,double *,
  300.     double *,double *,double *,long *,double *);
  301. int solve_dual(long,long,double,double,long,double *,double *,double *,
  302.        double *,double *,double *,double *,double *,double *,
  303.        double *,double *,double *,double *,long);
  304. void linvert_matrix(double *, long, double *, double, long *);
  305. void lprint_matrix(double *, long);
  306. void ladd_matrix(double *, long, double);
  307. void lcopy_matrix(double *, long, double *);
  308. void lswitch_rows_matrix(double *, long, long, long);
  309. void lswitchrk_matrix(double *, long, long, long);
  310. double calculate_qp_objective(long, double *, double *, double *);
  311. public:
  312. CCompute_Prompt com_pro;
  313. CCompute_Param com_param;
  314. CCompute_Result com_result;
  315. private:    //svm_hideo.h
  316. double *primal,*dual;
  317. long   precision_violations;
  318. double opt_precision;
  319. long   maxiter;
  320. double lindep_sensitivity;
  321. double *buffer;
  322. long   *nonoptimal;
  323. long  smallroundcount;
  324. private:
  325. char temstr[MAX_PATH*10];
  326. };
  327. #endif // !defined(AFX_SVM_H__169833C2_5291_4CDB_BE81_22DBD80908F0__INCLUDED_)