predict.cpp
上传用户:xgw_05
上传日期:2014-12-08
资源大小:2726k
文件大小:8k
源码类别:

.net编程

开发平台:

Java

  1. #include <stdlib.h>
  2. #include <string.h>
  3. #include <fstream.h>
  4. #include "globals.h"
  5. #include "example_set.h"
  6. #include "svm_c.h"
  7. #include "svm_nu.h"
  8. #include "parameters.h"
  9. #include "kernel.h"
  10. #include "version.h"
  11. // global svm-objects
  12. kernel_c* kernel=0;
  13. parameters_c* parameters=0;
  14. svm_c* svm;
  15. example_set_c* training_set=0;
  16. int is_linear=1; // linear kernel?
  17. struct example_set_list{
  18.   example_set_c* the_set;
  19.   example_set_list* next;
  20. };
  21. example_set_list* test_sets = 0;
  22. void print_help(){
  23.   cout<<endl;
  24.   cout<<"predict: predict a set of examples with a trained SVM."<<endl<<endl;
  25.   cout<<"usage: predict"<<endl
  26.       <<"       predict <FILE>"<<endl
  27.       <<"       predict <FILE1> <FILE2> ..."<<endl<<endl;
  28.   cout<<"The input has to consist of:"<<endl
  29.       <<"- the svm parameters"<<endl
  30.       <<"- the kernel definition"<<endl
  31.       <<"- the training result set"<<endl
  32.       <<"- one or more sets to predict"<<endl;
  33.   cout<<endl<<"See the documentation for the input format. The first example set to be entered is considered to be the training set, all others are test sets. Each input file can consist of one or more definitions. If no input file is specified, the input is read from <stdin>."<<endl<<endl;
  34.   cout<<endl<<"This software is free only for non-commercial use. It must not be modified and distributed without prior permission of the author. The author is not responsible for implications from the use of this software."<<endl;
  35.   exit(0);
  36. };
  37. void read_input(istream& input_stream, char* filename){
  38.   // returns number of examples sets read
  39.   char* s = new char[MAXCHAR];
  40.   char next;
  41.   next = input_stream.peek();
  42.   if(next == EOF){ 
  43.     // set stream to eof
  44.     next = input_stream.get(); 
  45.   };
  46.   while(! input_stream.eof()){
  47.     if('#' == next){
  48.       // ignore comment
  49.       input_stream.getline(s,MAXCHAR);
  50.     }
  51.     else if('n' == next){
  52.       // ignore newline
  53.       next = input_stream.get();
  54.     }
  55.     else if('@' == next){
  56.       // new section
  57.       input_stream >> s;
  58.       if(0==strcmp("@parameters",s)){
  59. // read parameters
  60. if(parameters == 0){
  61.   parameters = new parameters_c();
  62.   input_stream >> *parameters;
  63. }
  64. else{
  65.   cout <<"*** ERROR: Parameters multiply defined"<<endl;
  66.   throw input_exception();
  67. };
  68.       }
  69.       else if(0==strcmp("@examples",s)){
  70. if(0 == training_set){
  71.   // input training set
  72.   training_set = new example_set_c();
  73.   if(0 != parameters){
  74.     training_set->set_format(parameters->default_example_format);
  75.   };
  76.   input_stream  >> *training_set;     
  77.   training_set->set_filename(filename);
  78.   cout<<"   read "<<training_set->size()<<" examples, format "<<training_set->my_format<<", dimension = "<<training_set->get_dim()<<"."<<endl;
  79. }
  80. else{
  81.   // input test sets
  82.   example_set_list* test_set = new example_set_list;
  83.   test_set->the_set = new example_set_c();
  84.   if(0 != parameters){
  85.     (test_set->the_set)->set_format(parameters->default_example_format);
  86.   };
  87.   input_stream >> *(test_set->the_set);
  88.   (test_set->the_set)->set_filename(filename);
  89.   test_set->next = test_sets;
  90.   test_sets = test_set;
  91.   cout<<"   read "<<(test_set->the_set)->size()<<" examples, format "<<(test_set->the_set)->my_format<<", dimension = "<<(test_set->the_set)->get_dim()<<"."<<endl;
  92. };
  93.       }
  94.       else if(0==strcmp("@kernel",s)){
  95. if(0 == kernel){
  96.   kernel_container_c k_cont;
  97.   input_stream >> k_cont;
  98.   kernel = k_cont.get_kernel();
  99. }
  100. else{
  101.   cout <<"*** ERROR: Kernel multiply defined"<<endl;
  102.   throw input_exception();
  103. };
  104.       };
  105.     }
  106.     else{
  107.       // default = "@examples"
  108.       if(0 == training_set){
  109. // input training set
  110. training_set = new example_set_c();
  111. if(0 != parameters){
  112.   training_set->set_format(parameters->default_example_format);
  113. };
  114. input_stream  >> *training_set;     
  115. training_set->set_filename(filename);
  116. cout<<"   read "<<training_set->size()<<" examples, format "<<training_set->my_format<<", dimension = "<<training_set->get_dim()<<"."<<endl;
  117.       }
  118.       else{
  119. // input test sets
  120. example_set_list* test_set = new example_set_list;
  121. test_set->the_set = new example_set_c();
  122. if(0 != parameters){
  123.   (test_set->the_set)->set_format(parameters->default_example_format);
  124. };
  125. input_stream >> *(test_set->the_set);
  126. (test_set->the_set)->set_filename(filename);
  127. test_set->next = test_sets;
  128. test_sets = test_set;
  129. cout<<"   read "<<(test_set->the_set)->size()<<" examples, format "<<(test_set->the_set)->my_format<<", dimension = "<<(test_set->the_set)->get_dim()<<"."<<endl;
  130.       };
  131.     };
  132.     next = input_stream.peek();
  133.     if(next == EOF){ 
  134.       // set stream to eof
  135.       next = input_stream.get(); 
  136.     };
  137.   };
  138.   delete []s;
  139. };
  140. ///////////////////////////////////////////////////////////////
  141. int main(int argc,char* argv[]){
  142.   cout<<"*** mySVM version "<<mysvmversion<<" ***"<<endl;
  143.   // read objects
  144.   try{
  145.     if(argc<2){
  146.       cout<<"Reading from STDIN"<<endl;
  147.       // read vom cin
  148.       read_input(cin,"mysvm");
  149.     }
  150.     else{
  151.       char* s = argv[1];
  152.       if((0==strcmp("-h",s)) || (0==strcmp("-help",s)) || (0==strcmp("--help",s))){
  153. // print out command-line help
  154. print_help();
  155.       }
  156.       else{
  157. // read in all input files
  158. for(int i=1;i<argc;i++){
  159.   if(0==strcmp(argv[i],"-")){
  160.     cout<<"Reading from STDIN"<<endl;
  161.     // read vom cin
  162.     read_input(cin,"mysvm");
  163.   }
  164.   else{
  165.     cout<<"Reading "<<argv[i]<<endl;
  166.     ifstream input_file(argv[i]);
  167.     if(input_file.bad()){
  168.       cout<<"ERROR: Could not read file ""<<argv[i]<<"", exiting."<<endl;
  169.       exit(1);
  170.     };
  171.     read_input(input_file,argv[i]);
  172.     input_file.close();
  173.   };
  174. };
  175.       };
  176.     };
  177.   }
  178.   catch(general_exception &the_ex){
  179.     cout<<"*** Error while reading input: "<<the_ex.error_msg<<endl;
  180.     exit(1);
  181.   }
  182.   catch(...){
  183.     cout<<"*** Program ended because of unknown error while reading input"<<endl;
  184.     exit(1);
  185.   };
  186.   if(0 == parameters){
  187.     parameters = new parameters_c();
  188.     if(training_set->initialised_pattern_y()){
  189.       parameters->is_pattern = 1;
  190.       parameters->do_scale_y = 0;
  191.     };
  192.   };
  193.   if(0 == kernel){
  194.     kernel = new kernel_dot_c();
  195.   };
  196.   if(0 == training_set){
  197.     cout << "*** ERROR: You did not enter the training set"<<endl;
  198.     exit(1);
  199.   };
  200.   if(parameters->is_distribution){
  201.     svm = new svm_distribution_c();
  202.   }
  203.   else if(parameters->is_nu){
  204.     if(parameters->is_pattern){
  205.       svm = new svm_nu_pattern_c();
  206.     }
  207.     else{
  208.       svm = new svm_nu_regression_c();
  209.     };
  210.   }
  211.   else if(parameters->is_pattern){
  212.     svm = new svm_pattern_c();
  213.   }
  214.   else{
  215.     svm = new svm_regression_c();
  216.   };
  217.   // scale examples
  218.   if(parameters->do_scale){
  219.     training_set->scale(parameters->do_scale_y);
  220.   };
  221.   kernel->init(parameters->kernel_cache,training_set);
  222.   svm->init(kernel,parameters);
  223.   svm->set_svs(training_set);
  224.   // testing
  225.   if(0 != test_sets){
  226.     cout<<"----------------------------------------"<<endl;
  227.     cout<<"Predicting"<<endl;
  228.     example_set_c* next_test;
  229.     SVMINT test_no = 0;
  230.     char* outname = new char[MAXCHAR];
  231.     while(test_sets != 0){
  232.       test_no++;
  233.       next_test = test_sets->the_set;
  234.       if(training_set->initialised_scale()){
  235. next_test->scale(training_set->get_exp(),
  236.  training_set->get_var(),
  237.  training_set->get_dim());
  238.       };
  239.       if(next_test->initialised_y()){
  240. cout<<"Testing examples from file "<<(next_test->get_filename())<<endl;
  241. svm->test(next_test,1);
  242.       };
  243.       cout<<"Predicting examples from file "<<(next_test->get_filename())<<endl;
  244.       svm->predict(next_test);
  245.       // output to file .pred
  246.       strcpy(outname,next_test->get_filename());
  247.       strcat(outname,".pred");
  248.       ofstream output_file(outname,
  249.    ios::out|ios::trunc);
  250.       next_test->output_ys(output_file);
  251.       output_file.close();
  252.       cout<<"Prediction saved in file "<<(next_test->get_filename())<<".pred"<<endl;
  253.       test_sets = test_sets->next; // skip delete!
  254.     };
  255.     delete []outname;
  256.   };
  257.   if(parameters->verbosity > 1){
  258.     cout << "mysvm ended successfully."<<endl;
  259.   };
  260.   return(0);
  261. };