Skip to content

Commit

Permalink
Merge branch 'main' into stream
Browse files Browse the repository at this point in the history
  • Loading branch information
vnpy authored Aug 6, 2022
2 parents 295e26c + 32882b1 commit d8dbae8
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 47 deletions.
13 changes: 0 additions & 13 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -35,19 +35,6 @@ VeighNa不会主动为MySQL数据库创建实例,所以使用前请确保datab

若实例尚未创建,可以使用【MySQL Workbench】客户端的【new_schema】进行操作。

### Tick时间戳的毫秒支持

由于peewee的建表功能限制,默认情况下在保存tick数据时,时间精确度只能精确到秒。如果影响使用,可按照以下方式手动修改MySQL数据表来解决:

```
# 用MySQL命令行工具连接数据库
# 选择数据实例
use vnpy;
# 修改dbtickdata表datetime列的数据格式
ALTER TABLE `dbtickdata` MODIFY COLUMN `datetime` DATETIME(3);
```

### 字符串大小写敏感支持

Expand Down
74 changes: 40 additions & 34 deletions vnpy_mysql/mysql_database.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from vnpy.trader.setting import SETTINGS


db = PeeweeMySQLDatabase(
db: PeeweeMySQLDatabase = PeeweeMySQLDatabase(
database=SETTINGS["database.database"],
user=SETTINGS["database.user"],
password=SETTINGS["database.password"],
Expand All @@ -35,10 +35,16 @@
)


class DateTimeMillisecondField(DateTimeField):
# 毫秒支持
def get_modifiers(self):
return [3]


class DbBarData(Model):
"""K线数据表映射对象"""

id = AutoField()
id: AutoField = AutoField()

symbol: str = CharField()
exchange: str = CharField()
Expand All @@ -54,18 +60,18 @@ class DbBarData(Model):
close_price: float = FloatField()

class Meta:
database = db
indexes = ((("symbol", "exchange", "interval", "datetime"), True),)
database: PeeweeMySQLDatabase = db
indexes: tuple = ((("symbol", "exchange", "interval", "datetime"), True),)


class DbTickData(Model):
"""TICK数据表映射对象"""

id = AutoField()
id: AutoField = AutoField()

symbol: str = CharField()
exchange: str = CharField()
datetime: datetime = DateTimeField()
datetime: datetime = DateTimeMillisecondField()

name: str = CharField()
volume: float = FloatField()
Expand Down Expand Up @@ -108,14 +114,14 @@ class DbTickData(Model):
localtime: datetime = DateTimeField(null=True)

class Meta:
database = db
indexes = ((("symbol", "exchange", "datetime"), True),)
database: PeeweeMySQLDatabase = db
indexes: tuple = ((("symbol", "exchange", "datetime"), True),)


class DbBarOverview(Model):
"""K线汇总数据表映射对象"""

id = AutoField()
id: AutoField = AutoField()

symbol: str = CharField()
exchange: str = CharField()
Expand All @@ -125,8 +131,8 @@ class DbBarOverview(Model):
end: datetime = DateTimeField()

class Meta:
database = db
indexes = ((("symbol", "exchange", "interval"), True),)
database: PeeweeMySQLDatabase = db
indexes: tuple = ((("symbol", "exchange", "interval"), True),)


class DbTickOverview(Model):
Expand All @@ -150,25 +156,25 @@ class MysqlDatabase(BaseDatabase):

def __init__(self) -> None:
""""""
self.db = db
self.db: PeeweeMySQLDatabase = db
self.db.connect()
self.db.create_tables([DbBarData, DbTickData, DbBarOverview, DbTickOverview])

def save_bar_data(self, bars: List[BarData], stream: bool = False) -> bool:
"""保存K线数据"""
# 读取主键参数
bar = bars[0]
symbol = bar.symbol
exchange = bar.exchange
interval = bar.interval
bar: BarData = bars[0]
symbol: str = bar.symbol
exchange: Exchange = bar.exchange
interval: Interval = bar.interval

# 将BarData数据转换为字典,并调整时区
data = []
data: list = []

for bar in bars:
bar.datetime = convert_tz(bar.datetime)

d = bar.__dict__
d: dict = bar.__dict__
d["exchange"] = d["exchange"].value
d["interval"] = d["interval"].value
d.pop("gateway_name")
Expand All @@ -188,7 +194,7 @@ def save_bar_data(self, bars: List[BarData], stream: bool = False) -> bool:
)

if not overview:
overview = DbBarOverview()
overview: DbBarOverview = DbBarOverview()
overview.symbol = symbol
overview.exchange = exchange.value
overview.interval = interval.value
Expand Down Expand Up @@ -221,12 +227,12 @@ def save_tick_data(self, ticks: List[TickData], stream: bool = False) -> bool:
exchange: Exchange = tick.exchange

# 将TickData数据转换为字典,并调整时区
data = []
data: list = []

for tick in ticks:
tick.datetime = convert_tz(tick.datetime)

d = tick.__dict__
d: dict = tick.__dict__
d["exchange"] = d["exchange"].value
d.pop("gateway_name")
d.pop("vt_symbol")
Expand Down Expand Up @@ -276,8 +282,8 @@ def load_bar_data(
end: datetime
) -> List[BarData]:
""""""
start = start.replace(hour=0, minute=0, second=0)
end = end.replace(hour=23, minute=59, second=59)
start: datetime = start.replace(hour=0, minute=0, second=0)
end: datetime = end.replace(hour=23, minute=59, second=59)

s: ModelSelect = (
DbBarData.select().where(
Expand All @@ -291,7 +297,7 @@ def load_bar_data(

bars: List[BarData] = []
for db_bar in s:
bar = BarData(
bar: BarData = BarData(
symbol=db_bar.symbol,
exchange=Exchange(db_bar.exchange),
datetime=datetime.fromtimestamp(db_bar.datetime.timestamp(), DB_TZ),
Expand All @@ -317,8 +323,8 @@ def load_tick_data(
end: datetime
) -> List[TickData]:
"""读取TICK数据"""
start = start.replace(hour=0, minute=0, second=0)
end = end.replace(hour=23, minute=59, second=59)
start: datetime = start.replace(hour=0, minute=0, second=0)
end: datetime = end.replace(hour=23, minute=59, second=59)

s: ModelSelect = (
DbTickData.select().where(
Expand All @@ -331,7 +337,7 @@ def load_tick_data(

ticks: List[TickData] = []
for db_tick in s:
tick = TickData(
tick: TickData = TickData(
symbol=db_tick.symbol,
exchange=Exchange(db_tick.exchange),
datetime=datetime.fromtimestamp(db_tick.datetime.timestamp(), DB_TZ),
Expand Down Expand Up @@ -386,7 +392,7 @@ def delete_bar_data(
& (DbBarData.exchange == exchange.value)
& (DbBarData.interval == interval.value)
)
count = d.execute()
count: int = d.execute()

# 删除K线汇总数据
d2: ModelDelete = DbBarOverview.delete().where(
Expand All @@ -407,27 +413,27 @@ def delete_tick_data(
(DbTickData.symbol == symbol)
& (DbTickData.exchange == exchange.value)
)
count = d.execute()

count: int = d.execute()

# 删除Tick汇总数据
d2: ModelDelete = DbTickOverview.delete().where(
(DbTickOverview.symbol == symbol)
& (DbTickOverview.exchange == exchange.value)
)
d2.execute()

return count

def get_bar_overview(self) -> List[BarOverview]:
"""查询数据库中的K线汇总信息"""
# 如果已有K线,但缺失汇总信息,则执行初始化
data_count = DbBarData.select().count()
overview_count = DbBarOverview.select().count()
data_count: int = DbBarData.select().count()
overview_count: int = DbBarOverview.select().count()
if data_count and not overview_count:
self.init_bar_overview()

s: ModelSelect = DbBarOverview.select()
overviews = []
overviews: List[BarOverview] = []
for overview in s:
overview.exchange = Exchange(overview.exchange)
overview.interval = Interval(overview.interval)
Expand Down Expand Up @@ -459,7 +465,7 @@ def init_bar_overview(self) -> None:
)

for data in s:
overview = DbBarOverview()
overview: DbBarOverview = DbBarOverview()
overview.symbol = data.symbol
overview.exchange = data.exchange
overview.interval = data.interval
Expand Down

0 comments on commit d8dbae8

Please sign in to comment.