From 71160ac16ba6d8c9c201d91d85feadc16a43053d Mon Sep 17 00:00:00 2001 From: Cathy Fitzpatrick Date: Wed, 6 Oct 2010 22:01:49 -0600 Subject: [PATCH] Improved bans table. Transition SQL: create unique index ban_idx on bans (channel, user_id); alter table bans change column flags mod_id int; update bans set mod_id=0; alter table bans change channel channel_id int; --- src/database/DatabaseRegistry.cpp | 75 ++++++++++++++++--------------- src/database/DatabaseRegistry.h | 13 +++--- src/network/Channel.cpp | 10 ++--- src/network/network.cpp | 58 ++++++++++++++++-------- 4 files changed, 92 insertions(+), 64 deletions(-) diff --git a/src/database/DatabaseRegistry.cpp b/src/database/DatabaseRegistry.cpp index 2f0fd76..af0a2a0 100644 --- a/src/database/DatabaseRegistry.cpp +++ b/src/database/DatabaseRegistry.cpp @@ -506,13 +506,13 @@ bool DatabaseRegistry::registerUser(const string name, return true; } -void DatabaseRegistry::getGlobalBan(const std::string &user, - const std::string &ip, int &date, int &flags) { +int DatabaseRegistry::getGlobalBan(const std::string &user, + const std::string &ip) { ScopedConnection conn(m_impl->pool); Query query = conn->query( - "SELECT expiry, flags " + "SELECT expiry " "FROM bans " - "WHERE channel=-1 " + "WHERE channel_id=-1 " "AND (" "user_id=(SELECT id FROM users WHERE name=%0q) " "OR (" @@ -525,21 +525,22 @@ void DatabaseRegistry::getGlobalBan(const std::string &user, query.parse(); StoreQueryResult res = query.store(user, ip); if (res.empty()) { - date = flags = 0; - return; + return 0; } - date = (int)DateTime(res[0][0]); - flags = (int)res[0][1]; + return (int)DateTime(res[0][0]); } void DatabaseRegistry::getBan(const int channel, const string &user, int &date, int &flags) { ScopedConnection conn(m_impl->pool); Query query = conn->query( - "SELECT expiry, flags " + "SELECT expiry, IF(channel_users.flags, channel_users.flags, 0) " "FROM bans " - "WHERE channel=%0q " - "AND user_id=(SELECT id FROM users WHERE name= %1q )"); + "LEFT JOIN channel_users " + "ON mod_id=channel_users.user_id " + "AND channel_users.channel_id=bans.channel_id " + "WHERE bans.channel_id=%0q " + "AND bans.user_id=(SELECT id FROM users WHERE name=%1q)"); query.parse(); StoreQueryResult res = query.store(channel, user); if (res.empty()) { @@ -550,29 +551,33 @@ void DatabaseRegistry::getBan(const int channel, const string &user, flags = (int)res[0][1]; } -bool DatabaseRegistry::setBan(const int channel, const string &user, - const int flags, const long date, const bool ipBan) { +void DatabaseRegistry::removeBan(const int channel, const string &user) { ScopedConnection conn(m_impl->pool); - { - Query query = conn->query("delete from bans where channel="); - query << channel - << " and user_id=(select id from users where name= %0q )"; - query.parse(); - query.execute(user); - } - + Query query = conn->query( + "DELETE FROM bans WHERE channel_id=%0q AND user_id=(" + "SELECT id FROM users WHERE name=%1q" + ")" + ); + query.parse(); + query.execute(channel, user); +} + +bool DatabaseRegistry::setBan(const int channel, const string &user, + const int modId, const long date, const bool ipBan) { if (date < time(NULL)) { - return true; - } else { - Query query = conn->query( - "INSERT INTO bans " - "(channel, user_id, flags, expiry, ip_ban) " - "VALUES (%0q, (select id from users where name= %1q ), %2q, " - "%3q, %4q)"); - query.parse(); - query.execute(channel, user, flags, DateTime(date), ipBan); + // No ban to set. return false; } + ScopedConnection conn(m_impl->pool); + Query query = conn->query( + "INSERT INTO bans " + "(channel_id, user_id, mod_id, expiry, ip_ban) " + "VALUES (%0q, (SELECT id FROM users WHERE name=%1q), %2q, " + "%3q, %4q) " + "ON DUPLICATE KEY UPDATE mod_id=%2q, expiry=%3q, ip_ban=%4q"); + query.parse(); + query.execute(channel, user, modId, DateTime(date), ipBan); + return true; } void DatabaseRegistry::updateIp(const string &user, const string &ip) { @@ -614,11 +619,11 @@ const vector DatabaseRegistry::getAliases(const string &user) { const DatabaseRegistry::BAN_LIST DatabaseRegistry::getBans(const string &user) { ScopedConnection conn(m_impl->pool); - Query query = conn->query("SELECT bans.channel, users.name, bans.expiry, " - "bans.ip_ban " - "FROM bans "); - query << "JOIN users ON users.id=bans.user_id WHERE users.ip="; - query << "(SELECT ip FROM users WHERE name= %0q )"; + Query query = conn->query( + "SELECT bans.channel_id, users.name, bans.expiry, bans.ip_ban " + "FROM bans " + "JOIN users ON users.id=bans.user_id " + "WHERE users.ip=(SELECT ip FROM users WHERE name=%0q)"); query.parse(); StoreQueryResult res = query.store(user); DatabaseRegistry::BAN_LIST bans; diff --git a/src/database/DatabaseRegistry.h b/src/database/DatabaseRegistry.h index 4e1b5d9..13ab9ec 100644 --- a/src/database/DatabaseRegistry.h +++ b/src/database/DatabaseRegistry.h @@ -149,15 +149,18 @@ class DatabaseRegistry { void getBan(const int channel, const std::string &user, int &date, int &flags); - void getGlobalBan(const std::string &user, const std::string &ip, - int &date, int &flags); + int getGlobalBan(const std::string &user, const std::string &ip); /** - * Sets the ban for a user on a channel - * Returns if the user was unbanned + * Sets the ban for a user on a channel. */ - bool setBan(const int channel, const std::string &user, const int flags, + bool setBan(const int channel, const std::string &user, const int bannerId, const long date, const bool ipBan = false); + + /** + * Remove a ban. + */ + void removeBan(const int channel, const std::string &user); /** * Gets the maximum level of a user (including their alts) diff --git a/src/network/Channel.cpp b/src/network/Channel.cpp index 6980a03..ada5832 100644 --- a/src/network/Channel.cpp +++ b/src/network/Channel.cpp @@ -363,16 +363,16 @@ bool Channel::join(ClientPtr client) { bool Channel::handleBan(ClientPtr client) { int ban, flags; - m_impl->server->getRegistry()->getBan(m_impl->id, client->getName(), ban, + database::DatabaseRegistry *registry = m_impl->server->getRegistry(); + registry->getBan(m_impl->id, client->getName(), ban, flags); if (ban < time(NULL)) { // ban expired; remove it - m_impl->server->commitBan(m_impl->id, client->getName(), 0, 0); + registry->removeBan(m_impl->id, client->getName()); return false; - } else { - client->informBanned(ban); - return true; } + client->informBanned(ban); + return true; } void Channel::part(ClientPtr client) { diff --git a/src/network/network.cpp b/src/network/network.cpp index 4f34dbd..7934a7c 100644 --- a/src/network/network.cpp +++ b/src/network/network.cpp @@ -42,6 +42,7 @@ #include #include #include +#include #include #include #include @@ -129,6 +130,20 @@ OutMessageBuffer &OutMessageBuffer::operator<<(const string &str) { return *this; } +namespace { + +/** + * Truncate a collection such that its length can be stored in an object of the + * type T. + */ +template void truncate(U &collection) { + if (collection.size() > integer_traits::const_max) { + collection.resize(integer_traits::const_max); + } +} + +} + /** * A message that the client sends to the server. */ @@ -494,17 +509,20 @@ class UserDetailMessage : public OutMessage { *this << (unsigned char)0; finalise(); } - UserDetailMessage(const string& name, const string &ip, vector &aliases, + UserDetailMessage(const string& name, const string &ip, + vector &aliases, database::DatabaseRegistry::BAN_LIST &bans) : OutMessage(USER_DETAILS) { using database::DatabaseRegistry; *this << name; *this << ip; + truncate(aliases); *this << (unsigned char)aliases.size(); vector::iterator i = aliases.begin(); for (; i != aliases.end(); ++i) { *this << *i; } + truncate(bans); *this << (unsigned char)bans.size(); DatabaseRegistry::BAN_LIST::iterator j = bans.begin(); for (; j != bans.end(); ++j) { @@ -524,7 +542,7 @@ class UserPersonalMessage : public OutMessage { using database::DatabaseRegistry; *this << name; *this << *msg; - + truncate(estimates); *this << (unsigned char)estimates.size(); DatabaseRegistry::ESTIMATE_LIST::iterator i = estimates.begin(); for (; i != estimates.end(); ++i) { @@ -715,9 +733,9 @@ ChannelPtr Server::getMainChannel() const { return m_impl->getMainChannel(); } -bool Server::commitBan(const int id, const string &user, const int auth, +bool Server::commitBan(const int id, const string &user, const int bannerId, const int date) { - return m_impl->commitBan(id, user, auth, date); + return m_impl->commitBan(id, user, bannerId, date); } Server::~Server() { @@ -907,13 +925,11 @@ class ClientImpl : public Client, public enable_shared_from_this { void handleRequestChallenge(InMessage &msg) { string user; msg >> user; - int ban; - int flags; - m_server->getRegistry()->getGlobalBan(user, m_ip, ban, flags); + const int ban = m_server->getRegistry()->getGlobalBan(user, m_ip); if (ban > 0) { if (ban < time(NULL)) { - //ban expired remove the ban - m_server->commitBan(-1, user, 0, 0); + // ban expired remove the ban + m_server->getRegistry()->removeBan(-1, user); } else { informBanned(ban); return; @@ -1634,8 +1650,10 @@ void ClientImpl::handleBanMessage(InMessage &msg) { return; } + database::DatabaseRegistry *registry = m_server->getRegistry(); + // Escape if the target doesn't exist. - if (!m_server->getRegistry()->userExists(target)) { + if (!registry->userExists(target)) { sendMessage(ErrorMessage(ErrorMessage::NONEXISTENT_USER)); return; } @@ -1662,7 +1680,7 @@ void ClientImpl::handleBanMessage(InMessage &msg) { int setter; if (id == -1) { - m_server->getRegistry()->getBan(-1, target, ban, setter); + registry->getBan(-1, target, ban, setter); } else { channel->getBan(target, ban, setter); } @@ -1674,15 +1692,15 @@ void ClientImpl::handleBanMessage(InMessage &msg) { } string msg; - const bool unban = m_server->commitBan(id, target, auth.to_ulong(), - date, ipBan); - if (unban && (ban > 0)) { + const bool banned = m_server->commitBan(id, target, m_id, date, ipBan); + if (!banned && (ban > 0)) { + registry->removeBan(id, target); msg = (id == -1) ? "[unban global] " : "[unban] "; msg += m_name + " -> " + target; // Inform the invoker that the user was unbanned. sendMessage(KickBanMessage(channel->getId(), m_name, target, -1)); - } else if (!unban) { + } else if (banned) { msg = (id == -1) ? "[ban global] " : "[ban] "; msg += getBanString(m_name, target, date); @@ -1693,7 +1711,9 @@ void ClientImpl::handleBanMessage(InMessage &msg) { sendMessage(KickBanMessage(id, m_name, target, date)); } } - channel->writeLog(msg); + if (!msg.empty()) { + channel->writeLog(msg); + } } } @@ -1740,9 +1760,9 @@ void ServerImpl::loadPersonalMessage(const string &user, string &message) { m_registry.loadPersonalMessage(user, message); } -bool ServerImpl::commitBan(const int channel, const string &user, - const int auth, const int date, const bool ipBan) { - return m_registry.setBan(channel, user, auth, date, ipBan); +bool ServerImpl::commitBan(const int channel, const string &user, + const int bannerId, const int date, const bool ipBan) { + return m_registry.setBan(channel, user, bannerId, date, ipBan); } void ServerImpl::postLadderMatch(const int metagame,