diff --git a/README.md b/README.md index b16ac7c..a34faff 100644 --- a/README.md +++ b/README.md @@ -37,6 +37,9 @@ 注意:需要使用相应的数据服务权限,可以通过[该页面](https://www.tushare.pro)注册使用。 +## 数据使用事项 + +tushare数据源期货数据中,第一条夜盘k线数据是集合竞价数据,用户可以根据自己需求进行过滤或者合并。 ## 安装 diff --git a/vnpy_tushare/tushare_datafeed.py b/vnpy_tushare/tushare_datafeed.py index f1749a6..21e029c 100644 --- a/vnpy_tushare/tushare_datafeed.py +++ b/vnpy_tushare/tushare_datafeed.py @@ -1,6 +1,6 @@ from datetime import timedelta, datetime from pytz import timezone -from typing import List, Optional +from typing import Dict, List, Optional from copy import deepcopy import pandas as pd @@ -9,35 +9,47 @@ from vnpy.trader.setting import SETTINGS from vnpy.trader.datafeed import BaseDatafeed from vnpy.trader.constant import Exchange, Interval -from vnpy.trader.object import BarData, TickData, HistoryRequest +from vnpy.trader.object import BarData, HistoryRequest from vnpy.trader.utility import round_to +# 数据频率映射 INTERVAL_VT2TS = { Interval.MINUTE: "1min", Interval.HOUR: "60min", Interval.DAILY: "D", } -ASSET_VT2TS = { - Exchange.CFFEX: "FT", - Exchange.SHFE: "FT", - Exchange.CZCE: "FT", - Exchange.DCE: "FT", - Exchange.INE: "FT", - Exchange.SSE: "E", - Exchange.SZSE: "E", - Exchange.BITMEX: "C", - Exchange.BITSTAMP: "C", - Exchange.OKEX: "C", - Exchange.HUOBI: "C", - Exchange.BITFINEX: "C", - Exchange.BINANCE: "C", - Exchange.BYBIT: "C", - Exchange.COINBASE: "C", - Exchange.DERIBIT: "C", - Exchange.GATEIO: "C", -} - +# 股票支持列表 +STOCK_LIST = [ + Exchange.SSE, + Exchange.SZSE, + Exchange.BSE, +] + +# 期货支持列表 +FUTURE_LIST = [ + Exchange.CFFEX, + Exchange.SHFE, + Exchange.CZCE, + Exchange.DCE, + Exchange.INE, +] + +# 数字货币交易所支持列表 +CRYPTOCURRENCY_LIST = [ + Exchange.BITMEX, + Exchange.BITSTAMP, + Exchange.OKEX, + Exchange.HUOBI, + Exchange.BITFINEX, + Exchange.BINANCE, + Exchange.BYBIT, + Exchange.COINBASE, + Exchange.DERIBIT, + Exchange.GATEIO, +] + +# 交易所映射 EXCHANGE_VT2TS = { Exchange.CFFEX: "CFX", Exchange.SHFE: "SHF", @@ -48,41 +60,27 @@ Exchange.SZSE: "SZ", } +# 时间调整映射 INTERVAL_ADJUSTMENT_MAP = { Interval.MINUTE: timedelta(minutes=1), Interval.HOUR: timedelta(hours=1), Interval.DAILY: timedelta() } +# 中国上海时区 CHINA_TZ = timezone("Asia/Shanghai") def to_ts_symbol(symbol, exchange) -> Optional[str]: """将交易所代码转换为tushare代码""" # 股票 - if exchange in [Exchange.SSE, Exchange.SZSE]: + if exchange in STOCK_LIST: ts_symbol = f"{symbol}.{EXCHANGE_VT2TS[exchange]}" # 期货 - elif exchange in [ - Exchange.SHFE, - Exchange.CFFEX, - Exchange.DCE, - Exchange.CZCE, - Exchange.INE - ]: + elif exchange in FUTURE_LIST: ts_symbol = f"{symbol}.{EXCHANGE_VT2TS[exchange]}".upper() # 数字货币 - elif exchange in [ - Exchange.BITSTAMP, - Exchange.OKEX, - Exchange.HUOBI, - Exchange.BITFINEX, - Exchange.BINANCE, - Exchange.BYBIT, - Exchange.COINBASE, - Exchange.DERIBIT, - Exchange.BITSTAMP - ]: + elif exchange in CRYPTOCURRENCY_LIST: ts_symbol = symbol else: return None @@ -90,6 +88,28 @@ def to_ts_symbol(symbol, exchange) -> Optional[str]: return ts_symbol +def to_ts_asset(symbol, exchange) -> Optional[str]: + """生成tushare资产类别""" + # 股票 + if exchange in STOCK_LIST: + if exchange is Exchange.SSE and symbol[0] == "6": + asset = "E" + elif exchange is Exchange.SZSE and symbol[0] == "0" or symbol[0] == "3": + asset = "E" + else: + asset = "I" + # 期货 + elif exchange in FUTURE_LIST: + asset = "FT" + # 数字货币 + elif exchange in CRYPTOCURRENCY_LIST: + asset = "C" + else: + return None + + return asset + + class TushareDatafeed(BaseDatafeed): """TuShare数据服务接口""" @@ -121,9 +141,14 @@ def query_bar_history(self, req: HistoryRequest) -> Optional[List[BarData]]: interval = req.interval start = req.start.strftime("%Y%m%d") end = req.end.strftime("%Y%m%d") - asset = ASSET_VT2TS[exchange] ts_symbol = to_ts_symbol(symbol, exchange) + if not ts_symbol: + return None + + asset = to_ts_asset(symbol, exchange) + if not asset: + return None ts_interval = INTERVAL_VT2TS.get(interval) if not ts_interval: @@ -190,8 +215,13 @@ def query_bar_history(self, req: HistoryRequest) -> Optional[List[BarData]]: ) df = pd.concat([df[:-1], d1]) + bar_keys: List[datetime] = [] + bar_dict: Dict[datetime, BarData] = {} data: List[BarData] = [] + # 处理原始数据中的NaN值 + df.fillna(0, inplace=True) + if df is not None: for ix, row in df.iterrows(): if row["open"] is None: @@ -206,7 +236,7 @@ def query_bar_history(self, req: HistoryRequest) -> Optional[List[BarData]]: dt = CHINA_TZ.localize(dt) - bar = BarData( + bar: BarData = BarData( symbol=symbol, exchange=exchange, interval=interval, @@ -221,6 +251,11 @@ def query_bar_history(self, req: HistoryRequest) -> Optional[List[BarData]]: gateway_name="TS" ) - data.append(bar) + bar_dict[dt] = bar + + bar_keys = bar_dict.keys() + bar_keys = sorted(bar_keys, reverse=False) + for i in bar_keys: + data.append(bar_dict[i]) return data