#include "opencv2/ml/ml.hpp"
#include "opencv2/core/core_c.h"
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <map>
  5  
  6 void help() { 
  7   printf( 
  8     "\nThis sample demonstrates how to use different decision trees and forests including boosting and random trees:\n" 
  9     "CvDTree dtree;\n" 
 10     "CvBoost boost;\n" 
 11     "CvRTrees rtrees;\n" 
 12     "CvERTrees ertrees;\n" 
 13     "CvGBTrees gbtrees;\n" 
 14     "Call:\n\t./tree_engine [-r <response_column>] [-c] <csv filename>\n" 
 15     "where -r <response_column> specified the 0-based index of the response (0 by default)\n" 
 16     "-c specifies that the response is categorical (it's ordered by default) and\n" 
 17     "<csv filename> is the name of training data file in comma-separated value format\n\n"); 
 18 } 
 19  
 20 int count_classes(CvMLData& data) { 
 21   cv::Mat r(data.get_responses()); 
 22   std::map<int, int> rmap; 
 23   int i, n = (int)r.total();
 24   for (i = 0; i < n; i++) {
 25     float val = r.at<float>(i); 
 26     int ival = cvRound(val); 
 27     if (ival != val) 
 28       return -1; 
 29     rmap[ival] = 1; 
 30   } 
 31   return rmap.size(); 
 32 } 
 33  
 34 void print_result(float train_err, float test_err, const CvMat* _var_imp) {
 35   printf("train error    %f\n", train_err); 
 36   printf("test error    %f\n\n", test_err); 
 37  
 38   if (_var_imp) { 
 39     cv::Mat var_imp(_var_imp), sorted_idx; 
 40     cv::sortIdx(var_imp, sorted_idx, CV_SORT_EVERY_ROW + CV_SORT_DESCENDING);
 41  
 42     printf("variable importance:\n"); 
 43     int i, n = (int)var_imp.total();
 44     int type = var_imp.type(); 
 45     CV_Assert(type == CV_32F || type == CV_64F); 
 46  
 47     for (i = 0; i < n; i++) {
 48       int k = sorted_idx.at<int>(i); 
 49       printf("%d\t%f\n", k, type == CV_32F ? var_imp.at<float>(k) : var_imp.at<double>(k)); 
 50     } 
 51   } 
 52   printf("\n"); 
 53 } 
 54  
 55 int main(int argc, char** argv) {
 56   if (argc < 2) {
 57     help(); 
 58     return 0;
 59   } 
 60   const char* filename = 0;
 61   int response_idx = 0;
 62   bool categorical_response = false; 
 63  
 64   for (int i = 1; i < argc; i++) {
 65     if (strcmp(argv[i], "-r") == 0) 
 66       sscanf(argv[++i], "%d", &response_idx); 
 67     else if (strcmp(argv[i], "-c") == 0) 
 68       categorical_response = true; 
 69     else if (argv[i][0] != '-') 
 70       filename = argv[i]; 
 71     else { 
 72       printf("Error. Invalid option %s\n", argv[i]); 
 73       help(); 
 74       return -1; 
 75     } 
 76   } 
 77  
 78   printf("\nReading in %s...\n\n", filename); 
 79   CvDTree dtree; 
 80   CvBoost boost; 
 81   CvRTrees rtrees; 
 82   CvERTrees ertrees; 
 83   CvGBTrees gbtrees; 
 84  
 85   CvMLData data; 
 86  
 87  
 88   CvTrainTestSplit spl(0.5f); 
 89  
 90   if (data.read_csv(filename) == 0) {
 91     data.set_response_idx(response_idx);
 92     if (categorical_response) 
 93       data.change_var_type(response_idx, CV_VAR_CATEGORICAL); 
 94     data.set_train_test_split(&spl);
 95  
 96     printf("======DTREE=====\n"); 
 97     dtree.train(&data, CvDTreeParams(10, 2, 0, false, 16, 0, false, false, 0));
 98     print_result(dtree.calc_error(&data, CV_TRAIN_ERROR), dtree.calc_error(&data, CV_TEST_ERROR), dtree.get_var_importance());
 99  
100     if (categorical_response && count_classes(data) == 2) {
101       printf("======BOOST=====\n"); 
102       boost.train(&data, CvBoostParams(CvBoost::DISCRETE, 100, 0.95, 2, false, 0));
103       print_result(boost.calc_error(&data, CV_TRAIN_ERROR), boost.calc_error(&data, CV_TEST_ERROR), 0); //doesn't compute importance
104     } 
105  
106     printf("======RTREES=====\n"); 
107     rtrees.train(&data, CvRTParams(10, 2, 0, false, 16, 0, true, 0, 100, 0, CV_TERMCRIT_ITER));
108     print_result(rtrees.calc_error(&data, CV_TRAIN_ERROR), rtrees.calc_error(&data, CV_TEST_ERROR), rtrees.get_var_importance());
109  
110     printf("======ERTREES=====\n"); 
111     ertrees.train(&data, CvRTParams(10, 2, 0, false, 16, 0, true, 0, 100, 0, CV_TERMCRIT_ITER));
112     print_result(ertrees.calc_error(&data, CV_TRAIN_ERROR), ertrees.calc_error(&data, CV_TEST_ERROR), ertrees.get_var_importance());
113  
114     printf("======GBTREES=====\n"); 
115     gbtrees.train(&data, CvGBTreesParams(CvGBTrees::DEVIANCE_LOSS, 100, 0.05f, 0.6f, 10, true)); 
116     print_result(gbtrees.calc_error(&data, CV_TRAIN_ERROR), gbtrees.calc_error(&data, CV_TEST_ERROR), 0); //doesn't compute importance
117   } else 
118     printf("File can not be read"); 
119  
120   return 0;
121 }


Machine Learning은 주로 Java나 Python에서 제공하는 라이브러리로 작업하는 경우가 많지만 저를 포함하여 C++ 라이브러리를 사용하는 사람들도 있을 수 있어서 참고자료를 올려놓습니다. 

CSV 포맷의 파일을 받아서 여러 가지 트리 구조에서 결과를 찍어내는 소스코드입니다.