svm-toy.cpp
上传用户:xgw_05
上传日期:2014-12-08
资源大小:2726k
文件大小:11k
源码类别:

.net编程

开发平台:

Java

  1. #include <windows.h>
  2. #include <windowsx.h>
  3. #include <stdio.h>
  4. #include <ctype.h>
  5. #include <list>
  6. #include "../../svm.h"
  7. #define DEFAULT_PARAM "-t 2 -c 100"
  8. #define XLEN 500
  9. #define YLEN 500
  10. #define DrawLine(dc,x1,y1,x2,y2,c) 
  11. do { 
  12. HPEN hpen = CreatePen(PS_SOLID,0,c); 
  13. HPEN horig = SelectPen(dc,hpen); 
  14. MoveToEx(dc,x1,y1,NULL); 
  15. LineTo(dc,x2,y2); 
  16. SelectPen(dc,horig); 
  17. DeletePen(hpen); 
  18. } while(0)
  19. using namespace std;
  20. COLORREF colors[] =
  21. {
  22. RGB(0,0,0),
  23. RGB(0,120,120),
  24. RGB(120,120,0),
  25. RGB(120,0,120),
  26. RGB(0,200,200),
  27. RGB(200,200,0),
  28. RGB(200,0,200)
  29. };
  30. HWND main_window;
  31. HBITMAP buffer;
  32. HDC window_dc;
  33. HDC buffer_dc;
  34. HBRUSH brush1, brush2, brush3;
  35. HWND edit;
  36. enum {
  37. ID_BUTTON_CHANGE, ID_BUTTON_RUN, ID_BUTTON_CLEAR,
  38. ID_BUTTON_LOAD, ID_BUTTON_SAVE, ID_EDIT
  39. };
  40. struct point {
  41. double x, y;
  42. signed char value;
  43. };
  44. list<point> point_list;
  45. int current_value = 1;
  46. LRESULT CALLBACK WndProc(HWND, UINT, WPARAM, LPARAM);
  47. int WINAPI WinMain(HINSTANCE hInstance, HINSTANCE hPrevInstance,
  48.    PSTR szCmdLine, int iCmdShow)
  49. {
  50. static char szAppName[] = "SvmToy";
  51. MSG msg;
  52. WNDCLASSEX wndclass;
  53. wndclass.cbSize = sizeof(wndclass);
  54. wndclass.style = CS_HREDRAW | CS_VREDRAW;
  55. wndclass.lpfnWndProc = WndProc;
  56. wndclass.cbClsExtra = 0;
  57. wndclass.cbWndExtra = 0;
  58. wndclass.hInstance = hInstance;
  59. wndclass.hIcon = LoadIcon(NULL, IDI_APPLICATION);
  60. wndclass.hCursor = LoadCursor(NULL, IDC_ARROW);
  61. wndclass.hbrBackground = (HBRUSH) GetStockObject(BLACK_BRUSH);
  62. wndclass.lpszMenuName = NULL;
  63. wndclass.lpszClassName = szAppName;
  64. wndclass.hIconSm = LoadIcon(NULL, IDI_APPLICATION);
  65. RegisterClassEx(&wndclass);
  66. main_window = CreateWindow(szAppName, // window class name
  67.     "SVM Toy", // window caption
  68.     WS_OVERLAPPEDWINDOW,// window style
  69.     CW_USEDEFAULT, // initial x position
  70.     CW_USEDEFAULT, // initial y position
  71.     XLEN, // initial x size
  72.     YLEN+52, // initial y size
  73.     NULL, // parent window handle
  74.     NULL, // window menu handle
  75.     hInstance, // program instance handle
  76.     NULL); // creation parameters
  77. ShowWindow(main_window, iCmdShow);
  78. UpdateWindow(main_window);
  79. CreateWindow("button", "Change", WS_CHILD | WS_VISIBLE | BS_PUSHBUTTON,
  80.      0, YLEN, 50, 25, main_window, (HMENU) ID_BUTTON_CHANGE, hInstance, NULL);
  81. CreateWindow("button", "Run", WS_CHILD | WS_VISIBLE | BS_PUSHBUTTON,
  82.      50, YLEN, 50, 25, main_window, (HMENU) ID_BUTTON_RUN, hInstance, NULL);
  83. CreateWindow("button", "Clear", WS_CHILD | WS_VISIBLE | BS_PUSHBUTTON,
  84.      100, YLEN, 50, 25, main_window, (HMENU) ID_BUTTON_CLEAR, hInstance, NULL);
  85. CreateWindow("button", "Save", WS_CHILD | WS_VISIBLE | BS_PUSHBUTTON,
  86.      150, YLEN, 50, 25, main_window, (HMENU) ID_BUTTON_SAVE, hInstance, NULL);
  87. CreateWindow("button", "Load", WS_CHILD | WS_VISIBLE | BS_PUSHBUTTON,
  88.      200, YLEN, 50, 25, main_window, (HMENU) ID_BUTTON_LOAD, hInstance, NULL);
  89. edit = CreateWindow("edit", NULL, WS_CHILD | WS_VISIBLE,
  90.             250, YLEN, 250, 25, main_window, (HMENU) ID_EDIT, hInstance, NULL);
  91. Edit_SetText(edit,DEFAULT_PARAM);
  92. brush1 = CreateSolidBrush(colors[4]);
  93. brush2 = CreateSolidBrush(colors[5]);
  94. brush3 = CreateSolidBrush(colors[6]);
  95. window_dc = GetDC(main_window);
  96. buffer = CreateCompatibleBitmap(window_dc, XLEN, YLEN);
  97. buffer_dc = CreateCompatibleDC(window_dc);
  98. SelectObject(buffer_dc, buffer);
  99. PatBlt(buffer_dc, 0, 0, XLEN, YLEN, BLACKNESS);
  100. while (GetMessage(&msg, NULL, 0, 0)) {
  101. TranslateMessage(&msg);
  102. DispatchMessage(&msg);
  103. }
  104. return msg.wParam;
  105. }
  106. int getfilename( HWND hWnd , char *filename, int len, int save) 
  107. OPENFILENAME OpenFileName; 
  108. memset(&OpenFileName,0,sizeof(OpenFileName));
  109. filename[0]='';
  110.  
  111. OpenFileName.lStructSize       = sizeof(OPENFILENAME); 
  112. OpenFileName.hwndOwner         = hWnd; 
  113. OpenFileName.lpstrFile         = filename; 
  114. OpenFileName.nMaxFile          = len; 
  115. OpenFileName.Flags             = 0;
  116.  
  117. return save?GetSaveFileName(&OpenFileName):GetOpenFileName(&OpenFileName);
  118. }
  119. void clear_all()
  120. {
  121. point_list.clear();
  122. PatBlt(buffer_dc, 0, 0, XLEN, YLEN, BLACKNESS);
  123. InvalidateRect(main_window, 0, 0);
  124. }
  125. HBRUSH choose_brush(int v)
  126. {
  127. if(v==1) return brush1;
  128. else if(v==2) return brush2;
  129. else return brush3;
  130. }
  131. void draw_point(const point & p)
  132. {
  133. RECT rect;
  134. rect.left = int(p.x*XLEN);
  135. rect.top = int(p.y*YLEN);
  136. rect.right = int(p.x*XLEN) + 3;
  137. rect.bottom = int(p.y*YLEN) + 3;
  138. FillRect(window_dc, &rect, choose_brush(p.value));
  139. FillRect(buffer_dc, &rect, choose_brush(p.value));
  140. }
  141. void draw_all_points()
  142. {
  143. for(list<point>::iterator p = point_list.begin(); p != point_list.end(); p++)
  144. draw_point(*p);
  145. }
  146. void button_run_clicked()
  147. {
  148. // guard
  149. if(point_list.empty()) return;
  150. svm_parameter param;
  151. int i,j;
  152. // default values
  153. param.svm_type = C_SVC;
  154. param.kernel_type = RBF;
  155. param.degree = 3;
  156. param.gamma = 0;
  157. param.coef0 = 0;
  158. param.nu = 0.5;
  159. param.cache_size = 40;
  160. param.C = 1;
  161. param.eps = 1e-3;
  162. param.p = 0.1;
  163. param.shrinking = 1;
  164. param.nr_weight = 0;
  165. param.weight_label = NULL;
  166. param.weight = NULL;
  167. // parse options
  168. char str[1024];
  169. Edit_GetLine(edit, 0, str, sizeof(str));
  170. const char *p = str;
  171. while (1) {
  172. while (*p && *p != '-')
  173. p++;
  174. if (*p == '')
  175. break;
  176. p++;
  177. switch (*p++) {
  178. case 's':
  179. param.svm_type = atoi(p);
  180. break;
  181. case 't':
  182. param.kernel_type = atoi(p);
  183. break;
  184. case 'd':
  185. param.degree = atof(p);
  186. break;
  187. case 'g':
  188. param.gamma = atof(p);
  189. break;
  190. case 'r':
  191. param.coef0 = atof(p);
  192. break;
  193. case 'n':
  194. param.nu = atof(p);
  195. break;
  196. case 'm':
  197. param.cache_size = atof(p);
  198. break;
  199. case 'c':
  200. param.C = atof(p);
  201. break;
  202. case 'e':
  203. param.eps = atof(p);
  204. break;
  205. case 'p':
  206. param.p = atof(p);
  207. break;
  208. case 'h':
  209. param.shrinking = atoi(p);
  210. break;
  211. case 'w':
  212. ++param.nr_weight;
  213. param.weight_label = (int *)realloc(param.weight_label,sizeof(int)*param.nr_weight);
  214. param.weight = (double *)realloc(param.weight,sizeof(double)*param.nr_weight);
  215. param.weight_label[param.nr_weight-1] = atoi(p);
  216. while(*p && !isspace(*p)) ++p;
  217. param.weight[param.nr_weight-1] = atof(p);
  218. break;
  219. }
  220. }
  221. // build problem
  222. svm_problem prob;
  223. prob.l = point_list.size();
  224. prob.y = new double[prob.l];
  225. if(param.svm_type == EPSILON_SVR ||
  226.    param.svm_type == NU_SVR)
  227. {
  228. if(param.gamma == 0) param.gamma = 1;
  229. svm_node *x_space = new svm_node[2 * prob.l];
  230. prob.x = new svm_node *[prob.l];
  231. i = 0;
  232. for (list<point>::iterator q = point_list.begin(); q != point_list.end(); q++, i++)
  233. {
  234. x_space[2 * i].index = 1;
  235. x_space[2 * i].value = q->x;
  236. x_space[2 * i + 1].index = -1;
  237. prob.x[i] = &x_space[2 * i];
  238. prob.y[i] = q->y;
  239. }
  240. // build model & classify
  241. svm_model *model = svm_train(&prob, &param);
  242. svm_node x[2];
  243. x[0].index = 1;
  244. x[1].index = -1;
  245. int *j = new int[XLEN];
  246. for (i = 0; i < XLEN; i++)
  247. {
  248. x[0].value = (double) i / XLEN;
  249. j[i] = (int)(YLEN*svm_predict(model, x));
  250. }
  251. DrawLine(buffer_dc,0,0,0,YLEN,colors[0]);
  252. DrawLine(window_dc,0,0,0,YLEN,colors[0]);
  253. int p = (int)(param.p * YLEN);
  254. for(int i=1; i < XLEN; i++)
  255. {
  256. DrawLine(buffer_dc,i,0,i,YLEN,colors[0]);
  257. DrawLine(window_dc,i,0,i,YLEN,colors[0]);
  258. DrawLine(buffer_dc,i-1,j[i-1],i,j[i],colors[5]);
  259. DrawLine(window_dc,i-1,j[i-1],i,j[i],colors[5]);
  260. if(param.svm_type == EPSILON_SVR)
  261. {
  262. DrawLine(buffer_dc,i-1,j[i-1]+p,i,j[i]+p,colors[2]);
  263. DrawLine(window_dc,i-1,j[i-1]+p,i,j[i]+p,colors[2]);
  264. DrawLine(buffer_dc,i-1,j[i-1]-p,i,j[i]-p,colors[2]);
  265. DrawLine(window_dc,i-1,j[i-1]-p,i,j[i]-p,colors[2]);
  266. }
  267. }
  268. svm_destroy_model(model);
  269. delete[] j;
  270. delete[] x_space;
  271. delete[] prob.x;
  272. delete[] prob.y;
  273. }
  274. else
  275. {
  276. if(param.gamma == 0) param.gamma = 0.5;
  277. svm_node *x_space = new svm_node[3 * prob.l];
  278. prob.x = new svm_node *[prob.l];
  279. i = 0;
  280. for (list<point>::iterator q = point_list.begin(); q != point_list.end(); q++, i++)
  281. {
  282. x_space[3 * i].index = 1;
  283. x_space[3 * i].value = q->x;
  284. x_space[3 * i + 1].index = 2;
  285. x_space[3 * i + 1].value = q->y;
  286. x_space[3 * i + 2].index = -1;
  287. prob.x[i] = &x_space[3 * i];
  288. prob.y[i] = q->value;
  289. }
  290. // build model & classify
  291. svm_model *model = svm_train(&prob, &param);
  292. svm_node x[3];
  293. x[0].index = 1;
  294. x[1].index = 2;
  295. x[2].index = -1;
  296. for (i = 0; i < XLEN; i++)
  297. for (j = 0; j < YLEN; j++) {
  298. x[0].value = (double) i / XLEN;
  299. x[1].value = (double) j / YLEN;
  300. double d = svm_predict(model, x);
  301. SetPixel(window_dc, i, j, colors[(int)d]);
  302. SetPixel(buffer_dc, i, j, colors[(int)d]);
  303. }
  304. svm_destroy_model(model);
  305. delete[] x_space;
  306. delete[] prob.x;
  307. delete[] prob.y;
  308. }
  309. free(param.weight_label);
  310. free(param.weight);
  311. draw_all_points();
  312. }
  313. LRESULT CALLBACK WndProc(HWND hwnd, UINT iMsg, WPARAM wParam, LPARAM lParam)
  314. {
  315. HDC hdc;
  316. PAINTSTRUCT ps;
  317. switch (iMsg) {
  318. case WM_LBUTTONDOWN:
  319. {
  320. int x = LOWORD(lParam);
  321. int y = HIWORD(lParam);
  322. point p = {(double)x/XLEN, (double)y/YLEN, current_value};
  323. point_list.push_back(p);
  324. draw_point(p);
  325. }
  326. return 0;
  327. case WM_PAINT:
  328. {
  329. hdc = BeginPaint(hwnd, &ps);
  330. BitBlt(hdc, 0, 0, XLEN, YLEN, buffer_dc, 0, 0, SRCCOPY);
  331. EndPaint(hwnd, &ps);
  332. }
  333. return 0;
  334. case WM_COMMAND:
  335. {
  336. int id = LOWORD(wParam);
  337. switch (id) {
  338. case ID_BUTTON_CHANGE:
  339. ++current_value;
  340. if(current_value > 3) current_value = 1;
  341. break;
  342. case ID_BUTTON_RUN:
  343. button_run_clicked();
  344. break;
  345. case ID_BUTTON_CLEAR:
  346. clear_all();
  347. break;
  348. case ID_BUTTON_SAVE:
  349. {
  350. char filename[1024];
  351. if(getfilename(hwnd,filename,1024,1))
  352. {
  353. FILE *fp = fopen(filename,"w");
  354. if(fp)
  355. {
  356. for (list<point>::iterator p = point_list.begin(); p != point_list.end(); p++)
  357. fprintf(fp,"%d 1:%f 2:%fn",p->value,p->x,p->y);
  358. fclose(fp);
  359. }
  360. }
  361. }
  362. break;
  363. case ID_BUTTON_LOAD:
  364. {
  365. char filename[1024];
  366. if(getfilename(hwnd,filename,1024,0))
  367. {
  368. FILE *fp = fopen(filename,"r");
  369. if(fp)
  370. {
  371. clear_all();
  372. char buf[4096];
  373. while(fgets(buf,sizeof(buf),fp))
  374. {
  375. int v;
  376. double x,y;
  377. if(sscanf(buf,"%d%*d:%lf%*d:%lf",&v,&x,&y)!=3)
  378. break;
  379. point p = {x,y,v};
  380. point_list.push_back(p);
  381. }
  382. fclose(fp);
  383. draw_all_points();
  384. }
  385. }
  386. }
  387. break;
  388. }
  389. }
  390. return 0;
  391. case WM_DESTROY:
  392. PostQuitMessage(0);
  393. return 0;
  394. }
  395. return DefWindowProc(hwnd, iMsg, wParam, lParam);
  396. }