Press "Enter" to skip to content

Java Spark ML实现的文本分类

 

文本分类是指将一篇文章归到事先定义好的某一类或者某几类,在数据平台的一个典型的应用场景是,通过爬取用户浏览过的页面内容,识别出用户的浏览偏好,从而丰富该用户的画像。

 

本文介绍使用Spark MLlib提供的朴素贝叶斯(Naive Bayes)算法,完成对中文文本的分类过程。主要包括中文分词、文本表示(TF-IDF)、模型训练、分类预测等。

 

特征工程

 

文本处理

 

对于中文文本分类,需要先对内容进行分词,我使用的是ansj中文分析工具,其中自己可以配置扩展词库来使分词结果更合理,同时可以加一些停用词可以提高准确率,需要把数据样本分割成两批数据,一份用于训练模型,一份用于测试模型效果。

 

目录结构

 

DataFactory.java

 

package com.maweiming.spark.mllib.classifier;
import com.alibaba.fastjson.JSON;
import com.alibaba.fastjson.JSONObject;
import com.maweiming.spark.mllib.utils.AnsjUtils;
import com.maweiming.spark.mllib.utils.FileUtils;
import org.apache.commons.lang3.StringUtils;
import java.io.File;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
/**
 * 1、first step
 * data format
 * Created by Coder-Ma on 2017/6/12.
 */public class DataFactory {
 
    public static final String CLASS_PATH = "/Users/coderma/coders/github/SparkTextClassifier/src/main/resources";
    public static final String STOP_WORD_PATH = CLASS_PATH + "/data/stopWord.txt";
    public static final String NEWS_DATA_PATH = CLASS_PATH + "/data/NewsData";
    public static final String DATA_TRAIN_PATH = CLASS_PATH + "/data/data-train.txt";
    public static final String DATA_TEST_PATH = CLASS_PATH + "/data/data-test.txt";
    public static final String MODELS = CLASS_PATH + "/models";
    public static final String MODEL_PATH = CLASS_PATH + "/models/category-4";
    public static final String LABEL_PATH = CLASS_PATH + "/models/labels.txt";
    public static final String TF_PATH = CLASS_PATH + "/models/tf";
    public static final String IDF_PATH = CLASS_PATH + "/models/idf";
    public static void main(String[] args) {
 
        /**
         * 收集数据、特征工程
         * 1、遍历数据样本目录
         * 2、对数据进行清洗,剔除掉停用词
         */        //数据样本切割比例 80%用于训练样本,20%数据用于测试模型准确率
        Double spiltRate = 0.8;
        //停用词
        List<String> stopWords = FileUtils.readLine(line -> line, STOP_WORD_PATH);
        //分类标签(标签id,分类名)
        Map<Integer, String> labels = new HashMap<>();
        Integer dirIndex = 0;
        String[] dirNames = new File(NEWS_DATA_PATH).list();
        for (String dirName : dirNames) {
 
            dirIndex++;
            labels.put(dirIndex, dirName);
            String fileDirPath = String.format("%s/%s", NEWS_DATA_PATH, dirName);
            String[] fileNames = new File(fileDirPath).list();
            //当前分类目录的样本总数 * 切割比率
            int spilt = Double.valueOf(fileNames.length * spiltRate).intValue();
            for (int i = 0; i < fileNames.length; i++) {
 
                String fileName = fileNames[i];
                String filePath = String.format("%s/%s", fileDirPath, fileName);
                System.out.println(filePath);
                String text = FileUtils.readFile(filePath);
                for (String stopWord : stopWords) {
 
                    text = text.replaceAll(stopWord, "");
                }
                if (StringUtils.isBlank(text)) {
 
                    continue;
                }
                //把文本内容进行分词
                List<String> wordList = AnsjUtils.participle(text);
                JSONObject data = new JSONObject();
                data.put("text", wordList);
                data.put("category", Double.valueOf(dirIndex));
                if (i > spilt) {
 
                    //测试数据
                    FileUtils.appendText(DATA_TEST_PATH, data.toJSONString() + "
");
                } else {
 
                    //训练数据
                    FileUtils.appendText(DATA_TRAIN_PATH, data.toJSONString() + "
");
                }
            }
        }
        FileUtils.writer(LABEL_PATH, JSON.toJSONString(labels));//data labels
        System.out.println("Data processing successfully !");
        System.out.println("=======================================================");
        System.out.println("trainData:" + DATA_TRAIN_PATH);
        System.out.println("testData:" + DATA_TEST_PATH);
        System.out.println("labes:" + LABEL_PATH);
        System.out.println("=======================================================");
    }
}

 

训练模型

 

词语特征值处理(TF-IDF)

 

分好词后,每一个词都作为一个特征,需要将中文词语转换成Double型来表示,通常使用该词语的TF-IDF值作为特征值,Spark提供了全面的特征抽取及转换的API,非常方便,详见http://spark.apache.org/docs/latest/ml-features.html

 

为原始属于设置标签,按照resource->NewsData目录下面文件夹索引区分。

 

 

    1. car

 

    1. game

 

    1. it

 

    1. military

 

 

这里将中文词语转换成INT型的Hashing算法,类似于Bloomfilter,下面的setNumFeatures(500000)表示将Hash分桶的数量设置为500000个,这个值默认为2的20次方,即1048576,可以根据你的词语数量来调整,一般来说,这个值越大,不同的词被计算为一个Hash值的概率就越小,数据也更准确,但需要消耗更大的内存,和Bloomfilter是一个道理。

 

然后就可以训练模型,下面代码

 

package com.maweiming.spark.mllib.classifier;
import com.maweiming.spark.mllib.utils.FileUtils;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.ml.feature.HashingTF;
import org.apache.spark.ml.feature.IDF;
import org.apache.spark.ml.feature.IDFModel;
import org.apache.spark.ml.linalg.SparseVector;
import org.apache.spark.mllib.classification.NaiveBayes;
import org.apache.spark.mllib.classification.NaiveBayesModel;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.linalg.Vectors;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
import java.io.File;
import java.io.IOException;
/**
 * 2、The second step
 * Created by Coder-Ma on 2017/6/26.
 */public class NaiveBayesTrain {
 
    public static void main(String[] args) throws IOException {
 
        //1、创建一个SparkSession
        SparkSession spark = SparkSession.builder().appName("NaiveBayes").master("local")
                .getOrCreate();
        //2、加载训练数据样本
        Dataset<Row> train = spark.read().json(DataFactory.DATA_TRAIN_PATH);
        //3、通过tf-idf计算数据样本中的词频
        //word frequency count
        HashingTF hashingTF = new HashingTF().setNumFeatures(500000).setInputCol("text").setOutputCol("rawFeatures");
        Dataset<Row> featurizedData  = hashingTF.transform(train);
        //count tf-idf
        IDF idf = new IDF().setInputCol("rawFeatures").setOutputCol("features");
        IDFModel idfModel = idf.fit(featurizedData);
        Dataset<Row> rescaledData = idfModel.transform(featurizedData);
        //4、把数据样本转换成向量
        JavaRDD<LabeledPoint> trainDataRdd = rescaledData.select("category", "features").javaRDD().map(v1 -> {
 
            Double category = v1.getAs("category");
            SparseVector features = v1.getAs("features");
            Vector featuresVector = Vectors.dense(features.toArray());
            return new LabeledPoint(Double.valueOf(category),featuresVector);
        });
        System.out.println("Start training...");
        //调用朴素贝叶斯算法,传入向量数据训练模型
        NaiveBayesModel model  = NaiveBayes.train(trainDataRdd.rdd());
        //save model
        model.save(spark.sparkContext(), DataFactory.MODEL_PATH);
        //save tf
        hashingTF.save(DataFactory.TF_PATH);
        //save idf
        idfModel.save(DataFactory.IDF_PATH);
        System.out.println("train successfully !");
        System.out.println("=======================================================");
        System.out.println("modelPath:"+DataFactory.MODEL_PATH);
        System.out.println("tfPath:"+DataFactory.TF_PATH);
        System.out.println("idfPath:"+DataFactory.IDF_PATH);
        System.outprintln("=======================================================");
    }
}

 

训练模型完成

 

train successfully !
=======================================================
modelPath:/Users/coderma/coders/github/SparkTextClassifier/src/main/resources/models/category-4
tfPath:/Users/coderma/coders/github/SparkTextClassifier/src/main/resources/models/tf
idfPath:/Users/coderma/coders/github/SparkTextClassifier/src/main/resources/models/idf
=======================================================

 

测试模型

 

package com.maweiming.spark.mllib.classifier;
import com.alibaba.fastjson.JSON;
import com.maweiming.spark.mllib.dto.Result;
import com.maweiming.spark.mllib.utils.AnsjUtils;
import com.maweiming.spark.mllib.utils.FileUtils;
import org.apache.spark.ml.feature.HashingTF;
import org.apache.spark.ml.feature.IDFModel;
import org.apache.spark.ml.linalg.SparseVector;
import org.apache.spark.mllib.classification.NaiveBayesModel;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.linalg.Vectors;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.types.*;
import java.io.File;
import java.text.DecimalFormat;
import java.util.*;
/**
 * 3、the third step
 * Created by Coder-Ma on 2017/6/26.
 */public class NaiveBayesTest {
 
    private static HashingTF hashingTF;
    private static IDFModel idfModel;
    private static NaiveBayesModel model;
    private static Map<Integer,String> labels = new HashMap<>();
    public static void main(String[] args) {
 
        SparkSession spark = SparkSession.builder().appName("NaiveBayes").master("local")
                .getOrCreate();
        //load tf file
        hashingTF = HashingTF.load(DataFactory.TF_PATH);
        //load idf file
        idfModel = IDFModel.load(DataFactory.IDF_PATH);
        //load model
        model = NaiveBayesModel.load(spark.sparkContext(), DataFactory.MODEL_PATH);
        //batch test
        batchTestModel(spark, DataFactory.DATA_TEST_PATH);
        //test a single
        testModel(spark,"最近这段时间,由于印度三哥可能有些膨胀,在边境地区总想“搞事情”,这也让不少人的目光集中到此。事实上,我国在与印度的交界处有一军事要地,只要解放军一抬高水位,那幺印军或就“不战而退”。它就是地处我国西藏与印度控制克什米尔交界的班公湖。
" +
                "
" +
                "
" +
                "众所周知,从古至今那些地处与军事险要易守难攻的形胜之地,都具有非常重要的军事意义。经常能左右一场战争的胜负。据悉,班公湖位于西藏自治区阿里地区日土县城西北。全长有600多公里,其中地处中国的有400多公里,地处与印度约有200公里。整体成东西走向,海拔在4000多米以上。湖水整体为淡水湖,但由于湖水在西段的淡水补给量的大方面建少,东西方向上交替不通畅,使西部的区域变成了咸水湖。于是便出现了一个有趣的现象,在东部的中国境内班公湖为淡水湖,在西部的印度境内,班公湖为咸水湖。
" +
                "
" +
                "
" +
                "而我军在于印度交界的班公湖区域有一个阀门,这个区域有着非常大的军事作用,而如果印军将部队部署在班公湖地区,我军只需打开阀门,抬高班公湖的东部水位。将他们的军事设施和军用要道给全部淹没。而印军的军事物资和后勤保障都将全部瘫痪,到时印度的军事部署都将全部不攻自破。
" +
                "
" +
                "
" +
                "而印度应该知道现代战争最为重要的便是后勤制度的保障,军事行动能否取得胜利,很大程度取决于后勤能否及时的跟上。而我军在班公湖地区地势上就有了绝对的军事优势,军用物资也可源源不断的运输上来,而印度却优势全无。而我国自古以来就是爱好和平的国家,人不犯我我不犯人。只希望印军能认清与我国军事力量的差距,不要盲目自信。
" +
                "
");
    }
    public static void batchTestModel(SparkSession sparkSession, String testPath) {
 
        Dataset<Row> test = sparkSession.read().json(testPath);
        //word frequency count
        Dataset<Row> featurizedData = hashingTF.transform(test);
        //count tf-idf
        Dataset<Row> rescaledData = idfModel.transform(featurizedData);
        List<Row> rowList = rescaledData.select("category", "features").javaRDD().collect();
        List<Result> dataResults = new ArrayList<>();
        for (Row row : rowList) {
 
            Double category = row.getAs("category");
            SparseVector sparseVector = row.getAs("features");
            Vector features = Vectors.dense(sparseVector.toArray());
            double predict = model.predict(features);
            dataResults.add(new Result(category, predict));
        }
        Integer successNum = 0;
        Integer errorNum = 0;
        for (Result result : dataResults) {
 
            if (result.isCorrect()) {
 
                successNum++;
            } else {
 
                errorNum++;
            }
        }
        DecimalFormat df = new DecimalFormat("######0.0000");
        Double result = (Double.valueOf(successNum) / Double.valueOf(dataResults.size())) * 100;
        System.out.println("batch test");
        System.out.println("=======================================================");
        System.out.println("Summary");
        System.out.println("-------------------------------------------------------");
        System.out.println(String.format("Correctly Classified Instances          :      %s\t   %s%%",successNum,df.format(result)));
        System.out.println(String.format("Incorrectly Classified Instances        :       %s\t    %s%%",errorNum,df.format(100-result)));
        System.out.println(String.format("Total Classified Instances              :      %s",dataResults.size()));
        System.out.println("===================================");
    }
    public static void testModel(SparkSession sparkSession, String content){
 
        List<Row> data = Arrays.asList(
                RowFactory.create(AnsjUtils.participle(content))
        );
        StructType schema = new StructType(new StructField[]{
 
                new StructField("text", new ArrayType(DataTypes.StringType, false), false, Metadata.empty())
        });
        Dataset<Row> testData = sparkSession.createDataFrame(data, schema);
        //word frequency count
        Dataset<Row> transform = hashingTF.transform(testData);
        //count tf-idf
        Dataset<Row> rescaledData = idfModel.transform(transform);
        Row row =rescaledData.select("features").first();
        SparseVector sparseVector = row.getAs("features");
        Vector features = Vectors.dense(sparseVector.toArray());
        Double predict = model.predict(features);
        System.out.println("test a single");
        System.out.println("=======================================================");
        System.out.println("Result");
        System.out.println("-------------------------------------------------------");
        System.out.println(labels.get(predict.intValue()));
        System.out.println("===================================");
    }
}

 

测试结果

 

batch test
=======================================================
Summary
-------------------------------------------------------
Correctly Classified Instances          :      785   98.6181%
Incorrectly Classified Instances        :       11    1.3819%
Total Classified Instances              :      796
===================================

 

准确率98%,还可以。以上就是文本分类器的实现,我们还可以直接把数据样本换成 正常邮件|垃圾邮件 这两类的数据,就可以实现一个垃圾邮箱分类器了

 

https://github.com/Maweiming/SparkTextClassifier

Be First to Comment

发表回复

您的电子邮箱地址不会被公开。 必填项已用*标注