knnctl.cpp
资源名称:Knn.rar [点击查看]
上传用户:xmhs66
上传日期:2022-07-26
资源大小:989k
文件大小:7k
源码类别:
生物技术
开发平台:
Visual C++
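/**********************************************************************
 * filename:    knnctl.cpp
 * description: Implementation of the Knnctl class, the main part of the
 *              KNN (k-nearest neighbours) spam-mail classifier:
 *                1. read records from a file and parse them into Mail objects
 *                2. normalise the data
 *                3. find the K value with the lowest error rate
 *                4. classify the test set with that K
 * student:     Liwanjun
 * date:        2010-03-29
 **********************************************************************/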
#include "knnctl.h"

#include <algorithm>
#include <cmath>
#include <cstdlib>
#include <fstream>
#include <iostream>
#include <numeric>
#include <random>
#include <string>

using std::string;
using std::ifstream;
using std::ofstream;
using std::endl;
using std::cout;
using std::pair;
using std::make_pair;
- Knnctl::Knnctl():kno(1), bSimi(false){}
- //从文件中读取记录,不进行解析
- void Knnctl::readFile(string filename)
- {
- ifstream infile(filename.c_str());
- if(!infile)
- {
- cout << "Error Opening: " << filename << endl;
- exit(-1);
- }
- string record;
- while(getline(infile, record))
- {
- recovec.push_back(record);
- }
- }
- //从文件中读取记录,解析并以Mail对象的形式存入到Knnctl的成员中
- void Knnctl::readFile(string filename, vector<Mail>& vec, bool istest)
- {
- ifstream infile(filename.c_str());
- if(!infile)
- {
- cout << "Error Opening: " << filename << endl;
- exit(-1);
- }
- if(vec.size() != 0)
- vec.erase(vec.begin(), vec.end());
- string record;
- while(getline(infile, record))
- vec.push_back(Mail(record, istest));
- }
- //随机分割记录,2/3作为学习,其余作为测试。学习集命名为study
- void Knnctl::divideFile()
- {
- int recordno = recovec.size();
- vector<int> randvec;
- string studyname("div_study");
- string testname("div_test");
- ofstream studyfile(studyname.c_str());
- ofstream testfile(testname.c_str());
- for(int i = 0; i < recordno; i++)
- {
- randvec.push_back(i);
- }
- random_shuffle(randvec.begin(), randvec.end());
- int dividno = randvec.size() * 2 / 3;
- int tempno = 0;
- for(int i = 0; i < dividno; i++)
- {
- tempno = randvec[i];
- studyvec.push_back(Mail(recovec[tempno], 0));
- studyfile << recovec[tempno] << endl;
- }
- for(int i = dividno; i < randvec.size(); i++)
- {
- tempno = randvec[i];
- testvec.push_back(Mail(recovec[tempno], 0));
- testfile << recovec[tempno] << endl;
- }
- }
- //进行[0,1]区间标准化
- void Knnctl::standard(vector<Mail>& mailvec, string filename)
- {
- vector<double> columnvec;
- vector<double>::iterator iterbeg, iterend;
- vector<double>::iterator maxiter, miniter, tempiter;
- ofstream standardfile(filename.c_str());
- double temp;
- int tempsize = mailvec[0].dimvec.size();
- for(int i = 0; i < tempsize; i++)
- {
- for(int j = 0; j < mailvec.size(); j++)
- {
- columnvec.push_back(mailvec[j].dimvec[i]);
- }
- iterbeg = columnvec.begin();
- iterend = columnvec.end();
- maxiter = max_element(iterbeg, iterend);
- miniter = min_element(iterbeg, iterend);
- //for debug: 查看某一属性的最大和最小值
- //std::cout << *maxiter << " "<< *miniter <<endl;
- temp = *maxiter - *miniter;
- for(int j = 0; j < mailvec.size(); j++)
- {
- mailvec[j].dimvec[i] = (mailvec[j].dimvec[i] - *miniter) / temp;
- }
- columnvec.erase(iterbeg, iterend);
- }
- for(int i = 0; i < mailvec.size(); i++)
- {
- iterbeg = mailvec[i].dimvec.begin();
- iterend = mailvec[i].dimvec.end();
- for(tempiter = iterbeg; tempiter != iterend; tempiter++)
- {
- standardfile << *tempiter << ",";
- }
- standardfile << endl;
- }
- }
// Sort predicate for (training index, distance) pairs: true when `a` is
// strictly closer than `b`, which sorts pairs by ascending distance.
// The index (.first) is ignored, so pairs with equal distance are equivalent.
bool strict_weak_ordering(const pair<int, double> a, const pair<int, double> b)
{
    const double lhs = a.second;
    const double rhs = b.second;
    return rhs > lhs;
}
- //计算向量的相似度,默认为欧式距离
- void Knnctl::calSimi(vector<Mail>& svec, vector<Mail>& tvec)
- {
- bSimi = 1;
- vector<Mail>::iterator iterbegt, iterendt, itert;
- iterbegt = tvec.begin();
- iterendt = tvec.end();
- double distance;
- int count = 0;
- for(itert = iterbegt; itert != iterendt; itert++)
- {
- for(int i = 0; i < svec.size(); i++)
- {
- distance = (*itert).euclidDis(svec[i]);
- //可以用余弦求相似度,效果比欧氏距离稍差点
- // distance = (*itert).cosin(svec[i]);
- (*itert).disvec.push_back(make_pair(i, distance));
- }
- sort((*itert).disvec.begin(), (*itert).disvec.end(), strict_weak_ordering);
- cout << " fininsh number: " << ++count << endl;
- }
- }
- //循环求出错误率最小的K值
- void Knnctl::calKno(vector<Mail>& svec, vector<Mail>& tvec)
- {
- if(bSimi == false)
- calSimi(svec, tvec);
- vector<Mail>::iterator iterbegt, iterendt, itert;
- iterbegt = tvec.begin();
- iterendt = tvec.end();
- int recordno = svec.size();
- int total = tvec.size();
- int kmax = sqrt(recordno);
- int sum, sno, wcount = 0, k = 1;
- bool spam;
- int spamno = 0, unspamno = 0;
- double spamdis = 0, unspamdis = 0;
- double wrate = 0., wratemin = 100;
- for(; k <= kmax; k++)
- {
- for(itert = iterbegt; itert != iterendt; itert++)
- {
- for(int i = 0; i < k; i++)
- {
- sno = (*itert).disvec[i].first;
- if(svec[sno].isSpamMail() == 1)
- {
- spamno++;
- spamdis += (*itert).disvec[i].second;
- }
- else
- {
- unspamno++;
- unspamdis += (*itert).disvec[i].second;
- }
- }
- //在类别数量都相等的情况下,简单的选择了总距离最小的类
- if((spamno > unspamno) ||
- ((spamno == unspamno) && (spamdis < unspamdis)))
- spam = 1;
- else
- spam = 0;
- if(spam != (*itert).isSpamMail())
- wcount++;
- spam = spamno = unspamno = 0;
- spamdis = unspamdis = 0;
- }
- wrate = (double)wcount / total * 100.0;
- if(wrate < wratemin)
- {
- wratemin = wrate;
- kno = k;
- }
- cout << "When K= " << k <<" Wrong Rate Is:" << wrate
- << " % " << " Wrong Number Is: " << wcount << endl;
- wcount = 0;
- }
- cout << "==================================================" << endl;
- cout << "The best Kno is: " << kno << endl;
- }
- //确定了K值后,对真正的训练集和测试集进行分类
- void Knnctl::classification(vector<Mail>& svec, vector<Mail>& tvec)
- {
- ofstream result("knn_result_euclid");
- // ofstream result("knn_result_cosin");
- vector<Mail>::iterator iterbegt, iterendt, itert;
- iterbegt = tvec.begin();
- iterendt = tvec.end();
- int count = 0;
- int count_spam = 0;
- bool spam;
- int spamno = 0, unspamno = 0, sno;
- double spamdis = 0, unspamdis = 0;
- for(itert = iterbegt; itert != iterendt; itert++)
- {
- for(int i = 0; i < kno; i++)
- {
- sno = (*itert).disvec[i].first;
- if(svec[sno].isSpamMail() == 1)
- {
- spamno++;
- spamdis += (*itert).disvec[i].second;
- }
- else
- {
- unspamno++;
- unspamdis += (*itert).disvec[i].second;
- }
- }
- if((spamno > unspamno) ||
- ((spamno == unspamno) && (spamdis < unspamdis)))
- {
- spam = 1;
- count_spam++;
- }
- else
- spam = 0;
- result << (*itert).record() << " <--------------------------> "
- << spam << endl;
- cout << "the " << ++count << " record is: " << spam << endl;
- spam = spamno = unspamno = 0;
- spamdis = unspamdis = 0;
- }
- cout << "垃圾邮件的个数为 " << count_spam << endl;
- }