本文不涉及线性回归具体算法和原理性的东西,纯新手向、介绍性的文章。
线性回归
线性回归,对于初学者而言(比方说我)比较难理解,其实换个叫法可能就能立马知道线性回归是做什么的了:线性拟合。所谓拟合,就简单多了,如下图所示:
线性拟合,顾名思义拟合出来的预测函数是一条直线,数学表达如下:
\(h(x)=a_0+a_1x_1+a_2x_2+..+a_nx_n+J(\theta)\)
其中 \(h(x)\) 为预测函数, \(a_i(i=1,2,..,n)\) 为估计参数,模型训练的目的就是计算出这些参数的值。
而线性回归分析的整个过程可以简单描述为如下三个步骤:
- 寻找合适的预测函数,即上文中的 \(h(x)\) ,用来预测输入数据的判断结果。这个过程时非常关键的,需要对数据有一定的了解或分析,知道或者猜测预测函数的“大概”形式,比如是线性函数还是非线性函数,若是非线性的则无法用线性回归来得出高质量的结果。
- 构造一个Loss函数(损失函数),该函数表示预测的输出(h)与训练数据标签之间的偏差,可以是二者之间的差(h-y)或者是其他的形式(如平方差开方)。综合考虑所有训练数据的“损失”,将Loss求和或者求平均,记为 \(J(\theta)\) 函数,表示所有训练数据预测值与实际类别的偏差。
- 显然, \(J(\theta)\) 函数的值越小表示预测函数越准确(即h函数越准确),所以这一步需要做的是找到 \(J(\theta)\) 函数的最小值。找函数的最小值有不同的方法,Spark中采用的是梯度下降法(stochastic gradient descent, SGD)。
关于正则化手段
线性回归同样可以采用正则化手段,其主要目的就是防止过拟合。
当采用L1正则化时,则变成了Lasso Regresion;当采用L2正则化时,则变成了Ridge Regression;线性回归未采用正则化手段。通常来说,在训练模型时是建议采用正则化手段的,特别是在训练数据的量特别少的时候,若不采用正则化手段,过拟合现象会非常严重。L2正则化相比L1而言会更容易收敛(迭代次数少),但L1可以解决训练数据量小于维度的问题(也就是n元一次方程只有不到n个表达式,这种情况下是多解或无穷解的)。
MLlib提供L1、L2和无正则化三种方法:
regularizer \(R(w)\) | gradient or sub-gradient | |
---|---|---|
zero (unregularized) | 0 | 0 |
L2 | \(\frac{1}{2}\|w\|_2^2\) | \(w\) |
L1 | \(\|w\|_1\) | \(\mathrm{sign}(w)\) |
Spark线性回归实现
测试数据
1 2 3 4 |
-0.4307829,-1.63735562648104 -2.00621178480549 -1.86242597251066 -1.02470580167082 -0.522940888712441 -0.863171185425945 -1.04215728919298 -0.864466507337306 -0.1625189,-1.98898046126935 -0.722008756122123 -0.787896192088153 -1.02470580167082 -0.522940888712441 -0.863171185425945 -1.04215728919298 -0.864466507337306 -0.1625189,-1.57881887548545 -2.1887840293994 1.36116336875686 -1.02470580167082 -0.522940888712441 -0.863171185425945 0.342627053981254 -0.155348103855541 ... |
附件下载:lpsa
数据格式:逗号之前为label;之后为8个特征值,以空格分隔。
代码实现
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 |
public static void main(String[] args) { SparkConf sparkConf = new SparkConf() .setAppName("Regression") .setMaster("local[2]"); JavaSparkContext sc = new JavaSparkContext(sparkConf); JavaRDD<String> data = sc.textFile("/home/yurnom/lpsa.txt"); JavaRDD<LabeledPoint> parsedData = data.map(line -> { String[] parts = line.split(","); double[] ds = Arrays.stream(parts[1].split(" ")) .mapToDouble(Double::parseDouble) .toArray(); return new LabeledPoint(Double.parseDouble(parts[0]), Vectors.dense(ds)); }).cache(); int numIterations = 100; //迭代次数 LinearRegressionModel model = LinearRegressionWithSGD.train(parsedData.rdd(), numIterations); RidgeRegressionModel model1 = RidgeRegressionWithSGD.train(parsedData.rdd(), numIterations); LassoModel model2 = LassoWithSGD.train(parsedData.rdd(), numIterations); print(parsedData, model); print(parsedData, model1); print(parsedData, model2); //预测一条新数据方法 double[] d = new double[]{1.0, 1.0, 2.0, 1.0, 3.0, -1.0, 1.0, -2.0}; Vector v = Vectors.dense(d); System.out.println(model.predict(v)); System.out.println(model1.predict(v)); System.out.println(model2.predict(v)); } public static void print(JavaRDD<LabeledPoint> parsedData, GeneralizedLinearModel model) { JavaPairRDD<Double, Double> valuesAndPreds = parsedData.mapToPair(point -> { double prediction = model.predict(point.features()); //用模型预测训练数据 return new Tuple2<>(point.label(), prediction); }); Double MSE = valuesAndPreds.mapToDouble((Tuple2<Double, Double> t) -> Math.pow(t._1() - t._2(), 2)).mean(); //计算预测值与实际值差值的平方值的均值 System.out.println(model.getClass().getName() + " training Mean Squared Error = " + MSE); } |
运行结果
1 2 3 4 5 6 |
LinearRegressionModel training Mean Squared Error = 6.206807793307759 RidgeRegressionModel training Mean Squared Error = 6.416002077543526 LassoModel training Mean Squared Error = 6.972349839013683 Prediction of linear: 0.805390219777772 Prediction of ridge: 1.0907608111865237 Prediction of lasso: 0.18652645118913225 |
可以看到由于采用了正则化手段,ridge和lasso相对于linear其误差要大一些。在实际测试过程中,将迭代次数变成25时,有如下输出:
1 2 3 |
LinearRegressionModel training Mean Squared Error = 50.57566692735476 RidgeRegressionModel training Mean Squared Error = 1.664723124099061E7 LassoModel training Mean Squared Error = 6.972196762562953 |
可以看到此时linear还没有收敛到最终结果,而ridge却过拟合十分严重,此时lasso已经收敛等于最终结果。至于为什么产生这样的现象,我也不清楚,原理性的东西希望以后能有机会在写一篇文章。
你好,我是一个算法初学者,你的文章对我很有帮助。不过当前这篇我是真没太懂,比如我想用回归算法预测网站每天访问PV。我应该选取哪些数据,原始数据可能好多都是文本的,怎么转化成label和向量。最后怎么验证用哪种算法预测是最准的。因为其实是官方的demo没看懂,所以你的文章才更适合我这种初学者。So,方便的话可否另外在展开一篇实际例子的文章?非常不错的文章!支持下!
额,关于维度的选取,可以阅读PCA主成分分析方面的文章(不知道你的问题是不是这个意思)。文本维度转化为label这个看你的业务怎么定义了,比如我们通常把性别中的男定义为0,女定义为1,然后进行预测,0.6就算是女,0.4则是男。不过很少这么处理文本维度,因为通常转化为数字后失去了其原本的意思。
请问博主,这段代码可以运行吗?
看到其中有部分是Scala语句,和java混合到一起没有报错?
我是个新手,因为要用java实现Logistics Regression,所以想找些例子参考一下,如果你有这方面的资料,请介绍一下,谢谢。
java8 支持lamda表达式 可以在java里加一些函数式 jre升到1.8可以运行的
Fucking hate the pefrmroing rights society they are fucking donkey raping pieces of shit. They should stop licking lawyers assholes and become proper human.FUCK YOU PRS!!!!!
at my house growing up we always had dandelion sa,eolsdmatimes with dressings or sometimes my mother would heat some bacon fat that was left over and pour it over the greens.i did not know about store lettuce till i was in my teens.we always used the greens from the fields or yards.thanks
你好,我对你的文章很感兴趣,可以留一个邮箱或qq吗?
yurnom # 126.com欢迎相互交流学习
《深入浅出Spark机器学习实战(用户行为分析)》
课程网盘下载:http://pan.baidu.com/s/1mixvUli 密码:1pfn
为什么没有后面的呢,只有前10个
想要完整代码,可以帮忙发给我吗