Classifier.cpp
上传用户:sanxfzhen
上传日期:2014-12-28
资源大小:2324k
文件大小:55k
源码类别:

多国语言处理

开发平台:

Visual C++

  1. // Classifier.cpp: implementation of the CClassifier class.
  2. //
  3. //////////////////////////////////////////////////////////////////////
  4. #include "stdafx.h"
  5. #include "TextClassify.h"
  6. #include "Classifier.h"
  7. #include "WordSegment.h"
  8. #include "Message.h"
  9. #include <math.h>
  10. #include <direct.h>
  11. #ifdef _DEBUG
  12. #undef THIS_FILE
  13. static char THIS_FILE[]=__FILE__;
  14. #define new DEBUG_NEW
  15. #endif
  16. CClassifier theClassifier;
  17. const DWORD CClassifier::dwModelFileID=0xFFEFFFFF;
  18. //////////////////////////////////////////////////////////////////////
  19. // Construction/Destruction
  20. //////////////////////////////////////////////////////////////////////
  21. CClassifier::CClassifier()
  22. {
  23. n_Type=-1;
  24. m_pDocs=NULL;
  25. m_pSimilarityRatio=NULL;
  26. m_pProbability=NULL;
  27. m_lDocNum=0;
  28. m_nClassNum=0;
  29. }
  30. CClassifier::~CClassifier()
  31. {
  32. }
  33. //参数bGenDic=false代表无需重新扫描文档得到训练文档集中所有特征,一般在层次分类时使用
  34. //参数nType用来决定分类模型的类别,nType=0代表KNN分类器,nType=1代表SVM分类器
  35. bool CClassifier::Train(int nType, bool bFlag)
  36. {
  37. this->n_Type=nType;
  38. CTime startTime;
  39. CTimeSpan totalTime;
  40. if(bFlag)
  41. {
  42. InitTrain();
  43. //生成所有候选特征项,将其保存在m_lstWordList中
  44. GenDic();
  45. }
  46. CMessage::PrintStatusInfo("");
  47. if(m_lstWordList.GetCount()==0)
  48. return false;
  49. if(m_lstTrainCatalogList.GetCataNum()==0)
  50. return false;
  51. //清空特征项列表m_lstTrainWordList
  52. m_lstTrainWordList.InitWordList();
  53. //为特征项列表m_lstWordList中的每个特征加权
  54. CMessage::PrintInfo(_T("开始计算候选特征集中每个特征的类别区分度,请稍候..."));
  55. startTime=CTime::GetCurrentTime();
  56. FeatherWeight(m_lstWordList);
  57. totalTime=CTime::GetCurrentTime()-startTime;
  58. CMessage::PrintInfo(_T("特征区分度计算结束,耗时")+totalTime.Format("%H:%M:%S"));
  59. CMessage::PrintStatusInfo("");
  60. //从特征项列表m_lstWordList中选出最优特征
  61. CMessage::PrintInfo(_T("开始进行特征选择,请稍候..."));
  62. startTime=CTime::GetCurrentTime();
  63. FeatherSelection(m_lstTrainWordList);
  64.     //为最优特征集m_lstTrainWordList中的每个特征建立一个ID
  65. m_lstTrainWordList.IndexWord();
  66. totalTime=CTime::GetCurrentTime()-startTime;
  67. CMessage::PrintInfo(_T("特征选择结束,耗时")+totalTime.Format("%H:%M:%S"));
  68. CMessage::PrintStatusInfo("");
  69. // 清空m_lstWordList,释放它占用的空间
  70. m_lstWordList.InitWordList();
  71. CMessage::PrintInfo("开始生成文档向量,请稍候...");
  72. startTime=CTime::GetCurrentTime();
  73. GenModel();
  74. totalTime=CTime::GetCurrentTime()-startTime;
  75. CMessage::PrintInfo(_T("文档向量生成结束,耗时")+totalTime.Format("%H:%M:%S"));
  76. CMessage::PrintStatusInfo("");
  77. CMessage::PrintInfo("开始保存分类模型,请稍候...");
  78. startTime=CTime::GetCurrentTime();
  79. WriteModel(m_paramClassifier.m_txtResultDir+"\model.prj",nType);
  80. totalTime=CTime::GetCurrentTime()-startTime;
  81. CMessage::PrintInfo(_T("保存分类模型结束,耗时")+totalTime.Format("%H:%M:%S"));
  82. //训练SVM分类器必须在保存训练文档的文档向量后进行
  83. if(nType == 1)
  84. {
  85. CMessage::PrintInfo("开始训练SVM,请稍候...");
  86. m_lstTrainCatalogList.InitCatalogList(2); //删除文档向量所占用的空间
  87. startTime=CTime::GetCurrentTime();
  88. TrainSVM();
  89. totalTime=CTime::GetCurrentTime()-startTime;
  90. CMessage::PrintInfo(_T("SVM分类器训练结束,耗时")+totalTime.Format("%H:%M:%S"));
  91. CMessage::PrintStatusInfo("");
  92. }
  93. //为分类做好准备,否则不能进行分类
  94. Prepare();
  95. CMessage::PrintStatusInfo("");
  96. return TRUE;
  97. }
  98. void CClassifier::TrainSVM()
  99. {
  100. CString str;
  101. CTime tmStart;
  102. CTimeSpan tmSpan;
  103. m_paramClassifier.m_strModelFile="model";
  104. for(int i=1;i<=m_lstTrainCatalogList.GetCataNum();i++)
  105. {
  106. tmStart=CTime::GetCurrentTime();
  107. str.Format("正在训练第%d个SVM分类器,请稍侯...",i);
  108. CMessage::PrintInfo(str);
  109. m_theSVM.com_param.trainfile=m_paramClassifier.m_txtResultDir+"\train.txt";
  110. m_theSVM.com_param.modelfile.Format("%s\%s%d.mdl",m_paramClassifier.m_txtResultDir,m_paramClassifier.m_strModelFile,i);
  111. m_theSVM.svm_learn_main(i);
  112. tmSpan=CTime::GetCurrentTime()-tmStart;
  113. str.Format("第%d个SVM分类器训练完成,耗时%s!",i,tmSpan.Format("%H:%M:%S"));
  114. CMessage::PrintInfo(str);
  115. }
  116. }
  117. void CClassifier::TrainBAYES()
  118. {
  119. /*
  120. CString str;
  121. CTime tmStart;
  122. CTimeSpan tmSpan;
  123. m_paramClassifier.m_strModelFile="model";
  124. for(int i=1;i<=m_lstTrainCatalogList.GetCataNum();i++)
  125. {
  126. tmStart=CTime::GetCurrentTime();
  127. str.Format("正在训练第%d个SVM分类器,请稍侯...",i);
  128. CMessage::PrintInfo(str);
  129. m_theSVM.com_param.trainfile=m_paramClassifier.m_txtResultDir+"\train.txt";
  130. m_theSVM.com_param.modelfile.Format("%s\%s%d.mdl",m_paramClassifier.m_txtResultDir,m_paramClassifier.m_strModelFile,i);
  131. m_theSVM.svm_learn_main(i);
  132. tmSpan=CTime::GetCurrentTime()-tmStart;
  133. str.Format("第%d个SVM分类器训练完成,耗时%s!",i,tmSpan.Format("%H:%M:%S"));
  134. CMessage::PrintInfo(str);
  135. }
  136. */
  137. }
  138. // fill an array of CTrain::sSortType (train word length)
  139. // nCatalog mean the value of element of the array is the weight
  140. // of nCatalog(as an index of catalog) for each individual word
  141. // if nCatalog==-1 then sum weight for all catalog
  142. void CClassifier::GenSortBuf(CWordList& wordList,sSortType *psSortBuf,int nCatalog)
  143. {
  144. int nTotalCata=m_lstTrainCatalogList.GetCataNum();
  145. int i;
  146. ASSERT(nCatalog<nTotalCata);
  147. long lWordCount=0 ;
  148. POSITION pos_word = wordList.GetFirstPosition();
  149. CString str;
  150. while(pos_word!= NULL)  // for each word
  151. {
  152. CWordNode& wordnode = wordList.GetNext(pos_word,str);
  153. psSortBuf[lWordCount].pclsWordNode = &wordnode; 
  154. strcpy(psSortBuf[lWordCount].word,str);
  155. ASSERT(wordnode.m_nAllocLen==nTotalCata);
  156. if(nCatalog==-1)
  157. {
  158. psSortBuf[lWordCount].dWeight+=wordnode.m_dWeight;
  159. }
  160. else
  161. psSortBuf[lWordCount].dWeight=wordnode.m_pCataWeight[nCatalog];
  162. //拷贝词属于类的概率
  163. if(this->n_Type==2)
  164. for(i=0;i<nTotalCata;i++)
  165. {
  166. psSortBuf[lWordCount].pclsWordNode->m_pCataWeightPro[i]=wordnode.m_pCataWeightPro[i];
  167. // CString strtemp;
  168. // strtemp.Format("123 %f",psSortBuf[lWordCount].pclsWordNode->m_pCataWeightPro[i]);
  169. // CMessage::PrintInfo(strtemp);
  170. }
  171. lWordCount++;
  172. }
  173. }
  174. //从m_lstWordList选出最优特征子集,存到dstWordList中
  175. void CClassifier::FeatherSelection(CWordList& dstWordList)
  176. {
  177. if(m_lstWordList.GetCount()<=0) return;
  178. dstWordList.InitWordList();
  179. m_lstWordList.IndexWord();
  180. sSortType *psSortBuf;
  181. int nDistinctWordNum = m_lstWordList.GetCount();
  182. psSortBuf = new sSortType[nDistinctWordNum ];  // the distinct number of the word 
  183. ASSERT(psSortBuf!=NULL);
  184. long lDocNum=m_lstTrainCatalogList.GetDocNum();
  185. for(int i=0;i<nDistinctWordNum ;i++)
  186. {
  187. psSortBuf[i].pclsWordNode = NULL;
  188. psSortBuf[i].dWeight   = 0;
  189. }
  190. // catalog indivial selecting
  191. if(m_paramClassifier.m_nSelMode==CClassifierParam::nFSM_IndividualModel)
  192. {
  193. int nCatalogWordSize=m_paramClassifier.m_nWordSize;
  194. int nTotalCata=m_lstTrainCatalogList.GetCataNum();
  195. for(int i=0;i<nTotalCata;i++)
  196. {
  197. GenSortBuf(m_lstWordList,psSortBuf,i);//-1 mean sum all catalog
  198. Sort(psSortBuf,nDistinctWordNum-1);
  199. int nSelectWordNum=0;
  200. for(int j=0;j<nDistinctWordNum&&nSelectWordNum<nCatalogWordSize;j++)
  201. {
  202. CWordNode wordNode;
  203. if(m_paramClassifier.m_nWeightMode==CClassifierParam::nWM_TF_IDF)
  204. psSortBuf[i].pclsWordNode->ComputeWeight(lDocNum);
  205. else if(m_paramClassifier.m_nWeightMode==CClassifierParam::nWM_TF_IDF_DIFF)
  206. psSortBuf[i].pclsWordNode->ComputeWeight(lDocNum,true);
  207. wordNode.m_dWeight=psSortBuf[j].pclsWordNode->m_dWeight;
  208. wordNode.m_lDocFreq=psSortBuf[j].pclsWordNode->m_lDocFreq;
  209. wordNode.m_lWordFreq=psSortBuf[j].pclsWordNode->m_lWordFreq;
  210. dstWordList.SetAt(psSortBuf[j].word,wordNode);
  211. nSelectWordNum++;
  212. }
  213. }
  214. }
  215. // total selecting
  216. else //if(m_paramClassifier.m_nSelMode==CClassifierParam::nFSM_GolbalMode)
  217. {
  218. int iWord=0;
  219. GenSortBuf(m_lstWordList,psSortBuf,-1);//-1 mean sum all catalog
  220. Sort(psSortBuf,nDistinctWordNum-1);
  221. int nSelectWordNum=m_paramClassifier.m_nWordSize;
  222. if (nSelectWordNum>nDistinctWordNum)
  223. nSelectWordNum=nDistinctWordNum;
  224. for(i=0;i<nSelectWordNum;i++)
  225. {
  226. CWordNode wordNode;
  227. wordNode.InitBuffer(m_lstTrainCatalogList.GetCataNum());
  228. if(m_paramClassifier.m_nWeightMode==CClassifierParam::nWM_TF_IDF)
  229. psSortBuf[i].pclsWordNode->ComputeWeight(lDocNum);
  230. else if(m_paramClassifier.m_nWeightMode==CClassifierParam::nWM_TF_IDF_DIFF)
  231. psSortBuf[i].pclsWordNode->ComputeWeight(lDocNum,true);
  232. wordNode.m_dWeight=psSortBuf[i].pclsWordNode->m_dWeight;
  233. wordNode.m_lDocFreq=psSortBuf[i].pclsWordNode->m_lDocFreq;
  234. wordNode.m_lWordFreq=psSortBuf[i].pclsWordNode->m_lWordFreq;
  235. if(this->n_Type==2)
  236. // 拷贝词在不同类中的概率
  237. for(int k=0;k<m_lstTrainCatalogList.GetCataNum();k++)
  238. {
  239. double abc = psSortBuf[i].pclsWordNode->m_pCataWeightPro[k];
  240. wordNode.m_pCataWeightPro[k] = psSortBuf[i].pclsWordNode->m_pCataWeightPro[k];
  241. // CString strtemp;
  242. // strtemp.Format("123 %f",wordNode.m_pCataWeightPro[k]);
  243. // CMessage::PrintInfo(strtemp);
  244. }
  245. dstWordList.SetAt(psSortBuf[i].word,wordNode);
  246. }
  247. }
  248. delete [] psSortBuf;
  249. }
  250. void CClassifier::FeatherWeight(CWordList& wordList)
  251. {
  252. // ------------------------------------------------------------------------------
  253. //  based on document number model
  254. int N; //总的文档数;
  255. int N_c; //C类的文档数
  256. int N_ft; //含有ft的文档数
  257. int N_c_ft; //C类中含有ft的文档数
  258. // ------------------------------------------------------------------------------
  259. //  based on word number model
  260. long N_W;    //总的词数 m_lWordNum;
  261. long N_W_C;  //C类词数 CCatalogNode.m_lTotalWordNum;
  262. long N_W_f_t; //f_t出现的总次数
  263. long N_W_C_f_t;//C类中f_t出现的次数
  264. // ------------------------------------------------------------------------------
  265. double P_c_ft,P_c_n_ft,P_n_c_ft,P_n_c_n_ft;
  266. double P_c,P_n_c;
  267. double P_ft,P_n_ft;
  268. // ------------------------------------------------------------------------------
  269. POSITION pos_cata,pos_word;
  270. CString     strWord;
  271. // calculate the weight of each word to all catalog
  272. N = m_lstTrainCatalogList.GetDocNum();
  273. N_W = wordList.GetWordNum();
  274. int nTotalCata=m_lstTrainCatalogList.GetCataNum();
  275. pos_word = wordList.GetFirstPosition();
  276. while(pos_word!= NULL)  // for each word
  277. {
  278. CWordNode& wordnode = wordList.GetNext(pos_word,strWord);
  279. wordnode.m_dWeight=0;
  280. ASSERT(wordnode.m_pCataWeight);
  281. ASSERT(wordnode.m_pCataWeightPro);
  282. CMessage::PrintStatusInfo("特征:"+strWord+"..."); 
  283. N_ft = wordnode.GetDocNum();  
  284. N_W_f_t = wordnode.GetWordNum();
  285. int nCataCount=0;
  286. pos_cata = m_lstTrainCatalogList.GetFirstPosition();
  287. while(pos_cata!=NULL)  // for each catalog 
  288. {
  289. CCatalogNode& catanode = m_lstTrainCatalogList.GetNext(pos_cata);
  290. N_c  = catanode.GetDocNum();
  291. N_W_C  = catanode.m_lTotalWordNum;
  292. N_c_ft = wordnode.GetCataDocNum(catanode.m_idxCata);
  293. N_W_C_f_t =wordnode.GetCataWordNum(catanode.m_idxCata);
  294. // calculation model 
  295. if(m_paramClassifier.m_nOpMode==CClassifierParam::nOpWordMode)   
  296. {
  297. P_c     = 1.0 * N_W_C /N_W;
  298. P_ft = 1.0 * N_W_f_t/N_W;
  299. P_c_ft  = 1.0 * N_W_C_f_t/N_W;
  300. }
  301. else //if(m_paramClassifier.m_nOpMode==CClassifierParam::nOpDocMode)
  302. {
  303. P_c = 1.0 * N_c /N; //C类出现的概率
  304. P_ft = 1.0 * N_ft/N; //含有ft的文档出现的概率
  305. P_c_ft = 1.0 * N_c_ft/N; //C类中含有ft的文档的概率
  306. }
  307. P_n_c = 1 - P_c;
  308. P_n_ft = 1 - P_ft;
  309. P_n_c_ft = P_ft - P_c_ft;
  310. P_c_n_ft = P_c - P_c_ft;
  311. P_n_c_n_ft  = P_n_ft - P_c_n_ft;
  312. wordnode.m_pCataWeight[nCataCount]=0;
  313. // feature selection model
  314. if(m_paramClassifier.m_nFSMode==CClassifierParam::nFS_XXMode)
  315. {
  316. // Right half of IG
  317. if ( (fabs(P_c * P_n_ft) > dZero) && ( fabs(P_c_n_ft) > dZero) ) 
  318. {
  319. wordnode.m_pCataWeight[nCataCount]+=P_c_n_ft * log( P_c_n_ft/(P_c * P_n_ft) );
  320. }
  321. }
  322. else if(m_paramClassifier.m_nFSMode==CClassifierParam::nFS_MIMode)
  323. {
  324. // Mutual Informaiton feature selection
  325. if ( (fabs(P_c * P_ft) > dZero) && (fabs(P_c_ft) > dZero) ) 
  326. {
  327. wordnode.m_pCataWeight[nCataCount]+= P_c * log( P_c_ft/(P_c * P_ft) );
  328. }
  329. }
  330. else if(m_paramClassifier.m_nFSMode==CClassifierParam::nFS_CEMode)
  331. {
  332. // Cross Entropy for text feature selection
  333. if ( (fabs(P_c * P_ft) > dZero) && (fabs(P_c_ft) > dZero) ) 
  334. {
  335. wordnode.m_pCataWeight[nCataCount]+= P_c_ft * log( P_c_ft/(P_c * P_ft) );
  336. }
  337. }
  338. else if(m_paramClassifier.m_nFSMode==CClassifierParam::nFS_X2Mode)
  339. {
  340. // X^2 Statistics feature selection
  341. if ( (fabs(P_n_c * P_ft * P_n_ft) > dZero) ) 
  342. {
  343. wordnode.m_pCataWeight[nCataCount]+= (P_c_ft * P_n_c_n_ft - P_n_c_ft * P_c_n_ft) * (P_c_ft * P_n_c_n_ft - P_n_c_ft * P_c_n_ft) / ( P_ft * P_n_c * P_n_ft);
  344. }
  345. }
  346. else if(m_paramClassifier.m_nFSMode==CClassifierParam::nFS_WEMode)
  347. {
  348. // Weight of Evielence for text feature selection
  349. double odds_c_ft;
  350. double odds_c;
  351. double P_c_inv_ft=P_c_ft/P_ft;
  352. if( fabs(P_c_inv_ft) < dZero )
  353. odds_c_ft = 1.0 / ( N * N -1);
  354. else if ( fabs(P_c_inv_ft-1) < dZero )
  355. odds_c_ft = N * N -1;
  356. else
  357. odds_c_ft = P_c_inv_ft / (1.0 - P_c_inv_ft);
  358. if( fabs(P_c) < dZero )
  359. odds_c = 1.0 / ( N * N -1);
  360. else if ( fabs(P_c-1) < dZero )
  361. odds_c = N * N -1;
  362. else
  363. odds_c = P_c / (1.0 - P_c);
  364. if( fabs(odds_c) > dZero && fabs(odds_c_ft) > dZero )
  365. {
  366. wordnode.m_pCataWeight[nCataCount]+= P_c * P_ft * fabs( log(odds_c_ft / odds_c) );
  367. }
  368. }
  369. else //if(m_paramClassifier.m_nFSMode==CClassifierParam::nFS_IGMode) 
  370. {
  371. // Information gain feature selection
  372. if ( (fabs(P_c * P_n_ft) > dZero) && ( fabs(P_c_n_ft) > dZero) ) 
  373. {
  374. wordnode.m_pCataWeight[nCataCount]+=P_c_n_ft * log( P_c_n_ft/(P_c * P_n_ft) );
  375. }
  376. if ( (fabs(P_c * P_ft) > dZero) && (fabs(P_c_ft) > dZero) ) 
  377. {
  378. wordnode.m_pCataWeight[nCataCount]+= P_c_ft * log( P_c_ft/(P_c * P_ft) );
  379. }
  380. }
  381. wordnode.m_dWeight+=wordnode.m_pCataWeight[nCataCount];
  382. wordnode.m_pCataWeightPro[nCataCount] = 1.0 * (N_c_ft+1)/(2+N);//词str属于类别nCataCount的概率
  383. /*
  384. CString strtemp;
  385. strtemp.Format("第%d个类中,词的权重是%lf",nCataCount,wordnode.m_pCataWeight[nCataCount]);
  386. CMessage::PrintInfo(strtemp);
  387. */
  388. nCataCount++;
  389. }
  390. ASSERT(nCataCount==nTotalCata);
  391. }
  392. CMessage::PrintStatusInfo("");
  393. }
  394. //计算每一篇训练文档向量的每一维的权重
  395. void CClassifier::ComputeWeight(bool bMult)
  396. {
  397. long lWordNum=m_lstTrainWordList.GetCount();
  398. if(m_lstTrainWordList.GetCount()<=0) return;
  399. long lDocNum=m_lstTrainCatalogList.GetDocNum();
  400. if(lDocNum<=0) return;
  401. m_lstTrainWordList.ComputeWeight(lDocNum,bMult);
  402. double sum=0.0;
  403. int i=0;
  404. POSITION pos_cata = m_lstTrainCatalogList.GetFirstPosition();
  405. while(pos_cata != NULL)  // for each catalog 
  406. {
  407. //取类列表中的每一个类
  408. CCatalogNode& cataNode = m_lstTrainCatalogList.GetNext(pos_cata);
  409. POSITION pos_doc  = cataNode.GetFirstPosition();
  410. while(pos_doc!=NULL)
  411. {
  412. CDocNode& docNode=cataNode.GetNext(pos_doc);
  413. sum=0.0;
  414. for(i=0;i<docNode.m_nAllocLen;i++)
  415. {
  416. docNode.m_sWeightSet[i].s_dWeight*=docNode.m_sWeightSet[i].s_tfi;
  417. sum+=(docNode.m_sWeightSet[i].s_dWeight*docNode.m_sWeightSet[i].s_dWeight);
  418. }
  419. sum=sqrt(sum);
  420. for(i=0;i<docNode.m_nAllocLen;i++)
  421. docNode.m_sWeightSet[i].s_dWeight/=sum;
  422. CMessage::PrintStatusInfo("计算文档"+docNode.m_strDocName+"向量每一维的权重");
  423. }
  424. }
  425. }
  426. void CClassifier::QuickSort(sSortType *psData, int iLo,int iHi)
  427. {
  428.     int Lo, Hi;
  429. double Mid;
  430. sSortType t;
  431.     Lo = iLo;
  432.     Hi = iHi;
  433.     Mid = psData[(Lo + Hi)/2].dWeight;
  434.     do
  435. {
  436. while(psData[Lo].dWeight > Mid) Lo++;
  437. while(psData[Hi].dWeight < Mid) Hi--;
  438. if(Lo <= Hi)
  439. {
  440. t = psData[Lo];
  441. psData[Lo]=psData[Hi];
  442. psData[Hi]=t;
  443. Lo++;
  444. Hi--;
  445. }
  446. }while(Hi>Lo);
  447.     if(Hi > iLo) QuickSort(psData, iLo, Hi);
  448.     if(Lo < iHi) QuickSort(psData, Lo, iHi);
  449. }
  450. void CClassifier::Sort(sSortType *psData,int nSize)
  451. {
  452. QuickSort(psData,0,nSize);
  453. }
  454. // Give m_lstWordList & m_lstTrainCatalogList
  455. // Output the present vector of each document;
  456. // bFlag=false 层次分类的时候使用
  457. void CClassifier::GenModel()
  458. {
  459. CDocNode::AllocTempBuffer(m_lstTrainWordList.GetCount());
  460. POSITION pos_cata = m_lstTrainCatalogList.GetFirstPosition();
  461. while(pos_cata != NULL)  // for each catalog 
  462. {
  463. //取类列表中的每一个类
  464. CCatalogNode& cataNode = m_lstTrainCatalogList.GetNext(pos_cata);
  465. POSITION pos_doc  = cataNode.GetFirstPosition();
  466. while(pos_doc!=NULL)
  467. {
  468. CDocNode& docNode=cataNode.GetNext(pos_doc);
  469. if(m_paramClassifier.m_nLanguageType==CClassifierParam::nLT_Chinese)
  470. docNode.ScanChineseWithDict(cataNode.m_strDirName.GetBuffer(0),m_lstTrainWordList);
  471. else
  472. docNode.ScanEnglishWithDict(cataNode.m_strDirName.GetBuffer(0),m_lstTrainWordList,m_paramClassifier.m_bStem);
  473. docNode.GenDocVector();
  474. CMessage::PrintStatusInfo("生成文档"+docNode.m_strDocName+"的文档向量");
  475. }
  476. }
  477. CDocNode::DeallocTempBuffer();
  478. }
  479. // generate original dictionary (the largest one)
  480. // form train files
  481. bool CClassifier::GenDic()
  482. {
  483. m_lstWordList.InitWordList();
  484. CTime startTime;
  485. CTimeSpan totalTime;
  486. startTime=CTime::GetCurrentTime();
  487. CMessage::PrintInfo(_T("分词程序初始化,请稍候..."));
  488. if(!g_wordSeg.InitWorgSegment(theApp.m_strPath.GetBuffer(0),m_paramClassifier.m_nLanguageType))
  489. {
  490. CMessage::PrintError(_T("分词程序初始化失败!"));
  491. return false;
  492. }
  493. if(m_paramClassifier.m_nLanguageType==CClassifierParam::nLT_Chinese)
  494. g_wordSeg.SetSegSetting(CWordSegment::uPlace);
  495. totalTime=CTime::GetCurrentTime()-startTime;
  496. CMessage::PrintInfo(_T("分词程序初始化结束,耗时")+totalTime.Format("%H:%M:%S"));
  497. startTime=CTime::GetCurrentTime();
  498. CMessage::PrintInfo(_T("开始扫描训练文档,请稍候..."));
  499. if(m_lstTrainCatalogList.BuildLib(m_paramClassifier.m_txtTrainDir)<=0)
  500. {
  501. CMessage::PrintError("训练文档的总数为0!");
  502. return false;
  503. }
  504. CString strFileName;
  505. POSITION pos_cata=m_lstTrainCatalogList.GetFirstPosition();
  506. int nCount,nCataNum;
  507. nCataNum=m_lstTrainCatalogList.GetCataNum();
  508. while(pos_cata!=NULL)
  509. {
  510. CCatalogNode& catalognode=m_lstTrainCatalogList.GetNext(pos_cata);
  511. POSITION pos_doc=catalognode.GetFirstPosition();
  512. while(pos_doc!=NULL)
  513. {
  514. CDocNode& docnode=catalognode.GetNext(pos_doc);
  515. CMessage::PrintStatusInfo(_T("扫描文档")+docnode.m_strDocName);
  516. if(m_paramClassifier.m_nLanguageType==CClassifierParam::nLT_Chinese)
  517. nCount=docnode.ScanChinese(catalognode.m_strDirName.GetBuffer(0),
  518. m_lstWordList,nCataNum,catalognode.m_idxCata);
  519. else
  520. nCount=docnode.ScanEnglish(catalognode.m_strDirName.GetBuffer(0),
  521. m_lstWordList,nCataNum,catalognode.m_idxCata,
  522. m_paramClassifier.m_bStem);
  523. if(nCount==0)
  524. {
  525. CMessage::PrintError("文件"+catalognode.m_strDirName+"\"+docnode.m_strDocName+"无内容!");
  526. continue;
  527. }
  528. else if(nCount<0)
  529. {
  530. CMessage::PrintError("文件"+catalognode.m_strDirName+"\"+docnode.m_strDocName+"无法打开!");
  531. continue;
  532. }
  533. catalognode.m_lTotalWordNum+=nCount;// information collection point
  534. }
  535. }
  536. g_wordSeg.FreeWordSegment();
  537. totalTime=CTime::GetCurrentTime()-startTime;
  538. CMessage::PrintInfo(_T("扫描训练文档结束,耗时")+totalTime.Format("%H:%M:%S"));
  539. return true;
  540. }
  541. void CClassifier::InitTrain()
  542. {
  543. m_lstTrainWordList.InitWordList();
  544. m_lstTrainCatalogList.InitCatalogList();
  545. m_lstWordList.InitWordList();
  546. }
  547. //参数nType用来决定分类模型的类别,nType=0代表KNN分类器,nType=1代表SVM分类器
  548. bool CClassifier::WriteModel(CString strFileName, int nType)
  549. {
  550. CFile fOut;
  551. if( !fOut.Open(strFileName,CFile::modeCreate | CFile::modeWrite) )
  552. {
  553. CMessage::PrintError("无法创建文件"+strFileName+"!");
  554. return false;
  555. }
  556. CArchive ar(&fOut,CArchive::store);
  557. if(nType==0)
  558. {
  559. m_lstTrainWordList.DumpToFile(m_paramClassifier.m_txtResultDir+"\features.dat");
  560. m_lstTrainWordList.DumpWordList(m_paramClassifier.m_txtResultDir+"\features.txt");
  561. m_lstTrainCatalogList.DumpToFile(m_paramClassifier.m_txtResultDir+"\train.dat");
  562. m_lstTrainCatalogList.DumpDocList(m_paramClassifier.m_txtResultDir+"\train.txt");
  563. m_paramClassifier.DumpToFile(m_paramClassifier.m_txtResultDir+"\params.dat");
  564. ar<<dwModelFileID;
  565. ar<<CString("params.dat");
  566. ar<<CString("features.dat");
  567. ar<<CString("train.dat");
  568. }
  569. else if(nType==1)
  570. {
  571. m_lstTrainWordList.DumpToFile(m_paramClassifier.m_txtResultDir+"\features.dat");
  572. m_lstTrainWordList.DumpWordList(m_paramClassifier.m_txtResultDir+"\features.txt");
  573. m_lstTrainCatalogList.DumpToFile(m_paramClassifier.m_txtResultDir+"\train.dat",1);
  574. m_lstTrainCatalogList.DumpDocList(m_paramClassifier.m_txtResultDir+"\train.txt");
  575. m_paramClassifier.DumpToFile(m_paramClassifier.m_txtResultDir+"\params.dat");
  576. m_theSVM.com_param.classifier_num=m_lstTrainCatalogList.GetCataNum();
  577. m_theSVM.com_param.trainfile="train.txt";
  578. m_theSVM.com_param.resultpath=m_paramClassifier.m_txtResultDir;
  579. m_theSVM.com_param.DumpToFile(m_paramClassifier.m_txtResultDir+"\svmparams.dat");
  580. ar<<dwModelFileID;
  581. ar<<CString("params.dat");
  582. ar<<CString("features.dat");
  583. ar<<CString("train.dat");
  584. ar<<CString("svmparams.dat");
  585. }
  586. else if(nType==2)
  587. {
  588. m_lstTrainWordList.DumpToFile(m_paramClassifier.m_txtResultDir+"\features.dat");
  589. m_lstTrainWordList.DumpWordList(m_paramClassifier.m_txtResultDir+"\features.txt");
  590. m_lstTrainWordList.DumpWordProList(m_paramClassifier.m_txtResultDir+"\WordPro.txt",m_lstTrainCatalogList.GetCataNum());
  591. m_lstTrainCatalogList.DumpToFile(m_paramClassifier.m_txtResultDir+"\train.dat");
  592. m_lstTrainCatalogList.DumpDocList(m_paramClassifier.m_txtResultDir+"\train.txt");
  593. m_paramClassifier.DumpToFile(m_paramClassifier.m_txtResultDir+"\params.dat");
  594. ar<<dwModelFileID;
  595. ar<<CString("params.dat");
  596. ar<<CString("features.dat");
  597. ar<<CString("train.dat");
  598. ar<<CString("WordPro.txt");
  599. }
  600. ar.Close();
  601. fOut.Close(); 
  602. return true;
  603. }
  604. bool CClassifier::OpenModel(CString strFileName)
  605. {
  606. CFile fIn;
  607. if(!fIn.Open(strFileName,CFile::modeRead))
  608. {
  609. CMessage::PrintError("无法打开文件"+strFileName+"!") ;
  610. return false;
  611. }
  612. CTime startTime=CTime::GetCurrentTime();
  613. CMessage::PrintInfo(_T("正在打开分类模型文件,请稍候..."));
  614. CArchive ar(&fIn,CArchive::load);
  615. CString str,strPath;
  616. DWORD dwFileID;
  617. //读入文件格式标识符
  618. strPath=strFileName.Left(strFileName.ReverseFind('\'));
  619. ar>>dwFileID;
  620. if(dwFileID!=dwModelFileID)
  621. {
  622. ar.Close();
  623. fIn.Close();
  624. CMessage::PrintError("分类模型文件的格式不正确!");
  625. return false;
  626. }
  627. ar>>str;
  628. if(!m_paramClassifier.GetFromFile(strPath+"\"+str))
  629. {
  630. CMessage::PrintError(_T("无法打开训练参数文件"+str+"!"));
  631. return false;
  632. }
  633. m_paramClassifier.m_txtResultDir=strPath;
  634. if(m_paramClassifier.m_nClassifierType==0)
  635. {
  636. ar>>str;
  637. m_lstTrainWordList.InitWordList();
  638. if(!m_lstTrainWordList.GetFromFile(strPath+"\"+str))
  639. {
  640. CMessage::PrintError(_T("无法打开特征类表文件"+str+"!"));
  641. return false;
  642. }
  643. ar>>str;
  644. m_lstTrainCatalogList.InitCatalogList();
  645. if(!m_lstTrainCatalogList.GetFromFile(strPath+"\"+str))
  646. {
  647. CMessage::PrintError(_T("无法打开训练文档列表文件"+str+"!"));
  648. return false;
  649. }
  650. }
  651. else if(m_paramClassifier.m_nClassifierType==1)
  652. {
  653. ar>>str;
  654. m_lstTrainWordList.InitWordList();
  655. if(!m_lstTrainWordList.GetFromFile(strPath+"\"+str))
  656. {
  657. CMessage::PrintError(_T("无法打开特征类表文件"+str+"!"));
  658. return false;
  659. }
  660. //对于SVM分类起来说m_lstTrainCatalogList其实没用
  661. //此处读入它只是为了在CLeftViw中显示某些统计信息时使用
  662. ar>>str;
  663. m_lstTrainCatalogList.InitCatalogList();
  664. if(!m_lstTrainCatalogList.GetFromFile(strPath+"\"+str))
  665. {
  666. CMessage::PrintError(_T("无法打开训练文档列表文件"+str+"!"));
  667. return false;
  668. }
  669. ar>>str;
  670. if(!m_theSVM.com_param.GetFromFile(strPath+"\"+str))
  671. {
  672. CMessage::PrintError(_T("无法打开SVM训练参数文件"+str+"!"));
  673. return false;
  674. }
  675. m_theSVM.com_param.trainfile=strPath+"\train.txt";
  676. m_theSVM.com_param.resultpath=strPath;
  677. }
  678. else if(m_paramClassifier.m_nClassifierType==2)
  679. {
  680. ar>>str;
  681. m_lstTrainWordList.InitWordList();
  682. if(!m_lstTrainWordList.GetFromFile(strPath+"\"+str))
  683. {
  684. CMessage::PrintError(_T("无法打开特征类表文件"+str+"!"));
  685. return false;
  686. }
  687. ar>>str;
  688. m_lstTrainCatalogList.InitCatalogList();
  689. if(!m_lstTrainCatalogList.GetFromFile(strPath+"\"+str))
  690. {
  691. CMessage::PrintError(_T("无法打开训练文档列表文件"+str+"!"));
  692. return false;
  693. }
  694. ar>>str;
  695. if(!m_lstTrainWordList.GetProFromFile(strPath+"\"+str))
  696. {
  697. CMessage::PrintError(_T("无法打开特征词类属概率文件"+str+"!"));
  698. return false;
  699. }
  700. }
  701. ar.Close();
  702. fIn.Close();
  703. Prepare();
  704. CTimeSpan totalTime=CTime::GetCurrentTime()-startTime;
  705. CMessage::PrintInfo(_T("分类模型文件已经打开,耗时")+totalTime.Format("%H:%M:%S")+"rn");
  706. str.Empty();
  707. m_paramClassifier.GetParamString(str);
  708. CMessage::PrintInfo(str);
  709. return true;
  710. }
  711. bool CClassifier::Classify()
  712. {
  713. m_lstTrainCatalogList.DumpCataList(m_paramClassifier.m_strResultDir+"\classes.txt");
  714. CTime startTime;
  715. CTimeSpan totalTime;
  716. startTime=CTime::GetCurrentTime();
  717. CMessage::PrintInfo(_T("正在扫描测试文档,请稍候..."));
  718. if(m_paramClassifier.m_bEvaluation)
  719. {
  720. //vBuildLib方法中已经清空了g_lstTestCatalogList,所以此处无需再对其初始化
  721. m_lstTestCatalogList.BuildLib(m_paramClassifier.m_strTestDir);
  722. if(!m_lstTestCatalogList.BuildCatalogID(m_lstTrainCatalogList))
  723. {
  724. CMessage::PrintError("测试文件中包含有无法识别的类别!");
  725. return false;
  726. }
  727. }
  728. else
  729. {
  730. m_lstTestCatalogList.InitCatalogList();
  731. CCatalogNode catalognode;
  732. catalognode.m_strDirName=m_paramClassifier.m_strTestDir;
  733. catalognode.m_strCatalogName="测试文档";
  734. catalognode.m_idxCata=-1;
  735. POSITION posTemp=m_lstTestCatalogList.AddCata(catalognode);
  736. CCatalogNode& cataTemp=m_lstTestCatalogList.GetAt(posTemp);
  737. cataTemp.SetStartDocID(0);
  738. cataTemp.ScanDirectory(m_paramClassifier.m_strTestDir);
  739. }
  740. if(m_lstTestCatalogList.GetDocNum()<=0)
  741. {
  742. CMessage::PrintError("测试文件总数为0!rn如果不需要对分类结果进行评价时,分类文档必须在"分类文档目录"下,而不是它的子目录下!");
  743. return false;
  744. }
  745. totalTime=CTime::GetCurrentTime()-startTime;
  746. CMessage::PrintInfo(_T("扫描测试文档结束,耗时")+totalTime.Format("%H:%M:%S"));
  747. startTime=CTime::GetCurrentTime();
  748. CMessage::PrintInfo(_T("正在对测试文档进行分类,请稍候..."));
  749. long lCorrect=0,lUnknown=0;
  750. lUnknown=Classify(m_lstTestCatalogList);
  751. lCorrect=SaveResults(m_lstTestCatalogList,m_paramClassifier.m_strResultDir+"\results.txt");
  752. long lTotalNum=m_lstTestCatalogList.GetDocNum()-lUnknown;
  753. CString str;
  754. totalTime=CTime::GetCurrentTime()-startTime;
  755. CMessage::PrintInfo(_T("测试文档分类结束,耗时")+totalTime.Format("%H:%M:%S"));
  756. if (lUnknown>0) 
  757. {
  758. str.Format("无法分类的文档数%d:",lUnknown);
  759. CMessage::PrintInfo(str);
  760. }
  761. if(m_paramClassifier.m_bEvaluation&&lTotalNum>0&&lCorrect>0)
  762. str.Format("测试文档总数:%d,准确率:%f",m_lstTestCatalogList.GetDocNum(),(float)(lCorrect)/(float)(lTotalNum));
  763. else
  764. str.Format("测试文档总数:%d",m_lstTestCatalogList.GetDocNum());
  765. CMessage::PrintInfo(str);
  766. return true;
  767. }
  768. //对Smart格式的文档进行分类
  769. bool CClassifier::ClassifySmart()
  770. {
  771. m_lstTrainCatalogList.DumpCataList(m_paramClassifier.m_strResultDir+"\classes.txt");
  772. m_lstTestCatalogList.InitCatalogList();
  773. CCatalogNode catalognode;
  774. catalognode.m_strDirName=m_paramClassifier.m_strTestDir;
  775. catalognode.m_strCatalogName="测试文档";
  776. catalognode.m_idxCata=-1;
  777. POSITION posTemp=m_lstTestCatalogList.AddCata(catalognode);
  778. CCatalogNode& cataTemp=m_lstTestCatalogList.GetAt(posTemp);
  779. FILE *stream1,*stream2;
  780. if( (stream1 = fopen( m_paramClassifier.m_strTestDir, "r" )) == NULL )
  781. {
  782. CMessage::PrintError("无法打开文件"+m_paramClassifier.m_strTestDir+"!");
  783. return false;
  784. }
  785. //如果是SVM分类器,则需要先将所有测试文档转换为向量,保存到文件test.dat
  786. if(m_paramClassifier.m_nClassifierType==CClassifierParam::nCT_SVM)
  787. {
  788. m_theSVM.com_param.classifyfile=m_paramClassifier.m_strResultDir+"\test.dat";
  789. if((stream2=fopen(m_theSVM.com_param.classifyfile,"w"))==NULL)
  790. {
  791. CMessage::PrintError("无法创建测试文档向量文件"+m_theSVM.com_param.classifyfile+"!");
  792. return false;
  793. }
  794. }
  795. CTime startTime;
  796. CTimeSpan totalTime;
  797. startTime=CTime::GetCurrentTime();
  798. CMessage::PrintInfo(_T("正在对测试文档进行分类,请稍候..."));
  799. char fname[10],type[1024],line[4096],content[100*1024];
  800. //falg=1 下一行的内容是文档的类别
  801. //flag=2 下一行的内容是文档的标题
  802. //flag=3 下一行的内容是文档的内容
  803. int flag=0,nCount,len,i;
  804. long lUnknown=0,lDocNum=0;
  805. CStringArray typeArray;
  806. CString strFileName,strCopyFile;
  807. bool bTitle=false; //是否已经读出标题
  808. double dThreshold=(double)m_paramClassifier.m_dThreshold/100.0;
  809. int nWordNum=m_lstTrainWordList.GetCount();
  810. while(!feof(stream1))
  811. {
  812. if(fgets(line,4096,stream1)==NULL) continue;
  813. if(line[0]=='.')
  814. {
  815. if(flag==3)
  816. {
  817. CDocNode doc;
  818. posTemp=cataTemp.AddDoc(doc);
  819. CDocNode& docnode=cataTemp.GetAt(posTemp);
  820. docnode.m_strDocName=fname;
  821. if(m_paramClassifier.m_nClassifierType==CClassifierParam::nCT_KNN)
  822. {
  823. nCount=KNNCategory(content,docnode,false);
  824. }
  825. else
  826. {
  827. if(m_paramClassifier.m_nLanguageType==CClassifierParam::nLT_Chinese)
  828. nCount=docnode.ScanChineseStringWithDict(content,m_lstTrainWordList)-1;
  829. else
  830. nCount=docnode.ScanEnglishStringWithDict(content,m_lstTrainWordList,
  831. m_paramClassifier.m_bStem)-1;
  832. fprintf(stream2,"%d",1);
  833. for(i=0;i<nWordNum;i++)
  834. {
  835. if(docnode.m_pTemp[i].s_tfi!=0) 
  836. fprintf(stream2," %d:%f",i+1,docnode.m_pTemp[i].s_dWeight);
  837. }
  838. fprintf(stream2,"n");
  839. }
  840. if(m_paramClassifier.m_bEvaluation) typeArray.Add(type);
  841. if(nCount<0) 
  842. {
  843. CMessage::PrintError("无法识别文档"+docnode.m_strDocName+"的类别!");
  844. lUnknown++;
  845. }
  846. CMessage::PrintStatusInfo(_T("扫描文档")+docnode.m_strDocName);
  847. flag=0;
  848. bTitle=false;
  849. fname[0]=0;
  850. type[0]=0;
  851. content[0]=0;
  852. lDocNum++;
  853. }
  854. switch(line[1])
  855. {
  856. case 'I':
  857. strcpy(fname,line+3);
  858. len=strlen(fname);
  859. if(fname[len-1]='r') fname[len-1]='';
  860. break;
  861. case 'C':
  862. flag=1;
  863. break;
  864. case 'T':
  865. flag=2;
  866. break;
  867. case 'W':
  868. flag=3;
  869. break;
  870. }
  871. }
  872. else
  873. {
  874. switch(flag)
  875. {
  876. case 1:
  877. strcpy(type,line);
  878. break;
  879. case 2:
  880. if(!bTitle)
  881. {
  882. strcpy(content,line);
  883. bTitle=true;
  884. }
  885. else
  886. strcat(content,line);
  887. break;
  888. case 3:
  889. strcat(content,line);
  890. break;
  891. }
  892. }
  893. }
  894. fclose(stream1);
  895. CMessage::PrintStatusInfo("");
  896. if(m_paramClassifier.m_nClassifierType==CClassifierParam::nCT_SVM) 
  897. {
  898. fclose(stream2);
  899. startTime=CTime::GetCurrentTime();
  900. CMessage::PrintInfo(_T("正在使用SVM分类器对文档进行分类,请稍候..."));
  901. m_theSVM.com_param.classifyfile=m_paramClassifier.m_strResultDir+"\test.dat";
  902. SVMClassifyVectorFile(m_theSVM.com_param.classifyfile);
  903. totalTime=CTime::GetCurrentTime()-startTime;
  904. CMessage::PrintInfo(_T("SVM分类过程结束,耗时")+totalTime.Format("%H:%M:%S"));
  905. }
  906. long lCorrect=SaveResults(m_lstTestCatalogList,m_paramClassifier.m_strResultDir+"\results.txt",&typeArray);
  907. long lTotalNum=m_lstTestCatalogList.GetDocNum()-lUnknown;
  908. CString str;
  909. totalTime=CTime::GetCurrentTime()-startTime;
  910. CMessage::PrintInfo(_T("测试文档分类结束,耗时")+totalTime.Format("%H:%M:%S"));
  911. if (lUnknown>0) 
  912. {
  913. str.Format("无法分类的文档数%d:",lUnknown);
  914. CMessage::PrintInfo(str);
  915. }
  916. if(m_paramClassifier.m_bEvaluation&&lTotalNum>0&&lCorrect>0)
  917. str.Format("测试文档总数:%d,准确率:%f",m_lstTestCatalogList.GetDocNum(),(float)(lCorrect)/(float)(lTotalNum));
  918. else
  919. str.Format("测试文档总数:%d",m_lstTestCatalogList.GetDocNum());
  920. CMessage::PrintInfo(str);
  921. return true;
  922. }
  923. //对文档进行分类,计算文档和每个类别的相似度,返回值为类别无法识别的文档总数
  924. long CClassifier::SVMClassify(CCatalogList &cataList)
  925. {
  926. long lUnknown=0;
  927. FILE *stream;
  928. m_theSVM.com_param.classifyfile=m_paramClassifier.m_strResultDir+"\test.dat";
  929. if((stream=fopen(m_theSVM.com_param.classifyfile,"w"))==NULL)
  930. {
  931. CMessage::PrintError("无法创建测试文档向量文件"+m_theSVM.com_param.classifyfile+"!");
  932. return 0;
  933. }
  934. CTime startTime;
  935. CTimeSpan totalTime;
  936. CString str;
  937. int nCount=0;
  938. long lWordNum=m_lstTrainWordList.GetCount();
  939. startTime=CTime::GetCurrentTime();
  940. CMessage::PrintInfo(_T("正在生成测试文档的向量形式,请稍候..."));
  941. POSITION pos_cata=cataList.GetFirstPosition();
  942. while(pos_cata!=NULL)
  943. {
  944. CCatalogNode& catalognode=cataList.GetNext(pos_cata);
  945. char *path=catalognode.m_strDirName.GetBuffer(0);
  946. POSITION pos_doc=catalognode.GetFirstPosition();
  947. while(pos_doc!=NULL)
  948. {
  949. CDocNode& docnode=catalognode.GetNext(pos_doc);
  950. if(m_paramClassifier.m_nLanguageType==CClassifierParam::nLT_Chinese)
  951. nCount=docnode.ScanChineseWithDict(path,m_lstTrainWordList);
  952. else
  953. nCount=docnode.ScanEnglishWithDict(path,m_lstTrainWordList,
  954. m_paramClassifier.m_bStem);
  955. fprintf(stream,"%d",catalognode.m_idxCata+1);
  956. for(int i=0;i<lWordNum;i++)
  957. {
  958. if(docnode.m_pTemp[i].s_tfi!=0) 
  959. fprintf(stream," %d:%f",i+1,docnode.m_pTemp[i].s_dWeight);
  960. }
  961. fprintf(stream,"n");
  962. CMessage::PrintStatusInfo(_T("扫描文档")+docnode.m_strDocName);
  963. if(nCount<0)
  964. {
  965. str="无法识别文档";
  966. str+=catalognode.m_strDirName;
  967. str+="\"+docnode.m_strDocName+"的类别!";
  968. lUnknown++;
  969. CMessage::PrintError(str);
  970. }
  971. }
  972. }
  973. CMessage::PrintStatusInfo(_T(""));
  974. fclose(stream);
  975. totalTime=CTime::GetCurrentTime()-startTime;
  976. CMessage::PrintInfo(_T("测试文档的向量生成结束,耗时")+totalTime.Format("%H:%M:%S"));
  977. startTime=CTime::GetCurrentTime();
  978. CMessage::PrintInfo(_T("正在使用SVM分类器对文档进行分类,请稍候..."));
  979. //为每篇文档和各个类别的相似读数组分配内存
  980. pos_cata=cataList.GetFirstPosition();
  981. while(pos_cata!=NULL)
  982. {
  983. CCatalogNode& catalognode=cataList.GetNext(pos_cata);
  984. POSITION pos_doc=catalognode.GetFirstPosition();
  985. while(pos_doc!=NULL)
  986. {
  987. CDocNode& docnode=catalognode.GetNext(pos_doc);
  988. docnode.AllocResultsBuffer(m_nClassNum);
  989. }
  990. }
  991. SVMClassifyVectorFile(m_theSVM.com_param.classifyfile);
  992. CMessage::PrintStatusInfo(_T(""));
  993. totalTime=CTime::GetCurrentTime()-startTime;
  994. CMessage::PrintInfo(_T("SVM分类过程结束,耗时")+totalTime.Format("%H:%M:%S"));
  995. return lUnknown;
  996. }
  997. //使用KNN方法对文档进行分类,计算文档和每个类别的相似度
  998. //返回值为类别无法识别的文档总数
  999. long CClassifier::KNNClassify(CCatalogList& cataList,int nCmpType)
  1000. {
  1001. long docID=0,lUnknown=0;
  1002. CString str;
  1003. POSITION pos_cata=cataList.GetFirstPosition();
  1004. while(pos_cata!=NULL)
  1005. {
  1006. CCatalogNode& cataNode=cataList.GetNext(pos_cata);
  1007. POSITION pos_doc=cataNode.GetFirstPosition();
  1008. char *path=cataNode.m_strDirName.GetBuffer(0);
  1009. while(pos_doc!=NULL)
  1010. {
  1011. CDocNode& docNode=cataNode.GetNext(pos_doc);
  1012. short id=KNNCategory(path, docNode, true, nCmpType);
  1013. if(id==-1) 
  1014. {
  1015. str="无法识别文档";
  1016. str+=cataNode.m_strDirName;
  1017. str+="\"+docNode.m_strDocName+"的类别!";
  1018. CMessage::PrintError(str);
  1019. lUnknown++;
  1020. }
  1021. CMessage::PrintStatusInfo(_T("扫描文档")+docNode.m_strDocName);
  1022. }
  1023. }
  1024. return lUnknown;
  1025. }
  1026. //计算文档和每个类别的相似度,返回与文档相似度最大的类别ID
  1027. //nCmpType代表相似度的不同计算方式
  1028. short CClassifier::KNNCategory(char *pPath, CDocNode &docNode, bool bFile, int nCmpType)
  1029. {
  1030. short nCataID=-1;
  1031. if(KNNClassify(pPath,docNode,bFile,nCmpType)) nCataID=SingleCategory(docNode);
  1032. return nCataID;
  1033. }
  1034. //如果bFile为真,则参数file为文件的文件名称(包括它的路径)
  1035. short CClassifier::KNNCategory(char *file, bool bFile, int nCmpType)
  1036. {
  1037. CDocNode docNode;
  1038. short id=-1;
  1039. if(bFile)
  1040. {
  1041. char *fname=strrchr(file,'\');
  1042. if(fname==NULL) return -1;
  1043. docNode.m_strDocName=(fname+1);
  1044. char path[MAX_PATH];
  1045. strncpy(path,file,fname-file);
  1046. path[fname-file]=0;
  1047. id=KNNCategory(path,docNode,bFile,nCmpType);
  1048. }
  1049. else
  1050. id=KNNCategory(file,docNode,bFile,nCmpType);
  1051. return id;
  1052. }
  1053. //生成文档docNode的向量形式,调用方法ComputeSimRatio计算器其和每一个类别的相似度
  1054. //参数bFile为真代表pPath为docNode的路径,否则代表需要进行分类的文档的内容
  1055. //参数nCmpType代表相似度的不同计算方式
  1056. bool CClassifier::KNNClassify(char *pPath, CDocNode &docNode, bool bFile, int nCmpType)
  1057. {
  1058. int nCount=0;
  1059. if(bFile)
  1060. {
  1061. if(m_paramClassifier.m_nLanguageType==CClassifierParam::nLT_Chinese)
  1062. nCount=docNode.ScanChineseWithDict(pPath,m_lstTrainWordList);
  1063. else
  1064. nCount=docNode.ScanEnglishWithDict(pPath,m_lstTrainWordList,
  1065. m_paramClassifier.m_bStem);
  1066. }
  1067. else
  1068. {
  1069. if(m_paramClassifier.m_nLanguageType==CClassifierParam::nLT_Chinese)
  1070. nCount=docNode.ScanChineseStringWithDict(pPath,m_lstTrainWordList);
  1071. else
  1072. nCount=docNode.ScanEnglishStringWithDict(pPath,m_lstTrainWordList,
  1073. m_paramClassifier.m_bStem);
  1074. }
  1075. if((m_lDocNum>0)&&(nCount>0))
  1076. {
  1077. ComputeSimRatio(docNode,nCmpType);
  1078. return true;
  1079. }
  1080. else
  1081. return false;
  1082. }
  1083. //得到与文档docNode的相似度大于阈值dThreshold的所有类别
  1084. //如果没有大于值阈值的类别,则返回相似度最大的类别
  1085. bool CClassifier::MultiCategory(CDocNode &docNode, CArray<short,short>& aryResult, double dThreshold)
  1086. {
  1087. double *pSimRatio=docNode.m_pResults;
  1088. if(pSimRatio==NULL) return false;
  1089. double dMax=pSimRatio[0];
  1090. short nMax=0;
  1091. bool bFound=false;
  1092. aryResult.RemoveAll();
  1093. for(short i=1;i<m_nClassNum;i++)
  1094. {
  1095. if(pSimRatio[i]>dMax)
  1096. {
  1097. dMax=pSimRatio[i];
  1098. nMax=i;
  1099. }
  1100. if(pSimRatio[i]>dThreshold) 
  1101. {
  1102. aryResult.Add(i);
  1103. bFound=true;
  1104. }
  1105. }
  1106. if(!bFound) aryResult.Add(nMax);
  1107. return true;
  1108. }
  1109. //计算文档docNode和每一个类别的相似度
  1110. //nCmpType代表相似度的不同计算方式
  1111. void CClassifier::ComputeSimRatio(CDocNode &docNode,int nCmpType)
  1112. {
  1113. //计算文档与训练集中每一篇文档的相似度
  1114. int i;
  1115. long k;
  1116. for(i=0;i<m_lDocNum;i++)
  1117. {
  1118. m_pSimilarityRatio[i].dWeight=m_pDocs[i].pDocNode->ComputeSimilarityRatio();
  1119. m_pSimilarityRatio[i].lDocID=i;
  1120. }
  1121. //将测试文档与训练文档集中文档的相似度进行降序排序
  1122. Sort(m_pSimilarityRatio,m_lDocNum-1);
  1123. docNode.AllocResultsBuffer(m_nClassNum);
  1124. double *pSimRatio=docNode.m_pResults;
  1125. for(i=0;i<m_nClassNum;i++) pSimRatio[i]=0;
  1126. if(nCmpType<=0)
  1127. {
  1128. //计算出测试文档的k近邻在每个类别中的数目
  1129. for(i=0;i<m_paramClassifier.m_nKNN;i++)
  1130. {
  1131. k=m_pSimilarityRatio[i].lDocID;
  1132. k=m_pDocs[k].nCataID;
  1133. pSimRatio[k]+=1;
  1134. }
  1135. //按照"测试文档的k近邻在某个类别中的数目/k"得到测试文档和这个类别的相似度
  1136. for(i=0;i<m_nClassNum;i++)
  1137. pSimRatio[i]/=m_paramClassifier.m_nKNN;
  1138. }
  1139. else if(nCmpType==1)
  1140. {
  1141. for(i=0;i<m_paramClassifier.m_nKNN;i++)
  1142. {
  1143. k=m_pSimilarityRatio[i].lDocID;
  1144. k=m_pDocs[k].nCataID;
  1145. pSimRatio[k]+=m_pSimilarityRatio[i].dWeight;
  1146. }
  1147. }
  1148. }
  1149. bool CClassifier::SVMClassify(char *pPath, CDocNode &docNode, bool bFile)
  1150. {
  1151. int nCount=0;
  1152. if(bFile)
  1153. {
  1154. if(m_paramClassifier.m_nLanguageType==CClassifierParam::nLT_Chinese)
  1155. nCount=docNode.ScanChineseWithDict(pPath,m_lstTrainWordList);
  1156. else
  1157. nCount=docNode.ScanEnglishWithDict(pPath,m_lstTrainWordList,
  1158. m_paramClassifier.m_bStem);
  1159. }
  1160. else
  1161. {
  1162. if(m_paramClassifier.m_nLanguageType==CClassifierParam::nLT_Chinese)
  1163. nCount=docNode.ScanChineseStringWithDict(pPath,m_lstTrainWordList);
  1164. else
  1165. nCount=docNode.ScanEnglishStringWithDict(pPath,m_lstTrainWordList,
  1166. m_paramClassifier.m_bStem);
  1167. }
  1168. if((m_lDocNum>0)&&(nCount>0))
  1169. {
  1170. DOC doc;
  1171. CString str;
  1172. docNode.GenDocVector(doc);
  1173. docNode.AllocResultsBuffer(m_nClassNum);
  1174. for(int i=0;i<m_nClassNum;i++)
  1175. {
  1176. str.Format("%s\%s%d.mdl",m_paramClassifier.m_txtResultDir,m_paramClassifier.m_strModelFile,i+1);
  1177. theClassifier.m_theSVM.com_param.modelfile=str;
  1178. docNode.m_pResults[i]=theClassifier.m_theSVM.svm_classify(doc);
  1179. }
  1180. free(doc.words);
  1181. return true;
  1182. }
  1183. else
  1184. return false;
  1185. }
  1186. void CClassifier::Prepare()
  1187. {
  1188. CTime startTime;
  1189. CTimeSpan totalTime;
  1190. if(m_pDocs!=NULL)
  1191. {
  1192. m_lDocNum=0;
  1193. free(m_pDocs);
  1194. m_pDocs=NULL;
  1195. }
  1196. if(m_pSimilarityRatio!=NULL)
  1197. {
  1198. m_lDocNum=0;
  1199. delete[] m_pSimilarityRatio;
  1200. m_pSimilarityRatio=NULL;
  1201. }
  1202. if(m_pProbability!=NULL)
  1203. {
  1204. m_nClassNum=0;
  1205. delete[] m_pProbability;
  1206. m_pProbability=NULL;
  1207. }
  1208. m_nClassNum=m_lstTrainCatalogList.GetCataNum();
  1209. m_lDocNum=m_lstTrainCatalogList.GetDocNum();
  1210. if(m_paramClassifier.m_nKNN>m_lDocNum) m_paramClassifier.m_nKNN=m_lDocNum;
  1211. m_pSimilarityRatio=new DocWeight[m_lDocNum];
  1212. m_pProbability=new DocWeight[m_nClassNum];
  1213. m_pDocs=(DocCatalog*)malloc(sizeof(DocCatalog)*m_lDocNum);
  1214. long num=0;
  1215. POSITION pos_cata = m_lstTrainCatalogList.GetFirstPosition();
  1216. while(pos_cata != NULL)  // for each catalog 
  1217. {
  1218. CCatalogNode& catanode = m_lstTrainCatalogList.GetNext(pos_cata);
  1219. short idxCata=catanode.m_idxCata;
  1220. POSITION pos_doc  = catanode.GetFirstPosition();
  1221. while(pos_doc!=NULL)
  1222. {
  1223. CDocNode& docnode=catanode.GetNext(pos_doc);
  1224. m_pDocs[num].pDocNode=&docnode;
  1225. m_pDocs[num].nCataID=idxCata;
  1226. num++;
  1227. }
  1228. }
  1229. CDocNode::AllocTempBuffer(m_lstTrainWordList.GetCount());
  1230. }
  1231. void CClassifier::Sort(DocWeight *pData,int nSize)
  1232. {
  1233. QuickSort(pData,0,nSize);
  1234. }
  1235. void CClassifier::QuickSort(DocWeight *psData, int iLo,int iHi)
  1236. {
  1237.     int Lo, Hi;
  1238. double Mid;
  1239. DocWeight t;
  1240.     Lo = iLo;
  1241.     Hi = iHi;
  1242.     Mid = psData[(Lo + Hi)/2].dWeight;
  1243.     do
  1244. {
  1245. while(psData[Lo].dWeight > Mid) Lo++;
  1246. while(psData[Hi].dWeight < Mid) Hi--;
  1247. if(Lo <= Hi)
  1248. {
  1249. t = psData[Lo];
  1250. psData[Lo]=psData[Hi];
  1251. psData[Hi]=t;
  1252. Lo++;
  1253. Hi--;
  1254. }
  1255. }while(Hi>Lo);
  1256.     if(Hi > iLo) QuickSort(psData, iLo, Hi);
  1257.     if(Lo < iHi) QuickSort(psData, Lo, iHi);
  1258. }
  1259. //将分类结果保存到文件strFileName中,返回正确分类的文档总数
  1260. //如果分类参数中要求拷贝文件到结果类别目录,则执行拷贝操作
  1261. //参数typeArray只有在多类分类,且需要进行评价的时候才会用到
  1262. long CClassifier::SaveResults(CCatalogList &cataList, CString strFileName, CStringArray *aryType)
  1263. {
  1264. FILE *stream;
  1265. if( (stream = fopen(strFileName, "w+" )) == NULL )
  1266. {
  1267. CMessage::PrintError("无法创建分类结果文件"+strFileName+"!");
  1268. return -1;
  1269. }
  1270. CString str1,str2;
  1271. long lCorrect=0;
  1272. long docID=0;
  1273. int i;
  1274. char path[MAX_PATH];
  1275. CArray<short,short> aryResult;
  1276. CArray<short,short> aryAnswer;
  1277. double dThreshold=(double)m_paramClassifier.m_dThreshold/100.0;
  1278. POSITION pos_cata=cataList.GetFirstPosition();
  1279. while(pos_cata!=NULL)
  1280. {
  1281. CCatalogNode& cataNode=cataList.GetNext(pos_cata);
  1282. short id=cataNode.m_idxCata;
  1283. strcpy(path,cataNode.m_strDirName.GetBuffer(0));
  1284. POSITION pos_doc=cataNode.GetFirstPosition();
  1285. while(pos_doc!=NULL)
  1286. {
  1287. CDocNode& docNode=cataNode.GetNext(pos_doc);
  1288. if(docNode.m_nCataID<0) continue;
  1289. str1.Empty();
  1290. str2.Empty();
  1291. //如果是多类分类
  1292. if(m_paramClassifier.m_nClassifyType==CClassifierParam::nFT_Multi)
  1293. {
  1294. MultiCategory(docNode,aryResult,dThreshold);
  1295. //如果需要将分类结果拷贝到分类结果目录
  1296. if(m_paramClassifier.m_bCopyFiles)
  1297. {
  1298. for(i=0;i<aryResult.GetSize();i++)
  1299. {
  1300. m_lstTrainCatalogList.GetCataName(aryResult[i],str1);
  1301. str2=str2+str1+",";
  1302. if(m_paramClassifier.m_bCopyFiles)
  1303. CopyFile(docNode.m_strDocName.GetBuffer(0),path,
  1304. m_paramClassifier.m_strResultDir.GetBuffer(0),str1.GetBuffer(0));
  1305. }
  1306. str2.SetAt(str2.GetLength()-1,' ');
  1307. }
  1308. //如果需要对分类结果进行评价
  1309. if(m_paramClassifier.m_bEvaluation)
  1310. {
  1311. m_lstTrainCatalogList.GetCataIDArrayFromString(aryType->GetAt(docID).GetBuffer(0),aryAnswer);
  1312. //得到答案字符串
  1313. for(i=0;i<aryAnswer.GetSize();i++)
  1314. {
  1315. str1.Format("%d",aryAnswer[i]);
  1316. str2+=(str1+",");
  1317. }
  1318. str2.SetAt(str2.GetLength()-1,' ');
  1319. fprintf(stream,"%d %s %s",docID,docNode.m_strDocName,str2);
  1320. //得到分类结果字符串
  1321. str2.Empty();
  1322. for(i=0;i<aryResult.GetSize();i++)
  1323. {
  1324. str1.Format("%d",aryResult[i]);
  1325. str2+=(str1+",");
  1326. }
  1327. str2=str2.Left(str2.GetLength()-1);
  1328. fprintf(stream,"%sn",str2);
  1329. }
  1330. else
  1331. {
  1332. if(str2.IsEmpty())
  1333. {
  1334. for(i=0;i<aryResult.GetSize();i++)
  1335. {
  1336. m_lstTrainCatalogList.GetCataName(aryResult[i],str1);
  1337. str2=str2+str1+",";
  1338. }
  1339. str2.SetAt(str2.GetLength()-1,' ');
  1340. }
  1341. fprintf(stream,"%dt%stt%sn",docID,docNode.m_strDocName,str2);
  1342. }
  1343. }
  1344. //如果是单类分类
  1345. else
  1346. {
  1347. //如果需要将分类结果拷贝到分类结果目录
  1348. if(m_paramClassifier.m_bCopyFiles)
  1349. {
  1350. m_lstTrainCatalogList.GetCataName(docNode.m_nCataID,str1);
  1351. CopyFile(docNode.m_strDocName.GetBuffer(0),
  1352. cataNode.m_strDirName.GetBuffer(0),
  1353. m_paramClassifier.m_strResultDir.GetBuffer(0),
  1354. str1.GetBuffer(0));
  1355. }
  1356. //如果需要对分类结果进行评价
  1357. if(m_paramClassifier.m_bEvaluation)
  1358. {
  1359. if(docNode.m_nCataID==id) lCorrect++;
  1360. fprintf(stream,"%d %s %d %dn",docID,docNode.m_strDocName,
  1361. cataNode.m_idxCata,docNode.m_nCataID);
  1362. }
  1363. else
  1364. {
  1365. if(str1.IsEmpty()) m_lstTrainCatalogList.GetCataName(docNode.m_nCataID,str1);
  1366. fprintf(stream,"%dt%stt%sn",docID,docNode.m_strDocName,str1);
  1367. }
  1368. }
  1369. docID++;
  1370. }
  1371. }
  1372. fclose(stream);
  1373. return lCorrect;
  1374. }
  1375. void CClassifier::CopyFile(char *pFileName, char *pSource, char *pTarget, char *pCatalog)
  1376. {
  1377. char targetFile[MAX_PATH];
  1378. strcpy(targetFile,pTarget);
  1379. strcat(targetFile,"\");
  1380. strcat(targetFile,pCatalog);
  1381. if(_chdir(targetFile)<0)
  1382. if(_mkdir(targetFile)<0) return;
  1383. char sourceFile[MAX_PATH];
  1384. strcpy(sourceFile,pSource);
  1385. strcat(sourceFile,"\");
  1386. strcat(sourceFile,pFileName);
  1387. strcat(targetFile,"\");
  1388. strcat(targetFile,pFileName);
  1389. ::CopyFile(sourceFile,targetFile,false);
  1390. }
  1391. void CClassifier::Evaluate(CString strPath)
  1392. {
  1393. CString strFileName=strPath;
  1394. strFileName=strFileName+"\multieval.exe ";
  1395. strFileName=strFileName+theClassifier.m_paramClassifier.m_strResultDir+"\classes.txt ";
  1396. strFileName=strFileName+theClassifier.m_paramClassifier.m_strResultDir+"\results.txt";
  1397. if(WinExec(strFileName,SW_SHOWNORMAL)<32)
  1398. AfxMessageBox("分类结果评测程序不存在!");
  1399. }
  1400. short CClassifier::SingleCategory(CDocNode &docNode)
  1401. {
  1402. short nCataID=-1;
  1403. double *pSimRatio=docNode.m_pResults;
  1404. //得到文档的所属类别nMaxCatID
  1405. double dMaxNum=pSimRatio[0];
  1406. nCataID=0;
  1407. for(int i=1;i<m_nClassNum;i++)
  1408. {
  1409. if(pSimRatio[i]>dMaxNum)
  1410. {
  1411. dMaxNum=pSimRatio[i];
  1412. nCataID=i;
  1413. }
  1414. }
  1415. docNode.m_nCataID=nCataID;
  1416. return nCataID;
  1417. }
  1418. short CClassifier::SVMCategory(char *pPath, CDocNode &docNode, bool bFile)
  1419. {
  1420. short nCataID=-1;
  1421. if(SVMClassify(pPath,docNode,bFile)) nCataID=SingleCategory(docNode);
  1422. return nCataID;
  1423. }
  1424. short CClassifier::SVMCategory(char *file, bool bFile)
  1425. {
  1426. CDocNode docNode;
  1427. short id=-1;
  1428. if(bFile)
  1429. {
  1430. char *fname=strrchr(file,'\');
  1431. if(fname==NULL) return -1;
  1432. docNode.m_strDocName=(fname+1);
  1433. char path[MAX_PATH];
  1434. strncpy(path,file,fname-file);
  1435. path[fname-file]=0;
  1436. id=SVMCategory(path,docNode,bFile);
  1437. }
  1438. else
  1439. id=SVMCategory(file,docNode,bFile);
  1440. return id;
  1441. }
  1442. void CClassifier::SVMClassifyVectorFile(CString strFileName)
  1443. {
  1444. //为了计算分类结果,用来保存每个分类器分类结果的数组
  1445. CTime startTime;
  1446. CTimeSpan totalTime;
  1447. CString str;
  1448. long num=m_lstTestCatalogList.GetDocNum(),lDocNum=0;
  1449. double *fpWeight=new double[num];
  1450. POSITION pos_doc, pos_cata;
  1451. m_theSVM.com_param.classifyfile=strFileName;
  1452. for(int i=1;i<=m_nClassNum;i++)
  1453. {
  1454. memset(fpWeight,0,sizeof(double)*num);
  1455. startTime=CTime::GetCurrentTime();
  1456. str.Format("正在使用第%d个SVM分类器对文档进行分类,请稍候...",i);
  1457. CMessage::PrintInfo(str);
  1458. str.Format("%s\%s%d.mdl",m_paramClassifier.m_txtResultDir,m_paramClassifier.m_strModelFile,i);
  1459. m_theSVM.com_param.modelfile=str;
  1460. m_theSVM.svm_classify(i,fpWeight);
  1461. //将文档和当前类别的相似度赋给m_pResults[i-1]
  1462. lDocNum=0;
  1463. pos_cata=m_lstTestCatalogList.GetFirstPosition();
  1464. while(pos_cata!=NULL)
  1465. {
  1466. CCatalogNode& catalognode=m_lstTestCatalogList.GetNext(pos_cata);
  1467. pos_doc=catalognode.GetFirstPosition();
  1468. while(pos_doc!=NULL)
  1469. {
  1470. CDocNode& docnode=catalognode.GetNext(pos_doc);
  1471. docnode.AllocResultsBuffer(m_nClassNum);
  1472. docnode.m_pResults[i-1]=fpWeight[lDocNum];
  1473. lDocNum++;
  1474. }
  1475. }
  1476. totalTime=CTime::GetCurrentTime()-startTime;
  1477. str.Format("第%d个SVM分类器分类结束,耗时",i);
  1478. CMessage::PrintInfo(str+totalTime.Format("%H:%M:%S"));
  1479. }
  1480. delete[] fpWeight;
  1481. //计算和文档的相似度最大的类别
  1482. pos_cata=m_lstTestCatalogList.GetFirstPosition();
  1483. while(pos_cata!=NULL)
  1484. {
  1485. CCatalogNode& catalognode=m_lstTestCatalogList.GetNext(pos_cata);
  1486. pos_doc=catalognode.GetFirstPosition();
  1487. while(pos_doc!=NULL)
  1488. {
  1489. CDocNode& docnode=catalognode.GetNext(pos_doc);
  1490. docnode.m_nCataID=SingleCategory(docnode);
  1491. }
  1492. }
  1493. }
  1494. //计算文档docNode属于每一个类别的概率
  1495. void CClassifier::ComputePro(CDocNode &docNode)
  1496. {
  1497. //计算文档与训练集中每一类文档的概率
  1498. int i,j,l,FeaNum=0;
  1499. long k;
  1500. int N; //总的文档数;
  1501. int N_c; //C类的文档数
  1502. int N_Cata; //总类数
  1503. N = m_lstTrainCatalogList.GetDocNum();
  1504. POSITION pos_cata;
  1505. CString     strWord;
  1506. // calculate the weight of each word to all catalog
  1507. N = m_lstTrainCatalogList.GetDocNum();
  1508. N_Cata=m_lstTrainCatalogList.GetCataNum();
  1509. docNode.AllocResultsBuffer(m_nClassNum);
  1510. int nCataCount=0;
  1511. double ClassPro=0.0;
  1512. pos_cata = m_lstTrainCatalogList.GetFirstPosition();
  1513. i=m_lstTrainWordList.GetCount();
  1514. POSITION pos=m_lstTrainWordList.GetFirstPosition();
  1515. for(l=0;l<m_nClassNum;l++)
  1516. {
  1517. CCatalogNode& catanode = m_lstTrainCatalogList.GetNext(pos_cata);
  1518. N_c  = catanode.GetDocNum();
  1519. m_pProbability[l].dWeight=(1.0+N_c)/(N_Cata+N);
  1520. m_pProbability[l].lDocID=l;
  1521. }
  1522. for(j=0;j<i;j++)
  1523. {
  1524. if(docNode.m_pTemp[j].s_tfi!=0) 
  1525. {
  1526. FeaNum++;
  1527. CWordNode &WordProNode = m_lstTrainWordList.GetWordProByID(pos,j);
  1528. for(l=0;l<m_nClassNum;l++)
  1529. {
  1530. m_pProbability[l].dWeight+=WordProNode.m_pCataWeightPro[l];
  1531. // CString strtemp;
  1532. // strtemp.Format("%lf  %d",m_pProbability[s].dWeight,s);
  1533. // CMessage::PrintInfo(strtemp);
  1534. }
  1535. }
  1536. }
  1537. for(l=0;l<m_nClassNum;l++)
  1538. {
  1539. m_pProbability[l].dWeight/=FeaNum;
  1540. if(ClassPro<m_pProbability[l].dWeight)
  1541. {
  1542. ClassPro = m_pProbability[l].dWeight;
  1543. docNode.m_nCataID = l;
  1544. }
  1545. }
  1546. /*
  1547. while(pos_cata!=NULL)  // for each catalog 
  1548. {
  1549. docNode.m_pResults[nCataCount]=0;
  1550. CCatalogNode& catanode = m_lstTrainCatalogList.GetNext(pos_cata);
  1551. N_c  = catanode.GetDocNum();
  1552. m_pProbability[nCataCount].dWeight=(1.0+N_c)/(N_Cata+N);
  1553. m_pProbability[nCataCount].lDocID=nCataCount;
  1554. // for(int i=0;i<lWordNum;i++)
  1555. // {
  1556. // if(docnode.m_pTemp[i].s_tfi!=0) 
  1557. // fprintf(stream," %d:%f",i+1,docnode.m_pTemp[i].s_dWeight);
  1558. // }
  1559. // m_pProbability[nCataCount].dWeight+=docNode.ComputeProbability(m_lstTrainWordList,nCataCount);
  1560. // CString strttemp;
  1561. // strttemp.Format("%lf",m_pProbability[nCataCount].dWeight);
  1562. // CMessage::PrintInfo(strttemp);
  1563. POSITION pos=m_lstTrainWordList.GetFirstPosition();
  1564. for(j=0;j<i;j++)
  1565. {
  1566. if(docNode.m_pTemp[j].s_tfi!=0) 
  1567. {
  1568. m_pProbability[nCataCount].dWeight+=m_lstTrainWordList.GetWordProByID(pos,j,nCataCount);
  1569. // CString str = m_lstTrainWordList.GetWordByID(j);
  1570. // m_pProbability[nCataCount].dWeight+=1;
  1571. // CWordNode &wordnode = m_lstTrainWordList.m_lstWordList[str];
  1572. // m_pProbability[nCataCount].dWeight+= wordnode.m_pCataWeightPro[nCataCount];
  1573. }
  1574. // i=docNode.m_sWeightSet[nCataCount].s_idxWord;
  1575. // CString str = m_lstTrainWordList.GetWordByID(i);
  1576. // CWordNode &wordnode = m_lstTrainWordList.m_lstWordList[str];
  1577. // m_pProbability[nCataCount].dWeight*= wordnode.m_pCataWeightPro[nCataCount];
  1578. }
  1579. if(ClassPro<m_pProbability[nCataCount].dWeight)
  1580. {
  1581. ClassPro = m_pProbability[nCataCount].dWeight;
  1582. docNode.m_nCataID = nCataCount;
  1583. // CString strtemp;
  1584. // strtemp.Format("%d  %lf  %lf",nCataCount,(1.0+N_c)/(N_Cata+N),ClassPro);
  1585. // CMessage::PrintInfo(strtemp);
  1586. }
  1587. nCataCount++;
  1588. }
  1589. */
  1590. }
  1591. //生成文档docNode的向量形式,调用方法ComputePro计算器其和每一个类别的相似度
  1592. //参数bFile为真代表pPath为docNode的路径,否则代表需要进行分类的文档的内容
  1593. //参数nCmpType代表相似度的不同计算方式
  1594. bool CClassifier::BAYESClassify(char *pPath, CDocNode &docNode, bool bFile)
  1595. {
  1596. int nCount=0;
  1597. if(bFile)
  1598. {
  1599. if(m_paramClassifier.m_nLanguageType==CClassifierParam::nLT_Chinese)
  1600. nCount=docNode.ScanChineseWithDict(pPath,m_lstTrainWordList);
  1601. else
  1602. nCount=docNode.ScanEnglishWithDict(pPath,m_lstTrainWordList,
  1603. m_paramClassifier.m_bStem);
  1604. }
  1605. else
  1606. {
  1607. if(m_paramClassifier.m_nLanguageType==CClassifierParam::nLT_Chinese)
  1608. nCount=docNode.ScanChineseStringWithDict(pPath,m_lstTrainWordList);
  1609. else
  1610. nCount=docNode.ScanEnglishStringWithDict(pPath,m_lstTrainWordList,
  1611. m_paramClassifier.m_bStem);
  1612. }
  1613. if((m_lDocNum>0)&&(nCount>0))
  1614. {
  1615. ComputePro(docNode);
  1616. return true;
  1617. }
  1618. else
  1619. return false;
  1620. }
  1621. //计算文档和每个类别的相似度,返回与文档相似度最大的类别ID
  1622. short CClassifier::BAYESCategory(char *pPath, CDocNode &docNode, bool bFile)
  1623. {
  1624. BAYESClassify(pPath,docNode,bFile);
  1625. return docNode.m_nCataID;
  1626. }
  1627. long CClassifier::BAYESClassify(CCatalogList &cataList)
  1628. {
  1629. long docID=0,lUnknown=0;
  1630. CString str;
  1631. POSITION pos_cata=cataList.GetFirstPosition();
  1632. while(pos_cata!=NULL)
  1633. {
  1634. CCatalogNode& cataNode=cataList.GetNext(pos_cata);
  1635. POSITION pos_doc=cataNode.GetFirstPosition();
  1636. char *path=cataNode.m_strDirName.GetBuffer(0);
  1637. while(pos_doc!=NULL)
  1638. {
  1639. CDocNode& docNode=cataNode.GetNext(pos_doc);
  1640. short id=BAYESCategory(path, docNode, true);
  1641. if(id==-1) 
  1642. {
  1643. str="无法识别文档";
  1644. str+=cataNode.m_strDirName;
  1645. str+="\"+docNode.m_strDocName+"的类别!";
  1646. CMessage::PrintError(str);
  1647. lUnknown++;
  1648. }
  1649. CMessage::PrintStatusInfo(_T("扫描文档")+docNode.m_strDocName);
  1650. }
  1651. }
  1652. return lUnknown;
  1653. }
  1654. short CClassifier::GetCategory(char *file, bool bFile)
  1655. {
  1656. short result=-1;
  1657. if(m_paramClassifier.m_nClassifierType==CClassifierParam::nCT_KNN)
  1658. result=KNNCategory(file,bFile);
  1659. else if(m_paramClassifier.m_nClassifierType==CClassifierParam::nCT_SVM)
  1660. result=SVMCategory(file,bFile);
  1661. return result;
  1662. }
  1663. short CClassifier::GetCategory(char *path, CDocNode &docNode, bool bFile)
  1664. {
  1665. short result=-1;
  1666. if(m_paramClassifier.m_nClassifierType==CClassifierParam::nCT_KNN)
  1667. result=KNNCategory(path,docNode,bFile);
  1668. else if(m_paramClassifier.m_nClassifierType==CClassifierParam::nCT_SVM)
  1669. result=SVMCategory(path,docNode,bFile);
  1670. return result;
  1671. }
  1672. bool CClassifier::Classify(char *path, CDocNode &docNode, bool bFile)
  1673. {
  1674. bool result=false;
  1675. if(m_paramClassifier.m_nClassifierType==CClassifierParam::nCT_KNN)
  1676. result=KNNClassify(path,docNode,bFile);
  1677. else if(m_paramClassifier.m_nClassifierType==CClassifierParam::nCT_SVM)
  1678. result=SVMClassify(path,docNode,bFile);
  1679. else if(m_paramClassifier.m_nClassifierType==CClassifierParam::nCT_BAYES)
  1680. result=BAYESClassify(path,docNode,bFile);
  1681. return result;
  1682. }
  1683. long CClassifier::Classify(CCatalogList &cataList)
  1684. {
  1685. long lUnknown=0;
  1686. if(m_paramClassifier.m_nClassifierType==CClassifierParam::nCT_KNN)
  1687. lUnknown=KNNClassify(m_lstTestCatalogList);
  1688. else if(m_paramClassifier.m_nClassifierType==CClassifierParam::nCT_SVM)
  1689. lUnknown=SVMClassify(m_lstTestCatalogList);
  1690. else if(m_paramClassifier.m_nClassifierType==CClassifierParam::nCT_BAYES)
  1691. lUnknown=BAYESClassify(m_lstTestCatalogList);
  1692. else 
  1693. CMessage::PrintError("无法确定分类器的类型!");
  1694. return lUnknown;
  1695. }