如图所示,它在第10个周期停止了。
y_pred_test_lstm = lstm_model.predict(X_test_lmse) y_train_pred_lstm = lstm_model.predict(X_train_lmse) print("The R2 score on the Train set is:\t{:0.3f}".format(r2_score(y_train, y_train_pred_lstm))) print("The R2 score on the Test set is:\t{:0.3f}".format(r2_score(y_test, y_pred_test_lstm)))
可以看出,LSTM模型的训练和测试R^2均优于ANN模型。
比较模型
接下来,我们比较两种模型的测试MSE。
nn_test_mse = nn_model.evaluate(X_test, y_test, batch_size=1) lstm_test_mse = lstm_model.evaluate(X_test_lmse, y_test, batch_size=1) print('NN: %f'%nn_test_mse) print('LSTM: %f'%lstm_test_mse)
做出预测
nn_y_pred_test = nn_model.predict(X_test) lstm_y_pred_test = lstm_model.predict(X_test_lmse) plt.figure(figsize=(10, 6)) plt.plot(y_test, label='True') plt.plot(y_pred_test_nn, label='NN') plt.title("NN's Prediction") plt.xlabel('Observation') plt.ylabel('Adj Close Scaled') plt.legend() plt.show();
plt.figure(figsize=(10, 6)) plt.plot(y_test, label='True') plt.plot(y_pred_test_lstm, label='LSTM') plt.title("LSTM's Prediction") plt.xlabel('Observation') plt.ylabel('Adj Close scaled') plt.legend() plt.show();