svm.m4
上传用户:xgw_05
上传日期:2014-12-08
资源大小:2726k
文件大小:38k
源码类别:

.net编程

开发平台:

Java

  1. define(`swap',`do {$1 _=$2; $2=$3; $3=_;} while(false)')
  2. define(`Qfloat',`float')
  3. define(`SIZE_OF_QFLOAT',4)
  4. package libsvm;
  5. import java.io.*;
  6. import java.util.*;
  7. //
  8. // Kernel Cache
  9. //
  10. // l is the number of total data items
  11. // size is the cache size limit in bytes
  12. //
  13. class Cache {
  14. private final int l;
  15. private int size;
  16. private final class head_t
  17. {
  18. head_t prev, next; // a cicular list
  19. Qfloat[] data;
  20. int len; // data[0,len) is cached in this entry
  21. }
  22. private final head_t[] head;
  23. private head_t lru_head;
  24. Cache(int l_, int size_)
  25. {
  26. l = l_;
  27. size = size_;
  28. head = new head_t[l];
  29. for(int i=0;i<l;i++) head[i] = new head_t();
  30. size /= SIZE_OF_QFLOAT;
  31. size -= l * (16/SIZE_OF_QFLOAT); // sizeof(head_t) == 16
  32. lru_head = new head_t();
  33. lru_head.next = lru_head.prev = lru_head;
  34. }
  35. private void lru_delete(head_t h)
  36. {
  37. // delete from current location
  38. h.prev.next = h.next;
  39. h.next.prev = h.prev;
  40. }
  41. private void lru_insert(head_t h)
  42. {
  43. // insert to last position
  44. h.next = lru_head;
  45. h.prev = lru_head.prev;
  46. h.prev.next = h;
  47. h.next.prev = h;
  48. }
  49. // request data [0,len)
  50. // return some position p where [p,len) need to be filled
  51. // (p >= len if nothing needs to be filled)
  52. // java: simulate pointer using single-element array
  53. int get_data(int index, Qfloat[][] data, int len)
  54. {
  55. head_t h = head[index];
  56. if(h.len > 0) lru_delete(h);
  57. int more = len - h.len;
  58. if(more > 0)
  59. {
  60. // free old space
  61. while(size < more)
  62. {
  63. head_t old = lru_head.next;
  64. lru_delete(old);
  65. size += old.len;
  66. old.data = null;
  67. old.len = 0;
  68. }
  69. // allocate new space
  70. Qfloat[] new_data = new Qfloat[len];
  71. if(h.data != null) System.arraycopy(h.data,0,new_data,0,h.len);
  72. h.data = new_data;
  73. size -= more;
  74. swap(int,h.len,len);
  75. }
  76. lru_insert(h);
  77. data[0] = h.data;
  78. return len;
  79. }
  80. void swap_index(int i, int j)
  81. {
  82. if(i==j) return;
  83. if(head[i].len > 0) lru_delete(head[i]);
  84. if(head[j].len > 0) lru_delete(head[j]);
  85. swap(Qfloat[],head[i].data,head[j].data);
  86. swap(int,head[i].len,head[j].len);
  87. if(head[i].len > 0) lru_insert(head[i]);
  88. if(head[j].len > 0) lru_insert(head[j]);
  89. if(i>j) swap(int,i,j);
  90. for(head_t h = lru_head.next; h!=lru_head; h=h.next)
  91. {
  92. if(h.len > i)
  93. {
  94. if(h.len > j)
  95. swap(Qfloat,h.data[i],h.data[j]);
  96. else
  97. {
  98. // give up
  99. lru_delete(h);
  100. size += h.len;
  101. h.data = null;
  102. h.len = 0;
  103. }
  104. }
  105. }
  106. }
  107. }
  108. //
  109. // Kernel evaluation
  110. //
  111. // the static method k_function is for doing single kernel evaluation
  112. // the constructor of Kernel prepares to calculate the l*l kernel matrix
  113. // the member function get_Q is for getting one column from the Q Matrix
  114. //
  115. abstract class Kernel {
  116. private svm_node[][] x;
  117. private final double[] x_square;
  118. // svm_parameter
  119. private final int kernel_type;
  120. private final double degree;
  121. private final double gamma;
  122. private final double coef0;
  123. abstract Qfloat[] get_Q(int column, int len);
  124. void swap_index(int i, int j)
  125. {
  126. swap(svm_node[],x[i],x[j]);
  127. if(x_square != null) swap(double,x_square[i],x_square[j]);
  128. }
  129. private static double tanh(double x)
  130. {
  131. double e = Math.exp(x);
  132. return 1.0-2.0/(e*e+1);
  133. }
  134. double kernel_function(int i, int j)
  135. {
  136. switch(kernel_type)
  137. {
  138. case svm_parameter.LINEAR:
  139. return dot(x[i],x[j]);
  140. case svm_parameter.POLY:
  141. return Math.pow(gamma*dot(x[i],x[j])+coef0,degree);
  142. case svm_parameter.RBF:
  143. return Math.exp(-gamma*(x_square[i]+x_square[j]-2*dot(x[i],x[j])));
  144. case svm_parameter.SIGMOID:
  145. return tanh(gamma*dot(x[i],x[j])+coef0);
  146. default:
  147. System.err.print("unknown kernel function.n");
  148. System.exit(1);
  149. return 0; // java
  150. }
  151. }
  152. Kernel(int l, svm_node[][] x_, svm_parameter param)
  153. {
  154. this.kernel_type = param.kernel_type;
  155. this.degree = param.degree;
  156. this.gamma = param.gamma;
  157. this.coef0 = param.coef0;
  158. x = (svm_node[][])x_.clone();
  159. if(kernel_type == svm_parameter.RBF)
  160. {
  161. x_square = new double[l];
  162. for(int i=0;i<l;i++)
  163. x_square[i] = dot(x[i],x[i]);
  164. }
  165. else x_square = null;
  166. }
  167. static double dot(svm_node[] x, svm_node[] y)
  168. {
  169. double sum = 0;
  170. int xlen = x.length;
  171. int ylen = y.length;
  172. int i = 0;
  173. int j = 0;
  174. while(i < xlen && j < ylen)
  175. {
  176. if(x[i].index == y[j].index)
  177. sum += x[i++].value * y[j++].value;
  178. else
  179. {
  180. if(x[i].index > y[j].index)
  181. ++j;
  182. else
  183. ++i;
  184. }
  185. }
  186. return sum;
  187. }
  188. static double k_function(svm_node[] x, svm_node[] y,
  189. svm_parameter param)
  190. {
  191. switch(param.kernel_type)
  192. {
  193. case svm_parameter.LINEAR:
  194. return dot(x,y);
  195. case svm_parameter.POLY:
  196. return Math.pow(param.gamma*dot(x,y)+param.coef0,param.degree);
  197. case svm_parameter.RBF:
  198. {
  199. double sum = 0;
  200. int xlen = x.length;
  201. int ylen = y.length;
  202. int i = 0;
  203. int j = 0;
  204. while(i < xlen && j < ylen)
  205. {
  206. if(x[i].index == y[j].index)
  207. {
  208. double d = x[i++].value - y[j++].value;
  209. sum += d*d;
  210. }
  211. else if(x[i].index > y[j].index)
  212. {
  213. sum += y[j].value * y[j].value;
  214. ++j;
  215. }
  216. else
  217. {
  218. sum += x[i].value * x[i].value;
  219. ++i;
  220. }
  221. }
  222. while(i < xlen)
  223. {
  224. sum += x[i].value * x[i].value;
  225. ++i;
  226. }
  227. while(j < ylen)
  228. {
  229. sum += y[j].value * y[j].value;
  230. ++j;
  231. }
  232. return Math.exp(-param.gamma*sum);
  233. }
  234. case svm_parameter.SIGMOID:
  235. return tanh(param.gamma*dot(x,y)+param.coef0);
  236. default:
  237. System.err.print("unknown kernel function.n");
  238. System.exit(1);
  239. return 0; // java
  240. }
  241. }
  242. }
  243. // Generalized SMO+SVMlight algorithm
  244. // Solves:
  245. //
  246. // min 0.5(alpha^T Q alpha) + b^T alpha
  247. //
  248. // y^T alpha = delta
  249. // y_i = +1 or -1
  250. // 0 <= alpha_i <= Cp for y_i = 1
  251. // 0 <= alpha_i <= Cn for y_i = -1
  252. //
  253. // Given:
  254. //
  255. // Q, b, y, Cp, Cn, and an initial feasible point alpha
  256. // l is the size of vectors and matrices
  257. // eps is the stopping criterion
  258. //
  259. // solution will be put in alpha, objective value will be put in obj
  260. //
  261. class Solver {
  262. int active_size;
  263. byte[] y;
  264. double[] G; // gradient of objective function
  265. static final byte LOWER_BOUND = 0;
  266. static final byte UPPER_BOUND = 1;
  267. static final byte FREE = 2;
  268. byte[] alpha_status; // LOWER_BOUND, UPPER_BOUND, FREE
  269. double[] alpha;
  270. Kernel Q;
  271. double eps;
  272. double Cp,Cn;
  273. double[] b;
  274. int[] active_set;
  275. double[] G_bar; // gradient, if we treat free variables as 0
  276. int l;
  277. boolean unshrinked; // XXX
  278. static final double INF = java.lang.Double.POSITIVE_INFINITY;
  279. double get_C(int i)
  280. {
  281. return (y[i] > 0)? Cp : Cn;
  282. }
  283. void update_alpha_status(int i)
  284. {
  285. if(alpha[i] >= get_C(i))
  286. alpha_status[i] = UPPER_BOUND;
  287. else if(alpha[i] <= 0)
  288. alpha_status[i] = LOWER_BOUND;
  289. else alpha_status[i] = FREE;
  290. }
  291. boolean is_upper_bound(int i) { return alpha_status[i] == UPPER_BOUND; }
  292. boolean is_lower_bound(int i) { return alpha_status[i] == LOWER_BOUND; }
  293. boolean is_free(int i) {  return alpha_status[i] == FREE; }
  294. // java: information about solution except alpha,
  295. // because we cannot return multiple values otherwise...
  296. static class SolutionInfo {
  297. double obj;
  298. double rho;
  299. double upper_bound_p;
  300. double upper_bound_n;
  301. double r; // for Solver_NU
  302. }
  303. void swap_index(int i, int j)
  304. {
  305. Q.swap_index(i,j);
  306. swap(byte, y[i],y[j]);
  307. swap(double, G[i],G[j]);
  308. swap(byte, alpha_status[i],alpha_status[j]);
  309. swap(double, alpha[i],alpha[j]);
  310. swap(double, b[i],b[j]);
  311. swap(int, active_set[i],active_set[j]);
  312. swap(double, G_bar[i],G_bar[j]);
  313. }
  314. void reconstruct_gradient()
  315. {
  316. // reconstruct inactive elements of G from G_bar and free variables
  317. if(active_size == l) return;
  318. int i;
  319. for(i=active_size;i<l;i++)
  320. G[i] = G_bar[i] + b[i];
  321. for(i=0;i<active_size;i++)
  322. if(is_free(i))
  323. {
  324. Qfloat[] Q_i = Q.get_Q(i,l);
  325. double alpha_i = alpha[i];
  326. for(int j=active_size;j<l;j++)
  327. G[j] += alpha_i * Q_i[j];
  328. }
  329. }
  330. void Solve(int l, Kernel Q, double[] b_, byte[] y_,
  331.    double[] alpha_, double Cp, double Cn, double eps, SolutionInfo si, int shrinking)
  332. {
  333. this.l = l;
  334. this.Q = Q;
  335. b = (double[])b_.clone();
  336. y = (byte[])y_.clone();
  337. alpha = (double[])alpha_.clone();
  338. this.Cp = Cp;
  339. this.Cn = Cn;
  340. this.eps = eps;
  341. this.unshrinked = false;
  342. // initialize alpha_status
  343. {
  344. alpha_status = new byte[l];
  345. for(int i=0;i<l;i++)
  346. update_alpha_status(i);
  347. }
  348. // initialize active set (for shrinking)
  349. {
  350. active_set = new int[l];
  351. for(int i=0;i<l;i++)
  352. active_set[i] = i;
  353. active_size = l;
  354. }
  355. // initialize gradient
  356. {
  357. G = new double[l];
  358. G_bar = new double[l];
  359. int i;
  360. for(i=0;i<l;i++)
  361. {
  362. G[i] = b[i];
  363. G_bar[i] = 0;
  364. }
  365. for(i=0;i<l;i++)
  366. if(!is_lower_bound(i))
  367. {
  368. Qfloat[] Q_i = Q.get_Q(i,l);
  369. double alpha_i = alpha[i];
  370. int j;
  371. for(j=0;j<l;j++)
  372. G[j] += alpha_i*Q_i[j];
  373. if(is_upper_bound(i))
  374. for(j=0;j<l;j++)
  375. G_bar[j] += get_C(i) * Q_i[j];
  376. }
  377. }
  378. // optimization step
  379. int iter = 0;
  380. int counter = Math.min(l,1000)+1;
  381. int[] working_set = new int[2];
  382. while(true)
  383. {
  384. // show progress and do shrinking
  385. if(--counter == 0)
  386. {
  387. counter = Math.min(l,1000);
  388. if(shrinking!=0) do_shrinking();
  389. System.err.print(".");
  390. }
  391. if(select_working_set(working_set)!=0)
  392. {
  393. // reconstruct the whole gradient
  394. reconstruct_gradient();
  395. // reset active set size and check
  396. active_size = l;
  397. System.err.print("*");
  398. if(select_working_set(working_set)!=0)
  399. break;
  400. else
  401. counter = 1; // do shrinking next iteration
  402. }
  403. int i = working_set[0];
  404. int j = working_set[1];
  405. ++iter;
  406. // update alpha[i] and alpha[j], handle bounds carefully
  407. Qfloat[] Q_i = Q.get_Q(i,active_size);
  408. Qfloat[] Q_j = Q.get_Q(j,active_size);
  409. double C_i = get_C(i);
  410. double C_j = get_C(j);
  411. double old_alpha_i = alpha[i];
  412. double old_alpha_j = alpha[j];
  413. if(y[i]!=y[j])
  414. {
  415. double delta = (-G[i]-G[j])/(Q_i[i]+Q_j[j]+2*Q_i[j]);
  416. double diff = alpha[i] - alpha[j];
  417. alpha[i] += delta;
  418. alpha[j] += delta;
  419. if(diff > 0)
  420. {
  421. if(alpha[j] < 0)
  422. {
  423. alpha[j] = 0;
  424. alpha[i] = diff;
  425. }
  426. }
  427. else
  428. {
  429. if(alpha[i] < 0)
  430. {
  431. alpha[i] = 0;
  432. alpha[j] = -diff;
  433. }
  434. }
  435. if(diff > C_i - C_j)
  436. {
  437. if(alpha[i] > C_i)
  438. {
  439. alpha[i] = C_i;
  440. alpha[j] = C_i - diff;
  441. }
  442. }
  443. else
  444. {
  445. if(alpha[j] > C_j)
  446. {
  447. alpha[j] = C_j;
  448. alpha[i] = C_j + diff;
  449. }
  450. }
  451. }
  452. else
  453. {
  454. double delta = (G[i]-G[j])/(Q_i[i]+Q_j[j]-2*Q_i[j]);
  455. double sum = alpha[i] + alpha[j];
  456. alpha[i] -= delta;
  457. alpha[j] += delta;
  458. if(sum > C_i)
  459. {
  460. if(alpha[i] > C_i)
  461. {
  462. alpha[i] = C_i;
  463. alpha[j] = sum - C_i;
  464. }
  465. }
  466. else
  467. {
  468. if(alpha[j] < 0)
  469. {
  470. alpha[j] = 0;
  471. alpha[i] = sum;
  472. }
  473. }
  474. if(sum > C_j)
  475. {
  476. if(alpha[j] > C_j)
  477. {
  478. alpha[j] = C_j;
  479. alpha[i] = sum - C_j;
  480. }
  481. }
  482. else
  483. {
  484. if(alpha[i] < 0)
  485. {
  486. alpha[i] = 0;
  487. alpha[j] = sum;
  488. }
  489. }
  490. }
  491. // update G
  492. double delta_alpha_i = alpha[i] - old_alpha_i;
  493. double delta_alpha_j = alpha[j] - old_alpha_j;
  494. for(int k=0;k<active_size;k++)
  495. {
  496. G[k] += Q_i[k]*delta_alpha_i + Q_j[k]*delta_alpha_j;
  497. }
  498. // update alpha_status and G_bar
  499. {
  500. boolean ui = is_upper_bound(i);
  501. boolean uj = is_upper_bound(j);
  502. update_alpha_status(i);
  503. update_alpha_status(j);
  504. int k;
  505. if(ui != is_upper_bound(i))
  506. {
  507. Q_i = Q.get_Q(i,l);
  508. if(ui)
  509. for(k=0;k<l;k++)
  510. G_bar[k] -= C_i * Q_i[k];
  511. else
  512. for(k=0;k<l;k++)
  513. G_bar[k] += C_i * Q_i[k];
  514. }
  515. if(uj != is_upper_bound(j))
  516. {
  517. Q_j = Q.get_Q(j,l);
  518. if(uj)
  519. for(k=0;k<l;k++)
  520. G_bar[k] -= C_j * Q_j[k];
  521. else
  522. for(k=0;k<l;k++)
  523. G_bar[k] += C_j * Q_j[k];
  524. }
  525. }
  526. }
  527. // calculate rho
  528. si.rho = calculate_rho();
  529. // calculate objective value
  530. {
  531. double v = 0;
  532. int i;
  533. for(i=0;i<l;i++)
  534. v += alpha[i] * (G[i] + b[i]);
  535. si.obj = v/2;
  536. }
  537. // put back the solution
  538. {
  539. for(int i=0;i<l;i++)
  540. alpha_[active_set[i]] = alpha[i];
  541. }
  542. si.upper_bound_p = Cp;
  543. si.upper_bound_n = Cn;
  544. System.out.print("noptimization finished, #iter = "+iter+"n");
  545. }
  546. // return 1 if already optimal, return 0 otherwise
  547. int select_working_set(int[] working_set)
  548. {
  549. // return i,j which maximize -grad(f)^T d , under constraint
  550. // if alpha_i == C, d != +1
  551. // if alpha_i == 0, d != -1
  552. double Gmax1 = -INF; // max { -grad(f)_i * d | y_i*d = +1 }
  553. int Gmax1_idx = -1;
  554. double Gmax2 = -INF; // max { -grad(f)_i * d | y_i*d = -1 }
  555. int Gmax2_idx = -1;
  556. for(int i=0;i<active_size;i++)
  557. {
  558. if(y[i]==+1) // y = +1
  559. {
  560. if(!is_upper_bound(i)) // d = +1
  561. {
  562. if(-G[i] > Gmax1)
  563. {
  564. Gmax1 = -G[i];
  565. Gmax1_idx = i;
  566. }
  567. }
  568. if(!is_lower_bound(i)) // d = -1
  569. {
  570. if(G[i] > Gmax2)
  571. {
  572. Gmax2 = G[i];
  573. Gmax2_idx = i;
  574. }
  575. }
  576. }
  577. else // y = -1
  578. {
  579. if(!is_upper_bound(i)) // d = +1
  580. {
  581. if(-G[i] > Gmax2)
  582. {
  583. Gmax2 = -G[i];
  584. Gmax2_idx = i;
  585. }
  586. }
  587. if(!is_lower_bound(i)) // d = -1
  588. {
  589. if(G[i] > Gmax1)
  590. {
  591. Gmax1 = G[i];
  592. Gmax1_idx = i;
  593. }
  594. }
  595. }
  596. }
  597. if(Gmax1+Gmax2 < eps)
  598.   return 1;
  599. working_set[0] = Gmax1_idx;
  600. working_set[1] = Gmax2_idx;
  601. return 0;
  602. }
  603. void do_shrinking()
  604. {
  605. int i,j,k;
  606. int[] working_set = new int[2];
  607. if(select_working_set(working_set)!=0) return;
  608. i = working_set[0];
  609. j = working_set[1];
  610. double Gm1 = -y[j]*G[j];
  611. double Gm2 = y[i]*G[i];
  612. // shrink
  613. for(k=0;k<active_size;k++)
  614. {
  615. if(is_lower_bound(k))
  616. {
  617. if(y[k]==+1)
  618. {
  619. if(-G[k] >= Gm1) continue;
  620. }
  621. else if(-G[k] >= Gm2) continue;
  622. }
  623. else if(is_upper_bound(k))
  624. {
  625. if(y[k]==+1)
  626. {
  627. if(G[k] >= Gm2) continue;
  628. }
  629. else if(G[k] >= Gm1) continue;
  630. }
  631. else continue;
  632. --active_size;
  633. swap_index(k,active_size);
  634. --k; // look at the newcomer
  635. }
  636. // unshrink, check all variables again before final iterations
  637. if(unshrinked || -(Gm1 + Gm2) > eps*10) return;
  638. unshrinked = true;
  639. reconstruct_gradient();
  640. for(k=l-1;k>=active_size;k--)
  641. {
  642. if(is_lower_bound(k))
  643. {
  644. if(y[k]==+1)
  645. {
  646. if(-G[k] < Gm1) continue;
  647. }
  648. else if(-G[k] < Gm2) continue;
  649. }
  650. else if(is_upper_bound(k))
  651. {
  652. if(y[k]==+1)
  653. {
  654. if(G[k] < Gm2) continue;
  655. }
  656. else if(G[k] < Gm1) continue;
  657. }
  658. else continue;
  659. swap_index(k,active_size);
  660. active_size++;
  661. ++k; // look at the newcomer
  662. }
  663. }
  664. double calculate_rho()
  665. {
  666. double r;
  667. int nr_free = 0;
  668. double ub = INF, lb = -INF, sum_free = 0;
  669. for(int i=0;i<active_size;i++)
  670. {
  671. double yG = y[i]*G[i];
  672. if(is_lower_bound(i))
  673. {
  674. if(y[i] > 0)
  675. ub = Math.min(ub,yG);
  676. else
  677. lb = Math.max(lb,yG);
  678. }
  679. else if(is_upper_bound(i))
  680. {
  681. if(y[i] < 0)
  682. ub = Math.min(ub,yG);
  683. else
  684. lb = Math.max(lb,yG);
  685. }
  686. else
  687. {
  688. ++nr_free;
  689. sum_free += yG;
  690. }
  691. }
  692. if(nr_free>0)
  693. r = sum_free/nr_free;
  694. else
  695. r = (ub+lb)/2;
  696. return r;
  697. }
  698. }
  699. //
  700. // Solver for nu-svm classification and regression
  701. //
  702. // additional constraint: e^T alpha = constant
  703. //
  704. final class Solver_NU extends Solver
  705. {
  706. private SolutionInfo si;
  707. void Solve(int l, Kernel Q, double[] b, byte[] y,
  708.    double[] alpha, double Cp, double Cn, double eps,
  709.    SolutionInfo si, int shrinking)
  710. {
  711. this.si = si;
  712. super.Solve(l,Q,b,y,alpha,Cp,Cn,eps,si,shrinking);
  713. }
  714. int select_working_set(int[] working_set)
  715. {
  716. // return i,j which maximize -grad(f)^T d , under constraint
  717. // if alpha_i == C, d != +1
  718. // if alpha_i == 0, d != -1
  719. double Gmax1 = -INF; // max { -grad(f)_i * d | y_i = +1, d = +1 }
  720. int Gmax1_idx = -1;
  721. double Gmax2 = -INF; // max { -grad(f)_i * d | y_i = +1, d = -1 }
  722. int Gmax2_idx = -1;
  723. double Gmax3 = -INF; // max { -grad(f)_i * d | y_i = -1, d = +1 }
  724. int Gmax3_idx = -1;
  725. double Gmax4 = -INF; // max { -grad(f)_i * d | y_i = -1, d = -1 }
  726. int Gmax4_idx = -1;
  727. for(int i=0;i<active_size;i++)
  728. {
  729. if(y[i]==+1) // y == +1
  730. {
  731. if(!is_upper_bound(i)) // d = +1
  732. {
  733. if(-G[i] > Gmax1)
  734. {
  735. Gmax1 = -G[i];
  736. Gmax1_idx = i;
  737. }
  738. }
  739. if(!is_lower_bound(i)) // d = -1
  740. {
  741. if(G[i] > Gmax2)
  742. {
  743. Gmax2 = G[i];
  744. Gmax2_idx = i;
  745. }
  746. }
  747. }
  748. else // y == -1
  749. {
  750. if(!is_upper_bound(i)) // d = +1
  751. {
  752. if(-G[i] > Gmax3)
  753. {
  754. Gmax3 = -G[i];
  755. Gmax3_idx = i;
  756. }
  757. }
  758. if(!is_lower_bound(i)) // d = -1
  759. {
  760. if(G[i] > Gmax4)
  761. {
  762. Gmax4 = G[i];
  763. Gmax4_idx = i;
  764. }
  765. }
  766. }
  767. }
  768. if(Math.max(Gmax1+Gmax2,Gmax3+Gmax4) < eps)
  769.   return 1;
  770. if(Gmax1+Gmax2 > Gmax3+Gmax4)
  771. {
  772. working_set[0] = Gmax1_idx;
  773. working_set[1] = Gmax2_idx;
  774. }
  775. else
  776. {
  777. working_set[0] = Gmax3_idx;
  778. working_set[1] = Gmax4_idx;
  779. }
  780. return 0;
  781. }
  782. void do_shrinking()
  783. {
  784. double Gmax1 = -INF; // max { -grad(f)_i * d | y_i = +1, d = +1 }
  785. double Gmax2 = -INF; // max { -grad(f)_i * d | y_i = +1, d = -1 }
  786. double Gmax3 = -INF; // max { -grad(f)_i * d | y_i = -1, d = +1 }
  787. double Gmax4 = -INF; // max { -grad(f)_i * d | y_i = -1, d = -1 }
  788. int k;
  789. for(k=0;k<active_size;k++)
  790. {
  791. if(!is_upper_bound(k))
  792. {
  793. if(y[k]==+1)
  794. {
  795. if(-G[k] > Gmax1) Gmax1 = -G[k];
  796. }
  797. else if(-G[k] > Gmax3) Gmax3 = -G[k];
  798. }
  799. if(!is_lower_bound(k))
  800. {
  801. if(y[k]==+1)
  802. {
  803. if(G[k] > Gmax2) Gmax2 = G[k];
  804. }
  805. else if(G[k] > Gmax4) Gmax4 = G[k];
  806. }
  807. }
  808. double Gm1 = -Gmax2;
  809. double Gm2 = -Gmax1;
  810. double Gm3 = -Gmax4;
  811. double Gm4 = -Gmax3;
  812. for(k=0;k<active_size;k++)
  813. {
  814. if(is_lower_bound(k))
  815. {
  816. if(y[k]==+1)
  817. {
  818. if(-G[k] >= Gm1) continue;
  819. }
  820. else if(-G[k] >= Gm3) continue;
  821. }
  822. else if(is_upper_bound(k))
  823. {
  824. if(y[k]==+1)
  825. {
  826. if(G[k] >= Gm2) continue;
  827. }
  828. else if(G[k] >= Gm4) continue;
  829. }
  830. else continue;
  831. --active_size;
  832. swap_index(k,active_size);
  833. --k; // look at the newcomer
  834. }
  835. // unshrink, check all variables again before final iterations
  836. if(unshrinked || Math.max(-(Gm1+Gm2),-(Gm3+Gm4)) > eps*10) return;
  837. unshrinked = true;
  838. reconstruct_gradient();
  839. for(k=l-1;k>=active_size;k--)
  840. {
  841. if(is_lower_bound(k))
  842. {
  843. if(y[k]==+1)
  844. {
  845. if(-G[k] < Gm1) continue;
  846. }
  847. else if(-G[k] < Gm3) continue;
  848. }
  849. else if(is_upper_bound(k))
  850. {
  851. if(y[k]==+1)
  852. {
  853. if(G[k] < Gm2) continue;
  854. }
  855. else if(G[k] < Gm4) continue;
  856. }
  857. else continue;
  858. swap_index(k,active_size);
  859. active_size++;
  860. ++k; // look at the newcomer
  861. }
  862. }
  863. double calculate_rho()
  864. {
  865. int nr_free1 = 0,nr_free2 = 0;
  866. double ub1 = INF, ub2 = INF;
  867. double lb1 = -INF, lb2 = -INF;
  868. double sum_free1 = 0, sum_free2 = 0;
  869. for(int i=0;i<active_size;i++)
  870. {
  871. if(y[i]==+1)
  872. {
  873. if(is_lower_bound(i))
  874. ub1 = Math.min(ub1,G[i]);
  875. else if(is_upper_bound(i))
  876. lb1 = Math.max(lb1,G[i]);
  877. else
  878. {
  879. ++nr_free1;
  880. sum_free1 += G[i];
  881. }
  882. }
  883. else
  884. {
  885. if(is_lower_bound(i))
  886. ub2 = Math.min(ub2,G[i]);
  887. else if(is_upper_bound(i))
  888. lb2 = Math.max(lb2,G[i]);
  889. else
  890. {
  891. ++nr_free2;
  892. sum_free2 += G[i];
  893. }
  894. }
  895. }
  896. double r1,r2;
  897. if(nr_free1 > 0)
  898. r1 = sum_free1/nr_free1;
  899. else
  900. r1 = (ub1+lb1)/2;
  901. if(nr_free2 > 0)
  902. r2 = sum_free2/nr_free2;
  903. else
  904. r2 = (ub2+lb2)/2;
  905. si.r = (r1+r2)/2;
  906. return (r1-r2)/2;
  907. }
  908. }
  909. //
  910. // Q matrices for various formulations
  911. //
  912. class SVC_Q extends Kernel
  913. {
  914. private final byte[] y;
  915. private final Cache cache;
  916. SVC_Q(svm_problem prob, svm_parameter param, byte[] y_)
  917. {
  918. super(prob.l, prob.x, param);
  919. y = (byte[])y_.clone();
  920. cache = new Cache(prob.l,(int)(param.cache_size*(1<<20)));
  921. }
  922. Qfloat[] get_Q(int i, int len)
  923. {
  924. Qfloat[][] data = new Qfloat[1][];
  925. int start;
  926. if((start = cache.get_data(i,data,len)) < len)
  927. {
  928. for(int j=start;j<len;j++)
  929. data[0][j] = (Qfloat)(y[i]*y[j]*kernel_function(i,j));
  930. }
  931. return data[0];
  932. }
  933. void swap_index(int i, int j)
  934. {
  935. cache.swap_index(i,j);
  936. super.swap_index(i,j);
  937. swap(byte,y[i],y[j]);
  938. }
  939. }
  940. class ONE_CLASS_Q extends Kernel
  941. {
  942. private final Cache cache;
  943. ONE_CLASS_Q(svm_problem prob, svm_parameter param)
  944. {
  945. super(prob.l, prob.x, param);
  946. cache = new Cache(prob.l,(int)(param.cache_size*(1<<20)));
  947. }
  948. Qfloat[] get_Q(int i, int len)
  949. {
  950. Qfloat[][] data = new Qfloat[1][];
  951. int start;
  952. if((start = cache.get_data(i,data,len)) < len)
  953. {
  954. for(int j=start;j<len;j++)
  955. data[0][j] = (Qfloat)kernel_function(i,j);
  956. }
  957. return data[0];
  958. }
  959. void swap_index(int i, int j)
  960. {
  961. cache.swap_index(i,j);
  962. super.swap_index(i,j);
  963. }
  964. }
  965. class SVR_Q extends Kernel
  966. {
  967. private final int l;
  968. private final Cache cache;
  969. private final byte[] sign;
  970. private final int[] index;
  971. private int next_buffer;
  972. private Qfloat[][] buffer;
  973. SVR_Q(svm_problem prob, svm_parameter param)
  974. {
  975. super(prob.l, prob.x, param);
  976. l = prob.l;
  977. cache = new Cache(l,(int)(param.cache_size*(1<<20)));
  978. sign = new byte[2*l];
  979. index = new int[2*l];
  980. for(int k=0;k<l;k++)
  981. {
  982. sign[k] = 1;
  983. sign[k+l] = -1;
  984. index[k] = k;
  985. index[k+l] = k;
  986. }
  987. buffer = new Qfloat[2][2*l];
  988. next_buffer = 0;
  989. }
  990. void swap_index(int i, int j)
  991. {
  992. swap(byte,sign[i],sign[j]);
  993. swap(int,index[i],index[j]);
  994. }
  995. Qfloat[] get_Q(int i, int len)
  996. {
  997. Qfloat[][] data = new Qfloat[1][];
  998. int real_i = index[i];
  999. if(cache.get_data(real_i,data,l) < l)
  1000. {
  1001. for(int j=0;j<l;j++)
  1002. data[0][j] = (Qfloat)kernel_function(real_i,j);
  1003. }
  1004. // reorder and copy
  1005. Qfloat buf[] = buffer[next_buffer];
  1006. next_buffer = 1 - next_buffer;
  1007. byte si = sign[i];
  1008. for(int j=0;j<len;j++)
  1009. buf[j] = si * sign[j] * data[0][index[j]];
  1010. return buf;
  1011. }
  1012. }
  1013. public class svm {
  1014. //
  1015. // construct and solve various formulations
  1016. //
  1017. private static void solve_c_svc(svm_problem prob, svm_parameter param,
  1018. double[] alpha, Solver.SolutionInfo si,
  1019. double Cp, double Cn)
  1020. {
  1021. int l = prob.l;
  1022. double[] minus_ones = new double[l];
  1023. byte[] y = new byte[l];
  1024. int i;
  1025. for(i=0;i<l;i++)
  1026. {
  1027. alpha[i] = 0;
  1028. minus_ones[i] = -1;
  1029. if(prob.y[i] > 0) y[i] = +1; else y[i]=-1;
  1030. }
  1031. Solver s = new Solver();
  1032. s.Solve(l, new SVC_Q(prob,param,y), minus_ones, y,
  1033. alpha, Cp, Cn, param.eps, si, param.shrinking);
  1034. double sum_alpha=0;
  1035. for(i=0;i<l;i++)
  1036. sum_alpha += alpha[i];
  1037. System.out.print("nu = "+sum_alpha/(param.C*prob.l)+"n");
  1038. for(i=0;i<l;i++)
  1039. alpha[i] *= y[i];
  1040. }
  1041. private static void solve_nu_svc(svm_problem prob, svm_parameter param,
  1042.   double[] alpha, Solver.SolutionInfo si)
  1043. {
  1044. int i;
  1045. int l = prob.l;
  1046. double nu = param.nu;
  1047. int y_pos = 0;
  1048. int y_neg = 0;
  1049. byte[] y = new byte[l];
  1050. for(i=0;i<l;i++)
  1051. if(prob.y[i]>0)
  1052. {
  1053. y[i] = +1;
  1054. ++y_pos;
  1055. }
  1056. else
  1057. {
  1058. y[i] = -1;
  1059. ++y_neg;
  1060. }
  1061. if(nu < 0 || nu*l/2 > Math.min(y_pos,y_neg))
  1062. {
  1063. System.err.print("specified nu is infeasiblen");
  1064. System.exit(1);
  1065. }
  1066. double sum_pos = nu*l/2;
  1067. double sum_neg = nu*l/2;
  1068. for(i=0;i<l;i++)
  1069. if(y[i] == +1)
  1070. {
  1071. alpha[i] = Math.min(1.0,sum_pos);
  1072. sum_pos -= alpha[i];
  1073. }
  1074. else
  1075. {
  1076. alpha[i] = Math.min(1.0,sum_neg);
  1077. sum_neg -= alpha[i];
  1078. }
  1079. double[] zeros = new double[l];
  1080. for(i=0;i<l;i++)
  1081. zeros[i] = 0;
  1082. Solver_NU s = new Solver_NU();
  1083. s.Solve(l, new SVC_Q(prob,param,y), zeros, y,
  1084. alpha, 1.0, 1.0, param.eps, si, param.shrinking);
  1085. double r = si.r;
  1086. System.out.print("C = "+1/r+"n");
  1087. for(i=0;i<l;i++)
  1088. alpha[i] *= y[i]/r;
  1089. si.rho /= r;
  1090. si.obj /= (r*r);
  1091. si.upper_bound_p = 1/r;
  1092. si.upper_bound_n = 1/r;
  1093. }
  1094. private static void solve_one_class(svm_problem prob, svm_parameter param,
  1095.      double[] alpha, Solver.SolutionInfo si)
  1096. {
  1097. int l = prob.l;
  1098. double[] zeros = new double[l];
  1099. byte[] ones = new byte[l];
  1100. int i;
  1101. int n = (int)(param.nu*prob.l); // # of alpha's at upper bound
  1102. if(n>=prob.l)
  1103. {
  1104. System.err.print("nu must be in (0,1)n");
  1105. System.exit(1);
  1106. }
  1107. for(i=0;i<n;i++)
  1108. alpha[i] = 1;
  1109. alpha[n] = param.nu * prob.l - n;
  1110. for(i=n+1;i<l;i++)
  1111. alpha[i] = 0;
  1112. for(i=0;i<l;i++)
  1113. {
  1114. zeros[i] = 0;
  1115. ones[i] = 1;
  1116. }
  1117. Solver s = new Solver();
  1118. s.Solve(l, new ONE_CLASS_Q(prob,param), zeros, ones,
  1119. alpha, 1.0, 1.0, param.eps, si, param.shrinking);
  1120. }
  1121. private static void solve_epsilon_svr(svm_problem prob, svm_parameter param,
  1122. double[] alpha, Solver.SolutionInfo si)
  1123. {
  1124. int l = prob.l;
  1125. double[] alpha2 = new double[2*l];
  1126. double[] linear_term = new double[2*l];
  1127. byte[] y = new byte[2*l];
  1128. int i;
  1129. for(i=0;i<l;i++)
  1130. {
  1131. alpha2[i] = 0;
  1132. linear_term[i] = param.p - prob.y[i];
  1133. y[i] = 1;
  1134. alpha2[i+l] = 0;
  1135. linear_term[i+l] = param.p + prob.y[i];
  1136. y[i+l] = -1;
  1137. }
  1138. Solver s = new Solver();
  1139. s.Solve(2*l, new SVR_Q(prob,param), linear_term, y,
  1140. alpha2, param.C, param.C, param.eps, si, param.shrinking);
  1141. double sum_alpha = 0;
  1142. for(i=0;i<l;i++)
  1143. {
  1144. alpha[i] = alpha2[i] - alpha2[i+l];
  1145. sum_alpha += Math.abs(alpha[i]);
  1146. }
  1147. System.out.print("nu = "+sum_alpha/(param.C*l)+"n");
  1148. }
  1149. private static void solve_nu_svr(svm_problem prob, svm_parameter param,
  1150. double[] alpha, Solver.SolutionInfo si)
  1151. {
  1152. if(param.nu < 0 || param.nu > 1)
  1153. {
  1154. System.err.print("specified nu is out of rangen");
  1155. System.exit(1);
  1156. }
  1157. int l = prob.l;
  1158. double C = param.C;
  1159. double[] alpha2 = new double[2*l];
  1160. double[] linear_term = new double[2*l];
  1161. byte[] y = new byte[2*l];
  1162. int i;
  1163. double sum = C * param.nu * l / 2;
  1164. for(i=0;i<l;i++)
  1165. {
  1166. alpha2[i] = alpha2[i+l] = Math.min(sum,C);
  1167. sum -= alpha2[i];
  1168. linear_term[i] = - prob.y[i];
  1169. y[i] = 1;
  1170. linear_term[i+l] = prob.y[i];
  1171. y[i+l] = -1;
  1172. }
  1173. Solver_NU s = new Solver_NU();
  1174. s.Solve(2*l, new SVR_Q(prob,param), linear_term, y,
  1175. alpha2, C, C, param.eps, si, param.shrinking);
  1176. System.out.print("epsilon = "+(-si.r)+"n");
  1177. for(i=0;i<l;i++)
  1178. alpha[i] = alpha2[i] - alpha2[i+l];
  1179. }
  1180. //
  1181. // decision_function
  1182. //
  1183. static class decision_function
  1184. {
  1185. double[] alpha;
  1186. double rho;
  1187. };
  1188. static decision_function svm_train_one(
  1189. svm_problem prob, svm_parameter param,
  1190. double Cp, double Cn)
  1191. {
  1192. double[] alpha = new double[prob.l];
  1193. Solver.SolutionInfo si = new Solver.SolutionInfo();
  1194. switch(param.svm_type)
  1195. {
  1196. case svm_parameter.C_SVC:
  1197. solve_c_svc(prob,param,alpha,si,Cp,Cn);
  1198. break;
  1199. case svm_parameter.NU_SVC:
  1200. solve_nu_svc(prob,param,alpha,si);
  1201. break;
  1202. case svm_parameter.ONE_CLASS:
  1203. solve_one_class(prob,param,alpha,si);
  1204. break;
  1205. case svm_parameter.EPSILON_SVR:
  1206. solve_epsilon_svr(prob,param,alpha,si);
  1207. break;
  1208. case svm_parameter.NU_SVR:
  1209. solve_nu_svr(prob,param,alpha,si);
  1210. break;
  1211. }
  1212. System.out.print("obj = "+si.obj+", rho = "+si.rho+"n");
  1213. // output SVs
  1214. int nSV = 0;
  1215. int nBSV = 0;
  1216. for(int i=0;i<prob.l;i++)
  1217. {
  1218. if(Math.abs(alpha[i]) > 0)
  1219. {
  1220. ++nSV;
  1221. if(prob.y[i] > 0)
  1222. {
  1223. if(Math.abs(alpha[i]) >= si.upper_bound_p)
  1224. ++nBSV;
  1225. }
  1226. else
  1227. {
  1228. if(Math.abs(alpha[i]) >= si.upper_bound_n)
  1229. ++nBSV;
  1230. }
  1231. }
  1232. }
  1233. System.out.print("nSV = "+nSV+", nBSV = "+nBSV+"n");
  1234. decision_function f = new decision_function();
  1235. f.alpha = alpha;
  1236. f.rho = si.rho;
  1237. return f;
  1238. }
  1239. //
  1240. // Interface functions
  1241. //
  1242. public static svm_model svm_train(svm_problem prob, svm_parameter param)
  1243. {
  1244. svm_model model = new svm_model();
  1245. model.param = param;
  1246. if(param.svm_type == svm_parameter.ONE_CLASS ||
  1247.    param.svm_type == svm_parameter.EPSILON_SVR ||
  1248.    param.svm_type == svm_parameter.NU_SVR)
  1249. {
  1250. // regression or one-class-svm
  1251. model.nr_class = 2;
  1252. model.label = null;
  1253. model.nSV = null;
  1254. model.sv_coef = new double[1][];
  1255. decision_function f = svm_train_one(prob,param,0,0);
  1256. model.rho = new double[1];
  1257. model.rho[0] = f.rho;
  1258. int nSV = 0;
  1259. int i;
  1260. for(i=0;i<prob.l;i++)
  1261. if(Math.abs(f.alpha[i]) > 0) ++nSV;
  1262. model.l = nSV;
  1263. model.SV = new svm_node[nSV][];
  1264. model.sv_coef[0] = new double[nSV];
  1265. int j = 0;
  1266. for(i=0;i<prob.l;i++)
  1267. if(Math.abs(f.alpha[i]) > 0)
  1268. {
  1269. model.SV[j] = prob.x[i];
  1270. model.sv_coef[0][j] = f.alpha[i];
  1271. ++j;
  1272. }
  1273. }
  1274. else
  1275. {
  1276. // classification
  1277. // find out the number of classes
  1278. int l = prob.l;
  1279. int max_nr_class = 16;
  1280. int nr_class = 0;
  1281. int[] label = new int[max_nr_class];
  1282. int[] count = new int[max_nr_class];
  1283. int[] index = new int[l];
  1284. int i;
  1285. for(i=0;i<l;i++)
  1286. {
  1287. int this_label = (int)prob.y[i];
  1288. int j;
  1289. for(j=0;j<nr_class;j++)
  1290. if(this_label == label[j])
  1291. {
  1292. ++count[j];
  1293. break;
  1294. }
  1295. index[i] = j;
  1296. if(j == nr_class)
  1297. {
  1298. if(nr_class == max_nr_class)
  1299. {
  1300. max_nr_class *= 2;
  1301. int[] new_data = new int[max_nr_class];
  1302. System.arraycopy(label,0,new_data,0,label.length);
  1303. label = new_data;
  1304. new_data = new int[max_nr_class];
  1305. System.arraycopy(count,0,new_data,0,count.length);
  1306. count = new_data;
  1307. }
  1308. label[nr_class] = this_label;
  1309. count[nr_class] = 1;
  1310. ++nr_class;
  1311. }
  1312. }
  1313. // group training data of the same class
  1314. int[] start = new int[nr_class];
  1315. start[0] = 0;
  1316. for(i=1;i<nr_class;i++)
  1317. start[i] = start[i-1]+count[i-1];
  1318. svm_node[][] x = new svm_node[l][];
  1319. for(i=0;i<l;i++)
  1320. {
  1321. x[start[index[i]]] = prob.x[i];
  1322. ++start[index[i]];
  1323. }
  1324. start[0] = 0;
  1325. for(i=1;i<nr_class;i++)
  1326. start[i] = start[i-1]+count[i-1];
  1327. // calculate weighted C
  1328. double[] weighted_C = new double[nr_class];
  1329. for(i=0;i<nr_class;i++)
  1330. weighted_C[i] = param.C;
  1331. for(i=0;i<param.nr_weight;i++)
  1332. {
  1333. int j;
  1334. for(j=0;j<nr_class;j++)
  1335. if(param.weight_label[i] == label[j])
  1336. break;
  1337. if(j == nr_class)
  1338. System.err.print("warning: class label "+param.weight_label[i]+" specified in weight is not foundn");
  1339. else
  1340. weighted_C[j] *= param.weight[i];
  1341. }
  1342. // train n*(n-1)/2 models
  1343. boolean[] nonzero = new boolean[l];
  1344. for(i=0;i<l;i++)
  1345. nonzero[i] = false;
  1346. decision_function[] f = new decision_function[nr_class*(nr_class-1)/2];
  1347. int p = 0;
  1348. for(i=0;i<nr_class;i++)
  1349. for(int j=i+1;j<nr_class;j++)
  1350. {
  1351. svm_problem sub_prob = new svm_problem();
  1352. int si = start[i], sj = start[j];
  1353. int ci = count[i], cj = count[j];
  1354. sub_prob.l = ci+cj;
  1355. sub_prob.x = new svm_node[sub_prob.l][];
  1356. sub_prob.y = new double[sub_prob.l];
  1357. int k;
  1358. for(k=0;k<ci;k++)
  1359. {
  1360. sub_prob.x[k] = x[si+k];
  1361. sub_prob.y[k] = +1;
  1362. }
  1363. for(k=0;k<cj;k++)
  1364. {
  1365. sub_prob.x[ci+k] = x[sj+k];
  1366. sub_prob.y[ci+k] = -1;
  1367. }
  1368. f[p] = svm_train_one(sub_prob,param,weighted_C[i],weighted_C[j]);
  1369. for(k=0;k<ci;k++)
  1370. if(!nonzero[si+k] && Math.abs(f[p].alpha[k]) > 0)
  1371. nonzero[si+k] = true;
  1372. for(k=0;k<cj;k++)
  1373. if(!nonzero[sj+k] && Math.abs(f[p].alpha[ci+k]) > 0)
  1374. nonzero[sj+k] = true;
  1375. ++p;
  1376. }
  1377. // build output
  1378. model.nr_class = nr_class;
  1379. model.label = new int[nr_class];
  1380. for(i=0;i<nr_class;i++)
  1381. model.label[i] = label[i];
  1382. model.rho = new double[nr_class*(nr_class-1)/2];
  1383. for(i=0;i<nr_class*(nr_class-1)/2;i++)
  1384. model.rho[i] = f[i].rho;
  1385. int nnz = 0;
  1386. int[] nz_count = new int[nr_class];
  1387. model.nSV = new int[nr_class];
  1388. for(i=0;i<nr_class;i++)
  1389. {
  1390. int nSV = 0;
  1391. for(int j=0;j<count[i];j++)
  1392. if(nonzero[start[i]+j])
  1393. {
  1394. ++nSV;
  1395. ++nnz;
  1396. }
  1397. model.nSV[i] = nSV;
  1398. nz_count[i] = nSV;
  1399. }
  1400. System.out.print("Total nSV = "+nnz+"n");
  1401. model.l = nnz;
  1402. model.SV = new svm_node[nnz][];
  1403. p = 0;
  1404. for(i=0;i<l;i++)
  1405. if(nonzero[i]) model.SV[p++] = x[i];
  1406. int[] nz_start = new int[nr_class];
  1407. nz_start[0] = 0;
  1408. for(i=1;i<nr_class;i++)
  1409. nz_start[i] = nz_start[i-1]+nz_count[i-1];
  1410. model.sv_coef = new double[nr_class-1][];
  1411. for(i=0;i<nr_class-1;i++)
  1412. model.sv_coef[i] = new double[nnz];
  1413. p = 0;
  1414. for(i=0;i<nr_class;i++)
  1415. for(int j=i+1;j<nr_class;j++)
  1416. {
  1417. // classifier (i,j): coefficients with
  1418. // i are in sv_coef[j-1][nz_start[i]...],
  1419. // j are in sv_coef[i][nz_start[j]...]
  1420. int si = start[i];
  1421. int sj = start[j];
  1422. int ci = count[i];
  1423. int cj = count[j];
  1424. int q = nz_start[i];
  1425. int k;
  1426. for(k=0;k<ci;k++)
  1427. if(nonzero[si+k])
  1428. model.sv_coef[j-1][q++] = f[p].alpha[k];
  1429. q = nz_start[j];
  1430. for(k=0;k<cj;k++)
  1431. if(nonzero[sj+k])
  1432. model.sv_coef[i][q++] = f[p].alpha[ci+k];
  1433. ++p;
  1434. }
  1435. }
  1436. return model;
  1437. }
  1438. public static double svm_predict(svm_model model, svm_node[] x)
  1439. {
  1440. if(model.param.svm_type == svm_parameter.ONE_CLASS ||
  1441.    model.param.svm_type == svm_parameter.EPSILON_SVR ||
  1442.    model.param.svm_type == svm_parameter.NU_SVR)
  1443. {
  1444. double[] sv_coef = model.sv_coef[0];
  1445. double sum = 0;
  1446. for(int i=0;i<model.l;i++)
  1447. sum += sv_coef[i] * Kernel.k_function(x,model.SV[i],model.param);
  1448. sum -= model.rho[0];
  1449. if(model.param.svm_type == svm_parameter.ONE_CLASS)
  1450. return (sum>0)?1:-1;
  1451. else
  1452. return sum;
  1453. }
  1454. else
  1455. {
  1456. int i;
  1457. int nr_class = model.nr_class;
  1458. int l = model.l;
  1459. double[] kvalue = new double[l];
  1460. for(i=0;i<l;i++)
  1461. kvalue[i] = Kernel.k_function(x,model.SV[i],model.param);
  1462. int[] start = new int[nr_class];
  1463. start[0] = 0;
  1464. for(i=1;i<nr_class;i++)
  1465. start[i] = start[i-1]+model.nSV[i-1];
  1466. int[] vote = new int[nr_class];
  1467. for(i=0;i<nr_class;i++)
  1468. vote[i] = 0;
  1469. int p=0;
  1470. for(i=0;i<nr_class;i++)
  1471. for(int j=i+1;j<nr_class;j++)
  1472. {
  1473. double sum = 0;
  1474. int si = start[i];
  1475. int sj = start[j];
  1476. int ci = model.nSV[i];
  1477. int cj = model.nSV[j];
  1478. int k;
  1479. double[] coef1 = model.sv_coef[j-1];
  1480. double[] coef2 = model.sv_coef[i];
  1481. for(k=0;k<ci;k++)
  1482. sum += coef1[si+k] * kvalue[si+k];
  1483. for(k=0;k<cj;k++)
  1484. sum += coef2[sj+k] * kvalue[sj+k];
  1485. sum -= model.rho[p++];
  1486. if(sum > 0)
  1487. ++vote[i];
  1488. else
  1489. ++vote[j];
  1490. }
  1491. int vote_max_idx = 0;
  1492. for(i=1;i<nr_class;i++)
  1493. if(vote[i] > vote[vote_max_idx])
  1494. vote_max_idx = i;
  1495. return model.label[vote_max_idx];
  1496. }
  1497. }
  1498. static final String svm_type_table[] =
  1499. {
  1500. "c_svc","nu_svc","one_class","epsilon_svr","nu_svr",
  1501. };
  1502. static final String kernel_type_table[]=
  1503. {
  1504. "linear","polynomial","rbf","sigmoid",
  1505. };
  1506. public static void svm_save_model(String model_file_name, svm_model model) throws IOException
  1507. {
  1508. DataOutputStream fp = new DataOutputStream(new FileOutputStream(model_file_name));
  1509. svm_parameter param = model.param;
  1510. fp.writeBytes("svm_type "+svm_type_table[param.svm_type]+"n");
  1511. fp.writeBytes("kernel_type "+kernel_type_table[param.kernel_type]+"n");
  1512. if(param.kernel_type == svm_parameter.POLY)
  1513. fp.writeBytes("degree "+param.degree+"n");
  1514. if(param.kernel_type == svm_parameter.POLY ||
  1515.    param.kernel_type == svm_parameter.RBF ||
  1516.    param.kernel_type == svm_parameter.SIGMOID)
  1517. fp.writeBytes("gamma "+param.gamma+"n");
  1518. if(param.kernel_type == svm_parameter.POLY ||
  1519.    param.kernel_type == svm_parameter.SIGMOID)
  1520. fp.writeBytes("coef0 "+param.coef0+"n");
  1521. int nr_class = model.nr_class;
  1522. int l = model.l;
  1523. fp.writeBytes("nr_class "+nr_class+"n");
  1524. fp.writeBytes("total_sv "+l+"n");
  1525. {
  1526. fp.writeBytes("rho");
  1527. for(int i=0;i<nr_class*(nr_class-1)/2;i++)
  1528. fp.writeBytes(" "+model.rho[i]);
  1529. fp.writeBytes("n");
  1530. }
  1531. if(model.label != null)
  1532. {
  1533. fp.writeBytes("label");
  1534. for(int i=0;i<nr_class;i++)
  1535. fp.writeBytes(" "+model.label[i]);
  1536. fp.writeBytes("n");
  1537. }
  1538. if(model.nSV != null)
  1539. {
  1540. fp.writeBytes("nr_sv");
  1541. for(int i=0;i<nr_class;i++)
  1542. fp.writeBytes(" "+model.nSV[i]);
  1543. fp.writeBytes("n");
  1544. }
  1545. fp.writeBytes("SVn");
  1546. double[][] sv_coef = model.sv_coef;
  1547. svm_node[][] SV = model.SV;
  1548. for(int i=0;i<l;i++)
  1549. {
  1550. for(int j=0;j<nr_class-1;j++)
  1551. fp.writeBytes(sv_coef[j][i]+" ");
  1552. svm_node[] p = SV[i];
  1553. for(int j=0;j<p.length;j++)
  1554. fp.writeBytes(p[j].index+":"+p[j].value+" ");
  1555. fp.writeBytes("n");
  1556. }
  1557. fp.close();
  1558. }
  1559. private static double atof(String s)
  1560. {
  1561. return Double.valueOf(s).doubleValue();
  1562. }
  1563. private static int atoi(String s)
  1564. {
  1565. return Integer.parseInt(s);
  1566. }
  1567. public static svm_model svm_load_model(String model_file_name) throws IOException
  1568. {
  1569. BufferedReader fp = new BufferedReader(new FileReader(model_file_name));
  1570. // read parameters
  1571. svm_model model = new svm_model();
  1572. svm_parameter param = new svm_parameter();
  1573. model.param = param;
  1574. model.label = null;
  1575. model.nSV = null;
  1576. while(true)
  1577. {
  1578. String cmd = fp.readLine();
  1579. String arg = cmd.substring(cmd.indexOf(' ')+1);
  1580. if(cmd.startsWith("svm_type"))
  1581. {
  1582. int i;
  1583. for(i=0;i<svm_type_table.length;i++)
  1584. {
  1585. if(arg.indexOf(svm_type_table[i])!=-1)
  1586. {
  1587. param.svm_type=i;
  1588. break;
  1589. }
  1590. }
  1591. if(i == svm_type_table.length)
  1592. {
  1593. System.err.print("unknown svm type.n");
  1594. System.exit(1);
  1595. }
  1596. }
  1597. else if(cmd.startsWith("kernel_type"))
  1598. {
  1599. int i;
  1600. for(i=0;i<kernel_type_table.length;i++)
  1601. {
  1602. if(arg.indexOf(kernel_type_table[i])!=-1)
  1603. {
  1604. param.kernel_type=i;
  1605. break;
  1606. }
  1607. }
  1608. if(i == kernel_type_table.length)
  1609. {
  1610. System.err.print("unknown kernel function.n");
  1611. System.exit(1);
  1612. }
  1613. }
  1614. else if(cmd.startsWith("degree"))
  1615. param.degree = atof(arg);
  1616. else if(cmd.startsWith("gamma"))
  1617. param.gamma = atof(arg);
  1618. else if(cmd.startsWith("coef0"))
  1619. param.coef0 = atof(arg);
  1620. else if(cmd.startsWith("nr_class"))
  1621. model.nr_class = atoi(arg);
  1622. else if(cmd.startsWith("total_sv"))
  1623. model.l = atoi(arg);
  1624. else if(cmd.startsWith("rho"))
  1625. {
  1626. int n = model.nr_class * (model.nr_class-1)/2;
  1627. model.rho = new double[n];
  1628. StringTokenizer st = new StringTokenizer(arg);
  1629. for(int i=0;i<n;i++)
  1630. model.rho[i] = atof(st.nextToken());
  1631. }
  1632. else if(cmd.startsWith("label"))
  1633. {
  1634. int n = model.nr_class;
  1635. model.label = new int[n];
  1636. StringTokenizer st = new StringTokenizer(arg);
  1637. for(int i=0;i<n;i++)
  1638. model.label[i] = atoi(st.nextToken());
  1639. }
  1640. else if(cmd.startsWith("nr_sv"))
  1641. {
  1642. int n = model.nr_class;
  1643. model.nSV = new int[n];
  1644. StringTokenizer st = new StringTokenizer(arg);
  1645. for(int i=0;i<n;i++)
  1646. model.nSV[i] = atoi(st.nextToken());
  1647. }
  1648. else if(cmd.startsWith("SV"))
  1649. {
  1650. break;
  1651. }
  1652. else
  1653. {
  1654. System.err.print("unknown text in model filen");
  1655. System.exit(1);
  1656. }
  1657. }
  1658. // read sv_coef and SV
  1659. int m = model.nr_class - 1;
  1660. int l = model.l;
  1661. model.sv_coef = new double[m][l];
  1662. model.SV = new svm_node[l][];
  1663. for(int i=0;i<l;i++)
  1664. {
  1665. String line = fp.readLine();
  1666. StringTokenizer st = new StringTokenizer(line," tnrf:");
  1667. for(int k=0;k<m;k++)
  1668. model.sv_coef[k][i] = atof(st.nextToken());
  1669. int n = st.countTokens()/2;
  1670. model.SV[i] = new svm_node[n];
  1671. for(int j=0;j<n;j++)
  1672. {
  1673. model.SV[i][j] = new svm_node();
  1674. model.SV[i][j].index = atoi(st.nextToken());
  1675. model.SV[i][j].value = atof(st.nextToken());
  1676. }
  1677. }
  1678. fp.close();
  1679. return model;
  1680. }
  1681. }