diff --git a/vnpy_tushare/tushare_datafeed.py b/vnpy_tushare/tushare_datafeed.py index 1077517..1e413ba 100644 --- a/vnpy_tushare/tushare_datafeed.py +++ b/vnpy_tushare/tushare_datafeed.py @@ -12,32 +12,44 @@ from vnpy.trader.object import BarData, HistoryRequest from vnpy.trader.utility import round_to +# 数据频率映射 INTERVAL_VT2TS = { Interval.MINUTE: "1min", Interval.HOUR: "60min", Interval.DAILY: "D", } -ASSET_VT2TS = { - Exchange.CFFEX: "FT", - Exchange.SHFE: "FT", - Exchange.CZCE: "FT", - Exchange.DCE: "FT", - Exchange.INE: "FT", - Exchange.SSE: "E", - Exchange.SZSE: "E", - Exchange.BITMEX: "C", - Exchange.BITSTAMP: "C", - Exchange.OKEX: "C", - Exchange.HUOBI: "C", - Exchange.BITFINEX: "C", - Exchange.BINANCE: "C", - Exchange.BYBIT: "C", - Exchange.COINBASE: "C", - Exchange.DERIBIT: "C", - Exchange.GATEIO: "C", -} - +# 股票支持列表 +STOCK_LIST = [ + Exchange.SSE, + Exchange.SZSE, + Exchange.BSE, +] + +# 期货支持列表 +FUTURE_LIST = [ + Exchange.CFFEX, + Exchange.SHFE, + Exchange.CZCE, + Exchange.DCE, + Exchange.INE, +] + +# 数字货币交易所支持列表 +CRYPTOCURRENCY_LIST = [ + Exchange.BITMEX, + Exchange.BITSTAMP, + Exchange.OKEX, + Exchange.HUOBI, + Exchange.BITFINEX, + Exchange.BINANCE, + Exchange.BYBIT, + Exchange.COINBASE, + Exchange.DERIBIT, + Exchange.GATEIO, +] + +# 交易所映射 EXCHANGE_VT2TS = { Exchange.CFFEX: "CFX", Exchange.SHFE: "SHF", @@ -48,41 +60,27 @@ Exchange.SZSE: "SZ", } +# 时间调整映射 INTERVAL_ADJUSTMENT_MAP = { Interval.MINUTE: timedelta(minutes=1), Interval.HOUR: timedelta(hours=1), Interval.DAILY: timedelta() } +# 中国上海时区 CHINA_TZ = timezone("Asia/Shanghai") def to_ts_symbol(symbol, exchange) -> Optional[str]: """将交易所代码转换为tushare代码""" # 股票 - if exchange in [Exchange.SSE, Exchange.SZSE]: + if exchange in STOCK_LIST: ts_symbol = f"{symbol}.{EXCHANGE_VT2TS[exchange]}" # 期货 - elif exchange in [ - Exchange.SHFE, - Exchange.CFFEX, - Exchange.DCE, - Exchange.CZCE, - Exchange.INE - ]: + elif exchange in FUTURE_LIST: ts_symbol = f"{symbol}.{EXCHANGE_VT2TS[exchange]}".upper() # 数字货币 - elif exchange in [ - Exchange.BITSTAMP, - Exchange.OKEX, - Exchange.HUOBI, - Exchange.BITFINEX, - Exchange.BINANCE, - Exchange.BYBIT, - Exchange.COINBASE, - Exchange.DERIBIT, - Exchange.BITSTAMP - ]: + elif exchange in CRYPTOCURRENCY_LIST: ts_symbol = symbol else: return None @@ -90,6 +88,28 @@ def to_ts_symbol(symbol, exchange) -> Optional[str]: return ts_symbol +def to_ts_asset(symbol, exchange) -> Optional[str]: + """生成tushare资产类别""" + # 股票 + if exchange in STOCK_LIST: + if exchange is Exchange.SSE and symbol[0] == "6": + asset = "E" + elif exchange is Exchange.SZSE and symbol[0] == "0" or symbol[0] == "3": + asset = "E" + else: + asset = "I" + # 期货 + elif exchange in FUTURE_LIST: + asset = "FT" + # 数字货币 + elif exchange in CRYPTOCURRENCY_LIST: + asset = "C" + else: + return None + + return asset + + class TushareDatafeed(BaseDatafeed): """TuShare数据服务接口""" @@ -121,9 +141,9 @@ def query_bar_history(self, req: HistoryRequest) -> Optional[List[BarData]]: interval = req.interval start = req.start.strftime("%Y%m%d") end = req.end.strftime("%Y%m%d") - asset = ASSET_VT2TS[exchange] ts_symbol = to_ts_symbol(symbol, exchange) + asset = to_ts_asset(symbol, exchange) ts_interval = INTERVAL_VT2TS.get(interval) if not ts_interval: