编辑推荐: |
本文讲解训练数据集,SVM参数设定,函数作用:判断sample的类别等等希望对您有所帮助。
本文来自于知乎,由火龙果软件Delores编辑、推荐。 |
|
OpenCV中集成了多种机器学习算法供我们方便使用,如果我们要训练数据进行分类,不用自己写分类器,只需要调用相应的库和类即可轻松实现。本文重点不在于介绍机器学习原理及数学推导,着重介绍OpenCV中的机器学习相关函数,并且用十分简单的训练数据作为例子实现分类。
对于OpenCV的机器学习分类器,大多换汤不换药,构造方法和实现方法很类似,基本遵循原始数据—训练分类器—进行分类的步骤,某些算法可能有特殊的初始化参数,需要额外设置
在实现任何分类器之前,都需要训练数据。插句题外话,训练数据的好坏是一个分类器成功与否的决定性条件,数据选取永远凌驾于分类器算法选取之上,如果训练数据选取得当,无论使用任何算法都会得到不错的效果,反之如果训练数据选取不当,分类算法是无法弥补的。在此我们使用简单的二维数据作为训练数据,其标号分别为1和-1,我们用图像来直观的表示:
//设定800*800的二维坐标平面区域
int width = 800, height = 800;
Mat I = Mat::zeros(height, width, CV_8UC3);
//训练数据集,前10个标记为1,后10个标记为-1
float trainingData[20][2] =
{ { 100, 100 }, { 200, 100 }, { 400, 100 },
{ 200, 200 }, { 500, 200 },
{ 100, 300 }, { 300, 300 }, { 400, 300 }, {
100, 400 }, { 200, 500 },
{ 600, 600 }, { 700, 300 }, { 700, 300 }, {
400, 500 }, { 600, 500 },
{ 200, 700 }, { 300, 600 }, { 500, 600 }, {
600, 300 }, { 400, 700 } };
//训练数据集存入矩阵
Mat trainingDataMat(20, 2, CV_32FC1, trainingData);
//训练数据标记,前10个标记为1,后10个标记为-1
float labels[20] =
{ 1.0, 1.0, 1.0, 1.0, 1.0,
1.0, 1.0, 1.0, 1.0, 1.0,
-1.0, -1.0, -1.0, -1.0, -1.0,
-1.0, -1.0, -1.0, -1.0, -1.0 };
//训练数据标记存入矩阵
Mat labelsMat(20, 1, CV_32FC1, labels);
//将训练数据用不同颜色画出:1为绿色,-1为蓝色
for (int i = 0; i < 20; i++)
{
if (labels[i] == 1.0)
circle(I, Point(trainingData[i][0], trainingData[i][1]),
2, Scalar(255, 0, 0), 2);
else
circle(I, Point(trainingData[i][0], trainingData[i][1]),
2, Scalar(0, 255, 0), 2);
}
imshow("dataset", I); |
注意训练数据集矩阵类型一定是CV_32FC1型,长宽分别为数据个数和维度(20个训练数据,2维);训练数据标记矩阵是一维向量,也建议使用CV_32FC1型,还可用CV_32SC1型,长度为数据个数,要和训练数据一一对应(如例子中前10个数据标记为1,后10个数据标记为-1)
接下来是SVM参数设定,建议设定方法是初始化一个空类,需要什么参数单独设定,具体如下:
CvSVMParams params;
params.svm_type = CvSVM::C_SVC;
params.kernel_type = CvSVM::LINEAR;
params.term_crit = cvTermCriteria (CV_TERMCRIT_ITER,
100, FLT_EPSILON); |
其中,CvSVMParams可设置的参数有:(具体分类涉及SVM数学原理,不进行展开)
int svm_type:用来设定SVM的类型,分为C_SVC=100, NU_SVC=101, ONE_CLASS=102, EPS_SVR=103, NU_SVR=104这5种,通常使用C_SVC=100作为一般的SVM分类器
int kernel_type:用来设定SVM所用核函数类型,分为LINEAR=0, POLY=1, RBF=2, SIGMOID=3这三种,文中训练数据分类较为简单,用线性核LINEAR=0即可
double degree:用来设定多项式内核函数(POLY=1)的幂次
double gamma:用来设定内核函数(POLY/ RBF/ SIGMOID)的参数gamma(多项式系数)
double coef0:用来设定内核函数(POLY/ RBF/ SIGMOID)的参数coef0(常数项)
double C、double nu、double p、CvMat* class_weights:用来设定非C_SVC=100类型的相应参数
CvTermCriteria term_crit:用来设定SVM迭代终止条件,其构造类型为(int type, int max_iter, double expsolon),三个参数分别意为结束方式(迭代次数为基准的CV_TERMCRIT_ITER或误差值为基准的CV_TERMCRIT_EPS),最大迭代次数,最小误差值
综上所述,文中SVM参数设置为:一般SVM分类器,线性核,循环终止,100次循环,最小误差值为定义FLT_EPSILON(1.192092896e-07F)。
设置完参数后,就该是SVM训练了,由于类初始化需要CvMat*数据类型,依旧建议初始化一个空类,需要什么参数用函数添加,具体如下:
CvSVM SVM;
SVM.train(trainingDataMat, labelsMat, Mat(), Mat(),
params); |
svm.train即为训练函数,其参数为
bool train( const
cv::Mat& trainData, const cv::Mat& responses,
const cv::Mat& varIdx=cv::Mat(), const cv::Mat&
sampleIdx=cv::Mat(),
CvSVMParams params=CvSVMParams() ) |
const cv::Mat& trainData:训练数据集,前文设定20*2的Mat trainingDataMat,再次提醒格式一定是CV_32FC1
const cv::Mat& responses:响应数据,即前文的训练数据标记,20*1的向量Mat labelsMat,格式最好也是CV_32FC1
const cv::Mat& varIdx=cv::Mat(), const cv::Mat& sampleIdx=cv::Mat():两个参数表示感兴趣的特征和样本,如没有感兴趣对象则设为空矩阵Mat()即可
CvSVMParams params=CvSVMParams():SVM参数设定,即前文设定的CvSVMParams params
运行完train后,样本训练过程结束,可用SVM.predict()函数进行分类,用SVM.get_support_vector_count()函数和SVM.get_support_vector()函数查看支持向量,下面分别介绍三个函数:
float predict(
const cv::Mat& sample, bool returnDFVal=false
) const |
函数作用:判断sample的类别
参数const Mat& sample:待分类向量,文中训练数据是二维数据,因此待分类向量应是1*2的Mat矩阵,数据类型应为float型(CV_32F)
参数bool returnDFVal=false:判断是否为二分类器,通常情况下不用设定,默认false即可
返回值:const Mat& sample的分类结果,文中返回值应为前文设定的训练数据标记种类1或-1
简单例子:
float temp[2]
= { i, j };
Mat sampleMat(1, 2, CV_32F, temp);
float response = SVM.predict(sampleMat); |
int get_support_vector_count()
const
const float* get_support_vector(int i) const |
两个函数作用是获得支持向量,通常需要结合使用。int get_support_vector_count()得到支持向量个数,将结果遍历带入float* get_support_vector(int i)的参数i便可获得每个支持向量
简单例子:
int c = SVM.get_support_vector_count();
for (int i = 0; i < c; ++i)
{const float* v = SVM.get_support_vector(i);} |
SVM函数大体如此,完整代码及注释:
#include <iostream>
#include <opencv.hpp>
using namespace std;
using namespace cv;
void main()
{
//设定800*800的二维坐标平面区域
int width = 800, height = 800;
Mat I = Mat::zeros(height, width, CV_8UC3);
//训练数据集,前10个标记为1,后10个标记为-1
float trainingData[20][2] =
{ { 100, 100 }, { 200, 100 }, { 400, 100 },
{ 200, 200 }, { 500, 200 },
{ 100, 300 }, { 300, 300 }, { 400, 300 }, {
100, 400 }, { 200, 500 },
{ 600, 600 }, { 700, 300 }, { 700, 300 }, {
400, 500 }, { 600, 500 },
{ 200, 700 }, { 300, 600 }, { 500, 600 }, {
600, 300 }, { 400, 700 } };
//训练数据集存入矩阵
Mat trainingDataMat(20, 2, CV_32FC1, trainingData);
//训练数据标记,前10个标记为1,后10个标记为-1
float labels[20] =
{ 1.0, 1.0, 1.0, 1.0, 1.0,
1.0, 1.0, 1.0, 1.0, 1.0,
-1.0, -1.0, -1.0, -1.0, -1.0,
-1.0, -1.0, -1.0, -1.0, -1.0 };
//训练数据标记存入矩阵
Mat labelsMat(20, 1, CV_32FC1, labels);
//将训练数据用不同颜色画出:1为绿色,-1为蓝色
for (int i = 0; i < 20; i++)
{
if (labels[i] == 1.0)
circle(I, Point(trainingData[i][0], trainingData[i][1]),
2, Scalar(255, 0, 0), 2);
else
circle(I, Point(trainingData[i][0], trainingData[i][1]),
2, Scalar(0, 255, 0), 2);
}
imshow("dataset", I);
//SVM参数设置
CvSVMParams params;
params.svm_type = CvSVM::C_SVC;
params.kernel_type = CvSVM::LINEAR;
params.term_crit = cvTermCriteria (CV_TERMCRIT_ITER,
100, FLT_EPSILON);
//SVM训练
CvSVM SVM;
SVM.train(trainingDataMat, labelsMat, Mat(),
Mat(), params);
//SVM分类结果显示:1区域为绿色,-1区域为蓝色
for (int i = 0; i < I.rows; ++i)
for (int j = 0; j < I.cols; ++j)
{
float temp[2] = { i, j };
Mat sampleMat(1, 2, CV_32FC1, temp);
float response = SVM.predict(sampleMat);
if (response == 1)
I.at<Vec3b>(j, i) = Vec3b(255, 0, 0);
else if (response == -1)
I.at<Vec3b>(j, i) = Vec3b(0, 255, 0);
}
for (int i = 0; i < 20; i++)
{
if (labels[i] == 1.0)
circle(I, Point(trainingData[i][0], trainingData[i][1]),
2, Scalar(255, 255, 255), 2);
else
circle(I, Point(trainingData[i][0], trainingData[i][1]),
2, Scalar(0, 0, 0), 2);
}
//支持向量标注,用红圈圈出
int c = SVM.get_support_vector_count();
for (int i = 0; i < c; ++i)
{
const float* v = SVM.get_support_vector(i);
circle(I, Point((int)v[0], (int)v[1]), 6, Scalar(0,
0, 255), 2, 8);
}
imshow("result", I);
waitKey();
} |
分类结果:
|