赞
踩
输入特征可以根据实际情况进行选择。这里选用“收盘价”、“最高价”、“最低价”三个特征作为模型输入，预测未来的收盘价。
def preprocess_data(data, time_len, rate, seq_len, pre_len):
    """Split a series chronologically and cut each part into supervised windows.

    The first ``int(time_len * rate)`` rows of *data* form the training part and
    rows ``[train_size, time_len)`` form the test part. Each part is turned into
    overlapping windows of ``seq_len + pre_len`` consecutive rows: the first
    ``seq_len`` rows are an input sample and the next ``pre_len`` rows are its
    target.

    Args:
        data: Sliceable sequence of observations (e.g. a NumPy array whose
            first axis is time).
        time_len: Number of leading rows of *data* to use.
        rate: Fraction of ``time_len`` assigned to the training part.
        seq_len: Number of rows in each input window.
        pre_len: Number of rows to predict after each input window.

    Returns:
        ``(trainX, trainY, testX, testY)`` as NumPy arrays. ``trainX``/``testX``
        stack input windows of ``seq_len`` rows; ``trainY``/``testY`` stack the
        matching target windows of ``pre_len`` rows. A part shorter than
        ``seq_len + pre_len`` yields an empty array.
    """
    train_size = int(time_len * rate)
    # Chronological split: test rows start exactly where training rows end,
    # so training windows never contain rows from the test part.
    train_data = data[0:train_size]
    test_data = data[train_size:time_len]
    trainX, trainY = _make_windows(train_data, seq_len, pre_len)
    testX, testY = _make_windows(test_data, seq_len, pre_len)
    return trainX, trainY, testX, testY


def _make_windows(series, seq_len, pre_len):
    """Return ``(inputs, targets)`` arrays of sliding windows over *series*."""
    window = seq_len + pre_len
    inputs, targets = [], []
    for start in range(len(series) - window + 1):
        chunk = series[start:start + window]
        inputs.append(chunk[:seq_len])
        targets.append(chunk[seq_len:window])
    return np.array(inputs), np.array(targets)
def metric(pred, label):
    """Return masked ``(mae, rmse, mape)`` between predictions and ground truth.

    Positions where ``label == 0`` are excluded. The 0/1 mask is rescaled by its
    mean, so averaging over all positions equals averaging over only the
    non-zero-label positions.

    Args:
        pred: Array-like of predictions.
        label: Array-like of ground-truth values (NumPy broadcasting against
            *pred* applies).

    Returns:
        Tuple ``(mae, rmse, mape)``. ``mape`` is a fraction (multiply by 100
        for percent) and uses ``|label|`` as its denominator. If every label is
        zero, all three values are 0.
    """
    with np.errstate(divide='ignore', invalid='ignore'):
        mask = np.not_equal(label, 0).astype(np.float32)
        # 0/0 -> nan when all labels are zero; nan_to_num below turns it into 0.
        mask /= np.mean(mask)
        abs_err = np.abs(np.subtract(pred, label)).astype(np.float32)
        # nan_to_num clears the nan/inf produced where label == 0 (mask is 0 there).
        mae = np.mean(np.nan_to_num(abs_err * mask))
        rmse = np.sqrt(np.mean(np.nan_to_num(np.square(abs_err) * mask)))
        # Divide by |label| so negative targets cannot yield negative "absolute" errors.
        mape = np.mean(np.nan_to_num(np.divide(abs_err, np.abs(label)) * mask))
    return mae, rmse, mape
class LSTM(nn.Module):
    """LSTM regressor that predicts from the hidden state of the last time step.

    Args:
        feature: Number of input features per time step (LSTM ``input_size``).
        hidden_size: Width of the LSTM hidden state. Defaults to 8.
        output_size: Number of values produced per sample (e.g. the prediction
            horizon). Defaults to 1.
    """

    def __init__(self, feature, hidden_size=8, output_size=1):
        super().__init__()
        # batch_first=True: inputs are laid out as (batch, seq_len, feature).
        self.lstm = nn.LSTM(input_size=feature, hidden_size=hidden_size, batch_first=True)
        self.out = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        """Map ``x`` of shape (batch, seq_len, feature) to (batch, output_size)."""
        x, _ = self.lstm(x)
        # Keep only the final time step's hidden state for the regression head.
        return self.out(x[:, -1, :])
import torch

# Train for a fixed number of epochs, then evaluate on the test split.
# NOTE(review): model, criterion, optimizer, train_dataloader, test_dataloader,
# std, mean and metric are assumed to be defined earlier in the script — confirm.
EPOCHS = 100
for epoch in range(EPOCHS):
    model.train()
    loss_all = 0
    for x, y in train_dataloader:
        pre = model(x)
        # The prediction is de-normalised (pre * std + mean) before the loss,
        # which implies y is in the original price scale.
        loss = criterion(pre * std + mean, y)
        loss_all += loss.item()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    print(f"epoch {epoch + 1}/{EPOCHS}: total train loss {loss_all:.6f}")

# Evaluation: eval mode + no_grad so no autograd graph is built for inference.
model.eval()
pre_list, real_list = [], []
with torch.no_grad():
    for x, y in test_dataloader:
        pre = model(x) * std + mean
        # Flattening (instead of .item()) also works when a test batch holds
        # more than one sample.
        pre_list.extend(pre.reshape(-1).tolist())
        real_list.extend(y.reshape(-1).tolist())
mae, rmse, mape = metric(np.array(pre_list), np.array(real_list))
print(f"test MAE {mae:.4f}  RMSE {rmse:.4f}  MAPE {mape:.4%}")
# Overlay predicted (red) and actual (blue) test values, save to res.png, and display.
plt.figure(figsize=(20, 8))
curves = (
    (pre_list, "red", "pre"),
    (real_list, "blue", "real"),
)
for values, colour, tag in curves:
    plt.plot(range(len(values)), values, color=colour, label=tag)
plt.legend()
plt.savefig("res.png")
plt.show()
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。