OpenCV Machine Learning(Tree 구조) 참고자료
1 #include "opencv2/ml/ml.hpp" 2 #include "opencv2/core/core_c.h" 3 #include <stdio.h> 4 #include <map> 5 6 void help() { 7 printf( 8 "\nThis sample demonstrates how to use different decision trees and forests including boosting and random trees:\n" 9 "CvDTree dtree;\n" 10 "CvBoost boost;\n" 11 "CvRTrees rtrees;\n" 12 "CvERTrees ertrees;\n" 13 "CvGBTrees gbtrees;\n" 14 "Call:\n\t./tree_engine [-r <response_column>] [-c] <csv filename>\n" 15 "where -r <response_column> specified the 0-based index of the response (0 by default)\n" 16 "-c specifies that the response is categorical (it's ordered by default) and\n" 17 "<csv filename> is the name of training data file in comma-separated value format\n\n"); 18 } 19 20 int count_classes(CvMLData& data) { 21 cv::Mat r(data.get_responses()); 22 std::map<int, int> rmap; 23 int i, n = (int)r.total(); 24 for (i = 0; i < n; i++) { 25 float val = r.at<float>(i); 26 int ival = cvRound(val); 27 if (ival != val) 28 return -1; 29 rmap[ival] = 1; 30 } 31 return rmap.size(); 32 } 33 34 void print_result(float train_err, float test_err, const CvMat* _var_imp) { 35 printf("train error %f\n", train_err); 36 printf("test error %f\n\n", test_err); 37 38 if (_var_imp) { 39 cv::Mat var_imp(_var_imp), sorted_idx; 40 cv::sortIdx(var_imp, sorted_idx, CV_SORT_EVERY_ROW + CV_SORT_DESCENDING); 41 42 printf("variable importance:\n"); 43 int i, n = (int)var_imp.total(); 44 int type = var_imp.type(); 45 CV_Assert(type == CV_32F || type == CV_64F); 46 47 for (i = 0; i < n; i++) { 48 int k = sorted_idx.at<int>(i); 49 printf("%d\t%f\n", k, type == CV_32F ? 
var_imp.at<float>(k) : var_imp.at<double>(k)); 50 } 51 } 52 printf("\n"); 53 } 54 55 int main(int argc, char** argv) { 56 if (argc < 2) { 57 help(); 58 return 0; 59 } 60 const char* filename = 0; 61 int response_idx = 0; 62 bool categorical_response = false; 63 64 for (int i = 1; i < argc; i++) { 65 if (strcmp(argv[i], "-r") == 0) 66 sscanf(argv[++i], "%d", &response_idx); 67 else if (strcmp(argv[i], "-c") == 0) 68 categorical_response = true; 69 else if (argv[i][0] != '-') 70 filename = argv[i]; 71 else { 72 printf("Error. Invalid option %s\n", argv[i]); 73 help(); 74 return -1; 75 } 76 } 77 78 printf("\nReading in %s...\n\n", filename); 79 CvDTree dtree; 80 CvBoost boost; 81 CvRTrees rtrees; 82 CvERTrees ertrees; 83 CvGBTrees gbtrees; 84 85 CvMLData data; 86 87 88 CvTrainTestSplit spl(0.5f); 89 90 if (data.read_csv(filename) == 0) { 91 data.set_response_idx(response_idx); 92 if (categorical_response) 93 data.change_var_type(response_idx, CV_VAR_CATEGORICAL); 94 data.set_train_test_split(&spl); 95 96 printf("======DTREE=====\n"); 97 dtree.train(&data, CvDTreeParams(10, 2, 0, false, 16, 0, false, false, 0)); 98 print_result(dtree.calc_error(&data, CV_TRAIN_ERROR), dtree.calc_error(&data, CV_TEST_ERROR), dtree.get_var_importance()); 99 100 if (categorical_response && count_classes(data) == 2) { 101 printf("======BOOST=====\n"); 102 boost.train(&data, CvBoostParams(CvBoost::DISCRETE, 100, 0.95, 2, false, 0)); 103 print_result(boost.calc_error(&data, CV_TRAIN_ERROR), boost.calc_error(&data, CV_TEST_ERROR), 0); //doesn't compute importance 104 } 105 106 printf("======RTREES=====\n"); 107 rtrees.train(&data, CvRTParams(10, 2, 0, false, 16, 0, true, 0, 100, 0, CV_TERMCRIT_ITER)); 108 print_result(rtrees.calc_error(&data, CV_TRAIN_ERROR), rtrees.calc_error(&data, CV_TEST_ERROR), rtrees.get_var_importance()); 109 110 printf("======ERTREES=====\n"); 111 ertrees.train(&data, CvRTParams(10, 2, 0, false, 16, 0, true, 0, 100, 0, CV_TERMCRIT_ITER)); 112 
print_result(ertrees.calc_error(&data, CV_TRAIN_ERROR), ertrees.calc_error(&data, CV_TEST_ERROR), ertrees.get_var_importance()); 113 114 printf("======GBTREES=====\n"); 115 gbtrees.train(&data, CvGBTreesParams(CvGBTrees::DEVIANCE_LOSS, 100, 0.05f, 0.6f, 10, true)); 116 print_result(gbtrees.calc_error(&data, CV_TRAIN_ERROR), gbtrees.calc_error(&data, CV_TEST_ERROR), 0); //doesn't compute importance 117 } else 118 printf("File can not be read"); 119 120 return 0; 121 }
Machine Learning 작업은 주로 Java나 Python 라이브러리로 하는 경우가 많지만, 저처럼 C++ 라이브러리를 사용하는 분들도 있을 것 같아 참고자료로 올려 둡니다.
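위 소스코드는 CSV 형식의 데이터 파일을 읽어 OpenCV의 여러 트리 기반 모델(CvDTree, CvBoost, CvRTrees, CvERTrees, CvGBTrees)을 학습시키고, 각 모델의 학습/테스트 오차와 (지원하는 모델의 경우) 변수 중요도를 출력합니다. CvBoost는 응답이 범주형이고 클래스가 2개일 때만 실행됩니다.

실행 방법은 `./tree_engine [-r <응답 열 인덱스>] [-c] <CSV 파일>`입니다.
- `-r`: 응답(정답) 열의 0부터 시작하는 인덱스(기본값 0)
- `-c`: 응답이 범주형임을 지정(기본값은 순서형)

참고로 이 코드는 OpenCV 2.x의 구(舊) C++ ML API(CvMLData, CvDTree 등)를 사용하므로, OpenCV 3 이상에서는 그대로 컴파일되지 않습니다.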
댓글
이 글 공유하기
다른 글
-
[C/C++]Binary 파일 줄 바꿈 방법
[C/C++]Binary 파일 줄 바꿈 방법
2014.11.03 -
MySQL 회원 정보 테이블 생성
MySQL 회원 정보 테이블 생성
2014.10.26 -
Run-Time Check Failure #2 – Stack around the variable 'x' was corrupted
Run-Time Check Failure #2 – Stack around the variable 'x' was corrupted
2014.08.08 -
std::cout 속성 정리
std::cout 속성 정리
2013.07.04