knnctl.cpp
上传用户:xmhs66
上传日期:2022-07-26
资源大小:989k
文件大小:7k
源码类别:

生物技术

开发平台:

Visual C++

  1. /**********************************************************************
  2.  * * filename: knnctl.cpp
  3.  * * description: Knnctl类的实现,KNN算法实现的主体部分
  4.  * * 1. 从文件中读取记录并解析到Mail类中
  5.  * * 2. 对数据进行标准化 
  6.  * * 3. 求出错误率最小的K值 
  7.  * * student: Liwanjun
  8.  * * data: 2010-03-29 
  9.  * **********************************************************************/ 
  10. #include "knnctl.h"
  11. #include <iostream>
  12. #include <fstream>
  13. #include <string>
  14. #include <cstdlib>
  15. #include <cmath>
  16. #include <algorithm>
  17. using std::string; using std::ifstream;
  18. using std::ofstream;using std::endl;
  19. using std::cout; using std::pair;
  20. using std::make_pair;
  21. Knnctl::Knnctl():kno(1), bSimi(false){}
  22. //从文件中读取记录,不进行解析
  23. void Knnctl::readFile(string filename)
  24. {
  25. ifstream infile(filename.c_str());
  26. if(!infile)
  27. {
  28. cout << "Error Opening: " << filename << endl;
  29. exit(-1);
  30. }
  31. string record;
  32. while(getline(infile, record))
  33. {
  34. recovec.push_back(record);
  35. }
  36. }
  37. //从文件中读取记录,解析并以Mail对象的形式存入到Knnctl的成员中
  38. void Knnctl::readFile(string filename, vector<Mail>& vec, bool istest)
  39. {
  40. ifstream infile(filename.c_str());
  41. if(!infile)
  42. {
  43. cout << "Error Opening: " << filename << endl;
  44. exit(-1);
  45. }
  46. if(vec.size() != 0)
  47. vec.erase(vec.begin(), vec.end());
  48. string record;
  49. while(getline(infile, record))
  50. vec.push_back(Mail(record, istest));
  51. }
  52. //随机分割记录,2/3作为学习,其余作为测试。学习集命名为study
  53. void Knnctl::divideFile()
  54. {
  55. int  recordno = recovec.size();
  56. vector<int> randvec;
  57. string studyname("div_study");
  58. string testname("div_test");
  59. ofstream studyfile(studyname.c_str());
  60. ofstream testfile(testname.c_str());
  61. for(int i = 0; i < recordno; i++)
  62. {
  63. randvec.push_back(i);
  64. }
  65. random_shuffle(randvec.begin(), randvec.end());
  66. int dividno = randvec.size() * 2 / 3;
  67. int tempno = 0;
  68. for(int i = 0; i < dividno; i++)
  69. {
  70. tempno = randvec[i];
  71. studyvec.push_back(Mail(recovec[tempno], 0));
  72. studyfile << recovec[tempno] << endl;
  73. }
  74. for(int i = dividno; i < randvec.size(); i++)
  75. {
  76. tempno = randvec[i];
  77. testvec.push_back(Mail(recovec[tempno], 0));
  78. testfile << recovec[tempno] << endl;
  79. }
  80. }
  81. //进行[0,1]区间标准化
  82. void Knnctl::standard(vector<Mail>& mailvec, string filename)
  83. {
  84. vector<double>  columnvec;
  85. vector<double>::iterator iterbeg, iterend;
  86. vector<double>::iterator  maxiter, miniter, tempiter;
  87. ofstream standardfile(filename.c_str());
  88. double temp;
  89. int tempsize = mailvec[0].dimvec.size();
  90. for(int i = 0; i < tempsize; i++)
  91. {
  92. for(int j = 0; j < mailvec.size(); j++)
  93. {
  94. columnvec.push_back(mailvec[j].dimvec[i]);
  95. }
  96. iterbeg = columnvec.begin();
  97. iterend = columnvec.end();
  98. maxiter = max_element(iterbeg, iterend);
  99. miniter = min_element(iterbeg, iterend);
  100. //for debug: 查看某一属性的最大和最小值
  101. //std::cout << *maxiter << "  "<< *miniter <<endl;
  102. temp = *maxiter - *miniter;
  103. for(int j = 0; j < mailvec.size(); j++)
  104. {
  105. mailvec[j].dimvec[i] = (mailvec[j].dimvec[i] - *miniter) / temp;
  106. }
  107. columnvec.erase(iterbeg, iterend);
  108. }
  109. for(int i = 0; i < mailvec.size(); i++)
  110. {
  111. iterbeg = mailvec[i].dimvec.begin();
  112. iterend = mailvec[i].dimvec.end();
  113. for(tempiter = iterbeg; tempiter != iterend; tempiter++)
  114. {
  115. standardfile << *tempiter << ","; 
  116. }
  117. standardfile << endl;
  118. }
  119. }
  120. //让pair以值升序排序
  121. bool strict_weak_ordering(const pair<int, double> a, const pair<int, double> b)
  122. {
  123. return a.second < b.second;
  124. }
  125. //计算向量的相似度,默认为欧式距离
  126. void Knnctl::calSimi(vector<Mail>& svec, vector<Mail>& tvec)
  127. {
  128. bSimi = 1;
  129. vector<Mail>::iterator  iterbegt, iterendt, itert;
  130. iterbegt = tvec.begin();
  131. iterendt = tvec.end();
  132.    
  133. double distance;
  134. int count = 0;
  135. for(itert = iterbegt; itert != iterendt; itert++)
  136. {
  137. for(int i = 0; i < svec.size(); i++)
  138. {
  139. distance = (*itert).euclidDis(svec[i]);
  140. //可以用余弦求相似度,效果比欧氏距离稍差点
  141. // distance = (*itert).cosin(svec[i]);
  142. (*itert).disvec.push_back(make_pair(i, distance));
  143. }
  144. sort((*itert).disvec.begin(), (*itert).disvec.end(), strict_weak_ordering);
  145. cout << " fininsh number: " << ++count << endl;
  146. }
  147. }
  148. //循环求出错误率最小的K值
  149. void Knnctl::calKno(vector<Mail>& svec, vector<Mail>& tvec)
  150. {
  151. if(bSimi == false) 
  152. calSimi(svec, tvec);
  153. vector<Mail>::iterator  iterbegt, iterendt, itert;
  154. iterbegt = tvec.begin();
  155. iterendt = tvec.end();
  156. int  recordno  = svec.size();
  157. int total = tvec.size();
  158. int  kmax  = sqrt(recordno);
  159. int  sum, sno, wcount = 0, k = 1;
  160. bool  spam;
  161. int spamno  = 0,  unspamno  = 0;
  162. double  spamdis = 0,  unspamdis = 0;
  163. double  wrate  = 0.,  wratemin  = 100;
  164. for(; k <= kmax; k++)
  165. {
  166. for(itert = iterbegt; itert != iterendt; itert++)
  167. {
  168. for(int i = 0; i < k; i++)
  169. {
  170. sno = (*itert).disvec[i].first;
  171. if(svec[sno].isSpamMail() == 1)
  172. {
  173. spamno++;
  174. spamdis += (*itert).disvec[i].second;
  175. }
  176. else
  177. {
  178. unspamno++;
  179. unspamdis += (*itert).disvec[i].second;
  180. }
  181. }
  182. //在类别数量都相等的情况下,简单的选择了总距离最小的类
  183. if((spamno > unspamno) || 
  184. ((spamno == unspamno) && (spamdis < unspamdis)))
  185. spam = 1;
  186. else
  187. spam = 0;
  188. if(spam != (*itert).isSpamMail())
  189. wcount++;
  190. spam = spamno = unspamno = 0;
  191. spamdis = unspamdis = 0;
  192. }
  193. wrate = (double)wcount / total * 100.0; 
  194. if(wrate < wratemin)
  195. {
  196. wratemin = wrate;
  197. kno = k;
  198. }
  199. cout << "When K= " << k <<"  Wrong Rate Is:" << wrate 
  200.  << " %  " << "  Wrong Number Is: " << wcount << endl;
  201. wcount = 0;
  202. }
  203. cout << "==================================================" << endl;
  204. cout << "The best Kno is: " << kno << endl;
  205. }
  206. //确定了K值后,对真正的训练集和测试集进行分类
  207. void Knnctl::classification(vector<Mail>& svec, vector<Mail>& tvec)
  208. {
  209. ofstream result("knn_result_euclid");
  210. // ofstream result("knn_result_cosin");
  211. vector<Mail>::iterator  iterbegt, iterendt, itert;
  212. iterbegt = tvec.begin();
  213. iterendt = tvec.end();
  214. int count = 0;
  215. int count_spam = 0;
  216. bool  spam;
  217. int spamno  = 0,  unspamno  = 0, sno;
  218. double  spamdis = 0,  unspamdis = 0;
  219. for(itert = iterbegt; itert != iterendt; itert++)
  220. {
  221. for(int i = 0; i < kno; i++)
  222. {
  223. sno = (*itert).disvec[i].first;
  224. if(svec[sno].isSpamMail() == 1)
  225. {
  226. spamno++;
  227. spamdis += (*itert).disvec[i].second;
  228. }
  229. else
  230. {
  231. unspamno++;
  232. unspamdis += (*itert).disvec[i].second;
  233. }
  234. }
  235. if((spamno > unspamno) || 
  236. ((spamno == unspamno) && (spamdis < unspamdis)))
  237. {
  238. spam = 1;
  239. count_spam++;
  240. }
  241. else
  242. spam = 0;
  243. result << (*itert).record() << "  <-------------------------->  "
  244. << spam << endl;
  245. cout << "the " << ++count << " record is: " << spam << endl;
  246. spam = spamno = unspamno = 0;
  247. spamdis = unspamdis = 0;
  248. }
  249. cout << "垃圾邮件的个数为 " << count_spam << endl;
  250. }