Press "Enter" to skip to content

Opencv学习笔记 – 使用opencvsharp和决策树进行训练和预测

一、决策树

 

决策树是最早的机器学习算法之一,起源于对人类某些决策过程的模仿,属于监督学习算法。决策树的优点是易于理解,有些决策树既可以做分类,也可以做回归。在排名前十的数据挖掘算法中有两种是决策树[1]。决策树有许多不同版本,典型版本是最早出现的ID3算法,以及对其改进后形成的C4.5算法,这两种算法可用于分类。对ID3算法改进的另一个分支为“分类和回归树”(Classification AndRegression Trees,CART)算法,可用于分类或回归。CART算法为随机森林和Boosting等重要算法提供了基础。在OpenCV中,决策树实现的是CART算法。

 

1、决策树的核心问题

 

在决策树中,通常将样本向量中的特征称为样本的属性,下文将使用“属性”这一习惯性称呼。决策树通过把样本从根节点排列到某个叶子节点对样本进行分类。根节点是树第一次切分的位置,叶子节点即为样本所属的分类标签。树上的每一个节点都表示了对样本的某个属性的测试,并且该节点的每一个后继分支对应于该属性的一个可能值。分类样本的方法是从这棵树的根节点开始,测试这个节点指定的属性,然后按照给定样本的该属性值对应的树枝向下移动。这一过程在以新节点为根的子树上重复[2]。

 

决策树的核心问题是:自顶向下的各个节点应选择何种属性进行切分,才能获得最好的分类器?因此,选择最佳切分属性是决策树的关键所在。

 

2、最佳切分属性的选择

 

(1)信息熵

 

(2)信息增益

 

(3)信息增益率

 

(4)基尼系数

 

(5)均方误差

 

3、停止标准

 

决策树遵循贪婪的递归分裂节点,它们是如何停止又是何时停止的呢?实际上,可以应用许多策略来定义停止标准(Stopping Criteria)。最常见的是数据点的最小数量,如果进一步切分会违反此约束,则停止切分。另一个停止标准是树的深度。停止标准与其他参数一起可以帮助我们实现具有较好泛化能力的决策树模型。非常深或具有太多非叶子节点的决策树通常会导致过拟合。

 

4、剪枝

 

由于决策树的建立完全依赖训练样本,因此算法很容易对训练集过拟合,导致泛化能力变差。为了解决过拟合问题,需要对决策树进行剪枝(Pruning),即去掉一些节点,包括叶子节点和中间节点,以简化决策树。剪枝类似于线性回归的正则化,可以增加决策树的泛化能力。

 

剪枝的常用方法有预剪枝和后剪枝两种。预剪枝是在构建决策树过程中,提前终止决策树的生长,从而避免产生过多节点。该方法简单,但实用性不强,因为很难准确判断应何时终止生长。后剪枝是在决策树构建完成后再去掉一些节点。常见的后剪枝方法有悲观错误剪枝、最小错误剪枝、代价复杂度剪枝和基于错误的剪枝四种。OpenCV中的CART算法采用的是代价复杂度剪枝,即先生成决策树,然后生成所有可能的剪枝后的CART树,最后使用交叉验证来检验各种剪枝的效果,选择泛化能力最好的剪枝策略。

 

二、OpenCV函数说明

 

cv::ml::DTrees类表示单个决策树或决策树集合,它是RTrees和Boost的基类。

 

1、创建空决策树

 

cv::ml::DTrees::create函数可使用指定的参数创建空决策树;

 

之后使用 cv::ml::StatModel::train函数训练该决策树模型;

 

或者使用Algorithm::load<DTrees> (filename)从文件中加载决策树模型。

 

2、模型的基本设置

 

1)参数MaxDepth:树的最大可能深度

 

2)参数MinSampleCount:节点中的最小样本数

 

3)参数RegressionAccuracy:回归树的终止标准

 

4)参数MaxCategories:表示特征属性为类形式时最大类的数量

 

5)参数CVFolds:使用k折叠交叉验证剪枝时的交叉验证折数

 

6)参数Use1SERule:应用1SE规则剪枝标志位

 

7)参数TruncatePrunedTree:截断剪枝树标志位

 

8)参数priors:先验类概率Mat数组

 

9)参数UseSurrogates:是否构建代理切分标志位

 

3、训练决策树

 

 

函数参数:

 

◎ samples:训练集。

 

◎ layout:指定训练集的样本排列方式。具体如下:• ROW_SAMPLE:每个训练样本占一行的训练集。• COL_SAMPLE:每个训练样本占一列的训练集。

 

◎ responses与训练集样本排列顺序对应的标签向量。

 

三、mushroom数据集

 

数据集下载地址:

 

UCI Machine Learning Repository: Mushroom Data Set https://archive.ics.uci.edu/ml/datasets/mushroom 下面使用UCI数据集中的蘑菇可食用数据集[5]——Mushroom数据集,介绍决策树的分类应用。

 

Mushroom数据集包括以下主要文件。

 

agaricus-lepiota.data:样本数据。取前5个样本如下图看一下(第一列是标签,可食用(e)或有毒(p)),其它列都是数据

 

 

agaricus-lepiota.names:数据集概述与样本属性描述。

 

Index:数据集文件索引。

 

三、opencv(c++)代码参考

 

Learning-OpenCV-3_examples/example_21-01.cpp at 4fe1f6c8bb477e4393ea3cd94749441d93f9b3dd · oreillymedia/Learning-OpenCV-3_examples · GitHub https://github.com/oreillymedia/Learning-OpenCV-3_examples/blob/4fe1f6c8bb477e4393ea3cd94749441d93f9b3dd/example_21-01.cpp 代码是从上面官方git直接copy下来的,opencv3和opencv4决策树部分没什幺差别。

 

//Example 21-1. Creating and training a decision tree
#include <opencv2/opencv.hpp>
#include <stdio.h>
#include <iostream>
using namespace std;
using namespace cv;
void help(char **argv) {
  cout << "
"
       << "Using binary decision trees to learn to recognize poisonous
"
       << "    from edible mushrooms based on visible attributes.
" 
       << "    This program demonstrates how to create and a train a 
"
       << "    decision tree using ml library in OpenCV.
"
       << "Call:
" << argv[0] << " <csv-file-path>
"
       << "
If you don't enter a file, it defaults to ../mushroom/agaricus-lepiota.data
"
       << endl;
}
int main(int argc, char *argv[]) {
  // If the caller gave a filename, great. Otherwise, use a default.
  //
  const char *csv_file_name = argc >= 2 ? argv[1] : "../mushroom/agaricus-lepiota.data";
  cout << "OpenCV Version: " << CV_VERSION << endl;
  help(argv);
  // Read in the CSV file that we were given.
  //
  cv::Ptr<cv::ml::TrainData> data_set =
      cv::ml::TrainData::loadFromCSV(csv_file_name, // Input file name
                                     0, // Header lines (ignore this many)
                                     0, // Responses are (start) at thie column
                                     1, // Inputs start at this column
                                     "cat[0-22]" // All 23 columns are categorical
                                     );
  // Use defaults for delimeter (',') and missch ('?')
  // Verify that we read in what we think.
  //
  int n_samples = data_set->getNSamples();
  if (n_samples == 0) {
    cerr << "Could not read file: " << csv_file_name << endl;
    exit(-1);
  } else {
    cout << "Read " << n_samples << " samples from " << csv_file_name << endl;
  }
  // Split the data, so that 90% is train data
  //
  data_set->setTrainTestSplitRatio(0.90, false);
  int n_train_samples = data_set->getNTrainSamples();
  int n_test_samples = data_set->getNTestSamples();
  cout << "Found " << n_train_samples << " Train Samples, and "
       << n_test_samples << " Test Samples" << endl;
  // Create a DTrees classifier.
  //
  cv::Ptr<cv::ml::RTrees> dtree = cv::ml::RTrees::create();
  // set parameters
  //
  // These are the parameters from the old mushrooms.cpp code
  // Set up priors to penalize "poisonous" 10x as much as "edible"
  //
  float _priors[] = {1.0, 10.0};
  cv::Mat priors(1, 2, CV_32F, _priors);
  dtree->setMaxDepth(8);
  dtree->setMinSampleCount(10);
  dtree->setRegressionAccuracy(0.01f);
  dtree->setUseSurrogates(false /* true */);
  dtree->setMaxCategories(15);
  dtree->setCVFolds(0 /*10*/); // nonzero causes core dump
  dtree->setUse1SERule(true);
  dtree->setTruncatePrunedTree(true);
  // dtree->setPriors( priors );
  dtree->setPriors(cv::Mat()); // ignore priors for now...
  // Now train the model
  // NB: we are only using the "train" part of the data set
  //
  dtree->train(data_set);
  // Having successfully trained the data, we should be able
  // to calculate the error on both the training data, as well
  // as the test data that we held out.
  //
  cv::Mat results;
  float train_performance = dtree->calcError(data_set,
                                             false, // use train data
                                             results // cv::noArray()
                                             );
  std::vector<cv::String> names;
  data_set->getNames(names);
  Mat flags = data_set->getVarSymbolFlags();
  // Compute some statistics on our own:
  //
  {
    cv::Mat expected_responses = data_set->getResponses();
    int good = 0, bad = 0, total = 0;
    for (int i = 0; i < data_set->getNTrainSamples(); ++i) {
      float received = results.at<float>(i, 0);
      float expected = expected_responses.at<float>(i, 0);
      cv::String r_str = names[(int)received];
      cv::String e_str = names[(int)expected];
      cout << "Expected: " << e_str << ", got: " << r_str << endl;
      if (received == expected)
        good++;
      else
        bad++;
      total++;
    }
    cout << "Correct answers: " <<(float(good)/total) <<" % " << endl;
                cout << "Incorrect answers: " << (float(bad) / total) << "%"
         << endl;
  }
  float test_performance = dtree->calcError(data_set,
                                            true, // use test data
                                            results // cv::noArray()
                                            );
  cout << "Performance on training data: " << train_performance << "%" << endl;
  cout << "Performance on test data: " <<test_performance <<" % " <<endl;
  return 0;
}

 

四、opencvsharp(c#)代码参考

 

opencvsharp很多相关方法都没有进行实现,比如没有实现TrainData::loadcsv等,数据需要自行处理。

 

代码内的DTrees都可以替换为Rtrees。

 

代码里面写死了一些数量,8124条数据,取了前8000条进行训练,后124条进行测试了。

 

数据集第一列是标签,其余列是数据。

 

数据从文件内读取后后要对应ascii(可以百度码表直观的对比)转为相应数字(System.Text.Encoding.ASCII),才能传到方法内。

 

1、读取数据训练并保存模型

 

//读取数据
int[,] att = GetTArray(@"C:\Users\xiaomao\Desktop\mushroom数据集\agaricus-lepiota.train.data");
//读取标签,多冗余读取了一次文件,主要是为了区分清楚步骤
int[] label = GetTLabel(@"C:\Users\xiaomao\Desktop\mushroom数据集\agaricus-lepiota.train.data");
InputArray array = InputArray.Create(att);
InputArray outarray = InputArray.Create(label);
//创建决策树,这里也可以用RTrees,据说比Dtrees效果更好
OpenCvSharp.ML.DTrees dtrees = OpenCvSharp.ML.DTrees.Create();
dtrees.MaxDepth = 8;
dtrees.MinSampleCount = 10;
dtrees.RegressionAccuracy = 0.01f;
dtrees.UseSurrogates = false;
dtrees.MaxCategories = 15;
dtrees.CVFolds = 0;
dtrees.Use1SERule = true;
dtrees.TruncatePrunedTree = true;
//进行训练
dtrees.Train(array, OpenCvSharp.ML.SampleTypes.RowSample, outarray);
//保存模型
dtrees.Save(@"C:\Users\xiaomao\Desktop\1.xml");

 

2、读取数据的相关代码

 

//从mushroom数据集读取数据,数据集一共8124条数据,取8000条进行训练
public int[,] GetTArray(string filepath)
{
    int[,] att = new int[8000, 22];
    using (StreamReader sin = new StreamReader(new FileStream(filepath, FileMode.Open, FileAccess.Read, FileShare.Read)))
    {
        int pos = 0;
        for (string str = sin.ReadLine(); str != null; str = sin.ReadLine())
        {
            //分割每行数据
            string[] temp = str.Split(',');
            for(int i=1; i< temp.Length; i++)
            {
                //这里是把字母转ascii码,否则InputArray函数及train都不允许
                //实际上就是标准化数据
                //第一列不要
                att[pos, i-1] = System.Text.Encoding.ASCII.GetBytes(temp[i])[0];
            }
            pos++;
        }
    }
    return att;
}
//读取标签,就是从数据集中读取第一列
public int[] GetTLabel(string filepath)
{
    int[] att = new int[8000];
    using (StreamReader sin = new StreamReader(new FileStream(filepath, FileMode.Open, FileAccess.Read, FileShare.Read)))
    {
        int pos = 0;
        for (string str = sin.ReadLine(); str != null; str = sin.ReadLine())
        {
            string[] temp = str.Split(',');
            att[pos] = System.Text.Encoding.ASCII.GetBytes(temp[0])[0];
            pos++;
        }
    }
    return att;
}
//读取测试数据到list
public List<int[]> GetTestArray(string filepath)
{
    List<int[]> att = new List<int[]>();
    using (StreamReader sin = new StreamReader(new FileStream(filepath, FileMode.Open, FileAccess.Read, FileShare.Read)))
    {
        for (string str = sin.ReadLine(); str != null; str = sin.ReadLine())
        {
            string[] temp = str.Split(',');
            int[] vs = new int[temp.Length];
            for (int i = 1; i < temp.Length; i++)
            {
                vs[i-1] = System.Text.Encoding.ASCII.GetBytes(temp[i])[0];
            }
            att.Add(vs);
        }
    }
    return att;
}

 

3、加载模型并预测

 

OpenCvSharp.ML.DTrees tree = OpenCvSharp.ML.DTrees.Load(@"C:\Users\xiaomao\Desktop\1.xml");
List<int[]> att = GetTestArray(@"C:\Users\xiaomao\Desktop\mushroom数据集\agaricus-lepiota.test.data");
//循环进行预测,应该可以批量预测,没进行测试
for(int i=0;i <att.Count; i++)
{
    Mat p = new Mat(1, 22, OpenCvSharp.MatType.CV_32F, att[i]);
    float rrr = tree.Predict(p);
    System.Console.WriteLine("" + rrr);
}

 

预测结果如下,112对应小写的p,101对应小写的e,就是数据的第一列。

 

 

五、相关源码参考

 

opencv/tree_engine.cpp at 4.x · opencv/opencv · GitHub Open Source Computer Vision Library. Contribute to opencv/opencv development by creating an account on GitHub. https://github.com/opencv/opencv/blob/4.x/samples/cpp/tree_engine.cpp OpenCV: cv::ml::DTrees Class Reference https://docs.opencv.org/4.5.4/d8/d89/classcv_1_1ml_1_1DTrees.html#details

Be First to Comment

发表回复

您的电子邮箱地址不会被公开。 必填项已用*标注