Press "Enter" to skip to content

深度学习之线性回归基本原理

线性回归的基本元素

 

为了解释线性回归,我们举一个实际的例子:我们希望根据房屋面积和房龄来估算房屋的价格。为了开发一个能预测房价的模型,我们需要收集真实的数据集。这个数据集称为训练数据集(training data set)。每行数据称为样本(sample),也可以成为数据点(data point)。我们把试图预测的内容称为标签或者目标(label/target)所依据的自变量(面积和年龄)称为特征(feature)或协变量(covariate)。

 

线性模型

 

线性假设:指目标可以表示为自变量的加权和,例如下面的公式:

 

price = w1 x1+w2 x2+…+b

 

其中w称为权重,b称为偏置,偏移量或者截距。

 

给定一个数据集,我们的目标是寻找模型的权重w和偏置b,使得根据模型做出的预测大体符合数据的真实价格。输出的预测值有输入特征通过线性模型的仿射变换决定,仿射变换由所选的权重和偏置确定。

 

而在机器学习时,我们通常使用高维数据集,当我们输入d个特征的时候,将预测结果y_hat记为:

 

y_hat = Wx+ b(W为权重即成为向量的转置,x也是向量对应于单个样本的特征)

 

损失函数

 

损失函数(loss function)能够量化目标的真实值和预测值之间的差距,通常我们会选择非负数作为损害时,且数值越小表示损失越小,当完美预测的时候损失为0.

 

回归问题最常用的损失函数就是平方误差函数,也就是

 

LOSS(i) = 1/2*(真实值(i)-预测值(i))^2

 

而为了度量整个函数在数据集上的整体损失,我们需要对LOSS求和,取均值。也就是

 

LOSS = 1/n ∑((i从1到n)1/2 (真实值(i)-预测值(i))^2)

OK,现在我们有了模型,有了目标函数(损失函数),仅需要一个优化算法,就可以了,在回归问题中最常用的就是—随机梯度下降方法。

 

随机梯度下降

 

这种方法的好处在于,我们几乎可以把他运用到任何深度学习模型,尽管有时候不是很好用,但是几乎是全适用的。它通过不断地在损失函数递减的方向上更新参数来降低误差。

 

我们这里来好好理解一下梯度下降这一概念:

 

从导数到梯度

 

在我们都熟知的初中数学中,我们学习过了导数这一概念,例如F(x)=9 x^2这一表达式,求导之后得到F‘(X)=18 x,它在几何上的意义就是函数在某点的切线方向。

 

那幺在多元函数上呢?

对于
我们同样可以计算它的导数,也就是偏导数,当全部偏导构成一个向量的时候,我们就称其为梯度如下图所示:

那幺从数学角度出发,梯度方向就是函数增长最快的方向(这里参考梯度的定义就可知)。如果要计算函数的最小值,我们可以找到梯度的反方向,也就是梯度下降的方向来一步一步让函数朝着减小方向前进,直到那个方向都是梯度增加。

 

如下图所示:

但我们会产生一个疑问,如果我们的函数图像是一个多峰图,如图

当我们处于两峰之间的时候,那也可能走到四面都是梯度增加的“盆地”,但它可能不是最小值。

 

那盲生你就发现了华点,确实,梯度下降法得到的只能是局部最小值,也只能获得局部最优解,这也是其局限性,但如果损失函数是凸函数,那幺我们一定会得到全局最优解,这里我们不再详述,只需记住在线性回归中梯度下降是可行的。

 

梯度下降

 

那幺现在我们开始用文字来跑一下梯度下降法这一算法:

 

 

首先,我们需要假设一些条件来保证后续的正常推进:

 

我们假设我们的线性回归函数为:

(其中x0为1,这样我们就有了偏置项。)

 

而损失函数我们定义为:

还有我们定义一个概念:步长或者称为学习率-α(learning

 

rate):这决定了在梯度下降过程中,每一次沿着梯度下降方向前进的长度,就如上面图像中的例子,步长就是当前沿着最快下降的位置走的那一步。

 

最后,我们要找一个能够接受的值作为我们的终止值ε,当下降的距离小于这个值的时候,我们可以近似认为损失函数已经几乎不会减小,也就近似地完美,我们已经得到了最佳的参数组。

 

算法过程:

 

好,我们现在来跑一下算法。

 

a. 首先重申我们的目标,要让损失函数也就是上面的J(θ0~θn)取值最小。

 

所以我们取J的梯度

b. 用我们的步长/学习率乘以损失函数的梯度,这样就能得到当前位置减少的距离,也就是

c. 确定是否所有的θ的梯度下降的距离都小于ε,如果小于ε那幺算法终止,否则进行下一步

 

d. 更新所有的θ,更新的表达式如下:

 

其他问题

 

OK,至此,我们的理论部分介绍完成,你可能还有有以下疑问:

 

 

    1. 我们该如何实现对参数的不断迭代更新呢?

 

    1. 实际操作中,无数次遍历整个数据集真的能够达到要求幺?甚至真的可行幺?

 

    1. 这几个问题将会在下一篇代码实现中详细介绍。

 

Be First to Comment

发表回复

您的电子邮箱地址不会被公开。 必填项已用*标注