Skip to content

Commit

Permalink
Remove DM channel cache (#87)
Browse files Browse the repository at this point in the history
* Remove DM channel cache

* Improve `RequireContextAttribute`

* Blank
  • Loading branch information
KubaZ2 authored Jan 22, 2025
1 parent 8aab034 commit 6293d3d
Show file tree
Hide file tree
Showing 19 changed files with 108 additions and 234 deletions.
3 changes: 0 additions & 3 deletions Hosting/NetCord.Hosting/Gateway/GatewayClientOptions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,6 @@ internal partial class Validator : IValidateOptions<GatewayClientOptions>

public Shard? Shard { get; set; }

public bool? CacheDMChannels { get; set; }

public RestClientConfiguration? RestClientConfiguration { get; set; }

internal GatewayClientConfiguration CreateConfiguration()
Expand All @@ -71,7 +69,6 @@ internal GatewayClientConfiguration CreateConfiguration()
LargeThreshold,
Presence,
Shard,
CacheDMChannels,
RestClientConfiguration);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,6 @@ internal partial class Validator : IValidateOptions<ShardedGatewayClientOptions>

public int? ShardCount { get; set; }

public bool? CacheDMChannels { get; set; }

public RestClientConfiguration? RestClientConfiguration { get; set; }

// Simple properties
Expand Down Expand Up @@ -98,7 +96,6 @@ internal ShardedGatewayClientConfiguration CreateConfiguration()
CreateFactory(LargeThreshold, LargeThresholdFactory),
CreateFactory(Presence, PresenceFactory),
ShardCount,
CacheDMChannels,
RestClientConfiguration);

static Func<Shard, T?>? CreateFactory<T>(T? value, Func<Shard, T?>? func, [CallerArgumentExpression(nameof(value))] string valueName = "", [CallerArgumentExpression(nameof(func))] string funcName = "")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ public class ApplicationCommandContext(ApplicationCommandInteraction interaction
public Guild? Guild => Interaction.Guild;
public TextChannel Channel => Interaction.Channel;
public User User => Interaction.User;

ulong? IGuildContext.GuildId => Interaction.GuildId;
}

public class HttpApplicationCommandContext(ApplicationCommandInteraction interaction, RestClient client)
Expand Down Expand Up @@ -50,6 +52,8 @@ public class SlashCommandContext(SlashCommandInteraction interaction, GatewayCli
public Guild? Guild => Interaction.Guild;
public TextChannel Channel => Interaction.Channel;
public User User => Interaction.User;

ulong? IGuildContext.GuildId => Interaction.GuildId;
}

public class HttpSlashCommandContext(SlashCommandInteraction interaction, RestClient client)
Expand Down Expand Up @@ -82,6 +86,8 @@ public class UserCommandContext(UserCommandInteraction interaction, GatewayClien
public TextChannel Channel => Interaction.Channel;
public User User => Interaction.User;
public User Target => Interaction.Data.TargetUser;

ulong? IGuildContext.GuildId => Interaction.GuildId;
}

public class HttpUserCommandContext(UserCommandInteraction interaction, RestClient client)
Expand Down Expand Up @@ -115,6 +121,8 @@ public class MessageCommandContext(MessageCommandInteraction interaction, Gatewa
public TextChannel Channel => Interaction.Channel;
public User User => Interaction.User;
public RestMessage Target => Interaction.Data.TargetMessage;

ulong? IGuildContext.GuildId => Interaction.GuildId;
}

public class HttpMessageCommandContext(MessageCommandInteraction interaction, RestClient client)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,4 +18,6 @@ public class AutocompleteInteractionContext(AutocompleteInteraction interaction,
public Guild? Guild => Interaction.Guild;
public TextChannel Channel => Interaction.Channel;
public User User => Interaction.User;

ulong? IGuildContext.GuildId => Interaction.GuildId;
}
2 changes: 2 additions & 0 deletions NetCord.Services/Commands/CommandContexts.cs
Original file line number Diff line number Diff line change
Expand Up @@ -18,4 +18,6 @@ public class CommandContext(Message message, GatewayClient client)
public Guild? Guild => Message.Guild;
public TextChannel? Channel => Message.Channel;
public User User => Message.Author;

ulong? IGuildContext.GuildId => Message.GuildId;
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ public class ComponentInteractionContext(ComponentInteraction interaction, Gatew
public User User => Interaction.User;
public Guild? Guild => Interaction.Guild;
public TextChannel Channel => Interaction.Channel;

ulong? IGuildContext.GuildId => Interaction.GuildId;
}

public class HttpComponentInteractionContext(ComponentInteraction interaction, RestClient client)
Expand Down Expand Up @@ -52,6 +54,8 @@ public class MessageComponentInteractionContext(MessageComponentInteraction inte
public User User => Interaction.User;
public Guild? Guild => Interaction.Guild;
public TextChannel Channel => Interaction.Channel;

ulong? IGuildContext.GuildId => Interaction.GuildId;
}

public class HttpMessageComponentInteractionContext(MessageComponentInteraction interaction, RestClient client)
Expand Down Expand Up @@ -87,6 +91,8 @@ public class ButtonInteractionContext(ButtonInteraction interaction, GatewayClie
public User User => Interaction.User;
public Guild? Guild => Interaction.Guild;
public TextChannel Channel => Interaction.Channel;

ulong? IGuildContext.GuildId => Interaction.GuildId;
}

public class HttpButtonInteractionContext(ButtonInteraction interaction, RestClient client)
Expand Down Expand Up @@ -123,6 +129,8 @@ public class StringMenuInteractionContext(StringMenuInteraction interaction, Gat
public Guild? Guild => Interaction.Guild;
public TextChannel Channel => Interaction.Channel;
public IReadOnlyList<string> SelectedValues => Interaction.Data.SelectedValues;

ulong? IGuildContext.GuildId => Interaction.GuildId;
}

public class HttpStringMenuInteractionContext(StringMenuInteraction interaction, RestClient client)
Expand Down Expand Up @@ -160,6 +168,8 @@ public class EntityMenuInteractionContext(EntityMenuInteraction interaction, Gat
public Guild? Guild => Interaction.Guild;
public TextChannel Channel => Interaction.Channel;
public IReadOnlyList<ulong> SelectedValues => Interaction.Data.SelectedValues;

ulong? IGuildContext.GuildId => Interaction.GuildId;
}

public class HttpEntityMenuInteractionContext(EntityMenuInteraction interaction, RestClient client)
Expand Down Expand Up @@ -197,6 +207,8 @@ public class UserMenuInteractionContext(UserMenuInteraction interaction, Gateway
public Guild? Guild => Interaction.Guild;
public TextChannel Channel => Interaction.Channel;
public IReadOnlyList<User> SelectedUsers { get; } = Utils.GetUserMenuValues(interaction);

ulong? IGuildContext.GuildId => Interaction.GuildId;
}

public class HttpUserMenuInteractionContext(UserMenuInteraction interaction, RestClient client)
Expand Down Expand Up @@ -234,6 +246,8 @@ public class RoleMenuInteractionContext(RoleMenuInteraction interaction, Gateway
public Guild? Guild => Interaction.Guild;
public TextChannel Channel => Interaction.Channel;
public IReadOnlyList<Role> SelectedRoles { get; } = Utils.GetRoleMenuValues(interaction);

ulong? IGuildContext.GuildId => Interaction.GuildId;
}

public class HttpRoleMenuInteractionContext(RoleMenuInteraction interaction, RestClient client)
Expand Down Expand Up @@ -271,6 +285,8 @@ public class MentionableMenuInteractionContext(MentionableMenuInteraction intera
public Guild? Guild => Interaction.Guild;
public TextChannel Channel => Interaction.Channel;
public IReadOnlyList<Mentionable> SelectedMentionables { get; } = Utils.GetMentionableMenuValues(interaction);

ulong? IGuildContext.GuildId => Interaction.GuildId;
}

public class HttpMentionableMenuInteractionContext(MentionableMenuInteraction interaction, RestClient client)
Expand Down Expand Up @@ -308,6 +324,8 @@ public class ChannelMenuInteractionContext(ChannelMenuInteraction interaction, G
public Guild? Guild => Interaction.Guild;
public TextChannel Channel => Interaction.Channel;
public IReadOnlyList<Channel> SelectedChannels { get; } = Utils.GetChannelMenuValues(interaction);

ulong? IGuildContext.GuildId => Interaction.GuildId;
}

public class HttpChannelMenuInteractionContext(ChannelMenuInteraction interaction, RestClient client)
Expand Down Expand Up @@ -343,6 +361,8 @@ public class ModalInteractionContext(ModalInteraction interaction, GatewayClient
public Guild? Guild => Interaction.Guild;
public TextChannel Channel => Interaction.Channel;
public IReadOnlyList<IComponent> Components => Interaction.Data.Components;

ulong? IGuildContext.GuildId => Interaction.GuildId;
}

public class HttpModalInteractionContext(ModalInteraction interaction, RestClient client)
Expand Down
2 changes: 2 additions & 0 deletions NetCord.Services/Contexts/IGuildContext.cs
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,6 @@ namespace NetCord.Services;
public interface IGuildContext
{
public Guild? Guild { get; }

internal protected ulong? GuildId { get; }
}
35 changes: 21 additions & 14 deletions NetCord.Services/PreconditionAttributes/RequireContextAttribute.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,48 +4,55 @@

namespace NetCord.Services;

public class RequireContextAttribute<TContext> : PreconditionAttribute<TContext> where TContext : IChannelContext
#pragma warning disable IDE0290 // Use primary constructor

public class RequireContextAttribute<TContext> : PreconditionAttribute<TContext> where TContext : IGuildContext
{
public RequiredContext RequiredContext { get; }
public RequiredContext RequiredContext => GetRequiredContext(_guild);

public string Format => _format.Format;

private readonly bool _guild;
private readonly CompositeFormat _format;

/// <param name="requiredContext"></param>
/// <param name="format">{0} - required context</param>
public RequireContextAttribute(RequiredContext requiredContext, [StringSyntax(StringSyntaxAttribute.CompositeFormat)] string format = "Required context: {0}.")
{
if (requiredContext > RequiredContext.DM)
throw new InvalidEnumArgumentException(nameof(requiredContext), (int)requiredContext, typeof(RequiredContext));
_guild = requiredContext switch
{
RequiredContext.Guild => true,
RequiredContext.DM => false,
_ => throw new InvalidEnumArgumentException(nameof(requiredContext), (int)requiredContext, typeof(RequiredContext)),
};

RequiredContext = requiredContext;
_format = CompositeFormat.Parse(format);
}

public override ValueTask<PreconditionResult> EnsureCanExecuteAsync(TContext context, IServiceProvider? serviceProvider)
{
var channel = context.Channel;

var requiredContext = RequiredContext;
var guild = _guild;
var hasValue = context.GuildId.HasValue;

if (requiredContext switch
if (guild switch
{
RequiredContext.Guild => channel is not IGuildChannel,
RequiredContext.GroupDM => channel is not GroupDMChannel,
RequiredContext.DM => channel is not DMChannel,
_ => throw new InvalidOperationException(),
true => !hasValue,
false => hasValue,
})
{
var requiredContext = GetRequiredContext(guild);
return new(new InvalidContextResult(string.Format(null, _format, requiredContext), requiredContext));
}

return new(PreconditionResult.Success);
}

private static RequiredContext GetRequiredContext(bool guild) => guild ? RequiredContext.Guild : RequiredContext.DM;
}

public enum RequiredContext : byte
{
Guild,
GroupDM,
DM,
}

Expand Down
51 changes: 2 additions & 49 deletions NetCord/Gateway/GatewayClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,6 @@ public partial class GatewayClient : WebSocketClient, IEntity
private readonly int? _largeThreshold;
private readonly PresenceProperties? _presence;
private readonly GatewayIntents _intents;
private readonly bool _cacheDMChannels;
private readonly object? _DMsLock;
private readonly Dictionary<ulong, SemaphoreSlim>? _DMSemaphores;
private readonly IGatewayCompression _compression;
private readonly bool _disposeRest;

Expand Down Expand Up @@ -837,12 +834,6 @@ public partial class GatewayClient : WebSocketClient, IEntity
_presence = configuration.Presence;
_intents = configuration.Intents.GetValueOrDefault(GatewayIntents.AllNonPrivileged);

if (_cacheDMChannels = configuration.CacheDMChannels.GetValueOrDefault(true))
{
_DMsLock = new();
_DMSemaphores = [];
}

var compression = _compression = configuration.Compression ?? IGatewayCompression.CreateDefault();
Uri = new($"wss://{configuration.Hostname ?? Discord.GatewayHostname}/?v={(int)configuration.Version.GetValueOrDefault(ApiVersion.V10)}&encoding=json&compress={compression.Name}", UriKind.Absolute);
Cache = configuration.Cache ?? new GatewayClientCache();
Expand Down Expand Up @@ -1290,36 +1281,12 @@ await InvokeEventAsync(Ready, args, data =>
break;
case "MESSAGE_CREATE":
{
await InvokeEventAsync(
MessageCreate,
() => data.ToObject(Serialization.Default.JsonMessage),
json => Message.CreateFromJson(json, Cache, Rest),
json => _cacheDMChannels && !json.GuildId.HasValue && !json.Flags.GetValueOrDefault().HasFlag(MessageFlags.Ephemeral),
json =>
{
var channelId = json.ChannelId;
if (!_DMSemaphores!.TryGetValue(channelId, out var semaphore))
_DMSemaphores.Add(channelId, semaphore = new(1, 1));
return semaphore;
},
json => CacheChannelAsync(json.ChannelId)).ConfigureAwait(false);
await InvokeEventAsync(MessageCreate, () => Message.CreateFromJson(data.ToObject(Serialization.Default.JsonMessage), Cache, Rest)).ConfigureAwait(false);
}
break;
case "MESSAGE_UPDATE":
{
await InvokeEventAsync(
MessageUpdate,
() => data.ToObject(Serialization.Default.JsonMessage),
json => Message.CreateFromJson(json, Cache, Rest),
json => _cacheDMChannels && !json.GuildId.HasValue && !json.Flags.GetValueOrDefault().HasFlag(MessageFlags.Ephemeral),
json =>
{
var channelId = json.ChannelId;
if (!_DMSemaphores!.TryGetValue(channelId, out var semaphore))
_DMSemaphores.Add(channelId, semaphore = new(1, 1));
return semaphore;
},
json => CacheChannelAsync(json.ChannelId)).ConfigureAwait(false);
await InvokeEventAsync(MessageUpdate, () => Message.CreateFromJson(data.ToObject(Serialization.Default.JsonMessage), Cache, Rest)).ConfigureAwait(false);
}
break;
case "MESSAGE_DELETE":
Expand Down Expand Up @@ -1457,20 +1424,6 @@ await InvokeEventAsync(

[MethodImpl(MethodImplOptions.AggressiveInlining)]
ulong GetGuildId() => data.GetProperty("guild_id").ToObject(Serialization.Default.UInt64);

async ValueTask CacheChannelAsync(ulong channelId)
{
var cache = Cache;
if (!cache.DMChannels.ContainsKey(channelId))
{
var channel = await Rest.GetChannelAsync(channelId).ConfigureAwait(false);
if (channel is DMChannel dMChannel)
{
lock (_DMsLock!)
Cache = Cache.CacheDMChannel(dMChannel);
}
}
}
}

protected override void Dispose(bool disposing)
Expand Down
13 changes: 0 additions & 13 deletions NetCord/Gateway/GatewayClientCache.cs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ public sealed record GatewayClientCache : IGatewayClientCache
{
public GatewayClientCache()
{
_DMChannels = CollectionsUtils.CreateImmutableDictionary<ulong, DMChannel>();
_guilds = CollectionsUtils.CreateImmutableDictionary<ulong, Guild>();
}

Expand All @@ -19,38 +18,26 @@ public GatewayClientCache(JsonGatewayClientCache jsonModel, ulong clientId, Rest
var userModel = jsonModel.User;
if (userModel is not null)
_user = new(userModel, client);
_DMChannels = jsonModel.DMChannels.ToImmutableDictionary(c => c.Id, c => DMChannel.CreateFromJson(c, client));
_guilds = jsonModel.Guilds.ToImmutableDictionary(g => g.Id, g => new Guild(g, clientId, client));
}

public CurrentUser? User => _user;
public IReadOnlyDictionary<ulong, DMChannel> DMChannels => _DMChannels;
public IReadOnlyDictionary<ulong, Guild> Guilds => _guilds;

#pragma warning disable IDE0032 // Use auto property
private CurrentUser? _user;
#pragma warning restore IDE0032 // Use auto property
private ImmutableDictionary<ulong, DMChannel> _DMChannels;
private ImmutableDictionary<ulong, Guild> _guilds;

public JsonGatewayClientCache ToJsonModel()
{
return new()
{
User = _user is null ? null : ((IJsonModel<JsonUser>)_user).JsonModel,
DMChannels = _DMChannels.Select(p => ((IJsonModel<JsonChannel>)p.Value).JsonModel).ToArray(),
Guilds = _guilds.Select(p => ((IJsonModel<JsonGuild>)p.Value).JsonModel).ToArray(),
};
}

public IGatewayClientCache CacheDMChannel(DMChannel dMChannel)
{
return this with
{
_DMChannels = _DMChannels.SetItem(dMChannel.Id, dMChannel),
};
}

public IGatewayClientCache CacheGuild(Guild guild)
{
return this with
Expand Down
1 change: 0 additions & 1 deletion NetCord/Gateway/GatewayClientConfiguration.cs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ public class GatewayClientConfiguration : IWebSocketClientConfiguration
public int? LargeThreshold { get; init; }
public PresenceProperties? Presence { get; init; }
public Shard? Shard { get; init; }
public bool? CacheDMChannels { get; init; }
public RestClientConfiguration? RestClientConfiguration { get; init; }

IRateLimiterProvider? IWebSocketClientConfiguration.RateLimiterProvider => RateLimiterProvider is { } rateLimiter ? rateLimiter : new GatewayRateLimiterProvider(120, 60_000);
Expand Down
Loading

0 comments on commit 6293d3d

Please sign in to comment.