gru模型的训练以及调优

关纯
• 阅读 180

本周尝试使用gru模型对相对距离进行预测
gru模型如下:

class GRUModel(nn.Module):
    def __init__(self, input_size=6, hidden_size=32, num_layers=2, output_size=1, dropout_prob=0.2):
        super(GRUModel, self).__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.dropout_prob = dropout_prob

        # GRU层
        self.gru = nn.GRU(input_size, hidden_size, num_layers, batch_first=True)
        # Dropout层(在GRU层与全连接层之间)  
        self.dropout = nn.Dropout(self.dropout_prob)
        # 输出层
        self.fc = nn.Linear(hidden_size, output_size)
        
    def forward(self, x):
        # 初始化隐藏状态
        h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(x.device)
        
        # 前向传播GRU
        out, _ = self.gru(x, h0)
        
        # 取最后一个时间步的输出(或根据需要聚合输出)
        out = out[:, -1, :]

        # 通过全连接层得到最终输出
        out = self.fc(out)
        
        return out

遇到的问题:
出现如下报错

RuntimeError: For unbatched 2-D input, hx should also be 2-D but got 3-D tensor

问题就出现在gru模型要求的输入格式为
输入序列:一个形状为 [batch_size, sequence_length, input_size] 的张量
其中 batch_size 是批量中的样本数量,sequence_length 是序列中的时间步长数量,input_size 是每个时间步的输入特征大小。
而当前数据来源输入格式为:

torch.Size([8, 6])

解决方法:使用unsqueeze扩展为三维数据
扩展后数据格式:

torch.Size([8, 1, 6])

整体运行效果如下

gru模型的训练以及调优
目前使用的是一个二层的gru模型,并通过一个全连接层进行输出。
效果并不是很理想,经过多次调参,发现存在过拟合等问题
于是在GRU层与全连接层之间添加了Dropout层用于解决过拟合的问题

  self.dropout = nn.Dropout(self.dropout_prob)
    def forward(self, x):
        # 初始化隐藏状态
        h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(x.device)
        
        # 前向传播GRU
        out, _ = self.gru(x, h0)
        
        # 取最后一个时间步的输出(或根据需要聚合输出)
        out = out[:, -1, :]

        # 在GRU输出后应用dropout  
        out = self.dropout(out) 

        # 通过全连接层得到最终输出
        out = self.fc(out)
        
        return out

可以发现有了较大的改善,但经过多次调试准确率也并没有达到很高的水平。
gru模型的训练以及调优
之后尝试采用GRU与Cnn模型进行组合来做优化

class Tudui(nn.Module):
    def __init__(self):
        super(Tudui, self).__init__()
        self.model1 = nn.Sequential(
            nn.Conv1d(in_channels=1, out_channels=8, kernel_size=3, stride=1, padding=1),
            nn.Tanh(),
            nn.MaxPool1d(2),  
            nn.Conv1d(in_channels=8, out_channels=16, kernel_size=3, stride=1, padding=1),
            nn.Tanh(),
            nn.MaxPool1d(2),  
            nn.Flatten()  
        )
        self.lstm = nn.GRU(
            input_size=16,
            hidden_size=32,  # RNN隐藏神经元个数
            num_layers=2,  # RNN隐藏层个数
            batch_first=True,
            dropout=0.2
        )
        self.out = nn.Linear(32, 1)

同样,为了适配gru模型的输出,使用nn.Flatten()将cnn输出结果展平为三维数据。
调参后可以得到很好的效果:
gru模型的训练以及调优
模型结构如下所示:
gru模型的训练以及调优
并且GRU-CNN模型在更精细的判定下(如误差在1%以内)则比LSTM-CNN模型表现更优秀
LSTM-CNN:
gru模型的训练以及调优
GRU-CNN:
gru模型的训练以及调优

下图是Gru的内部结构:
gru模型的训练以及调优
我们先通过上一个传输下来的状态Ht-1和当前节点的输入Xt来获取两个门控状态

有一个当前的输入Xt,和上一个节点传递下来的隐状态 Ht-1,这个隐状态包含了之前节点的相关信息。
结合Xt和Ht-1,GRU会得到当前隐藏节点的输出Yt和Ht传递给下一个节点的隐状态,其中r为控制重置的门控(reset gate),z为控制更新的门控(update gate)
gru模型的训练以及调优
上图可转化为公式gru模型的训练以及调优gru模型的训练以及调优
其中σ为Sigmoid函数,[]表示将两个向量进行堆叠组合,其中重置门后续用于对历史记忆状态进行筛选,
更新门的输出结果将同时作用于对历史信息和当前时刻信息的筛选,只是两者为互补关系。

最后,当前时刻的输入同经过重置门后的历史记忆状态经过一个非线性层后便得到了当前时刻的新输入信息,然后再将更新门作用的结果相加便得到了当前时刻GRU的输出ht

gru模型的训练以及调优
gru模型的训练以及调优

点赞
收藏
评论区
推荐文章
blmius blmius
3年前
MySQL:[Err] 1292 - Incorrect datetime value: ‘0000-00-00 00:00:00‘ for column ‘CREATE_TIME‘ at row 1
文章目录问题用navicat导入数据时,报错:原因这是因为当前的MySQL不支持datetime为0的情况。解决修改sql\mode:sql\mode:SQLMode定义了MySQL应支持的SQL语法、数据校验等,这样可以更容易地在不同的环境中使用MySQL。全局s
深度学习技术开发与应用
关键点1.强化学习的发展历程2.马尔可夫决策过程3.动态规划4.无模型预测学习5.无模型控制学习6.价值函数逼近7.策略梯度方法8.深度强化学习DQN算法系列9.深度策略梯度DDPG,PPO等第一天9:0012:0014:0017:00一、强化学习概述1.强化学习介绍2.强化学习与其它机器学习的不同3.强化学习发展历史4.强化学习典
Wesley13 Wesley13
3年前
PPDB:今晚老齐直播
【今晚老齐直播】今晚(本周三晚)20:0021:00小白开始“用”飞桨(https://www.oschina.net/action/visit/ad?id1185)由PPDE(飞桨(https://www.oschina.net/action/visit/ad?id1185)开发者专家计划)成员老齐,为深度学习小白指点迷津。
序列数据和文本的深度学习
序列数据和文本的深度学习用于构建深度学习模型的不同文本数据表示法:理解递归神经网络及其不同实现,例如长短期记忆网络(LSTM)和门控循环单元(GatedRecurrentUnit,GRU),它们为大多数深度学习模型提供文本和序列化数据;为序列化数据使用一维卷积。可以使用RNN构建的一些应用程序如下所示。文档分类器:识别推文或评论的情感,对新闻文章
Wesley13 Wesley13
3年前
mysql设置时区
mysql设置时区mysql\_query("SETtime\_zone'8:00'")ordie('时区设置失败,请联系管理员!');中国在东8区所以加8方法二:selectcount(user\_id)asdevice,CONVERT\_TZ(FROM\_UNIXTIME(reg\_time),'08:00','0
Stella981 Stella981
3年前
Django中Admin中的一些参数配置
设置在列表中显示的字段,id为django模型默认的主键list_display('id','name','sex','profession','email','qq','phone','status','create_time')设置在列表可编辑字段list_editable
Wesley13 Wesley13
3年前
60分钟视频带你掌握NLP BERT理论与实战
向AI转型的程序员都关注了这个号👇👇👇机器学习AI算法工程 公众号:datayx本课程会介绍最近NLP领域取得突破性进展的BERT模型。首先会介绍一些背景知识,包括WordEmbedding、RNN/LSTM/GRU、Seq2Seq模型和Attention机制等。然后介绍BERT的基础Transformer模
Wesley13 Wesley13
3年前
MySQL部分从库上面因为大量的临时表tmp_table造成慢查询
背景描述Time:20190124T00:08:14.70572408:00User@Host:@Id:Schema:sentrymetaLast_errno:0Killed:0Query_time:0.315758Lock_
文本的深度学习
序列数据和文本的深度学习用于构建深度学习模型的不同文本数据表示法:理解递归神经网络及其不同实现,例如长短期记忆网络(LSTM)和门控循环单元(GatedRecurrentUnit,GRU),它们为大多数深度学习模型提供文本和序列化数据;为序列化数据使用一维卷积。可以使用RNN构建的一些应用程序如下所示。文档分类器:识别推文或评论的情感,对新闻文章
Python进阶者 Python进阶者
1年前
Excel中这日期老是出来00:00:00,怎么用Pandas把这个去除
大家好,我是皮皮。一、前言前几天在Python白银交流群【上海新年人】问了一个Pandas数据筛选的问题。问题如下:这日期老是出来00:00:00,怎么把这个去除。二、实现过程后来【论草莓如何成为冻干莓】给了一个思路和代码如下:pd.toexcel之前把这
美凌格栋栋酱 美凌格栋栋酱
5个月前
Oracle 分组与拼接字符串同时使用
SELECTT.,ROWNUMIDFROM(SELECTT.EMPLID,T.NAME,T.BU,T.REALDEPART,T.FORMATDATE,SUM(T.S0)S0,MAX(UPDATETIME)CREATETIME,LISTAGG(TOCHAR(
关纯
关纯
Lv1
因为除草麻烦,所以不再种花。
文章
3
粉丝
0
获赞
0