Before Spark 2.0, logistic regression in Spark could only handle binary classification; multiclass support was added later.
As I understand it, logistic regression can be used for both classification and regression, but the official documentation does not give an example of using logistic regression for regression, only one for linear regression.
In addition, the official documentation provides examples of both binary and multiclass logistic regression, but the setFamily method used in the multiclass example does not run on Spark 1.6 or Spark 2.0, whether in IDEA or in spark-shell.
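The most likely reason is that setFamily was only introduced in Spark 2.1, so the method simply does not exist in 1.6 or 2.0. For reference, on Spark 2.1+ the multinomial case would be configured roughly like this (a sketch following the official example; the parameter values are illustrative):

// Requires Spark 2.1 or later: setFamily is not available in Spark 1.6 / 2.0,
// which is why the multiclass example from the docs fails on those versions.
import org.apache.spark.ml.classification.LogisticRegression

val mlr = new LogisticRegression()
  .setMaxIter(10)
  .setRegParam(0.3)
  .setElasticNetParam(0.8)
  .setFamily("multinomial") // "auto" (default), "binomial", or "multinomial"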
Below is complete code for binary classification with logistic regression. The training data can be created in two ways: built by hand in the code, or read from a file.
The code is as follows:
/**
  * Created by wangtuntun on 17-3-7.
  * Feed the data into the model and run predictions.
  * I originally wanted to use logistic regression for regression, but the example I found
  * on the official site sits under classification, and it only supports binary classification.
  */
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.sql.Row
import org.apache.spark.sql.SQLContext

object predict_with_logistic_regression_classification {

  def main(args: Array[String]): Unit = {

    val conf = new SparkConf().setAppName("tianchi").setMaster("local")
    val sc = new SparkContext(conf)
    val sqc = new SQLContext(sc)

    // Prepare training data from a list of (label, features) tuples.
    // Seq is an ordered collection; training is a DataFrame.
    val training = sqc.createDataFrame(Seq(
      (1.0, Vectors.dense(0.0, 1.1, 0.1)),
      (0.0, Vectors.dense(2.0, 1.0, -1.0)),
      (0.0, Vectors.dense(2.0, 1.3, 1.0)),
      (1.0, Vectors.dense(0.0, 1.2, -0.5))
    )).toDF("label", "features") // convert to a DataFrame

    // Alternative: build the training DataFrame from a file instead.
    // val raw_data = sc.textFile("/home/wangtuntun/IJCAI/Data/lr_format_data.txt")
    // val map_data = raw_data.map { line =>
    //   val list_split = line.split(",")
    //   val mylable = list_split.last.toDouble  // the last element is the label
    //   val features = list_split.dropRight(1)  // all elements except the last one
    //   var myarr: Array[Double] = Array()
    //   features.foreach { x =>
    //     myarr = myarr :+ x.toDouble
    //   }
    //   val myvector = Vectors.dense(myarr)
    //   (mylable, myvector)
    // }
    // val training = sqc.createDataFrame(map_data).toDF("label", "features")

    // Create a LogisticRegression instance.
    val lr = new LogisticRegression()
    // Print out the parameters, documentation, and any default values.
    println("LogisticRegression parameters:\n" + lr.explainParams() + "\n")

    // Override the default model parameters.
    lr.setMaxIter(10)      // maximum number of iterations
      .setRegParam(0.01)   // regularization parameter

    // Fit the model to the training data with the parameters set above.
    val model1 = lr.fit(training)

    // Since model1 is a Model (i.e., a Transformer produced by an Estimator),
    // we can view the parameters it used during fit().
    // This prints the parameter (name: value) pairs, where names are unique IDs
    // for this LogisticRegression instance.
    println("Model 1 was fit using parameters: " + model1.parent.extractParamMap)

    // val testData = sqc.createDataFrame(Seq(
    //   (1.0, Vectors.dense(-1.0, 1.5, 1.3)),
    //   (0.0, Vectors.dense(3.0, 2.0, -0.1)),
    //   (1.0, Vectors.dense(0.0, 2.2, -1.5)),
    //   (1.0, Vectors.dense(2.0, 2.2, -1.0))
    // )).toDF("label", "features")

    model1.transform(training)
      .select("features", "label", "probability", "prediction") // keep only these columns
      .collect() // collect() brings everything to the driver; only safe for small results
      .foreach(println(_))

    sc.stop()
  }
}
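The code above targets the Spark 1.6-style API (SQLContext plus org.apache.spark.mllib.linalg vectors). On Spark 2.x, the DataFrame-based spark.ml pipeline expects org.apache.spark.ml.linalg vectors and is normally driven through a SparkSession. Here is a minimal sketch of the equivalent setup, assuming a local Spark 2.x environment; it is an outline rather than a drop-in replacement:

import org.apache.spark.sql.SparkSession
import org.apache.spark.ml.linalg.Vectors // ml vectors, not mllib, for the Spark 2.x pipeline API
import org.apache.spark.ml.classification.LogisticRegression

object lr_spark2_sketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().appName("tianchi").master("local").getOrCreate()

    // Same toy training set as above, built through SparkSession.
    val training = spark.createDataFrame(Seq(
      (1.0, Vectors.dense(0.0, 1.1, 0.1)),
      (0.0, Vectors.dense(2.0, 1.0, -1.0)),
      (0.0, Vectors.dense(2.0, 1.3, 1.0)),
      (1.0, Vectors.dense(0.0, 1.2, -0.5))
    )).toDF("label", "features")

    val lr = new LogisticRegression().setMaxIter(10).setRegParam(0.01)
    val model = lr.fit(training)

    // Predict on the training set and show the result instead of collecting it.
    model.transform(training)
      .select("features", "label", "probability", "prediction")
      .show(false)

    spark.stop()
  }
}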
A sample of the data file (used by the file-reading branch) looks like this:

942,0,193,7,3,1,0,1,0,2016,10,30,1
942,0,193,7,3,1,0,1,0,2016,10,29,1
942,0,193,7,3,1,0,1,0,2016,10,28,1
942,0,193,7,3,1,0,1,0,2016,10,27,0
942,0,193,7,3,1,0,1,0,2016,10,26,1
942,0,193,7,3,1,0,1,0,2016,10,25,0
942,0,193,7,3,1,0,1,0,2016,10,24,0
942,0,193,7,3,1,0,1,0,2016,10,23,1
942,0,193,7,3,1,0,1,0,2016,10,22,1
942,0,193,7,3,1,0,1,0,2016,10,21,0
942,0,193,7,3,1,0,1,0,2016,10,20,1
942,0,193,7,3,1,0,1,0,2016,10,19,0
942,0,193,7,3,1,0,1,0,2016,10,18,1
942,0,193,7,3,1,0,1,0,2016,10,17,1
942,0,193,7,3,1,0,1,0,2016,10,16,0
942,0,193,7,3,1,0,1,0,2016,10,15,0
942,0,193,7,3,1,0,1,0,2016,10,14,0
942,0,193,7,3,1,0,1,0,2016,10,13,1
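For the file-reading branch in the code above, each record has this layout: the last comma-separated field is the label (0 or 1) and the preceding twelve fields are the features. A small standalone sketch of that parsing logic, with parseLine as a hypothetical helper mirroring the commented-out code:

import org.apache.spark.mllib.linalg.{Vector, Vectors}

// Hypothetical helper: the last field is the label, everything before it is a feature.
def parseLine(line: String): (Double, Vector) = {
  val fields = line.split(",")
  val label = fields.last.toDouble
  val features = Vectors.dense(fields.dropRight(1).map(_.toDouble))
  (label, features)
}

// The first record above parses to:
// (1.0, [942.0, 0.0, 193.0, 7.0, 3.0, 1.0, 0.0, 1.0, 0.0, 2016.0, 10.0, 30.0])
println(parseLine("942,0,193,7,3,1,0,1,0,2016,10,30,1"))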