Skip to content

Commit

Permalink
✨ use pydantic compat
Browse files Browse the repository at this point in the history
  • Loading branch information
yanyongyu authored Mar 29, 2024
1 parent a8304d1 commit 2de6833
Showing 1 changed file with 50 additions and 45 deletions.
95 changes: 50 additions & 45 deletions nonebot/adapters/qq/bot.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,10 @@
overload,
)

from pydantic import BaseModel
from nonebot.message import handle_event
from pydantic import BaseModel, parse_obj_as
from nonebot.drivers import Request, Response
from nonebot.compat import type_validate_python

from nonebot.adapters import Bot as BaseBot

Expand Down Expand Up @@ -579,7 +580,7 @@ async def me(self) -> User:
"GET",
self.adapter.get_api_base().joinpath("users/@me"),
)
return parse_obj_as(User, await self._request(request))
return type_validate_python(User, await self._request(request))

@API
async def guilds(
Expand All @@ -594,7 +595,7 @@ async def guilds(
self.adapter.get_api_base().joinpath("users", "@me", "guilds"),
params=exclude_none({"before": before, "after": after, "limit": limit}),
)
return parse_obj_as(List[Guild], await self._request(request))
return type_validate_python(List[Guild], await self._request(request))

# Guild API
@API
Expand All @@ -603,7 +604,7 @@ async def get_guild(self, *, guild_id: str) -> Guild:
"GET",
self.adapter.get_api_base().joinpath("guilds", guild_id),
)
return parse_obj_as(Guild, await self._request(request))
return type_validate_python(Guild, await self._request(request))

# Channel API
@API
Expand All @@ -612,15 +613,15 @@ async def get_channels(self, *, guild_id: str) -> List[Channel]:
"GET",
self.adapter.get_api_base().joinpath("guilds", guild_id, "channels"),
)
return parse_obj_as(List[Channel], await self._request(request))
return type_validate_python(List[Channel], await self._request(request))

@API
async def get_channel(self, *, channel_id: str) -> Channel:
request = Request(
"GET",
self.adapter.get_api_base().joinpath("channels", channel_id),
)
return parse_obj_as(Channel, await self._request(request))
return type_validate_python(Channel, await self._request(request))

@API
async def post_channels(
Expand Down Expand Up @@ -654,7 +655,7 @@ async def post_channels(
}
),
)
return parse_obj_as(List[Channel], await self._request(request))
return type_validate_python(List[Channel], await self._request(request))

@API
async def patch_channel(
Expand Down Expand Up @@ -686,7 +687,7 @@ async def patch_channel(
}
),
)
return parse_obj_as(Channel, await self._request(request))
return type_validate_python(Channel, await self._request(request))

@API
async def delete_channel(self, *, channel_id: str) -> None:
Expand All @@ -710,7 +711,7 @@ async def get_members(
self.adapter.get_api_base().joinpath("guilds", guild_id, "members"),
params=exclude_none({"after": after, "limit": limit}),
)
return parse_obj_as(List[Member], await self._request(request))
return type_validate_python(List[Member], await self._request(request))

@API
async def get_role_members(
Expand All @@ -728,7 +729,7 @@ async def get_role_members(
),
params=exclude_none({"start_index": start_index, "limit": limit}),
)
return parse_obj_as(GetRoleMembersReturn, await self._request(request))
return type_validate_python(GetRoleMembersReturn, await self._request(request))

@API
async def get_member(self, *, guild_id: str, user_id: str) -> Member:
Expand All @@ -738,7 +739,7 @@ async def get_member(self, *, guild_id: str, user_id: str) -> Member:
"guilds", guild_id, "members", user_id
),
)
return parse_obj_as(Member, await self._request(request))
return type_validate_python(Member, await self._request(request))

@API
async def delete_member(
Expand Down Expand Up @@ -770,7 +771,7 @@ async def get_guild_roles(self, *, guild_id: str) -> GetGuildRolesReturn:
"GET",
self.adapter.get_api_base().joinpath("guilds", guild_id, "roles"),
)
return parse_obj_as(GetGuildRolesReturn, await self._request(request))
return type_validate_python(GetGuildRolesReturn, await self._request(request))

@API
async def post_guild_role(
Expand All @@ -792,7 +793,7 @@ async def post_guild_role(
}
),
)
return parse_obj_as(PostGuildRoleReturn, await self._request(request))
return type_validate_python(PostGuildRoleReturn, await self._request(request))

@API
async def patch_guild_role(
Expand All @@ -815,7 +816,7 @@ async def patch_guild_role(
}
),
)
return parse_obj_as(PatchGuildRoleReturn, await self._request(request))
return type_validate_python(PatchGuildRoleReturn, await self._request(request))

@API
async def delete_guild_role(self, *, guild_id: str, role_id: str) -> None:
Expand Down Expand Up @@ -872,7 +873,7 @@ async def get_channel_permissions(
"channels", channel_id, "members", user_id, "permissions"
),
)
return parse_obj_as(ChannelPermissions, await self._request(request))
return type_validate_python(ChannelPermissions, await self._request(request))

@API
async def put_channel_permissions(
Expand Down Expand Up @@ -902,7 +903,7 @@ async def get_channel_roles_permissions(
"channels", channel_id, "roles", role_id, "permissions"
),
)
return parse_obj_as(ChannelPermissions, await self._request(request))
return type_validate_python(ChannelPermissions, await self._request(request))

@API
async def put_channel_roles_permissions(
Expand Down Expand Up @@ -936,7 +937,7 @@ async def get_message_of_id(
result = await self._request(request)
if isinstance(result, dict) and "message" in result:
result = result["message"]
return parse_obj_as(GuildMessage, result)
return type_validate_python(GuildMessage, result)

@staticmethod
def _parse_send_message(data: Dict[str, Any]) -> Dict[str, Any]:
Expand Down Expand Up @@ -1000,7 +1001,7 @@ async def post_messages(
self.adapter.get_api_base().joinpath("channels", channel_id, "messages"),
**params,
)
return parse_obj_as(GuildMessage, await self._request(request))
return type_validate_python(GuildMessage, await self._request(request))

@API
async def delete_message(
Expand Down Expand Up @@ -1028,7 +1029,7 @@ async def get_message_setting(self, *, guild_id: str) -> MessageSetting:
"guilds", guild_id, "message", "setting"
),
)
return parse_obj_as(MessageSetting, await self._request(request))
return type_validate_python(MessageSetting, await self._request(request))

# DMS API
@API
Expand All @@ -1040,7 +1041,7 @@ async def post_dms(self, *, recipient_id: str, source_guild_id: str) -> DMS:
{"recipient_id": recipient_id, "source_guild_id": source_guild_id}
),
)
return parse_obj_as(DMS, await self._request(request))
return type_validate_python(DMS, await self._request(request))

@API
async def post_dms_messages(
Expand Down Expand Up @@ -1077,7 +1078,7 @@ async def post_dms_messages(
self.adapter.get_api_base().joinpath("dms", guild_id, "messages"),
**params,
)
return parse_obj_as(GuildMessage, await self._request(request))
return type_validate_python(GuildMessage, await self._request(request))

@API
async def delete_dms_message(
Expand Down Expand Up @@ -1174,7 +1175,7 @@ async def patch_guild_mute_multi_member(
}
),
)
return parse_obj_as(List[int], await self._request(request))
return type_validate_python(List[int], await self._request(request))

# Announce API
@API
Expand Down Expand Up @@ -1226,7 +1227,7 @@ async def put_pins_message(
"channels", channel_id, "pins", message_id
),
)
return parse_obj_as(PinsMessage, await self._request(request))
return type_validate_python(PinsMessage, await self._request(request))

@API
async def delete_pins_message(self, *, channel_id: str, message_id: str) -> None:
Expand All @@ -1244,7 +1245,7 @@ async def get_pins_message(self, *, channel_id: str) -> PinsMessage:
"GET",
self.adapter.get_api_base().joinpath("channels", channel_id, "pins"),
)
return parse_obj_as(PinsMessage, await self._request(request))
return type_validate_python(PinsMessage, await self._request(request))

# Schedule API
@API
Expand All @@ -1259,7 +1260,7 @@ async def get_schedules(
self.adapter.get_api_base() / f"channels/{channel_id}/schedules",
json=exclude_none({"since": since}),
)
return parse_obj_as(List[Schedule], await self._request(request))
return type_validate_python(List[Schedule], await self._request(request))

@API
async def get_schedule(self, *, channel_id: str, schedule_id: str) -> Schedule:
Expand All @@ -1269,7 +1270,7 @@ async def get_schedule(self, *, channel_id: str, schedule_id: str) -> Schedule:
"channels", channel_id, "schedules", schedule_id
),
)
return parse_obj_as(Schedule, await self._request(request))
return type_validate_python(Schedule, await self._request(request))

@API
async def post_schedule(
Expand Down Expand Up @@ -1311,7 +1312,7 @@ async def post_schedule(
)
},
)
return parse_obj_as(Schedule, await self._request(request))
return type_validate_python(Schedule, await self._request(request))

@API
async def patch_schedule(
Expand Down Expand Up @@ -1358,7 +1359,7 @@ async def patch_schedule(
)
},
)
return parse_obj_as(Schedule, await self._request(request))
return type_validate_python(Schedule, await self._request(request))

@API
async def delete_schedule(self, *, channel_id: str, schedule_id: str) -> None:
Expand Down Expand Up @@ -1474,7 +1475,7 @@ async def get_threads_list(self, *, channel_id: str) -> GetThreadsListReturn:
"GET",
self.adapter.get_api_base().joinpath("channels", channel_id, "threads"),
)
return parse_obj_as(GetThreadsListReturn, await self._request(request))
return type_validate_python(GetThreadsListReturn, await self._request(request))

@API
async def get_thread(self, *, channel_id: str, thread_id: str) -> GetThreadReturn:
Expand All @@ -1484,7 +1485,7 @@ async def get_thread(self, *, channel_id: str, thread_id: str) -> GetThreadRetur
"channels", channel_id, "threads", thread_id
),
)
return parse_obj_as(GetThreadReturn, await self._request(request))
return type_validate_python(GetThreadReturn, await self._request(request))

@overload
async def put_thread(
Expand Down Expand Up @@ -1530,7 +1531,7 @@ async def put_thread(
}
),
)
return parse_obj_as(PutThreadReturn, await self._request(request))
return type_validate_python(PutThreadReturn, await self._request(request))

@API
async def delete_thread(self, *, channel_id: str, thread_id: str) -> None:
Expand All @@ -1551,7 +1552,9 @@ async def get_guild_api_permission(
"GET",
self.adapter.get_api_base().joinpath("guilds", guild_id, "api_permission"),
)
return parse_obj_as(GetGuildAPIPermissionReturn, await self._request(request))
return type_validate_python(
GetGuildAPIPermissionReturn, await self._request(request)
)

@API
async def post_api_permission_demand(
Expand All @@ -1575,7 +1578,7 @@ async def post_api_permission_demand(
}
),
)
return parse_obj_as(APIPermissionDemand, await self._request(request))
return type_validate_python(APIPermissionDemand, await self._request(request))

# WebSocket API
@API
Expand All @@ -1584,15 +1587,15 @@ async def url_get(self) -> UrlGetReturn:
"GET",
self.adapter.get_api_base().joinpath("gateway"),
)
return parse_obj_as(UrlGetReturn, await self._request(request))
return type_validate_python(UrlGetReturn, await self._request(request))

@API
async def shard_url_get(self) -> ShardUrlGetReturn:
request = Request(
"GET",
self.adapter.get_api_base().joinpath("gateway", "bot"),
)
return parse_obj_as(ShardUrlGetReturn, await self._request(request))
return type_validate_python(ShardUrlGetReturn, await self._request(request))

# Interaction API
@API
Expand Down Expand Up @@ -1668,7 +1671,7 @@ async def post_c2c_messages(
}
),
)
return parse_obj_as(PostC2CMessagesReturn, await self._request(request))
return type_validate_python(PostC2CMessagesReturn, await self._request(request))

@API
async def post_c2c_files(
Expand All @@ -1694,15 +1697,15 @@ async def post_c2c_files(
}
),
)
return parse_obj_as(PostC2CFilesReturn, await self._request(request))
return type_validate_python(PostC2CFilesReturn, await self._request(request))

@API
async def delete_c2c_message(self, *, openid: str, message_id: str) -> None:
request = Request(
"DELETE",
self.adapter.get_api_base().joinpath(
"v2", "users", openid, "messages", message_id
)
),
)
return await self._request(request)

Expand Down Expand Up @@ -1770,7 +1773,9 @@ async def post_group_messages(
}
),
)
return parse_obj_as(PostGroupMessagesReturn, await self._request(request))
return type_validate_python(
PostGroupMessagesReturn, await self._request(request)
)

@API
async def post_group_files(
Expand All @@ -1796,17 +1801,15 @@ async def post_group_files(
}
),
)
return parse_obj_as(PostGroupFilesReturn, await self._request(request))
return type_validate_python(PostGroupFilesReturn, await self._request(request))

@API
async def delete_group_message(
self, *, group_openid: str, message_id: str
) -> None:
async def delete_group_message(self, *, group_openid: str, message_id: str) -> None:
request = Request(
"DELETE",
self.adapter.get_api_base().joinpath(
"v2", "groups", group_openid, "messages", message_id
)
),
)
return await self._request(request)

Expand All @@ -1823,4 +1826,6 @@ async def post_group_members(
self.adapter.get_api_base().joinpath("v2", "groups", group_id, "members"),
json=exclude_none({"limit": limit, "start_index": start_index}),
)
return parse_obj_as(PostGroupMembersReturn, await self._request(request))
return type_validate_python(
PostGroupMembersReturn, await self._request(request)
)

0 comments on commit 2de6833

Please sign in to comment.