搞了几天终于搞定这玩意了,也大概搞清楚了一些流程,本来写了用爬虫抓股票信息,但是太慢了,所以直接用了tushare库(财经数据接口包)来获取信息,可以直接得到例如去年2017-01-01至今的数据。不过中途专门安装第三方库都搞了好久。也遇到了好多莫名其妙的问题,所幸还是一一解决了。
废话不多说,直接贴代码:1
2
3
4
5
6
7
8
9
10
11
12import numpy as np
import matplotlib.pyplot as plt
import tushare as ts
import matplotlib.dates as mdates
import time
import datetime
import math
from matplotlib.dates import AutoDateLocator
from matplotlib import style
from sklearn import preprocessing
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LinearRegression
以上是需要用到的库,其中需要注意的是如果你用的是pandas库来下载数据的话,首先,yahoo、google、edgar的数据源是用不了了的,而且0.6版本的pandas有两个问题,一是datareader方法从pandas.io.data转移到了pandas_datareader.data下,导入的时候要注意,二是0.6版本的pandas会有个“ImportError: cannot import name ‘is_list_like’”的问题,解决办法就是重新安装库1
pip3 install git+https://github.com/pydata/pandas-datareader
只有等作者在以后的版本修复一下了
接下来获取股票信息1
2
3
4codeName = 'sh601001' # 股票代码
starDate = '2017-01-01' # 起始日期
endDate = time.strftime("%Y/%m/%d") # 结束日期 我这里默认是当天的日期
data = ts.get_hist_data(codeName, start=starDate, end=endDate) # 获取信息
大概格式如下:1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28 open high close low volume price_change p_change ma5 \
date
2018-05-30 5.35 5.37 5.13 5.12 86686.50 -0.27 -5.00 5.336
2018-05-29 5.37 5.44 5.39 5.36 42680.93 0.03 0.56 5.410
2018-05-28 5.33 5.41 5.37 5.29 46136.34 0.04 0.75 5.468
2018-05-25 5.46 5.50 5.33 5.30 67176.92 -0.13 -2.38 5.530
2018-05-24 5.46 5.52 5.46 5.43 51348.99 -0.03 -0.55 5.604
2018-05-23 5.64 5.64 5.50 5.43 132375.31 -0.18 -3.17 5.614
2018-05-22 5.73 5.73 5.68 5.63 79332.33 -0.01 -0.18 5.612
2018-05-21 5.67 5.74 5.68 5.61 129048.71 -0.03 -0.53 5.582
2018-05-18 5.53 5.72 5.70 5.52 164566.48 0.19 3.45 5.548
2018-05-17 5.46 5.54 5.51 5.46 60344.20 0.03 0.55 5.518
2018-05-16 5.48 5.53 5.49 5.46 55381.03 -0.03 -0.54 5.518
2018-05-15 5.55 5.55 5.53 5.42 73622.78 0.03 0.55 5.510
2018-05-14 5.55 5.61 5.51 5.48 85036.80 -0.05 -0.90 5.480
2018-05-11 5.51 5.62 5.55 5.46 119567.14 0.04 0.73 5.454
2018-05-10 5.49 5.58 5.51 5.49 110337.44 0.05 0.92 5.402
2018-05-09 5.39 5.48 5.45 5.38 153947.75 0.06 1.11 5.364
2018-05-08 5.40 5.42 5.38 5.36 86968.29 -0.01 -0.19 5.332
2018-05-07 5.29 5.42 5.38 5.29 103588.98 0.09 1.70 5.326
2018-05-04 5.28 5.34 5.29 5.27 59665.61 -0.02 -0.38 5.306
2018-05-03 5.26 5.32 5.32 5.23 73446.02 0.03 0.57 5.312
2018-05-02 5.29 5.33 5.29 5.17 107738.95 -0.05 -0.94 5.324
2018-04-27 5.35 5.38 5.35 5.29 82541.35 0.07 1.33 5.324
2018-04-26 5.33 5.36 5.28 5.23 72458.88 -0.05 -0.94 5.302
2018-04-25 5.35 5.36 5.32 5.30 49497.68 -0.07 -1.30 5.326
2018-04-24 5.30 5.41 5.38 5.30 86297.36 0.08 1.51 5.322
······
其中有一些数据是我们不需要的,我们需要的只有开盘价、收盘价、最高价、最低价、成交量1
2
3
4
5
6
7# 构建数据集
df = data[['open', 'high', 'low', 'close', 'volume']]
# print(df.head())
df['HL_PCT'] = (df['high'] - df['low']) / df['close'] * 100.0
# 上涨幅度(阳线)
df['PCT_change'] = (df['close'] - df['open']) / df['open'] * 100.0
df = df[['close', 'HL_PCT', 'PCT_change', 'volume']]
em…..其实我也不是很清楚为什么要用这两个数据来进行预测,因为我不是炒股的,但看见有人这么用就假装很懂的用上吧╮(╯▽╰)╭
接下来进行预测1
2
3
4
5forecast_col = 'close'
# fillna()填充nan数据,返回填充后的结果,inplace=True表示在原dataFrame中修改
df.fillna(0, inplace=True)
# forecast_out表示往后预测的天数 math.ceil 将小数向整数进位
forecast_out = int(math.ceil(0.01 * len(df)))
1 | df['label'] = df[forecast_col].shift(-forecast_out) |
使用sklearn(机器学习库)中的train_test_split()方法来进行训练,用来随机划分训练集和测试集的。这里要说一下的是,以前这个方法是在sklearn.cross_validation下的,现在已经转移到了sklearn.model_selection下了。1
2
3
4
5
6# 将数据集的80%作为训练,20%为测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)
clf = LinearRegression()
clf.fit(X_train, y_train)
accuracy = clf.score(X_test, y_test)
print(accuracy)
这里clf.score()的意思是正实例占所有正实例的比例,意思可以理解为准确度,理论上作为训练集的数据越多,准确度越高。这里我用了LinearRegression()做回归器,合理性有待商榷,其它的方法我也没有试验过,也许还有更合理的方法。1
2
3
4
5
6
7
8
9
10
11
12
13
14
15# 用训练好的clf模型去预测X_lately(X_lately表示forecast_out天后的数据集)
forecast_set = clf.predict(X_lately)
df['forecast'] = np.nan
# iloc通过行号索引行数据
last_date = df.iloc[1].name
last_unix = time.mktime(time.strptime(last_date, '%Y-%m-%d'))
one_day = 86400 # 24*60*60
# next_unix 22+1=23号
next_unix = last_unix + one_day
# 为预测数据添加时间戳
for i in forecast_set:
next_unix += 86400
next_date = datetime.datetime.fromtimestamp(next_unix)
next_dateStr = next_date.strftime("%Y-%m-%d")
df.loc[next_dateStr] = [np.nan for _ in range(len(df.columns) - 1)] + [i]
这里将预测出来的forecast_out天后的数据添加到总数据当中,并添加好时间戳。
到这里为止,我们所有的数据都准备完成了(下载数据→分析→训练→验证→预测→整合),接下来就是使用matplotlib库来画图了1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16# 生成横纵坐标信息
data_time = df.index
data_time_translation = [datetime.datetime.strptime(d, '%Y-%m-%d').date() for d in data_time]
data_close = df['close'].values
data_forecast = df['forecast'].values
# 获取端点坐标
data_forecastP = len(data_forecast[~np.isnan(data_forecast)])
Xi = [data_time_translation[0], data_time_translation[-data_forecastP]]
Yi = [data_close[0], data_forecast[-data_forecastP]]
# 配置时间坐标轴
plt.gca().xaxis.set_major_formatter(mdates.DateFormatter('%Y-%m-%d')) # 显示时间坐标的格式
autodates = AutoDateLocator() # 时间间隔自动选取
plt.gca().xaxis.set_major_locator(autodates)
由于我们将画两条线,一条原始数据的线,另一条预测数据的线,所以会造成两条线中间是断开的,这里计算出了原始数据线的末尾点和预测线的开始点,用第三条线连接起来(其实只是为了好看)
最后画图1
2
3
4
5
6
7
8
9
10
11
12
13
14
15# 绘制原始数据线
plt.plot(data_time_translation, data_close, color='r', label='原始数据', lw=1.5)
# 绘制预测走势线
plt.plot(data_time_translation, data_forecast, color='b', label='预测走势', lw=1.5)
# 将上面两条线连接起来
plt.plot(Xi, Yi, color='b', lw=1.5)
plt.gcf().autofmt_xdate() # 自动旋转日期标记
plt.rcParams['font.sans-serif'] = ['SimHei'] # 设置字体,不然无法显示中文
plt.grid(True) # 显示网格
plt.axis("tight") # 修改x、y坐标的范围让所有的数据显示出来
plt.xlabel('日期') # 横坐标说明
plt.ylabel('价格') # 纵坐标说明
plt.title('股票代码:' + codeName) # 标题
plt.legend(loc='lower left') # 显示图例
plt.show()
最终结果如下:
放大,放大,再放大!
其中红线是真实数据,蓝线是通过机器学习预测出来的数据,最后附上forecast_out(预测天数)和accuracy(准确率)的值
好了,这个基于sklearn的股票预测就完成了。不过应该不会有人真的按照这个结果去买股吧,毕竟影响股票价格的因素太多了,不是单单能通过机器预测出来的,娱乐娱乐就好了。