基于Spark实现随机森林代码如下:
public class RandomForestClassficationTest extends TestCase implements Serializable
{
/**
*
*/
private static final long serialVersionUID = 7802523720751354318L;
class PredictResult implements Serializable{
/**
*
*/
private static final long serialVersionUID = -168308887976477219L;
double label;
double prediction;
public PredictResult(double label,double prediction){
this.label = label;
this.prediction = prediction;
}
@Override
public String toString(){
return this.label + " : " + this.prediction ;
}
}
public void test_randomForest() throws JAXBException{
SparkConf sparkConf = new SparkConf();
sparkConf.setAppName("RandomForest");
sparkConf.setMaster("local");
SparkContext sc = new SparkContext(sparkConf);
String dataPath = RandomForestClassficationTest.class.getResource("/").getPath() + "/sample_libsvm_data.txt";
RDD
dataSet = MLUtils.loadLibSVMFile(sc, dataPath);
RDD
[] rddList = dataSet.randomSplit(new double[]{0.7,0.3},1);
RDD
trainingData = rddList[0];
RDD
testData = rddList[1];
ClassTag
labelPointClassTag = trainingData.elementClassTag(); JavaRDD
trainingJavaData = new JavaRDD
(trainingData,labelPointClassTag); int numClasses = 2; Map
categoricalFeatureInfos = new HashMap
(); int numTrees = 3; String featureSubsetStrategy = "auto"; String impurity = "gini"; int maxDepth = 4; int maxBins = 32; /** * 1 numClasses分类个数为2 * 2 numTrees 表示的是随机森林中树的个数 * 3 featureSubsetStrategy * 4 */ final RandomForestModel model = RandomForest.trainClassifier(trainingJavaData, numClasses, categoricalFeatureInfos, numTrees, featureSubsetStrategy, impurity, maxDepth, maxBins, 1); JavaRDD
testJavaData = new JavaRDD
(testData,testData.elementClassTag()); JavaRDD
predictRddResult = testJavaData.map(new Function
(){ /** * */ private static final long serialVersionUID = 1L; public PredictResult call(LabeledPoint point) throws Exception { // TODO Auto-generated method stub double pointLabel = point.label(); double prediction = model.predict(point.features()); PredictResult result = new PredictResult(pointLabel,prediction); return result; } }); List
predictResultList = predictRddResult.collect(); for(PredictResult result:predictResultList){ System.out.println(result.toString()); } System.out.println(model.toDebugString()); } } 得到的随机森林的展示结果如下: TreeEnsembleModel classifier with 3 trees Tree 0: If (feature 435 <= 0.0) If (feature 516 <= 0.0) Predict: 0.0 Else (feature 516 > 0.0) Predict: 1.0 Else (feature 435 > 0.0) Predict: 1.0 Tree 1: If (feature 512 <= 0.0) Predict: 1.0 Else (feature 512 > 0.0) Predict: 0.0 Tree 2: If (feature 377 <= 1.0) Predict: 0.0 Else (feature 377 > 1.0) If (feature 455 <= 0.0) Predict: 1.0 Else (feature 455 > 0.0) Predict: 0.0
转载请注明原文地址: https://ju.6miu.com/read-2197.html