Skip to content

Commit

Permalink
Merge pull request #9 from Edanflame/stream
Browse files Browse the repository at this point in the history
添加Tickoverview并增加流式写入参数支持
  • Loading branch information
vnpy authored Aug 6, 2022
2 parents 32882b1 + d8dbae8 commit 74bae97
Showing 1 changed file with 74 additions and 4 deletions.
78 changes: 74 additions & 4 deletions vnpy_mysql/mysql_database.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from vnpy.trader.database import (
BaseDatabase,
BarOverview,
TickOverview,
DB_TZ,
convert_tz
)
Expand Down Expand Up @@ -134,16 +135,32 @@ class Meta:
indexes: tuple = ((("symbol", "exchange", "interval"), True),)


class DbTickOverview(Model):
"""Tick汇总数据表映射对象"""

id: AutoField = AutoField()

symbol: str = CharField()
exchange: str = CharField()
count: int = IntegerField()
start: datetime = DateTimeField()
end: datetime = DateTimeField()

class Meta:
database: PeeweeMySQLDatabase = db
indexes: tuple = ((("symbol", "exchange"), True),)


class MysqlDatabase(BaseDatabase):
"""Mysql数据库接口"""

def __init__(self) -> None:
""""""
self.db: PeeweeMySQLDatabase = db
self.db.connect()
self.db.create_tables([DbBarData, DbTickData, DbBarOverview])
self.db.create_tables([DbBarData, DbTickData, DbBarOverview, DbTickOverview])

def save_bar_data(self, bars: List[BarData]) -> bool:
def save_bar_data(self, bars: List[BarData], stream: bool = False) -> bool:
"""保存K线数据"""
# 读取主键参数
bar: BarData = bars[0]
Expand Down Expand Up @@ -184,6 +201,9 @@ def save_bar_data(self, bars: List[BarData]) -> bool:
overview.start = bars[0].datetime
overview.end = bars[-1].datetime
overview.count = len(bars)
elif stream:
overview.end = bars[-1].datetime
overview.count += len(bars)
else:
overview.start = min(bars[0].datetime, overview.start)
overview.end = max(bars[-1].datetime, overview.end)
Expand All @@ -199,8 +219,13 @@ def save_bar_data(self, bars: List[BarData]) -> bool:

return True

def save_tick_data(self, ticks: List[TickData]) -> bool:
def save_tick_data(self, ticks: List[TickData], stream: bool = False) -> bool:
"""保存TICK数据"""
# 读取主键参数
tick: TickData = ticks[0]
symbol: str = tick.symbol
exchange: Exchange = tick.exchange

# 将TickData数据转换为字典,并调整时区
data: list = []

Expand All @@ -218,6 +243,34 @@ def save_tick_data(self, ticks: List[TickData]) -> bool:
for c in chunked(data, 50):
DbTickData.insert_many(c).on_conflict_replace().execute()

# 更新Tick汇总数据
overview: DbTickOverview = DbTickOverview.get_or_none(
DbTickOverview.symbol == symbol,
DbTickOverview.exchange == exchange.value,
)

if not overview:
overview: DbTickOverview = DbTickOverview()
overview.symbol = symbol
overview.exchange = exchange.value
overview.start = ticks[0].datetime
overview.end = ticks[-1].datetime
overview.count = len(ticks)
elif stream:
overview.end = ticks[-1].datetime
overview.count += len(ticks)
else:
overview.start = min(ticks[0].datetime, overview.start)
overview.end = max(ticks[-1].datetime, overview.end)

s: ModelSelect = DbTickData.select().where(
(DbTickData.symbol == symbol)
& (DbTickData.exchange == exchange.value)
)
overview.count = s.count()

overview.save()

return True

def load_bar_data(
Expand Down Expand Up @@ -355,12 +408,20 @@ def delete_tick_data(
symbol: str,
exchange: Exchange
) -> int:
""""""
"""删除TICK数据"""
d: ModelDelete = DbTickData.delete().where(
(DbTickData.symbol == symbol)
& (DbTickData.exchange == exchange.value)
)

count: int = d.execute()

# 删除Tick汇总数据
d2: ModelDelete = DbTickOverview.delete().where(
(DbTickOverview.symbol == symbol)
& (DbTickOverview.exchange == exchange.value)
)
d2.execute()
return count

def get_bar_overview(self) -> List[BarOverview]:
Expand All @@ -379,6 +440,15 @@ def get_bar_overview(self) -> List[BarOverview]:
overviews.append(overview)
return overviews

def get_tick_overview(self) -> List[TickOverview]:
"""查询数据库中的Tick汇总信息"""
s: ModelSelect = DbTickOverview.select()
overviews: list = []
for overview in s:
overview.exchange = Exchange(overview.exchange)
overviews.append(overview)
return overviews

def init_bar_overview(self) -> None:
"""初始化数据库中的K线汇总信息"""
s: ModelSelect = (
Expand Down

0 comments on commit 74bae97

Please sign in to comment.