svm.cpp
上传用户:xgw_05
上传日期:2014-12-08
资源大小:2726k
文件大小:42k
源码类别:

.net编程

开发平台:

Java

  1. #include <math.h>
  2. #include <stdio.h>
  3. #include <stdlib.h>
  4. #include <ctype.h>
  5. #include <float.h>
  6. #include <string.h>
  7. #include <stdarg.h>
  8. #include "svm.h"
  9. typedef float Qfloat;
  10. typedef signed char schar;
  11. #ifndef min
  12. template <class T> inline T min(T x,T y) { return (x<y)?x:y; }
  13. #endif
  14. #ifndef max
  15. template <class T> inline T max(T x,T y) { return (x>y)?x:y; }
  16. #endif
  17. template <class T> inline void swap(T& x, T& y) { T t=x; x=y; y=t; }
  18. template <class S, class T> inline void clone(T*& dst, S* src, int n)
  19. {
  20. dst = new T[n];
  21. memcpy((void *)dst,(void *)src,sizeof(T)*n);
  22. }
  23. #define INF HUGE_VAL
  24. #define Malloc(type,n) (type *)malloc((n)*sizeof(type))
  25. #if 1
  26. void info(char *fmt,...)
  27. {
  28. va_list ap;
  29. va_start(ap,fmt);
  30. vprintf(fmt,ap);
  31. va_end(ap);
  32. }
  33. void info_flush()
  34. {
  35. fflush(stdout);
  36. }
  37. #else
  38. void info(char *fmt,...) {}
  39. void info_flush() {}
  40. #endif
  41. //
  42. // Kernel Cache
  43. //
  44. // l is the number of total data items
  45. // size is the cache size limit in bytes
  46. //
  47. class Cache
  48. {
  49. public:
  50. Cache(int l,int size);
  51. ~Cache();
  52. // request data [0,len)
  53. // return some position p where [p,len) need to be filled
  54. // (p >= len if nothing needs to be filled)
  55. int get_data(const int index, Qfloat **data, int len);
  56. void swap_index(int i, int j); // future_option
  57. private:
  58. int l;
  59. int size;
  60. struct head_t
  61. {
  62. head_t *prev, *next; // a cicular list
  63. Qfloat *data;
  64. int len; // data[0,len) is cached in this entry
  65. };
  66. head_t* head;
  67. head_t lru_head;
  68. void lru_delete(head_t *h);
  69. void lru_insert(head_t *h);
  70. };
  71. Cache::Cache(int l_,int size_):l(l_),size(size_)
  72. {
  73. head = (head_t *)calloc(l,sizeof(head_t)); // initialized to 0
  74. size /= sizeof(Qfloat);
  75. size -= l * sizeof(head_t) / sizeof(Qfloat);
  76. lru_head.next = lru_head.prev = &lru_head;
  77. }
  78. Cache::~Cache()
  79. {
  80. for(head_t *h = lru_head.next; h != &lru_head; h=h->next)
  81. free(h->data);
  82. free(head);
  83. }
  84. void Cache::lru_delete(head_t *h)
  85. {
  86. // delete from current location
  87. h->prev->next = h->next;
  88. h->next->prev = h->prev;
  89. }
  90. void Cache::lru_insert(head_t *h)
  91. {
  92. // insert to last position
  93. h->next = &lru_head;
  94. h->prev = lru_head.prev;
  95. h->prev->next = h;
  96. h->next->prev = h;
  97. }
  98. int Cache::get_data(const int index, Qfloat **data, int len)
  99. {
  100. head_t *h = &head[index];
  101. if(h->len) lru_delete(h);
  102. int more = len - h->len;
  103. if(more > 0)
  104. {
  105. // free old space
  106. while(size < more)
  107. {
  108. head_t *old = lru_head.next;
  109. lru_delete(old);
  110. free(old->data);
  111. size += old->len;
  112. old->data = 0;
  113. old->len = 0;
  114. }
  115. // allocate new space
  116. h->data = (Qfloat *)realloc(h->data,sizeof(Qfloat)*len);
  117. size -= more;
  118. swap(h->len,len);
  119. }
  120. lru_insert(h);
  121. *data = h->data;
  122. return len;
  123. }
  124. void Cache::swap_index(int i, int j)
  125. {
  126. if(i==j) return;
  127. if(head[i].len) lru_delete(&head[i]);
  128. if(head[j].len) lru_delete(&head[j]);
  129. swap(head[i].data,head[j].data);
  130. swap(head[i].len,head[j].len);
  131. if(head[i].len) lru_insert(&head[i]);
  132. if(head[j].len) lru_insert(&head[j]);
  133. if(i>j) swap(i,j);
  134. for(head_t *h = lru_head.next; h!=&lru_head; h=h->next)
  135. {
  136. if(h->len > i)
  137. {
  138. if(h->len > j)
  139. swap(h->data[i],h->data[j]);
  140. else
  141. {
  142. // give up
  143. lru_delete(h);
  144. free(h->data);
  145. size += h->len;
  146. h->data = 0;
  147. h->len = 0;
  148. }
  149. }
  150. }
  151. }
  152. //
  153. // Kernel evaluation
  154. //
  155. // the static method k_function is for doing single kernel evaluation
  156. // the constructor of Kernel prepares to calculate the l*l kernel matrix
  157. // the member function get_Q is for getting one column from the Q Matrix
  158. //
  159. class Kernel {
  160. public:
  161. Kernel(int l, svm_node * const * x, const svm_parameter& param);
  162. virtual ~Kernel();
  163. static double k_function(const svm_node *x, const svm_node *y,
  164.  const svm_parameter& param);
  165. virtual Qfloat *get_Q(int column, int len) const = 0;
  166. virtual void swap_index(int i, int j) const // no so const...
  167. {
  168. swap(x[i],x[j]);
  169. if(x_square) swap(x_square[i],x_square[j]);
  170. }
  171. protected:
  172. double (Kernel::*kernel_function)(int i, int j) const;
  173. private:
  174. const svm_node **x;
  175. double *x_square;
  176. // svm_parameter
  177. const int kernel_type;
  178. const double degree;
  179. const double gamma;
  180. const double coef0;
  181. static double dot(const svm_node *px, const svm_node *py);
  182. double kernel_linear(int i, int j) const
  183. {
  184. return dot(x[i],x[j]);
  185. }
  186. double kernel_poly(int i, int j) const
  187. {
  188. return pow(gamma*dot(x[i],x[j])+coef0,degree);
  189. }
  190. double kernel_rbf(int i, int j) const
  191. {
  192. return exp(-gamma*(x_square[i]+x_square[j]-2*dot(x[i],x[j])));
  193. }
  194. double kernel_sigmoid(int i, int j) const
  195. {
  196. return tanh(gamma*dot(x[i],x[j])+coef0);
  197. }
  198. };
  199. Kernel::Kernel(int l, svm_node * const * x_, const svm_parameter& param)
  200. :kernel_type(param.kernel_type), degree(param.degree),
  201.  gamma(param.gamma), coef0(param.coef0)
  202. {
  203. switch(kernel_type)
  204. {
  205. case LINEAR:
  206. kernel_function = &Kernel::kernel_linear;
  207. break;
  208. case POLY:
  209. kernel_function = &Kernel::kernel_poly;
  210. break;
  211. case RBF:
  212. kernel_function = &Kernel::kernel_rbf;
  213. break;
  214. case SIGMOID:
  215. kernel_function = &Kernel::kernel_sigmoid;
  216. break;
  217. default:
  218. fprintf(stderr,"unknown kernel function.n");
  219. exit(1);
  220. }
  221. clone(x,x_,l);
  222. if(kernel_type == RBF)
  223. {
  224. x_square = new double[l];
  225. for(int i=0;i<l;i++)
  226. x_square[i] = dot(x[i],x[i]);
  227. }
  228. else
  229. x_square = 0;
  230. }
  231. Kernel::~Kernel()
  232. {
  233. delete[] x;
  234. delete[] x_square;
  235. }
  236. double Kernel::dot(const svm_node *px, const svm_node *py)
  237. {
  238. double sum = 0;
  239. while(px->index != -1 && py->index != -1)
  240. {
  241. if(px->index == py->index)
  242. {
  243. sum += px->value * py->value;
  244. ++px;
  245. ++py;
  246. }
  247. else
  248. {
  249. if(px->index > py->index)
  250. ++py;
  251. else
  252. ++px;
  253. }
  254. }
  255. return sum;
  256. }
  257. double Kernel::k_function(const svm_node *x, const svm_node *y,
  258.   const svm_parameter& param)
  259. {
  260. switch(param.kernel_type)
  261. {
  262. case LINEAR:
  263. return dot(x,y);
  264. case POLY:
  265. return pow(param.gamma*dot(x,y)+param.coef0,param.degree);
  266. case RBF:
  267. {
  268. double sum = 0;
  269. while(x->index != -1 && y->index !=-1)
  270. {
  271. if(x->index == y->index)
  272. {
  273. double d = x->value - y->value;
  274. sum += d*d;
  275. ++x;
  276. ++y;
  277. }
  278. else
  279. {
  280. if(x->index > y->index)
  281. {
  282. sum += y->value * y->value;
  283. ++y;
  284. }
  285. else
  286. {
  287. sum += x->value * x->value;
  288. ++x;
  289. }
  290. }
  291. }
  292. while(x->index != -1)
  293. {
  294. sum += x->value * x->value;
  295. ++x;
  296. }
  297. while(y->index != -1)
  298. {
  299. sum += y->value * y->value;
  300. ++y;
  301. }
  302. return exp(-param.gamma*sum);
  303. }
  304. case SIGMOID:
  305. return tanh(param.gamma*dot(x,y)+param.coef0);
  306. default:
  307. break;
  308. }
  309. fprintf(stderr,"unknown kernel function.n");
  310. exit(1);
  311. }
  312. // Generalized SMO+SVMlight algorithm
  313. // Solves:
  314. //
  315. // min 0.5(alpha^T Q alpha) + b^T alpha
  316. //
  317. // y^T alpha = delta
  318. // y_i = +1 or -1
  319. // 0 <= alpha_i <= Cp for y_i = 1
  320. // 0 <= alpha_i <= Cn for y_i = -1
  321. //
  322. // Given:
  323. //
  324. // Q, b, y, Cp, Cn, and an initial feasible point alpha
  325. // l is the size of vectors and matrices
  326. // eps is the stopping criterion
  327. //
  328. // solution will be put in alpha, objective value will be put in obj
  329. //
  330. class Solver {
  331. public:
  332. Solver() {};
  333. virtual ~Solver() {};
  334. struct SolutionInfo {
  335. double obj;
  336. double rho;
  337. double upper_bound_p;
  338. double upper_bound_n;
  339. double r; // for Solver_NU
  340. };
  341. void Solve(int l, const Kernel& Q, const double *b_, const schar *y_,
  342.    double *alpha_, double Cp, double Cn, double eps,
  343.    SolutionInfo* si, int shrinking);
  344. protected:
  345. int active_size;
  346. schar *y;
  347. double *G; // gradient of objective function
  348. enum { LOWER_BOUND, UPPER_BOUND, FREE };
  349. char *alpha_status; // LOWER_BOUND, UPPER_BOUND, FREE
  350. double *alpha;
  351. const Kernel *Q;
  352. double eps;
  353. double Cp,Cn;
  354. double *b;
  355. int *active_set;
  356. double *G_bar; // gradient, if we treat free variables as 0
  357. int l;
  358. bool unshrinked; // XXX
  359. double get_C(int i)
  360. {
  361. return (y[i] > 0)? Cp : Cn;
  362. }
  363. void update_alpha_status(int i)
  364. {
  365. if(alpha[i] >= get_C(i))
  366. alpha_status[i] = UPPER_BOUND;
  367. else if(alpha[i] <= 0)
  368. alpha_status[i] = LOWER_BOUND;
  369. else alpha_status[i] = FREE;
  370. }
  371. bool is_upper_bound(int i) { return alpha_status[i] == UPPER_BOUND; }
  372. bool is_lower_bound(int i) { return alpha_status[i] == LOWER_BOUND; }
  373. bool is_free(int i) { return alpha_status[i] == FREE; }
  374. void swap_index(int i, int j);
  375. void reconstruct_gradient();
  376. virtual int select_working_set(int &i, int &j);
  377. virtual double calculate_rho();
  378. virtual void do_shrinking();
  379. };
  380. void Solver::swap_index(int i, int j)
  381. {
  382. Q->swap_index(i,j);
  383. swap(y[i],y[j]);
  384. swap(G[i],G[j]);
  385. swap(alpha_status[i],alpha_status[j]);
  386. swap(alpha[i],alpha[j]);
  387. swap(b[i],b[j]);
  388. swap(active_set[i],active_set[j]);
  389. swap(G_bar[i],G_bar[j]);
  390. }
  391. void Solver::reconstruct_gradient()
  392. {
  393. // reconstruct inactive elements of G from G_bar and free variables
  394. if(active_size == l) return;
  395. int i;
  396. for(i=active_size;i<l;i++)
  397. G[i] = G_bar[i] + b[i];
  398. for(i=0;i<active_size;i++)
  399. if(is_free(i))
  400. {
  401. const Qfloat *Q_i = Q->get_Q(i,l);
  402. double alpha_i = alpha[i];
  403. for(int j=active_size;j<l;j++)
  404. G[j] += alpha_i * Q_i[j];
  405. }
  406. }
  407. void Solver::Solve(int l, const Kernel& Q, const double *b_, const schar *y_,
  408.    double *alpha_, double Cp, double Cn, double eps,
  409.    SolutionInfo* si, int shrinking)
  410. {
  411. this->l = l;
  412. this->Q = &Q;
  413. clone(b, b_,l);
  414. clone(y, y_,l);
  415. clone(alpha,alpha_,l);
  416. this->Cp = Cp;
  417. this->Cn = Cn;
  418. this->eps = eps;
  419. unshrinked = false;
  420. // initialize alpha_status
  421. {
  422. alpha_status = new char[l];
  423. for(int i=0;i<l;i++)
  424. update_alpha_status(i);
  425. }
  426. // initialize active set (for shrinking)
  427. {
  428. active_set = new int[l];
  429. for(int i=0;i<l;i++)
  430. active_set[i] = i;
  431. active_size = l;
  432. }
  433. // initialize gradient
  434. {
  435. G = new double[l];
  436. G_bar = new double[l];
  437. int i;
  438. for(i=0;i<l;i++)
  439. {
  440. G[i] = b[i];
  441. G_bar[i] = 0;
  442. }
  443. for(i=0;i<l;i++)
  444. if(!is_lower_bound(i))
  445. {
  446. Qfloat *Q_i = Q.get_Q(i,l);
  447. double alpha_i = alpha[i];
  448. int j;
  449. for(j=0;j<l;j++)
  450. G[j] += alpha_i*Q_i[j];
  451. if(is_upper_bound(i))
  452. for(j=0;j<l;j++)
  453. G_bar[j] += get_C(i) * Q_i[j];
  454. }
  455. }
  456. // optimization step
  457. int iter = 0;
  458. int counter = min(l,1000)+1;
  459. while(1)
  460. {
  461. // show progress and do shrinking
  462. if(--counter == 0)
  463. {
  464. counter = min(l,1000);
  465. if(shrinking) do_shrinking();
  466. info("."); info_flush();
  467. }
  468. int i,j;
  469. if(select_working_set(i,j)!=0)
  470. {
  471. // reconstruct the whole gradient
  472. reconstruct_gradient();
  473. // reset active set size and check
  474. active_size = l;
  475. info("*"); info_flush();
  476. if(select_working_set(i,j)!=0)
  477. break;
  478. else
  479. counter = 1; // do shrinking next iteration
  480. }
  481. ++iter;
  482. // update alpha[i] and alpha[j], handle bounds carefully
  483. const Qfloat *Q_i = Q.get_Q(i,active_size);
  484. const Qfloat *Q_j = Q.get_Q(j,active_size);
  485. double C_i = get_C(i);
  486. double C_j = get_C(j);
  487. double old_alpha_i = alpha[i];
  488. double old_alpha_j = alpha[j];
  489. if(y[i]!=y[j])
  490. {
  491. double delta = (-G[i]-G[j])/(Q_i[i]+Q_j[j]+2*Q_i[j]);
  492. double diff = alpha[i] - alpha[j];
  493. alpha[i] += delta;
  494. alpha[j] += delta;
  495. if(diff > 0)
  496. {
  497. if(alpha[j] < 0)
  498. {
  499. alpha[j] = 0;
  500. alpha[i] = diff;
  501. }
  502. }
  503. else
  504. {
  505. if(alpha[i] < 0)
  506. {
  507. alpha[i] = 0;
  508. alpha[j] = -diff;
  509. }
  510. }
  511. if(diff > C_i - C_j)
  512. {
  513. if(alpha[i] > C_i)
  514. {
  515. alpha[i] = C_i;
  516. alpha[j] = C_i - diff;
  517. }
  518. }
  519. else
  520. {
  521. if(alpha[j] > C_j)
  522. {
  523. alpha[j] = C_j;
  524. alpha[i] = C_j + diff;
  525. }
  526. }
  527. }
  528. else
  529. {
  530. double delta = (G[i]-G[j])/(Q_i[i]+Q_j[j]-2*Q_i[j]);
  531. double sum = alpha[i] + alpha[j];
  532. alpha[i] -= delta;
  533. alpha[j] += delta;
  534. if(sum > C_i)
  535. {
  536. if(alpha[i] > C_i)
  537. {
  538. alpha[i] = C_i;
  539. alpha[j] = sum - C_i;
  540. }
  541. }
  542. else
  543. {
  544. if(alpha[j] < 0)
  545. {
  546. alpha[j] = 0;
  547. alpha[i] = sum;
  548. }
  549. }
  550. if(sum > C_j)
  551. {
  552. if(alpha[j] > C_j)
  553. {
  554. alpha[j] = C_j;
  555. alpha[i] = sum - C_j;
  556. }
  557. }
  558. else
  559. {
  560. if(alpha[i] < 0)
  561. {
  562. alpha[i] = 0;
  563. alpha[j] = sum;
  564. }
  565. }
  566. }
  567. // update G
  568. double delta_alpha_i = alpha[i] - old_alpha_i;
  569. double delta_alpha_j = alpha[j] - old_alpha_j;
  570. for(int k=0;k<active_size;k++)
  571. {
  572. G[k] += Q_i[k]*delta_alpha_i + Q_j[k]*delta_alpha_j;
  573. }
  574. // update alpha_status and G_bar
  575. {
  576. bool ui = is_upper_bound(i);
  577. bool uj = is_upper_bound(j);
  578. update_alpha_status(i);
  579. update_alpha_status(j);
  580. int k;
  581. if(ui != is_upper_bound(i))
  582. {
  583. Q_i = Q.get_Q(i,l);
  584. if(ui)
  585. for(k=0;k<l;k++)
  586. G_bar[k] -= C_i * Q_i[k];
  587. else
  588. for(k=0;k<l;k++)
  589. G_bar[k] += C_i * Q_i[k];
  590. }
  591. if(uj != is_upper_bound(j))
  592. {
  593. Q_j = Q.get_Q(j,l);
  594. if(uj)
  595. for(k=0;k<l;k++)
  596. G_bar[k] -= C_j * Q_j[k];
  597. else
  598. for(k=0;k<l;k++)
  599. G_bar[k] += C_j * Q_j[k];
  600. }
  601. }
  602. }
  603. // calculate rho
  604. si->rho = calculate_rho();
  605. // calculate objective value
  606. {
  607. double v = 0;
  608. int i;
  609. for(i=0;i<l;i++)
  610. v += alpha[i] * (G[i] + b[i]);
  611. si->obj = v/2;
  612. }
  613. // put back the solution
  614. {
  615. for(int i=0;i<l;i++)
  616. alpha_[active_set[i]] = alpha[i];
  617. }
  618. // juggle everything back
  619. /*{
  620. for(int i=0;i<l;i++)
  621. while(active_set[i] != i)
  622. swap_index(i,active_set[i]);
  623. // or Q.swap_index(i,active_set[i]);
  624. }*/
  625. si->upper_bound_p = Cp;
  626. si->upper_bound_n = Cn;
  627. info("noptimization finished, #iter = %dn",iter);
  628. delete[] b;
  629. delete[] y;
  630. delete[] alpha;
  631. delete[] alpha_status;
  632. delete[] active_set;
  633. delete[] G;
  634. delete[] G_bar;
  635. }
  636. // return 1 if already optimal, return 0 otherwise
  637. int Solver::select_working_set(int &out_i, int &out_j)
  638. {
  639. // return i,j which maximize -grad(f)^T d , under constraint
  640. // if alpha_i == C, d != +1
  641. // if alpha_i == 0, d != -1
  642. double Gmax1 = -INF; // max { -grad(f)_i * d | y_i*d = +1 }
  643. int Gmax1_idx = -1;
  644. double Gmax2 = -INF; // max { -grad(f)_i * d | y_i*d = -1 }
  645. int Gmax2_idx = -1;
  646. for(int i=0;i<active_size;i++)
  647. {
  648. if(y[i]==+1) // y = +1
  649. {
  650. if(!is_upper_bound(i)) // d = +1
  651. {
  652. if(-G[i] > Gmax1)
  653. {
  654. Gmax1 = -G[i];
  655. Gmax1_idx = i;
  656. }
  657. }
  658. if(!is_lower_bound(i)) // d = -1
  659. {
  660. if(G[i] > Gmax2)
  661. {
  662. Gmax2 = G[i];
  663. Gmax2_idx = i;
  664. }
  665. }
  666. }
  667. else // y = -1
  668. {
  669. if(!is_upper_bound(i)) // d = +1
  670. {
  671. if(-G[i] > Gmax2)
  672. {
  673. Gmax2 = -G[i];
  674. Gmax2_idx = i;
  675. }
  676. }
  677. if(!is_lower_bound(i)) // d = -1
  678. {
  679. if(G[i] > Gmax1)
  680. {
  681. Gmax1 = G[i];
  682. Gmax1_idx = i;
  683. }
  684. }
  685. }
  686. }
  687. if(Gmax1+Gmax2 < eps)
  688.   return 1;
  689. out_i = Gmax1_idx;
  690. out_j = Gmax2_idx;
  691. return 0;
  692. }
  693. void Solver::do_shrinking()
  694. {
  695. int i,j,k;
  696. if(select_working_set(i,j)!=0) return;
  697. double Gm1 = -y[j]*G[j];
  698. double Gm2 = y[i]*G[i];
  699. // shrink
  700. for(k=0;k<active_size;k++)
  701. {
  702. if(is_lower_bound(k))
  703. {
  704. if(y[k]==+1)
  705. {
  706. if(-G[k] >= Gm1) continue;
  707. }
  708. else if(-G[k] >= Gm2) continue;
  709. }
  710. else if(is_upper_bound(k))
  711. {
  712. if(y[k]==+1)
  713. {
  714. if(G[k] >= Gm2) continue;
  715. }
  716. else if(G[k] >= Gm1) continue;
  717. }
  718. else continue;
  719. --active_size;
  720. swap_index(k,active_size);
  721. --k; // look at the newcomer
  722. }
  723. // unshrink, check all variables again before final iterations
  724. if(unshrinked || -(Gm1 + Gm2) > eps*10) return;
  725. unshrinked = true;
  726. reconstruct_gradient();
  727. for(k=l-1;k>=active_size;k--)
  728. {
  729. if(is_lower_bound(k))
  730. {
  731. if(y[k]==+1)
  732. {
  733. if(-G[k] < Gm1) continue;
  734. }
  735. else if(-G[k] < Gm2) continue;
  736. }
  737. else if(is_upper_bound(k))
  738. {
  739. if(y[k]==+1)
  740. {
  741. if(G[k] < Gm2) continue;
  742. }
  743. else if(G[k] < Gm1) continue;
  744. }
  745. else continue;
  746. swap_index(k,active_size);
  747. active_size++;
  748. ++k; // look at the newcomer
  749. }
  750. }
  751. double Solver::calculate_rho()
  752. {
  753. double r;
  754. int nr_free = 0;
  755. double ub = INF, lb = -INF, sum_free = 0;
  756. for(int i=0;i<active_size;i++)
  757. {
  758. double yG = y[i]*G[i];
  759. if(is_lower_bound(i))
  760. {
  761. if(y[i] > 0)
  762. ub = min(ub,yG);
  763. else
  764. lb = max(lb,yG);
  765. }
  766. else if(is_upper_bound(i))
  767. {
  768. if(y[i] < 0)
  769. ub = min(ub,yG);
  770. else
  771. lb = max(lb,yG);
  772. }
  773. else
  774. {
  775. ++nr_free;
  776. sum_free += yG;
  777. }
  778. }
  779. if(nr_free>0)
  780. r = sum_free/nr_free;
  781. else
  782. r = (ub+lb)/2;
  783. return r;
  784. }
  785. //
  786. // Solver for nu-svm classification and regression
  787. //
  788. // additional constraint: e^T alpha = constant
  789. //
  790. class Solver_NU : public Solver
  791. {
  792. public:
  793. Solver_NU() {}
  794. void Solve(int l, const Kernel& Q, const double *b, const schar *y,
  795.    double *alpha, double Cp, double Cn, double eps,
  796.    SolutionInfo* si, int shrinking)
  797. {
  798. this->si = si;
  799. Solver::Solve(l,Q,b,y,alpha,Cp,Cn,eps,si,shrinking);
  800. }
  801. private:
  802. SolutionInfo *si;
  803. int select_working_set(int &i, int &j);
  804. double calculate_rho();
  805. void do_shrinking();
  806. };
  807. int Solver_NU::select_working_set(int &out_i, int &out_j)
  808. {
  809. // return i,j which maximize -grad(f)^T d , under constraint
  810. // if alpha_i == C, d != +1
  811. // if alpha_i == 0, d != -1
  812. double Gmax1 = -INF; // max { -grad(f)_i * d | y_i = +1, d = +1 }
  813. int Gmax1_idx = -1;
  814. double Gmax2 = -INF; // max { -grad(f)_i * d | y_i = +1, d = -1 }
  815. int Gmax2_idx = -1;
  816. double Gmax3 = -INF; // max { -grad(f)_i * d | y_i = -1, d = +1 }
  817. int Gmax3_idx = -1;
  818. double Gmax4 = -INF; // max { -grad(f)_i * d | y_i = -1, d = -1 }
  819. int Gmax4_idx = -1;
  820. for(int i=0;i<active_size;i++)
  821. {
  822. if(y[i]==+1) // y == +1
  823. {
  824. if(!is_upper_bound(i)) // d = +1
  825. {
  826. if(-G[i] > Gmax1)
  827. {
  828. Gmax1 = -G[i];
  829. Gmax1_idx = i;
  830. }
  831. }
  832. if(!is_lower_bound(i)) // d = -1
  833. {
  834. if(G[i] > Gmax2)
  835. {
  836. Gmax2 = G[i];
  837. Gmax2_idx = i;
  838. }
  839. }
  840. }
  841. else // y == -1
  842. {
  843. if(!is_upper_bound(i)) // d = +1
  844. {
  845. if(-G[i] > Gmax3)
  846. {
  847. Gmax3 = -G[i];
  848. Gmax3_idx = i;
  849. }
  850. }
  851. if(!is_lower_bound(i)) // d = -1
  852. {
  853. if(G[i] > Gmax4)
  854. {
  855. Gmax4 = G[i];
  856. Gmax4_idx = i;
  857. }
  858. }
  859. }
  860. }
  861. if(max(Gmax1+Gmax2,Gmax3+Gmax4) < eps)
  862.   return 1;
  863. if(Gmax1+Gmax2 > Gmax3+Gmax4)
  864. {
  865. out_i = Gmax1_idx;
  866. out_j = Gmax2_idx;
  867. }
  868. else
  869. {
  870. out_i = Gmax3_idx;
  871. out_j = Gmax4_idx;
  872. }
  873. return 0;
  874. }
  875. void Solver_NU::do_shrinking()
  876. {
  877. double Gmax1 = -INF; // max { -grad(f)_i * d | y_i = +1, d = +1 }
  878. double Gmax2 = -INF; // max { -grad(f)_i * d | y_i = +1, d = -1 }
  879. double Gmax3 = -INF; // max { -grad(f)_i * d | y_i = -1, d = +1 }
  880. double Gmax4 = -INF; // max { -grad(f)_i * d | y_i = -1, d = -1 }
  881. int k;
  882. for(k=0;k<active_size;k++)
  883. {
  884. if(!is_upper_bound(k))
  885. {
  886. if(y[k]==+1)
  887. {
  888. if(-G[k] > Gmax1) Gmax1 = -G[k];
  889. }
  890. else if(-G[k] > Gmax3) Gmax3 = -G[k];
  891. }
  892. if(!is_lower_bound(k))
  893. {
  894. if(y[k]==+1)
  895. {
  896. if(G[k] > Gmax2) Gmax2 = G[k];
  897. }
  898. else if(G[k] > Gmax4) Gmax4 = G[k];
  899. }
  900. }
  901. double Gm1 = -Gmax2;
  902. double Gm2 = -Gmax1;
  903. double Gm3 = -Gmax4;
  904. double Gm4 = -Gmax3;
  905. for(k=0;k<active_size;k++)
  906. {
  907. if(is_lower_bound(k))
  908. {
  909. if(y[k]==+1)
  910. {
  911. if(-G[k] >= Gm1) continue;
  912. }
  913. else if(-G[k] >= Gm3) continue;
  914. }
  915. else if(is_upper_bound(k))
  916. {
  917. if(y[k]==+1)
  918. {
  919. if(G[k] >= Gm2) continue;
  920. }
  921. else if(G[k] >= Gm4) continue;
  922. }
  923. else continue;
  924. --active_size;
  925. swap_index(k,active_size);
  926. --k; // look at the newcomer
  927. }
  928. // unshrink, check all variables again before final iterations
  929. if(unshrinked || max(-(Gm1+Gm2),-(Gm3+Gm4)) > eps*10) return;
  930. unshrinked = true;
  931. reconstruct_gradient();
  932. for(k=l-1;k>=active_size;k--)
  933. {
  934. if(is_lower_bound(k))
  935. {
  936. if(y[k]==+1)
  937. {
  938. if(-G[k] < Gm1) continue;
  939. }
  940. else if(-G[k] < Gm3) continue;
  941. }
  942. else if(is_upper_bound(k))
  943. {
  944. if(y[k]==+1)
  945. {
  946. if(G[k] < Gm2) continue;
  947. }
  948. else if(G[k] < Gm4) continue;
  949. }
  950. else continue;
  951. swap_index(k,active_size);
  952. active_size++;
  953. ++k; // look at the newcomer
  954. }
  955. }
  956. double Solver_NU::calculate_rho()
  957. {
  958. int nr_free1 = 0,nr_free2 = 0;
  959. double ub1 = INF, ub2 = INF;
  960. double lb1 = -INF, lb2 = -INF;
  961. double sum_free1 = 0, sum_free2 = 0;
  962. for(int i=0;i<active_size;i++)
  963. {
  964. if(y[i]==+1)
  965. {
  966. if(is_lower_bound(i))
  967. ub1 = min(ub1,G[i]);
  968. else if(is_upper_bound(i))
  969. lb1 = max(lb1,G[i]);
  970. else
  971. {
  972. ++nr_free1;
  973. sum_free1 += G[i];
  974. }
  975. }
  976. else
  977. {
  978. if(is_lower_bound(i))
  979. ub2 = min(ub2,G[i]);
  980. else if(is_upper_bound(i))
  981. lb2 = max(lb2,G[i]);
  982. else
  983. {
  984. ++nr_free2;
  985. sum_free2 += G[i];
  986. }
  987. }
  988. }
  989. double r1,r2;
  990. if(nr_free1 > 0)
  991. r1 = sum_free1/nr_free1;
  992. else
  993. r1 = (ub1+lb1)/2;
  994. if(nr_free2 > 0)
  995. r2 = sum_free2/nr_free2;
  996. else
  997. r2 = (ub2+lb2)/2;
  998. si->r = (r1+r2)/2;
  999. return (r1-r2)/2;
  1000. }
  1001. //
  1002. // Q matrices for various formulations
  1003. //
  1004. class SVC_Q: public Kernel
  1005. public:
  1006. SVC_Q(const svm_problem& prob, const svm_parameter& param, const schar *y_)
  1007. :Kernel(prob.l, prob.x, param)
  1008. {
  1009. clone(y,y_,prob.l);
  1010. cache = new Cache(prob.l,(int)(param.cache_size*(1<<20)));
  1011. }
  1012. Qfloat *get_Q(int i, int len) const
  1013. {
  1014. Qfloat *data;
  1015. int start;
  1016. if((start = cache->get_data(i,&data,len)) < len)
  1017. {
  1018. for(int j=start;j<len;j++)
  1019. data[j] = (Qfloat)(y[i]*y[j]*(this->*kernel_function)(i,j));
  1020. }
  1021. return data;
  1022. }
  1023. void swap_index(int i, int j) const
  1024. {
  1025. cache->swap_index(i,j);
  1026. Kernel::swap_index(i,j);
  1027. swap(y[i],y[j]);
  1028. }
  1029. ~SVC_Q()
  1030. {
  1031. delete[] y;
  1032. delete cache;
  1033. }
  1034. private:
  1035. schar *y;
  1036. Cache *cache;
  1037. };
  1038. class ONE_CLASS_Q: public Kernel
  1039. {
  1040. public:
  1041. ONE_CLASS_Q(const svm_problem& prob, const svm_parameter& param)
  1042. :Kernel(prob.l, prob.x, param)
  1043. {
  1044. cache = new Cache(prob.l,(int)(param.cache_size*(1<<20)));
  1045. }
  1046. Qfloat *get_Q(int i, int len) const
  1047. {
  1048. Qfloat *data;
  1049. int start;
  1050. if((start = cache->get_data(i,&data,len)) < len)
  1051. {
  1052. for(int j=start;j<len;j++)
  1053. data[j] = (Qfloat)(this->*kernel_function)(i,j);
  1054. }
  1055. return data;
  1056. }
  1057. void swap_index(int i, int j) const
  1058. {
  1059. cache->swap_index(i,j);
  1060. Kernel::swap_index(i,j);
  1061. }
  1062. ~ONE_CLASS_Q()
  1063. {
  1064. delete cache;
  1065. }
  1066. private:
  1067. Cache *cache;
  1068. };
  1069. class SVR_Q: public Kernel
  1070. public:
  1071. SVR_Q(const svm_problem& prob, const svm_parameter& param)
  1072. :Kernel(prob.l, prob.x, param)
  1073. {
  1074. l = prob.l;
  1075. cache = new Cache(l,(int)(param.cache_size*(1<<20)));
  1076. sign = new schar[2*l];
  1077. index = new int[2*l];
  1078. for(int k=0;k<l;k++)
  1079. {
  1080. sign[k] = 1;
  1081. sign[k+l] = -1;
  1082. index[k] = k;
  1083. index[k+l] = k;
  1084. }
  1085. buffer[0] = new Qfloat[2*l];
  1086. buffer[1] = new Qfloat[2*l];
  1087. next_buffer = 0;
  1088. }
  1089. void swap_index(int i, int j) const
  1090. {
  1091. swap(sign[i],sign[j]);
  1092. swap(index[i],index[j]);
  1093. }
  1094. Qfloat *get_Q(int i, int len) const
  1095. {
  1096. Qfloat *data;
  1097. int real_i = index[i];
  1098. if(cache->get_data(real_i,&data,l) < l)
  1099. {
  1100. for(int j=0;j<l;j++)
  1101. data[j] = (Qfloat)(this->*kernel_function)(real_i,j);
  1102. }
  1103. // reorder and copy
  1104. Qfloat *buf = buffer[next_buffer];
  1105. next_buffer = 1 - next_buffer;
  1106. schar si = sign[i];
  1107. for(int j=0;j<len;j++)
  1108. buf[j] = si * sign[j] * data[index[j]];
  1109. return buf;
  1110. }
  1111. ~SVR_Q()
  1112. {
  1113. delete cache;
  1114. delete[] sign;
  1115. delete[] index;
  1116. delete[] buffer[0];
  1117. delete[] buffer[1];
  1118. }
  1119. private:
  1120. int l;
  1121. Cache *cache;
  1122. schar *sign;
  1123. int *index;
  1124. mutable int next_buffer;
  1125. Qfloat* buffer[2];
  1126. };
  1127. //
  1128. // construct and solve various formulations
  1129. //
  1130. static void solve_c_svc(
  1131. const svm_problem *prob, const svm_parameter* param,
  1132. double *alpha, Solver::SolutionInfo* si, double Cp, double Cn)
  1133. {
  1134. int l = prob->l;
  1135. double *minus_ones = new double[l];
  1136. schar *y = new schar[l];
  1137. int i;
  1138. for(i=0;i<l;i++)
  1139. {
  1140. alpha[i] = 0;
  1141. minus_ones[i] = -1;
  1142. if(prob->y[i] > 0) y[i] = +1; else y[i]=-1;
  1143. }
  1144. Solver s;
  1145. s.Solve(l, SVC_Q(*prob,*param,y), minus_ones, y,
  1146. alpha, Cp, Cn, param->eps, si, param->shrinking);
  1147. double sum_alpha=0;
  1148. for(i=0;i<l;i++)
  1149. sum_alpha += alpha[i];
  1150. info("nu = %fn", sum_alpha/(param->C*prob->l));
  1151. for(i=0;i<l;i++)
  1152. alpha[i] *= y[i];
  1153. delete[] minus_ones;
  1154. delete[] y;
  1155. }
  1156. static void solve_nu_svc(
  1157. const svm_problem *prob, const svm_parameter *param,
  1158. double *alpha, Solver::SolutionInfo* si)
  1159. {
  1160. int i;
  1161. int l = prob->l;
  1162. double nu = param->nu;
  1163. int y_pos = 0;
  1164. int y_neg = 0;
  1165. schar *y = new schar[l];
  1166. for(i=0;i<l;i++)
  1167. if(prob->y[i]>0)
  1168. {
  1169. y[i] = +1;
  1170. ++y_pos;
  1171. }
  1172. else
  1173. {
  1174. y[i] = -1;
  1175. ++y_neg;
  1176. }
  1177. if(nu < 0 || nu*l/2 > min(y_pos,y_neg))
  1178. {
  1179. fprintf(stderr,"specified nu is infeasiblen");
  1180. exit(1);
  1181. }
  1182. double sum_pos = nu*l/2;
  1183. double sum_neg = nu*l/2;
  1184. for(i=0;i<l;i++)
  1185. if(y[i] == +1)
  1186. {
  1187. alpha[i] = min(1.0,sum_pos);
  1188. sum_pos -= alpha[i];
  1189. }
  1190. else
  1191. {
  1192. alpha[i] = min(1.0,sum_neg);
  1193. sum_neg -= alpha[i];
  1194. }
  1195. double *zeros = new double[l];
  1196. for(i=0;i<l;i++)
  1197. zeros[i] = 0;
  1198. Solver_NU s;
  1199. s.Solve(l, SVC_Q(*prob,*param,y), zeros, y,
  1200. alpha, 1.0, 1.0, param->eps, si,  param->shrinking);
  1201. double r = si->r;
  1202. info("C = %fn",1/r);
  1203. for(i=0;i<l;i++)
  1204. alpha[i] *= y[i]/r;
  1205. si->rho /= r;
  1206. si->obj /= (r*r);
  1207. si->upper_bound_p = 1/r;
  1208. si->upper_bound_n = 1/r;
  1209. delete[] y;
  1210. delete[] zeros;
  1211. }
  1212. static void solve_one_class(
  1213. const svm_problem *prob, const svm_parameter *param,
  1214. double *alpha, Solver::SolutionInfo* si)
  1215. {
  1216. int l = prob->l;
  1217. double *zeros = new double[l];
  1218. schar *ones = new schar[l];
  1219. int i;
  1220. int n = (int)(param->nu*prob->l); // # of alpha's at upper bound
  1221. if(n>=prob->l)
  1222. {
  1223. fprintf(stderr,"nu must be in (0,1)n");
  1224. exit(1);
  1225. }
  1226. for(i=0;i<n;i++)
  1227. alpha[i] = 1;
  1228. alpha[n] = param->nu * prob->l - n;
  1229. for(i=n+1;i<l;i++)
  1230. alpha[i] = 0;
  1231. for(i=0;i<l;i++)
  1232. {
  1233. zeros[i] = 0;
  1234. ones[i] = 1;
  1235. }
  1236. Solver s;
  1237. s.Solve(l, ONE_CLASS_Q(*prob,*param), zeros, ones,
  1238. alpha, 1.0, 1.0, param->eps, si, param->shrinking);
  1239. delete[] zeros;
  1240. delete[] ones;
  1241. }
  1242. static void solve_epsilon_svr(
  1243. const svm_problem *prob, const svm_parameter *param,
  1244. double *alpha, Solver::SolutionInfo* si)
  1245. {
  1246. int l = prob->l;
  1247. double *alpha2 = new double[2*l];
  1248. double *linear_term = new double[2*l];
  1249. schar *y = new schar[2*l];
  1250. int i;
  1251. for(i=0;i<l;i++)
  1252. {
  1253. alpha2[i] = 0;
  1254. linear_term[i] = param->p - prob->y[i];
  1255. y[i] = 1;
  1256. alpha2[i+l] = 0;
  1257. linear_term[i+l] = param->p + prob->y[i];
  1258. y[i+l] = -1;
  1259. }
  1260. Solver s;
  1261. s.Solve(2*l, SVR_Q(*prob,*param), linear_term, y,
  1262. alpha2, param->C, param->C, param->eps, si, param->shrinking);
  1263. double sum_alpha = 0;
  1264. for(i=0;i<l;i++)
  1265. {
  1266. alpha[i] = alpha2[i] - alpha2[i+l];
  1267. sum_alpha += fabs(alpha[i]);
  1268. }
  1269. info("nu = %fn",sum_alpha/(param->C*l));
  1270. delete[] alpha2;
  1271. delete[] linear_term;
  1272. delete[] y;
  1273. }
  1274. static void solve_nu_svr(
  1275. const svm_problem *prob, const svm_parameter *param,
  1276. double *alpha, Solver::SolutionInfo* si)
  1277. {
  1278. if(param->nu < 0 || param->nu > 1)
  1279. {
  1280. fprintf(stderr,"specified nu is out of rangen");
  1281. exit(1);
  1282. }
  1283. int l = prob->l;
  1284. double C = param->C;
  1285. double *alpha2 = new double[2*l];
  1286. double *linear_term = new double[2*l];
  1287. schar *y = new schar[2*l];
  1288. int i;
  1289. double sum = C * param->nu * l / 2;
  1290. for(i=0;i<l;i++)
  1291. {
  1292. alpha2[i] = alpha2[i+l] = min(sum,C);
  1293. sum -= alpha2[i];
  1294. linear_term[i] = - prob->y[i];
  1295. y[i] = 1;
  1296. linear_term[i+l] = prob->y[i];
  1297. y[i+l] = -1;
  1298. }
  1299. Solver_NU s;
  1300. s.Solve(2*l, SVR_Q(*prob,*param), linear_term, y,
  1301. alpha2, C, C, param->eps, si, param->shrinking);
  1302. info("epsilon = %fn",-si->r);
  1303. for(i=0;i<l;i++)
  1304. alpha[i] = alpha2[i] - alpha2[i+l];
  1305. delete[] alpha2;
  1306. delete[] linear_term;
  1307. delete[] y;
  1308. }
  1309. //
  1310. // decision_function
  1311. //
  1312. struct decision_function
  1313. {
  1314. double *alpha;
  1315. double rho;
  1316. };
  1317. decision_function svm_train_one(
  1318. const svm_problem *prob, const svm_parameter *param,
  1319. double Cp, double Cn)
  1320. {
  1321. double *alpha = Malloc(double,prob->l);
  1322. Solver::SolutionInfo si;
  1323. switch(param->svm_type)
  1324. {
  1325. case C_SVC:
  1326. solve_c_svc(prob,param,alpha,&si,Cp,Cn);
  1327. break;
  1328. case NU_SVC:
  1329. solve_nu_svc(prob,param,alpha,&si);
  1330. break;
  1331. case ONE_CLASS:
  1332. solve_one_class(prob,param,alpha,&si);
  1333. break;
  1334. case EPSILON_SVR:
  1335. solve_epsilon_svr(prob,param,alpha,&si);
  1336. break;
  1337. case NU_SVR:
  1338. solve_nu_svr(prob,param,alpha,&si);
  1339. break;
  1340. }
  1341. info("obj = %f, rho = %fn",si.obj,si.rho);
  1342. // output SVs
  1343. int nSV = 0;
  1344. int nBSV = 0;
  1345. for(int i=0;i<prob->l;i++)
  1346. {
  1347. if(fabs(alpha[i]) > 0)
  1348. {
  1349. ++nSV;
  1350. if(prob->y[i] > 0)
  1351. {
  1352. if(fabs(alpha[i]) >= si.upper_bound_p)
  1353. ++nBSV;
  1354. }
  1355. else
  1356. {
  1357. if(fabs(alpha[i]) >= si.upper_bound_n)
  1358. ++nBSV;
  1359. }
  1360. }
  1361. }
  1362. info("nSV = %d, nBSV = %dn",nSV,nBSV);
  1363. decision_function f;
  1364. f.alpha = alpha;
  1365. f.rho = si.rho;
  1366. return f;
  1367. }
  1368. //
  1369. // svm_model
  1370. //
  1371. struct svm_model
  1372. {
  1373. svm_parameter param; // parameter
  1374. int nr_class; // number of classes, = 2 in regression/one class svm
  1375. int l; // total #SV
  1376. svm_node **SV; // SVs (SV[l])
  1377. double **sv_coef; // coefficients for SVs in decision functions (sv_coef[n-1][l])
  1378. double *rho; // constants in decision functions (rho[n*(n-1)/2])
  1379. // for classification only
  1380. int *label; // label of each class (label[n])
  1381. int *nSV; // number of SVs for each class (nSV[n])
  1382. // nSV[0] + nSV[1] + ... + nSV[n-1] = l
  1383. // XXX
  1384. int free_sv; // 1 if svm_model is created by svm_load_model
  1385. // 0 if svm_model is created by svm_train
  1386. };
  1387. //
  1388. // Interface functions
  1389. //
  1390. svm_model *svm_train(const svm_problem *prob, const svm_parameter *param)
  1391. {
  1392. svm_model *model = Malloc(svm_model,1);
  1393. model->param = *param;
  1394. model->free_sv = 0; // XXX
  1395. if(param->svm_type == ONE_CLASS ||
  1396.    param->svm_type == EPSILON_SVR ||
  1397.    param->svm_type == NU_SVR)
  1398. {
  1399. // regression or one-class-svm
  1400. model->nr_class = 2;
  1401. model->label = NULL;
  1402. model->nSV = NULL;
  1403. model->sv_coef = Malloc(double *,1);
  1404. decision_function f = svm_train_one(prob,param,0,0);
  1405. model->rho = Malloc(double,1);
  1406. model->rho[0] = f.rho;
  1407. int nSV = 0;
  1408. int i;
  1409. for(i=0;i<prob->l;i++)
  1410. if(fabs(f.alpha[i]) > 0) ++nSV;
  1411. model->l = nSV;
  1412. model->SV = Malloc(svm_node *,nSV);
  1413. model->sv_coef[0] = Malloc(double,nSV);
  1414. int j = 0;
  1415. for(i=0;i<prob->l;i++)
  1416. if(fabs(f.alpha[i]) > 0)
  1417. {
  1418. model->SV[j] = prob->x[i];
  1419. model->sv_coef[0][j] = f.alpha[i];
  1420. ++j;
  1421. }
  1422. free(f.alpha);
  1423. }
  1424. else
  1425. {
  1426. // classification
  1427. // find out the number of classes
  1428. int l = prob->l;
  1429. int max_nr_class = 16;
  1430. int nr_class = 0;
  1431. int *label = Malloc(int,max_nr_class);
  1432. int *count = Malloc(int,max_nr_class);
  1433. int *index = Malloc(int,l);
  1434. int i;
  1435. for(i=0;i<l;i++)
  1436. {
  1437. int this_label = (int)prob->y[i];
  1438. int j;
  1439. for(j=0;j<nr_class;j++)
  1440. if(this_label == label[j])
  1441. {
  1442. ++count[j];
  1443. break;
  1444. }
  1445. index[i] = j;
  1446. if(j == nr_class)
  1447. {
  1448. if(nr_class == max_nr_class)
  1449. {
  1450. max_nr_class *= 2;
  1451. label = (int *)realloc(label,max_nr_class*sizeof(int));
  1452. count = (int *)realloc(count,max_nr_class*sizeof(int));
  1453. }
  1454. label[nr_class] = this_label;
  1455. count[nr_class] = 1;
  1456. ++nr_class;
  1457. }
  1458. }
  1459. // group training data of the same class
  1460. int *start = Malloc(int,nr_class);
  1461. start[0] = 0;
  1462. for(i=1;i<nr_class;i++)
  1463. start[i] = start[i-1]+count[i-1];
  1464. svm_node **x = Malloc(svm_node *,l);
  1465. for(i=0;i<l;i++)
  1466. {
  1467. x[start[index[i]]] = prob->x[i];
  1468. ++start[index[i]];
  1469. }
  1470. start[0] = 0;
  1471. for(i=1;i<nr_class;i++)
  1472. start[i] = start[i-1]+count[i-1];
  1473. // calculate weighted C
  1474. double *weighted_C = Malloc(double, nr_class);
  1475. for(i=0;i<nr_class;i++)
  1476. weighted_C[i] = param->C;
  1477. for(i=0;i<param->nr_weight;i++)
  1478. {
  1479. int j;
  1480. for(j=0;j<nr_class;j++)
  1481. if(param->weight_label[i] == label[j])
  1482. break;
  1483. if(j == nr_class)
  1484. fprintf(stderr,"warning: class label %d specified in weight is not foundn", param->weight_label[i]);
  1485. else
  1486. weighted_C[j] *= param->weight[i];
  1487. }
  1488. // train n*(n-1)/2 models
  1489. bool *nonzero = Malloc(bool,l);
  1490. for(i=0;i<l;i++)
  1491. nonzero[i] = false;
  1492. decision_function *f = Malloc(decision_function,nr_class*(nr_class-1)/2);
  1493. int p = 0;
  1494. for(i=0;i<nr_class;i++)
  1495. for(int j=i+1;j<nr_class;j++)
  1496. {
  1497. svm_problem sub_prob;
  1498. int si = start[i], sj = start[j];
  1499. int ci = count[i], cj = count[j];
  1500. sub_prob.l = ci+cj;
  1501. sub_prob.x = Malloc(svm_node *,sub_prob.l);
  1502. sub_prob.y = Malloc(double,sub_prob.l);
  1503. int k;
  1504. for(k=0;k<ci;k++)
  1505. {
  1506. sub_prob.x[k] = x[si+k];
  1507. sub_prob.y[k] = +1;
  1508. }
  1509. for(k=0;k<cj;k++)
  1510. {
  1511. sub_prob.x[ci+k] = x[sj+k];
  1512. sub_prob.y[ci+k] = -1;
  1513. }
  1514. f[p] = svm_train_one(&sub_prob,param,weighted_C[i],weighted_C[j]);
  1515. for(k=0;k<ci;k++)
  1516. if(!nonzero[si+k] && fabs(f[p].alpha[k]) > 0)
  1517. nonzero[si+k] = true;
  1518. for(k=0;k<cj;k++)
  1519. if(!nonzero[sj+k] && fabs(f[p].alpha[ci+k]) > 0)
  1520. nonzero[sj+k] = true;
  1521. free(sub_prob.x);
  1522. free(sub_prob.y);
  1523. ++p;
  1524. }
  1525. // build output
  1526. model->nr_class = nr_class;
  1527. model->label = Malloc(int,nr_class);
  1528. for(i=0;i<nr_class;i++)
  1529. model->label[i] = label[i];
  1530. model->rho = Malloc(double,nr_class*(nr_class-1)/2);
  1531. for(i=0;i<nr_class*(nr_class-1)/2;i++)
  1532. model->rho[i] = f[i].rho;
  1533. int total_sv = 0;
  1534. int *nz_count = Malloc(int,nr_class);
  1535. model->nSV = Malloc(int,nr_class);
  1536. for(i=0;i<nr_class;i++)
  1537. {
  1538. int nSV = 0;
  1539. for(int j=0;j<count[i];j++)
  1540. if(nonzero[start[i]+j])
  1541. {
  1542. ++nSV;
  1543. ++total_sv;
  1544. }
  1545. model->nSV[i] = nSV;
  1546. nz_count[i] = nSV;
  1547. }
  1548. info("Total nSV = %dn",total_sv);
  1549. model->l = total_sv;
  1550. model->SV = Malloc(svm_node *,total_sv);
  1551. p = 0;
  1552. for(i=0;i<l;i++)
  1553. if(nonzero[i]) model->SV[p++] = x[i];
  1554. int *nz_start = Malloc(int,nr_class);
  1555. nz_start[0] = 0;
  1556. for(i=1;i<nr_class;i++)
  1557. nz_start[i] = nz_start[i-1]+nz_count[i-1];
  1558. model->sv_coef = Malloc(double *,nr_class-1);
  1559. for(i=0;i<nr_class-1;i++)
  1560. model->sv_coef[i] = Malloc(double,total_sv);
  1561. p = 0;
  1562. for(i=0;i<nr_class;i++)
  1563. for(int j=i+1;j<nr_class;j++)
  1564. {
  1565. // classifier (i,j): coefficients with
  1566. // i are in sv_coef[j-1][nz_start[i]...],
  1567. // j are in sv_coef[i][nz_start[j]...]
  1568. int si = start[i];
  1569. int sj = start[j];
  1570. int ci = count[i];
  1571. int cj = count[j];
  1572. int q = nz_start[i];
  1573. int k;
  1574. for(k=0;k<ci;k++)
  1575. if(nonzero[si+k])
  1576. model->sv_coef[j-1][q++] = f[p].alpha[k];
  1577. q = nz_start[j];
  1578. for(k=0;k<cj;k++)
  1579. if(nonzero[sj+k])
  1580. model->sv_coef[i][q++] = f[p].alpha[ci+k];
  1581. ++p;
  1582. }
  1583. free(label);
  1584. free(count);
  1585. free(index);
  1586. free(start);
  1587. free(x);
  1588. free(weighted_C);
  1589. free(nonzero);
  1590. for(i=0;i<nr_class*(nr_class-1)/2;i++)
  1591. free(f[i].alpha);
  1592. free(f);
  1593. free(nz_count);
  1594. free(nz_start);
  1595. }
  1596. return model;
  1597. }
  1598. double svm_predict(const svm_model *model, const svm_node *x)
  1599. {
  1600. if(model->param.svm_type == ONE_CLASS ||
  1601.    model->param.svm_type == EPSILON_SVR ||
  1602.    model->param.svm_type == NU_SVR)
  1603. {
  1604. double *sv_coef = model->sv_coef[0];
  1605. double sum = 0;
  1606. for(int i=0;i<model->l;i++)
  1607. sum += sv_coef[i] * Kernel::k_function(x,model->SV[i],model->param);
  1608. sum -= model->rho[0];
  1609. if(model->param.svm_type == ONE_CLASS)
  1610. return (sum>0)?1:-1;
  1611. else
  1612. return sum;
  1613. }
  1614. else
  1615. {
  1616. int i;
  1617. int nr_class = model->nr_class;
  1618. int l = model->l;
  1619. double *kvalue = Malloc(double,l);
  1620. for(i=0;i<l;i++)
  1621. kvalue[i] = Kernel::k_function(x,model->SV[i],model->param);
  1622. int *start = Malloc(int,nr_class);
  1623. start[0] = 0;
  1624. for(i=1;i<nr_class;i++)
  1625. start[i] = start[i-1]+model->nSV[i-1];
  1626. int *vote = Malloc(int,nr_class);
  1627. for(i=0;i<nr_class;i++)
  1628. vote[i] = 0;
  1629. int p=0;
  1630. for(i=0;i<nr_class;i++)
  1631. for(int j=i+1;j<nr_class;j++)
  1632. {
  1633. double sum = 0;
  1634. int si = start[i];
  1635. int sj = start[j];
  1636. int ci = model->nSV[i];
  1637. int cj = model->nSV[j];
  1638. int k;
  1639. double *coef1 = model->sv_coef[j-1];
  1640. double *coef2 = model->sv_coef[i];
  1641. for(k=0;k<ci;k++)
  1642. sum += coef1[si+k] * kvalue[si+k];
  1643. for(k=0;k<cj;k++)
  1644. sum += coef2[sj+k] * kvalue[sj+k];
  1645. sum -= model->rho[p++];
  1646. if(sum > 0)
  1647. ++vote[i];
  1648. else
  1649. ++vote[j];
  1650. }
  1651. int vote_max_idx = 0;
  1652. for(i=1;i<nr_class;i++)
  1653. if(vote[i] > vote[vote_max_idx])
  1654. vote_max_idx = i;
  1655. free(kvalue);
  1656. free(start);
  1657. free(vote);
  1658. return model->label[vote_max_idx];
  1659. }
  1660. }
  1661. const char *svm_type_table[] =
  1662. {
  1663. "c_svc","nu_svc","one_class","epsilon_svr","nu_svr",NULL
  1664. };
  1665. const char *kernel_type_table[]=
  1666. {
  1667. "linear","polynomial","rbf","sigmoid",NULL
  1668. };
  1669. int svm_save_model(const char *model_file_name, const svm_model *model)
  1670. {
  1671. FILE *fp = fopen(model_file_name,"w");
  1672. if(fp==NULL) return -1;
  1673. const svm_parameter& param = model->param;
  1674. fprintf(fp,"svm_type %sn", svm_type_table[param.svm_type]);
  1675. fprintf(fp,"kernel_type %sn", kernel_type_table[param.kernel_type]);
  1676. if(param.kernel_type == POLY)
  1677. fprintf(fp,"degree %gn", param.degree);
  1678. if(param.kernel_type == POLY || param.kernel_type == RBF || param.kernel_type == SIGMOID)
  1679. fprintf(fp,"gamma %gn", param.gamma);
  1680. if(param.kernel_type == POLY || param.kernel_type == SIGMOID)
  1681. fprintf(fp,"coef0 %gn", param.coef0);
  1682. int nr_class = model->nr_class;
  1683. int l = model->l;
  1684. fprintf(fp, "nr_class %dn", nr_class);
  1685. fprintf(fp, "total_sv %dn",l);
  1686. {
  1687. fprintf(fp, "rho");
  1688. for(int i=0;i<nr_class*(nr_class-1)/2;i++)
  1689. fprintf(fp," %g",model->rho[i]);
  1690. fprintf(fp, "n");
  1691. }
  1692. if(model->label)
  1693. {
  1694. fprintf(fp, "label");
  1695. for(int i=0;i<nr_class;i++)
  1696. fprintf(fp," %d",model->label[i]);
  1697. fprintf(fp, "n");
  1698. }
  1699. if(model->nSV)
  1700. {
  1701. fprintf(fp, "nr_sv");
  1702. for(int i=0;i<nr_class;i++)
  1703. fprintf(fp," %d",model->nSV[i]);
  1704. fprintf(fp, "n");
  1705. }
  1706. fprintf(fp, "SVn");
  1707. const double * const *sv_coef = model->sv_coef;
  1708. const svm_node * const *SV = model->SV;
  1709. for(int i=0;i<l;i++)
  1710. {
  1711. for(int j=0;j<nr_class-1;j++)
  1712. fprintf(fp, "%.16g ",sv_coef[j][i]);
  1713. const svm_node *p = SV[i];
  1714. while(p->index != -1)
  1715. {
  1716. fprintf(fp,"%d:%.8g ",p->index,p->value);
  1717. p++;
  1718. }
  1719. fprintf(fp, "n");
  1720. }
  1721. fclose(fp);
  1722. return 0;
  1723. }
  1724. svm_model *svm_load_model(const char *model_file_name)
  1725. {
  1726. FILE *fp = fopen(model_file_name,"rb");
  1727. if(fp==NULL) return NULL;
  1728. // read parameters
  1729. svm_model *model = (svm_model *)malloc(sizeof(svm_model));
  1730. svm_parameter& param = model->param;
  1731. model->label = NULL;
  1732. model->nSV = NULL;
  1733. char cmd[81];
  1734. while(1)
  1735. {
  1736. fscanf(fp,"%80s",cmd);
  1737. if(strcmp(cmd,"svm_type")==0)
  1738. {
  1739. fscanf(fp,"%80s",cmd);
  1740. int i;
  1741. for(i=0;svm_type_table[i];i++)
  1742. {
  1743. if(strcmp(svm_type_table[i],cmd)==0)
  1744. {
  1745. param.svm_type=i;
  1746. break;
  1747. }
  1748. }
  1749. if(svm_type_table[i] == NULL)
  1750. {
  1751. fprintf(stderr,"unknown svm type.n");
  1752. exit(1);
  1753. }
  1754. }
  1755. else if(strcmp(cmd,"kernel_type")==0)
  1756. {
  1757. fscanf(fp,"%80s",cmd);
  1758. int i;
  1759. for(i=0;kernel_type_table[i];i++)
  1760. {
  1761. if(strcmp(kernel_type_table[i],cmd)==0)
  1762. {
  1763. param.kernel_type=i;
  1764. break;
  1765. }
  1766. }
  1767. if(kernel_type_table[i] == NULL)
  1768. {
  1769. fprintf(stderr,"unknown kernel function.n");
  1770. exit(1);
  1771. }
  1772. }
  1773. else if(strcmp(cmd,"degree")==0)
  1774. fscanf(fp,"%lf",&param.degree);
  1775. else if(strcmp(cmd,"gamma")==0)
  1776. fscanf(fp,"%lf",&param.gamma);
  1777. else if(strcmp(cmd,"coef0")==0)
  1778. fscanf(fp,"%lf",&param.coef0);
  1779. else if(strcmp(cmd,"nr_class")==0)
  1780. fscanf(fp,"%d",&model->nr_class);
  1781. else if(strcmp(cmd,"total_sv")==0)
  1782. fscanf(fp,"%d",&model->l);
  1783. else if(strcmp(cmd,"rho")==0)
  1784. {
  1785. int n = model->nr_class * (model->nr_class-1)/2;
  1786. model->rho = Malloc(double,n);
  1787. for(int i=0;i<n;i++)
  1788. fscanf(fp,"%lf",&model->rho[i]);
  1789. }
  1790. else if(strcmp(cmd,"label")==0)
  1791. {
  1792. int n = model->nr_class;
  1793. model->label = Malloc(int,n);
  1794. for(int i=0;i<n;i++)
  1795. fscanf(fp,"%d",&model->label[i]);
  1796. }
  1797. else if(strcmp(cmd,"nr_sv")==0)
  1798. {
  1799. int n = model->nr_class;
  1800. model->nSV = Malloc(int,n);
  1801. for(int i=0;i<n;i++)
  1802. fscanf(fp,"%d",&model->nSV[i]);
  1803. }
  1804. else if(strcmp(cmd,"SV")==0)
  1805. {
  1806. while(1)
  1807. {
  1808. int c = getc(fp);
  1809. if(c==EOF || c=='n') break;
  1810. }
  1811. break;
  1812. }
  1813. else
  1814. {
  1815. fprintf(stderr,"unknown text in model filen");
  1816. exit(1);
  1817. }
  1818. }
  1819. // read sv_coef and SV
  1820. int elements = 0;
  1821. long pos = ftell(fp);
  1822. while(1)
  1823. {
  1824. int c = fgetc(fp);
  1825. switch(c)
  1826. {
  1827. case 'n':
  1828. // count the '-1' element
  1829. case ':':
  1830. ++elements;
  1831. break;
  1832. case EOF:
  1833. goto out;
  1834. default:
  1835. ;
  1836. }
  1837. }
  1838. out:
  1839. fseek(fp,pos,SEEK_SET);
  1840. int m = model->nr_class - 1;
  1841. int l = model->l;
  1842. model->sv_coef = Malloc(double *,m);
  1843. int i;
  1844. for(i=0;i<m;i++)
  1845. model->sv_coef[i] = Malloc(double,l);
  1846. model->SV = Malloc(svm_node*,l);
  1847. svm_node *x_space = Malloc(svm_node,elements);
  1848. int j=0;
  1849. for(i=0;i<l;i++)
  1850. {
  1851. model->SV[i] = &x_space[j];
  1852. for(int k=0;k<m;k++)
  1853. fscanf(fp,"%lf",&model->sv_coef[k][i]);
  1854. while(1)
  1855. {
  1856. int c;
  1857. do {
  1858. c = getc(fp);
  1859. if(c=='n') goto out2;
  1860. } while(isspace(c));
  1861. ungetc(c,fp);
  1862. fscanf(fp,"%d:%lf",&(x_space[j].index),&(x_space[j].value));
  1863. ++j;
  1864. }
  1865. out2:
  1866. x_space[j++].index = -1;
  1867. }
  1868. fclose(fp);
  1869. model->free_sv = 1; // XXX
  1870. return model;
  1871. }
  1872. void svm_destroy_model(svm_model* model)
  1873. {
  1874. if(model->free_sv)
  1875. free((void *)(model->SV[0]));
  1876. for(int i=0;i<model->nr_class-1;i++)
  1877. free(model->sv_coef[i]);
  1878. free(model->SV);
  1879. free(model->sv_coef);
  1880. free(model->rho);
  1881. free(model->label);
  1882. free(model->nSV);
  1883. free(model);
  1884. }