三.机器学习算法篇-线性回归(2)

强转根系
• 阅读 1153

1.梯度下降法

上文写的求解损失函数的最小二乘法
三.机器学习算法篇-线性回归(2)

除了最小二乘法还可以使用梯度下降求解。
我们先随机给θ一个值,然后朝着负梯度的方向移动,也就是迭代,每次得到的θ值使用J(θ)比之前更小。

三.机器学习算法篇-线性回归(2)
这个α是指学习率,或者说是步长,这个影响的迭代的快慢。

我们函数y = (x - 0.1)²/2为例,使用梯度下降的方法,求其y达到最小时,x的值
代码示例

# coding:utf-8
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import font_manager


font = font_manager.FontProperties(fname="/usr/share/fonts/wps-office/msyhbd.ttf", size=25)

class OneGard(object):
    def __init__(self,fx,hx):
        """
        :param fx: 原函数
        :param hx: 导函数
        """
        self.fx = fx
        self.hx = hx
        self.x = None
        self.GD_X = []
        self.GD_Y = []
        self.iter_num = 0
        self.f_change = None
        self.f_current = None

    def gard_fun(self,x, alpha=0.5):
        """
        梯度下降
        :param x: 初始随机x值
        :param alpha: 学习率
        :return:
        """
        self.x = x
        self.f_change = self.fx(self.x)
        self.f_current = self.f_change
        self.GD_X.append(x)
        self.GD_Y.append(self.f_current)
        while self.f_change > 1e-10 and self.iter_num < 100:
            self.iter_num += 1
            self.x = self.x - alpha * self.hx(self.x)
            tmp = self.fx(self.x)
            self.f_change = np.abs(self.f_current - tmp)
            self.f_current = tmp
            self.GD_X.append(self.x)
            self.GD_Y.append(self.f_current)


def f(x):
    """
    y = (x - 0.1)²/2
    :param x:
    :return:
    """
    return (x - 0.1) ** 2 /2

def h(x):
    """
    y = (x - 0.1)²/2的导数
    :param x:
    :return:
    """
    return (x - 0.1)



gard = OneGard(f, h)
gard.gard_fun(x=4,alpha=0.5)

print("最终x:{:.2f},y:{:.2f}" .format(gard.x, gard.f_current))
print("迭代次数{}" .format( gard.iter_num))
print("迭代过程x的取值:\n{}".format(gard.GD_X))

# 画图
X = np.arange(-4, 4.5, 0.05)
Y = np.array(list(map(lambda t: f(t), X)))

plt.figure(figsize=(20,10), facecolor='w')
plt.plot(X, Y, 'r-', linewidth=2)
plt.plot(gard.GD_X, gard.GD_Y, 'bo--', linewidth=2)

plt.show()

结果为

最终x:0.10,y:0.00
迭代次数19
迭代过程x的取值:
[4, 2.05, 1.075, 0.5874999999999999, 0.34374999999999994, 0.221875, 0.1609375, 0.13046875000000002, 0.11523437500000001, 0.10761718750000002, 0.10380859375000001, 0.10190429687500001, 0.1009521484375, 0.10047607421875, 0.10023803710937501, 0.10011901855468751, 0.10005950927734375, 0.10002975463867188, 0.10001487731933595, 0.10000743865966798]

图像如下:
三.机器学习算法篇-线性回归(2)
机器学习中梯度下降常用的有三种:
批量梯度下降(BGD)、随机梯度下降(SGD)、小批量梯度下降(MBGD)。
参考 梯度下降法的三种形式BGD、SGD以及MBGD
(这个写的清楚明了,以前学的时候,脑子听成了浆糊,看了这篇文,我是豁然开朗)

2.多项式回归

线性回归针对的是θ而言是一种,对于样本本身而言,样本可以是非线性的。
例如
三.机器学习算法篇-线性回归(2)
这时我们可以领x1 = x, x2 = x²,
得到如下
三.机器学习算法篇-线性回归(2)
这样就转变成我们的熟悉的线性回归。

多项式扩展,就是将低纬度空间的点映射到高纬度空间中。

3.其他线性回归

3.1Ridge回归

线性回归的L2正则化通常称为Ridge回归,也叫作岭回归,与标准线性回归的差异,在于它在损失函数上增加了一个L2正则化的项。
三.机器学习算法篇-线性回归(2)

λ是常数系数,是正则化系数,属于一个超参数,需要调参。
λ太小就会失去处理过拟合的能力,太大就会因力度过大而出现欠拟合的现象。

3.2 LASSO回归

使用L1正则的线性回归模型就称为LASSO回归,和Ridge回归区别在于,它加的是L1正则化的项。
三.机器学习算法篇-线性回归(2)

3.3 弹性网络

弹性网络,Elasitc Net,同时使用L1正则和L2正则
三.机器学习算法篇-线性回归(2)

3.4 Ridge回归和LASSO回归比较

1)两者都可以都可以来解决标准线性回归的过拟合问题。
2)LASSO可以用来做特征选择,但Ridge回归则不行,因为LASSO能够使得不重要的变量的系数变为0,而Ridge回归则不行。
参考 线性回归、lasso回归、岭回归以及弹性网络的系统解释

4.代码示例

岭回归预测波士顿房价

from sklearn.datasets import load_boston
from sklearn.linear_model import Ridge
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler

from matplotlib import pyplot as plt
from matplotlib import font_manager

font = font_manager.FontProperties(fname="/usr/share/fonts/wps-office/msyhbd.ttf", size=25)

def radge_fun():
   """
   岭回归预测波士顿房价
   :return:
   """
   lb = load_boston()

   x_train, x_test, y_train, y_test = train_test_split(lb.data, lb.target, test_size=0.2)

   x_std = StandardScaler()
   y_std = StandardScaler()

   x_train = x_std.fit_transform(x_train)
   x_test = x_std.transform(x_test)
   y_train = y_std.fit_transform(y_train.reshape(-1,1))
   y_test = y_std.transform(y_test.reshape(-1,1))

   model = Ridge(alpha=1.0)

   model.fit(x_train, y_train)

   y_predict = y_std.inverse_transform(model.predict(x_test))
   return y_predict, y_std.inverse_transform(y_test)


def draw_fun(y_predict, y_test):
   """
   绘制房价预测与真实值的散点和折线图
   :param y_predict:
   :param y_test:
   :return:
   """
   x = range(1,len(y_predict)+1)
   plt.figure(figsize=(25, 10), dpi=80)
   plt.scatter(x, y_test, label="真实值",color='blue')
   plt.scatter(x, y_predict,label='预测值', color='red')
   plt.plot(x,y_test)
   plt.plot(x,y_predict)

   x_tick = list(x)
   y_tick = list(range(0,60,5))

   plt.legend(prop=font, loc='best')
   plt.xticks(list(x), x_tick)
   plt.yticks(y_tick)
   plt.grid(alpha=0.8)
   plt.show()


if __name__ == '__main__':
   y_predict, y_test = radge_fun()
   draw_fun(y_predict, y_test)
   

结果
三.机器学习算法篇-线性回归(2)

点赞
收藏
评论区
推荐文章
深度学习技术开发与应用
关键点1.强化学习的发展历程2.马尔可夫决策过程3.动态规划4.无模型预测学习5.无模型控制学习6.价值函数逼近7.策略梯度方法8.深度强化学习DQN算法系列9.深度策略梯度DDPG,PPO等第一天9:0012:0014:0017:00一、强化学习概述1.强化学习介绍2.强化学习与其它机器学习的不同3.强化学习发展历史4.强化学习典
代码哈士奇 代码哈士奇
4年前
vue实现桌面向网页拖动文件(可显示图片/音频/视频)
效果在这里插入图片描述(https://imghelloworld.osscnbeijing.aliyuncs.com/062771391
Wesley13 Wesley13
3年前
java简单的用户登录界面+mysql
1.概述一个简单的swing登录界面,使用了简单的JDBC.如图:!在这里插入图片描述(https://imgblog.csdnimg.cn/20191210013512615.png)!在这里插入图片描述(https://imgblog.csdnimg.cn/20191210013543435.png)2.UI
Stella981 Stella981
3年前
Python实现——二次多项式回归(最小二乘法)
2019/3/25真的,当那个图像出现的时候,我真的感觉太美了。或许是一路上以来自我的摸索加深的我对于这个模型的感受吧。二次函数拟合——最小二乘法公式法与线性回归相似,对二次函数进行拟合某种意义上也只是加了一个函数,虽然求解的方程变得更加繁琐,需要准备的变量也增加到了七个。思路有借鉴于:最小二乘法拟合二次曲线C语言(https://w
Stella981 Stella981
3年前
Redis未授权访问漏洞复现学习
0x00前言前段时间看到想复现学习一下,然后就忘了越临近考试越不想复习!在这里插入图片描述(https://oscimg.oschina.net/oscnet/ec73a943a3d9e18184946ee4c4ca290e14f.jpg)常见的未授权访问漏洞Redis未授权访问漏洞MongoDB未授权访问漏
Wesley13 Wesley13
3年前
AI金融知识自学偏量化方向
前提:统计学习(统计分析)和机器学习之间的区别金融公司采用机器学习技术及招募相关人才要求第一个问题:  机器学习和统计学都是数据科学的一部分。机器学习中的学习一词表示算法依赖于一些数据(被用作训练集),来调整模型或算法的参数。这包含了许多的技术,比如回归、朴素贝叶斯或监督聚类。但不是所有的技术都适合机器学习。例如有一种统计和数
Stella981 Stella981
3年前
QT使用label显示图片或者gif并自动适应label尺寸
显示图片1.在ui界面拖动label控件至界面。2.将想要显示的图片加入qt资源库。3.添加图片至label(利用setPixmap函数)。4.自使用label尺寸(利用setScaledContents函数)。未自适应label大小的效果:!在这里插入图片描述(https://oscimg.oschina.
Wesley13 Wesley13
3年前
mysql查询每个学生的各科成绩,以及总分和平均分
今天看一个mysql教程,看到一个例子,感觉里面的解决方案不是很合理。问题如下:有学生表:!在这里插入图片描述(https://oscimg.oschina.net/oscnet/07b001b0c6cb7e0038a9299e768fc00a0d3.png)成绩表:!在这里插入图片描述(https://oscimg.o
迁移学习核心技术的开发与应用
一、机器学习简介与经典机器学习算法介绍1.什么是机器学习?2.机器学习框架与基本组成3.机器学习的训练步骤4.机器学习问题的分类5.经典机器学习算法介绍章节目标:机器学习是人工智能的重要技术之一,详细了解机器学习的原理、机制和方法,为学习深度学习与迁移学习打下坚实的基础。二、深度学习简介与经典网络结构介绍1.神经网络简介2.神经网络组件简介3.神经网
机器学习入门指南
资料获取地址见文末或评论!一、预备知识微积分(偏导数、梯度等等)概率论与数理统计(例如极大似然估计、中央极限定理、大数法则等等)最优化方法(比如梯度下降、牛顿拉普什方法、变分法(欧拉拉格朗日方程)、凸优化等等)二、路线1(基于普通最小二乘法的)简单线性回归线性回归中的新进展(岭回归和LASSO回归)(此处可以插入Bagging和AdaBoost的内容