Skip to content

Commit

Permalink
Merge pull request #232 from lonelam/feat/1m-index
Browse files Browse the repository at this point in the history
feat: add 1m kdata for qmt recording
  • Loading branch information
foolcage authored Dec 13, 2024
2 parents 21b09fd + 880348a commit c38662c
Show file tree
Hide file tree
Showing 9 changed files with 294 additions and 10 deletions.
22 changes: 19 additions & 3 deletions src/zvt/broker/qmt/qmt_quote.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,12 @@
import pandas as pd
from xtquant import xtdata

from zvt.contract import Exchange
from zvt.contract import IntervalLevel, AdjustType
from zvt.contract.api import decode_entity_id, df_to_db, get_db_session
from zvt.domain import StockQuote, Stock, Stock1dKdata
from zvt.domain.quotes.stock.stock_quote import Stock1mQuote, StockQuoteLog
from zvt.recorders.em import em_api
from zvt.utils.pd_utils import pd_is_not_null
from zvt.utils.time_utils import (
to_time_str,
Expand Down Expand Up @@ -84,7 +86,12 @@ def _qmt_instrument_detail_to_stock(stock_detail):


def get_qmt_stocks():
return xtdata.get_stock_list_in_sector("沪深A股")
df = em_api.get_tradable_list(exchange=Exchange.bj)
bj_stock_list = df["entity_id"].map(_to_qmt_code).tolist()

stock_list = xtdata.get_stock_list_in_sector("沪深A股")
stock_list += bj_stock_list
return stock_list


def get_entity_list():
Expand Down Expand Up @@ -127,7 +134,10 @@ def get_entity_list():

tick = xtdata.get_full_tick(code_list=[stock])
if tick and tick[stock]:
if code.startswith("300") or code.startswith("688"):
if code.startswith(("83", "87", "88", "889", "82", "920")):
limit_up_price = tick[stock]["lastClose"] * 1.3
limit_down_price = tick[stock]["lastClose"] * 0.7
elif code.startswith("300") or code.startswith("688"):
limit_up_price = tick[stock]["lastClose"] * 1.2
limit_down_price = tick[stock]["lastClose"] * 0.8
else:
Expand All @@ -150,9 +160,15 @@ def get_kdata(
):
code = _to_qmt_code(entity_id=entity_id)
period = level.value
start_time = to_time_str(start_timestamp, fmt="YYYYMMDDHHmmss")
end_time = to_time_str(end_timestamp, fmt="YYYYMMDDHHmmss")
# download比较耗时,建议单独定时任务来做
if download_history:
xtdata.download_history_data(stock_code=code, period=period)
print(f"download from {start_time} to {end_time}")
xtdata.download_history_data(
stock_code=code, period=period,
start_time=start_time, end_time=end_time
)
records = xtdata.get_market_data(
stock_list=[code],
period=period,
Expand Down
4 changes: 2 additions & 2 deletions src/zvt/contract/recorder.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,8 @@ def __init__(
end_timestamp=None,
return_unfinished=False,
) -> None:
self.start_timestamp = to_pd_timestamp(start_timestamp)
self.end_timestamp = to_pd_timestamp(end_timestamp)
super().__init__(
force_update,
sleeping_time,
Expand All @@ -213,8 +215,6 @@ def __init__(
self.real_time = real_time
self.close_hour, self.close_minute = self.entity_schema.get_close_hour_and_minute()
self.fix_duplicate_way = fix_duplicate_way
self.start_timestamp = to_pd_timestamp(start_timestamp)
self.end_timestamp = to_pd_timestamp(end_timestamp)

def get_latest_saved_record(self, entity):
order = eval("self.data_schema.{}.desc()".format(self.get_evaluated_time_field()))
Expand Down
4 changes: 4 additions & 0 deletions src/zvt/domain/quotes/index/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,3 +20,7 @@
from .index_1wk_kdata import __all__ as _index_1wk_kdata_all

__all__ += _index_1wk_kdata_all

from .index_1m_kdata import *
from .index_1m_kdata import __all__ as _index_1m_kdata_all
__all__ += _index_1m_kdata_all
20 changes: 20 additions & 0 deletions src/zvt/domain/quotes/index/index_1m_kdata.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
# -*- coding: utf-8 -*-
# this file is generated by gen_kdata_schema function, dont't change it
from sqlalchemy.orm import declarative_base

from zvt.contract import TradableEntity
from zvt.contract.register import register_schema
from zvt.domain.quotes import IndexKdataCommon

KdataBase = declarative_base()


class Index1mKdata(KdataBase, IndexKdataCommon, TradableEntity):
__tablename__ = "index_1m_kdata"


register_schema(providers=["em", "sina", "qmt"], db_name="index_1m_kdata", schema_base=KdataBase, entity_type="index")


# the __all__ is generated
__all__ = ["Index1mKdata"]
25 changes: 23 additions & 2 deletions src/zvt/recorders/qmt/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,26 @@
# -*- coding: utf-8 -*-

# -*- coding: utf-8 -*-#

# the __all__ is generated
__all__ = []

# __init__.py structure:
# common code of the package
# export interface in __all__ which contains __all__ of its sub modules

# import all from submodule quotes
from .quotes import *
from .quotes import __all__ as _quotes_all

__all__ += _quotes_all

# import all from submodule money_flow
from .index import *
from .index import __all__ as _index_all

__all__ += _index_all

# import all from submodule meta
from .meta import *
from .meta import __all__ as _meta_all

__all__ += _meta_all
15 changes: 15 additions & 0 deletions src/zvt/recorders/qmt/index/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
# -*- coding: utf-8 -*-


# the __all__ is generated
__all__ = []

# __init__.py structure:
# common code of the package
# export interface in __all__ which contains __all__ of its sub modules

# import all from submodule qmt_kdata_recorder
from .qmt_index_recorder import *
from .qmt_index_recorder import __all__ as _qmt_index_recorder_all

__all__ += _qmt_index_recorder_all
173 changes: 173 additions & 0 deletions src/zvt/recorders/qmt/index/qmt_index_recorder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,173 @@
# -*- coding: utf-8 -*-
import pandas as pd
from zvt.api.kdata import get_kdata_schema
from zvt.broker.qmt import qmt_quote
from zvt.consts import IMPORTANT_INDEX
from zvt.contract import IntervalLevel
from zvt.contract.api import df_to_db
from zvt.contract.recorder import FixedCycleDataRecorder
from zvt.contract.utils import evaluate_size_from_timestamp
from zvt.domain import Index, Index1mKdata
from zvt.utils.pd_utils import pd_is_not_null
from zvt.utils.time_utils import TIME_FORMAT_DAY, TIME_FORMAT_MINUTE, current_date, to_time_str


class QmtIndexRecorder(FixedCycleDataRecorder):
provider = "qmt"
data_schema = Index1mKdata
entity_provider = "em"
entity_schema = Index
download_history_data = False

def __init__(
self,
force_update=True,
sleeping_time=10,
exchanges=None,
entity_id=None,
entity_ids=None,
code=None,
codes=None,
day_data=False,
entity_filters=None,
ignore_failed=True,
real_time=False,
fix_duplicate_way="ignore",
start_timestamp=None,
end_timestamp=None,
level=IntervalLevel.LEVEL_1DAY,
kdata_use_begin_time=False,
one_day_trading_minutes=24 * 60,
return_unfinished=False,
download_history_data=False
) -> None:
level = IntervalLevel(level)
self.entity_type = "index"
self.download_history_data = download_history_data

self.data_schema = get_kdata_schema(entity_type=self.entity_type, level=level, adjust_type=None)

super().__init__(
force_update,
sleeping_time,
exchanges,
entity_id,
entity_ids,
code,
codes,
day_data,
entity_filters,
ignore_failed,
real_time,
fix_duplicate_way,
start_timestamp,
end_timestamp,
level,
kdata_use_begin_time,
one_day_trading_minutes,
return_unfinished,
)
self.one_day_trading_minutes = 240

def record(self, entity, start, end, size, timestamps):
if start and (self.level == IntervalLevel.LEVEL_1DAY):
start = start.date()
if not start:
start = "2005-01-01"
if not end:
end = current_date()

# 统一高频数据习惯,减小数据更新次数,分钟K线需要直接多读1根K线,以兼容start_timestamp=9:30, end_timestamp=15:00的情况
if self.level == IntervalLevel.LEVEL_1MIN:
end += pd.Timedelta(seconds=1)

df = qmt_quote.get_kdata(
entity_id=entity.id,
start_timestamp=start,
end_timestamp=end,
adjust_type=None,
level=self.level,
download_history=self.download_history_data,
)
time_str_fmt = TIME_FORMAT_DAY if self.level == IntervalLevel.LEVEL_1DAY else TIME_FORMAT_MINUTE
if pd_is_not_null(df):
df["entity_id"] = entity.id
df["timestamp"] = pd.to_datetime(df.index)
df["id"] = df.apply(lambda row: f"{row['entity_id']}_{to_time_str(row['timestamp'], fmt=time_str_fmt)}",
axis=1)
df["provider"] = "qmt"
df["level"] = self.level.value
df["code"] = entity.code
df["name"] = entity.name
df.rename(columns={"amount": "turnover"}, inplace=True)
df["change_pct"] = (df["close"] - df["preClose"]) / df["preClose"]
df_to_db(df=df, data_schema=self.data_schema, provider=self.provider, force_update=self.force_update)

else:
self.logger.info(f"no kdata for {entity.id}")

def evaluate_start_end_size_timestamps(self, entity):
if self.download_history_data and self.start_timestamp and self.end_timestamp:
# 历史数据可能碎片化,允许按照实际start和end之间有没有写满数据
expected_size = evaluate_size_from_timestamp(start_timestamp=self.start_timestamp,
end_timestamp=self.end_timestamp, level=self.level,
one_day_trading_minutes=self.one_day_trading_minutes)

recorded_size = self.session.query(self.data_schema).filter(
self.data_schema.entity_id == entity.id,
self.data_schema.timestamp >= self.start_timestamp,
self.data_schema.timestamp <= self.end_timestamp
).count()

if expected_size != recorded_size:
# print(f"expected_size: {expected_size}, recorded_size: {recorded_size}")
return self.start_timestamp, self.end_timestamp, self.default_size, None

start_timestamp, end_timestamp, size, timestamps = super().evaluate_start_end_size_timestamps(entity)
# start_timestamp is the last updated timestamp
if self.end_timestamp is not None:
if start_timestamp >= self.end_timestamp:
return start_timestamp, end_timestamp, 0, None
else:
size = evaluate_size_from_timestamp(
start_timestamp=start_timestamp,
level=self.level,
one_day_trading_minutes=self.one_day_trading_minutes,
end_timestamp=self.end_timestamp,
)
return start_timestamp, self.end_timestamp, size, timestamps

return start_timestamp, end_timestamp, size, timestamps

# # 中证,上海
# def record_cs_index(self, index_type):
# df = cs_index_api.get_cs_index(index_type=index_type)
# df_to_db(data_schema=self.data_schema, df=df, provider=self.provider, force_update=True)
# self.logger.info(f"finish record {index_type} index")
#
# # 国证,深圳
# def record_cn_index(self, index_type):
# if index_type == "cni":
# category_map_url = cn_index_api.cni_category_map_url
# elif index_type == "sz":
# category_map_url = cn_index_api.sz_category_map_url
# else:
# self.logger.error(f"not support index_type: {index_type}")
# assert False
#
# for category, _ in category_map_url.items():
# df = cn_index_api.get_cn_index(index_type=index_type, category=category)
# df_to_db(data_schema=self.data_schema, df=df, provider=self.provider, force_update=True)
# self.logger.info(f"finish record {index_type} index:{category.value}")


if __name__ == "__main__":
# init_log('china_stock_category.log')
start_timestamp = pd.Timestamp("2024-12-01")
end_timestamp = pd.Timestamp("2024-12-03")
QmtIndexRecorder(codes=IMPORTANT_INDEX, level=IntervalLevel.LEVEL_1MIN, sleeping_time=0,
start_timestamp=start_timestamp, end_timestamp=end_timestamp,
download_history_data=True).run()

# the __all__ is generated
__all__ = ["QmtIndexRecorder"]
38 changes: 36 additions & 2 deletions src/zvt/recorders/qmt/quotes/qmt_kdata_recorder.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,14 @@
from zvt.api.kdata import get_kdata_schema, get_kdata
from zvt.broker.qmt import qmt_quote
from zvt.contract import IntervalLevel, AdjustType
from zvt.contract.api import df_to_db
from zvt.contract.api import df_to_db, get_db_session, get_entities
from zvt.contract.recorder import FixedCycleDataRecorder
from zvt.domain import (
Stock,
StockKdataCommon,
)
from zvt.utils.pd_utils import pd_is_not_null
from zvt.utils.time_utils import current_date, to_time_str
from zvt.utils.time_utils import current_date, to_time_str, now_time_str


class BaseQmtKdataRecorder(FixedCycleDataRecorder):
Expand Down Expand Up @@ -69,6 +69,40 @@ def __init__(
return_unfinished,
)

def init_entities(self):
"""
init the entities which we would record data for
"""
if self.entity_provider == self.provider and self.entity_schema == self.data_schema:
self.entity_session = self.session
else:
self.entity_session = get_db_session(provider=self.entity_provider, data_schema=self.entity_schema)

if self.day_data:
df = self.data_schema.query_data(
start_timestamp=now_time_str(), columns=["entity_id", "timestamp"], provider=self.provider
)
if pd_is_not_null(df):
entity_ids = df["entity_id"].tolist()
self.logger.info(f"ignore entity_ids:{entity_ids}")
if self.entity_filters:
self.entity_filters.append(self.entity_schema.entity_id.notin_(entity_ids))
else:
self.entity_filters = [self.entity_schema.entity_id.notin_(entity_ids)]

#: init the entity list
self.entities = get_entities(
session=self.entity_session,
entity_schema=self.entity_schema,
exchanges=self.exchanges,
entity_ids=self.entity_ids,
codes=self.codes,
return_type="domain",
provider=self.entity_provider,
filters=self.entity_filters,
)

def record(self, entity, start, end, size, timestamps):
if start and (self.level == IntervalLevel.LEVEL_1DAY):
start = start.date()
Expand Down
3 changes: 2 additions & 1 deletion src/zvt/tasks/qmt_data_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from xtquant import xtdata

from zvt import init_log
from zvt.broker.qmt.qmt_quote import get_qmt_stocks
from zvt.contract import AdjustType
from zvt.recorders.qmt.meta import QMTStockRecorder
from zvt.recorders.qmt.quotes import QMTStockKdataRecorder
Expand All @@ -16,7 +17,7 @@
def download_data(download_tick=False):
period = "1d"
xtdata.download_sector_data()
stock_codes = xtdata.get_stock_list_in_sector("沪深A股")
stock_codes = get_qmt_stocks()
stock_codes = sorted(stock_codes)
count = len(stock_codes)
download_status = {"ok": False}
Expand Down

0 comments on commit c38662c

Please sign in to comment.