当前位置:
首页 >
keras实现简单lstm_四十二.长短期记忆网络(LSTM)过程和keras实现股票预测
发布时间:2024/9/19
37
豆豆
生活随笔
收集整理的这篇文章主要介绍了
keras实现简单lstm_四十二.长短期记忆网络(LSTM)过程和keras实现股票预测
小编觉得挺不错的,现在分享给大家,帮大家做个参考.
一.概述
传统循环网络RNN可以通过记忆体实现短期记忆进行连续数据的预测,但是,当连续数据的序列边长时,会使展开时间步过长,在反向传播更新参数的过程中,梯度要按时间步连续相乘,会导致梯度消失或者梯度爆炸。
LSTM是RNN的变体,通过门结构,有效的解决了梯度爆炸或者梯度消失问题。
LSTM在RNN的基础上引入了三个门结构和记录长期记忆的细胞态以及归纳出新知识的候选态。
二.LSTM结构
1.短期记忆
短期记忆即为RNN中的记忆体,在LSTM中,它的通过输出门
和经过tanh函数的长期记忆的哈达玛积得到:2.细胞态(长期记忆)
长期记忆记录了当前时刻的历史信息:
其中,
为上一时刻的长期记忆, 为遗忘门, 为输入门, 为候选状态,表示在本时间段归纳出的新知识:3.输入门、遗忘门、输出门
它们三个都是当前时刻的输入特征
和上个时刻的短期记忆 的函数。遗忘门通过sigmod函数,将上一层隐藏状态
和本层输入 映射到[0,1],表示上一层的内部状态 需要遗忘多少信息,公式为下:输入门
控制当前候选状态 有多少信息需要保存。输出门
控制当前时刻的内部状态 有多少信息传递给隐藏信息 。三.LSTM过程
1.先利用上一时刻的隐藏状态
和当前输入计算出三个门和候选状态:2.结合遗忘门
和输入门 更新长期记忆:3.结合输出门和内部状态更新隐藏状态:
4.反向传播,利用梯度下降等优化方法更新参数矩阵和偏置。
四.keras+LSTM实现股票预测
导入依赖包
import tensorflow as tf import numpy as np import matplotlib.pyplot as plt import pandas as pd from tensorflow.keras.layers import Dense,Dropout,LSTM from sklearn.preprocessing import MinMaxScaler from sklearn.metrics import mean_absolute_error,mean_squared_error读取数据
maotai = pd.read_csv('./SH600519.csv') training_set = maotai.iloc[0:2126,2:3].values test_set = maotai.iloc[2126:,2:3].values print(training_set.shape,test_set.shape) 输出: (2126, 1) (300, 1)归一化
sc = MinMaxScaler(feature_range=(0,1)) training_set = sc.fit_transform(training_set) test_set = sc.fit_transform(test_set)划分训练数据和测试数据
x_train,y_train,x_test,y_test=[],[],[],[] for i in range(60,len(training_set)):x_train.append(training_set[i-60:i,0])y_train.append(training_set[i,0]) np.random.seed(7) np.random.shuffle(x_train) np.random.seed(7) np.random.shuffle(y_train) tf.random.set_seed(7) x_train,y_train = np.array(x_train),np.array(y_train) x_train = np.reshape(x_train, (x_train.shape[0], 60, 1)) for i in range(60, len(test_set)):x_test.append(test_set[i - 60:i, 0])y_test.append(test_set[i, 0]) x_test, y_test = np.array(x_test), np.array(y_test) x_test = np.reshape(x_test, (x_test.shape[0], 60, 1))搭建网络
model = tf.keras.Sequential([LSTM(80,return_sequences=True),Dropout(0.2),LSTM(100),Dropout(0.2),Dense(1) ])配置网络
model.compile(optimizer=tf.keras.optimizers.Adam(0.001),loss='mean_squared_error')开始训练
history = model.fit(x_train, y_train, batch_size=64, epochs=50, validation_data=(x_test, y_test), validation_freq=1)训练过程
Epoch 1/50 33/33 [==============================] - 4s 114ms/step - loss: 0.0135 - val_loss: 0.0110 Epoch 2/50 33/33 [==============================] - 3s 95ms/step - loss: 0.0013 - val_loss: 0.0049 Epoch 3/50 33/33 [==============================] - 3s 99ms/step - loss: 0.0011 - val_loss: 0.0051 Epoch 4/50 33/33 [==============================] - 3s 98ms/step - loss: 0.0013 - val_loss: 0.0057 Epoch 5/50 33/33 [==============================] - 3s 95ms/step - loss: 0.0011 - val_loss: 0.0047 Epoch 6/50 33/33 [==============================] - 3s 99ms/step - loss: 0.0011 - val_loss: 0.0046 Epoch 7/50 33/33 [==============================] - 3s 92ms/step - loss: 0.0011 - val_loss: 0.0046 Epoch 8/50 33/33 [==============================] - 3s 86ms/step - loss: 0.0010 - val_loss: 0.0049 Epoch 9/50 33/33 [==============================] - 3s 84ms/step - loss: 0.0010 - val_loss: 0.0051 Epoch 10/50 33/33 [==============================] - 3s 86ms/step - loss: 0.0010 - val_loss: 0.0051 Epoch 11/50 33/33 [==============================] - 3s 86ms/step - loss: 9.7592e-04 - val_loss: 0.0044 Epoch 12/50 33/33 [==============================] - 3s 87ms/step - loss: 9.6163e-04 - val_loss: 0.0043 Epoch 13/50 33/33 [==============================] - 3s 88ms/step - loss: 0.0011 - val_loss: 0.0041 Epoch 14/50 33/33 [==============================] - 3s 89ms/step - loss: 9.1143e-04 - val_loss: 0.0042 Epoch 15/50 33/33 [==============================] - 3s 89ms/step - loss: 0.0011 - val_loss: 0.0046 Epoch 16/50 33/33 [==============================] - 3s 89ms/step - loss: 8.8493e-04 - val_loss: 0.0040 Epoch 17/50 33/33 [==============================] - 3s 90ms/step - loss: 9.2448e-04 - val_loss: 0.0042 Epoch 18/50 33/33 [==============================] - 3s 91ms/step - loss: 8.7795e-04 - val_loss: 0.0038 Epoch 19/50 33/33 [==============================] - 3s 91ms/step - loss: 7.1217e-04 - val_loss: 0.0045 Epoch 20/50 33/33 [==============================] - 3s 91ms/step - loss: 0.0012 - val_loss: 0.0038 Epoch 21/50 33/33 [==============================] - 3s 93ms/step - loss: 8.5274e-04 - val_loss: 0.0037 Epoch 22/50 33/33 [==============================] - 3s 92ms/step - loss: 9.9773e-04 - val_loss: 0.0052 Epoch 23/50 33/33 [==============================] - 3s 93ms/step - loss: 9.0810e-04 - val_loss: 0.0046 Epoch 24/50 33/33 [==============================] - 3s 93ms/step - loss: 8.4353e-04 - val_loss: 0.0041 Epoch 25/50 33/33 [==============================] - 3s 95ms/step - loss: 8.7846e-04 - val_loss: 0.0037 Epoch 26/50 33/33 [==============================] - 3s 94ms/step - loss: 7.2408e-04 - val_loss: 0.0035 Epoch 27/50 33/33 [==============================] - 3s 95ms/step - loss: 7.8355e-04 - val_loss: 0.0059 Epoch 28/50 33/33 [==============================] - 3s 96ms/step - loss: 8.1942e-04 - val_loss: 0.0035 Epoch 29/50 33/33 [==============================] - 3s 96ms/step - loss: 7.7674e-04 - val_loss: 0.0033 Epoch 30/50 33/33 [==============================] - 3s 95ms/step - loss: 7.3867e-04 - val_loss: 0.0037 Epoch 31/50 33/33 [==============================] - 3s 97ms/step - loss: 7.2609e-04 - val_loss: 0.0033 Epoch 32/50 33/33 [==============================] - 3s 96ms/step - loss: 6.9374e-04 - val_loss: 0.0033 Epoch 33/50 33/33 [==============================] - 3s 96ms/step - loss: 6.3776e-04 - val_loss: 0.0050 Epoch 34/50 33/33 [==============================] - 3s 97ms/step - loss: 7.6443e-04 - val_loss: 0.0036 Epoch 35/50 33/33 [==============================] - 3s 98ms/step - loss: 7.9301e-04 - val_loss: 0.0032 Epoch 36/50 33/33 [==============================] - 3s 97ms/step - loss: 7.7646e-04 - val_loss: 0.0036 Epoch 37/50 33/33 [==============================] - 3s 99ms/step - loss: 8.3467e-04 - val_loss: 0.0033 Epoch 38/50 33/33 [==============================] - 3s 99ms/step - loss: 7.6392e-04 - val_loss: 0.0032 Epoch 39/50 33/33 [==============================] - 3s 99ms/step - loss: 6.3954e-04 - val_loss: 0.0047 Epoch 40/50 33/33 [==============================] - 3s 99ms/step - loss: 7.3498e-04 - val_loss: 0.0034 Epoch 41/50 33/33 [==============================] - 3s 99ms/step - loss: 5.8371e-04 - val_loss: 0.0031 Epoch 42/50 33/33 [==============================] - 3s 99ms/step - loss: 5.7156e-04 - val_loss: 0.0034 Epoch 43/50 33/33 [==============================] - 3s 100ms/step - loss: 6.2417e-04 - val_loss: 0.0030 Epoch 44/50 33/33 [==============================] - 3s 101ms/step - loss: 6.8761e-04 - val_loss: 0.0035 Epoch 45/50 33/33 [==============================] - 4s 108ms/step - loss: 6.7483e-04 - val_loss: 0.0031 Epoch 46/50 33/33 [==============================] - 4s 113ms/step - loss: 6.2236e-04 - val_loss: 0.0031 Epoch 47/50 33/33 [==============================] - 4s 115ms/step - loss: 6.4746e-04 - val_loss: 0.0034 Epoch 48/50 33/33 [==============================] - 4s 112ms/step - loss: 7.4622e-04 - val_loss: 0.0029 Epoch 49/50 33/33 [==============================] - 3s 101ms/step - loss: 6.8864e-04 - val_loss: 0.0028 Epoch 50/50 33/33 [==============================] - 3s 101ms/step - loss: 5.6762e-04 - val_loss: 0.0028loss曲线
loss = history.history['loss'] val_loss = history.history['val_loss'] plt.plot(loss,label='Training Loss') plt.plot(val_loss,label='Validation Loss') plt.legend() plt.title('Loss') plt.show()预测结果与真实值比较
predict_price = model.predict(x_test) predict_price = sc.inverse_transform(predict_price) real_price = sc.inverse_transform(test_set[60:]) plt.plot(real_price, color='red', label='MaoTai Stock Price') plt.plot(predict_price, color='blue', label='Predicted MaoTai Stock Price') plt.title('MaoTai Stock Price Prediction') plt.xlabel('Time') plt.ylabel('MaoTai Stock Price') plt.legend() plt.show()查看评价指标(均方误差和均方根差)
mse=mean_squared_error(predict_price,real_price) mae = mean_absolute_error(predict_price,real_price) print('mean_squared_error',mse) print('mean_absolute_error',mae) 输出: mean_squared_error 922.6493975725148 mean_absolute_error 23.789508666992194总结
以上是生活随笔为你收集整理的keras实现简单lstm_四十二.长短期记忆网络(LSTM)过程和keras实现股票预测的全部内容,希望文章能够帮你解决所遇到的问题。
- 上一篇: java中coverage怎么取消_别人
- 下一篇: 提高电脑反应速度_宁美千元价电脑,一体机