diff --git a/src/Miki.Framework.Commands/CommandPipeline.cs b/src/Miki.Framework.Commands/CommandPipeline.cs index 51b5c5f..9083173 100644 --- a/src/Miki.Framework.Commands/CommandPipeline.cs +++ b/src/Miki.Framework.Commands/CommandPipeline.cs @@ -5,6 +5,7 @@ using System.Diagnostics; using System.Threading.Tasks; using Miki.Discord.Common; + using Miki.Framework.Models; using Miki.Framework.Commands.Pipelines; using Miki.Logging; @@ -28,7 +29,7 @@ internal CommandPipeline( public async ValueTask ExecuteAsync(IDiscordMessage data) { var sw = Stopwatch.StartNew(); - using ContextObject contextObj = new ContextObject(services); + using ContextObject contextObj = new ContextObject(services, new DiscordMessage(data)); int index = 0; Func nextFunc = null; diff --git a/src/Miki.Framework.Commands/CommandPipelineBuilder.cs b/src/Miki.Framework.Commands/CommandPipelineBuilder.cs index 61b4e02..ce59093 100644 --- a/src/Miki.Framework.Commands/CommandPipelineBuilder.cs +++ b/src/Miki.Framework.Commands/CommandPipelineBuilder.cs @@ -9,12 +9,14 @@ /// public class CommandPipelineBuilder { + private readonly List stages = new List(); + /// /// Services that can be used throughout the command pipeline. /// public IServiceProvider Services { get; } - private readonly List stages = new List(); + public IReadOnlyList Stages => stages; /// /// Creates a new CommandPipelineBuilder. @@ -42,5 +44,14 @@ public CommandPipelineBuilder UseStage(IPipelineStage stage) stages.Add(stage); return this; } + + /// + /// Initializes a pipeline stage as a runnable stage in the pipeline. + /// + public CommandPipelineBuilder UseStage() + where T : class, IPipelineStage + { + return UseStage(Services.GetOrCreateService()); + } } } diff --git a/src/Miki.Framework.Commands/CommandTreeBuilder.cs b/src/Miki.Framework.Commands/CommandTreeBuilder.cs index 19580d1..16f75a1 100644 --- a/src/Miki.Framework.Commands/CommandTreeBuilder.cs +++ b/src/Miki.Framework.Commands/CommandTreeBuilder.cs @@ -1,136 +1,52 @@ -namespace Miki.Framework.Commands +using System; +using System.Collections.Generic; +using System.Linq; +using System.Reflection; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.DependencyInjection.Extensions; + +namespace Miki.Framework.Commands { - using Miki.Framework.Commands.Nodes; - using System; - using System.Collections.Generic; - using System.Linq; - using System.Reflection; - using System.Threading.Tasks; - public class CommandTreeBuilder { - [Obsolete("use 'CommandTreeBuilder::AddCommandBuildStep()' intead")] - public event Action OnContainerLoaded; + private readonly List types = new List(); - private readonly IServiceProvider services; - - private List buildSteps; - - public CommandTreeBuilder(IServiceProvider services) + public CommandTreeBuilder(IServiceCollection services) { - this.services = services; + Services = services; } - - public CommandTreeBuilder AddCommandBuildStep(ICommandBuildStep buildStep) - { - if(buildSteps == null) - { - buildSteps = new List(); - } - buildSteps.Add(buildStep); - return this; - } - - public CommandTree Create(Assembly assembly) + + public IServiceCollection Services { get; } + + public CommandTreeBuilder AddType(Type type) { - var allTypes = assembly.GetTypes() - .Where(x => x.GetCustomAttribute() != null); - var root = new CommandTree(); - foreach(var t in allTypes) - { - var module = LoadModule(t, root.Root); - if(module != null) - { - root.Root.Children.Add(module); - } - } - return root; - } - - private NodeContainer LoadModule(Type t, NodeContainer parent) - { - var moduleAttrib = t.GetCustomAttribute(); - if(moduleAttrib == null) - { - throw new InvalidOperationException("Modules must have a valid ModuleAttribute."); - } - - NodeContainer module = new NodeModule(moduleAttrib.Name, parent, services, t); - OnContainerLoaded?.Invoke(module, services); - - var allCommands = t.GetNestedTypes(BindingFlags.NonPublic | BindingFlags.Public) - .Where(x => x.GetCustomAttribute() != null); - foreach(var c in allCommands) - { - module.Children.Add(LoadCommand(c, module)); - } - - var allSingleCommands = t.GetMethods() - .Where(x => x.GetCustomAttribute() != null); - foreach(var c in allSingleCommands) - { - module.Children.Add(LoadCommand(c, module)); - } - - return module; + types.Add(type); + return this; } - private Node LoadCommand(Type t, NodeContainer parent) - { - var commandAttrib = t.GetCustomAttribute(); - if(commandAttrib == null) - { - throw new InvalidOperationException( - $"Multi command of type '{t.ToString()}' must have a valid CommandAttribute."); - } - - if(commandAttrib.Aliases?.Count() == 0) - { - throw new InvalidOperationException( - $"Multi commands cannot have an invalid name."); - } - - var multiCommand = new NodeNestedExecutable(commandAttrib.AsMetadata(), parent, services, t); - OnContainerLoaded?.Invoke(multiCommand, services); - - var allCommands = t.GetNestedTypes() - .Where(x => x.GetCustomAttribute() != null); - foreach(var c in allCommands) - { - multiCommand.Children.Add(LoadCommand(c, multiCommand)); - } - - var allSingleCommands = t.GetMethods() - .Where(x => x.GetCustomAttribute() != null); - foreach(var c in allSingleCommands) - { - var attrib = c.GetCustomAttribute(); - if(attrib.Aliases == null - || attrib.Aliases.Count() == 0) - { - var node = LoadCommand(c, multiCommand); - if(node is IExecutable execNode) - { - multiCommand.SetDefaultExecution(async (e) - => await execNode.ExecuteAsync(e)); - } - } - else - { - multiCommand.Children.Add(LoadCommand(c, multiCommand)); - } - } - return multiCommand; - } - private Node LoadCommand(MethodInfo m, NodeContainer parent) - { - var commandAttrib = m.GetCustomAttribute(); - var command = new NodeExecutable(commandAttrib.AsMetadata(), parent, m); + + public CommandTreeBuilder AddAssembly(Assembly assembly) + { + types.AddRange(assembly.GetTypes().Where(x => x.GetCustomAttribute() != null)); + return this; + } + + public CommandTree Build(IServiceProvider provider) + { + var root = new CommandTree(); + var compiler = ActivatorUtilities.CreateInstance(provider); + + foreach (var type in types) + { + var module = compiler.LoadModule(type, root.Root); + + if (module != null) + { + root.Root.Children.Add(module); + } + } + + return root; + } - if(m.ReturnType != typeof(Task)) - { - throw new Exception("Methods with attribute 'Command' require to be Tasks."); - } - return command; - } } } diff --git a/src/Miki.Framework.Commands/CommandTreeCompiler.cs b/src/Miki.Framework.Commands/CommandTreeCompiler.cs new file mode 100644 index 0000000..c929cc9 --- /dev/null +++ b/src/Miki.Framework.Commands/CommandTreeCompiler.cs @@ -0,0 +1,179 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Linq.Expressions; +using System.Reflection; +using System.Threading.Tasks; +using Miki.Framework.Commands.Nodes; +using Miki.Framework.Hosting; +using Miki.Framework.Models; + +namespace Miki.Framework.Commands +{ + internal class CommandTreeCompiler + { + private static readonly PropertyInfo TaskCompleteProperty = typeof(Task).GetProperty(nameof(Task.CompletedTask)); + + private readonly IReadOnlyList parameterProviders; + + public CommandTreeCompiler(IEnumerable parameterProviders) + { + this.parameterProviders = parameterProviders.ToArray(); + } + + public NodeContainer LoadModule(Type t, NodeContainer parent) + { + var moduleAttrib = t.GetCustomAttribute(); + if(moduleAttrib == null) + { + throw new InvalidOperationException("Modules must have a valid ModuleAttribute."); + } + + NodeContainer module = new NodeModule(moduleAttrib.Name, parent, t); + + var allCommands = t.GetNestedTypes(BindingFlags.NonPublic | BindingFlags.Public) + .Where(x => x.GetCustomAttribute() != null); + foreach(var c in allCommands) + { + module.Children.Add(LoadCommand(c, module)); + } + + var allSingleCommands = t.GetMethods() + .Where(x => x.GetCustomAttribute() != null); + foreach(var c in allSingleCommands) + { + module.Children.Add(LoadCommand(c, module)); + } + + return module; + } + + private Node LoadCommand(Type t, NodeContainer parent) + { + var commandAttrib = t.GetCustomAttribute(); + if(commandAttrib == null) + { + throw new InvalidOperationException( + $"Multi command of type '{t}' must have a valid CommandAttribute."); + } + + if(commandAttrib.Aliases?.Count == 0) + { + throw new InvalidOperationException( + $"Multi commands cannot have an invalid name."); + } + + var multiCommand = new NodeNestedExecutable(commandAttrib.AsMetadata(), parent, t); + + var allCommands = t.GetNestedTypes() + .Where(x => x.GetCustomAttribute() != null); + foreach(var c in allCommands) + { + multiCommand.Children.Add(LoadCommand(c, multiCommand)); + } + + var allSingleCommands = t.GetMethods() + .Where(x => x.GetCustomAttribute() != null); + foreach(var c in allSingleCommands) + { + var attrib = c.GetCustomAttribute(); + if(attrib.Aliases == null + || attrib.Aliases.Count() == 0) + { + var node = LoadCommand(c, multiCommand); + if(node is IExecutable execNode) + { + multiCommand.SetDefaultExecution(async (e) + => await execNode.ExecuteAsync(e)); + } + } + else + { + multiCommand.Children.Add(LoadCommand(c, multiCommand)); + } + } + return multiCommand; + } + + private Node LoadCommand(MethodInfo m, NodeContainer parent) + { + var commandAttrib = m.GetCustomAttribute(); + var command = new NodeExecutable(commandAttrib.AsMetadata(), parent, CreateDelegate(parent.Type, m)); + + if(m.ReturnType != typeof(Task)) + { + throw new Exception("Methods with attribute 'Command' require to be Tasks."); + } + return command; + } + + private CommandDelegate CreateDelegate(Type type, MethodInfo methodInfo) + { + var context = Expression.Parameter(typeof(IContext)); + var builder = new ParameterBuilder(context); + var constructors = type.GetConstructors(); + var module = constructors.Length switch + { + 0 => Expression.New(type), + 1 => Expression.New(constructors[0], GetParameterValues(builder, constructors[0])), + _ => throw new NotSupportedException($"The module {type} has multiple constructors") + }; + var parameterValues = GetParameterValues(builder, methodInfo); + var returnType = methodInfo.ReturnType; + Expression result = Expression.Call(module, methodInfo, parameterValues); + + if (returnType == typeof(void)) + { + var taskComplete = Expression.Property(null, TaskCompleteProperty); + var returnTarget = Expression.Label(typeof(Task)); + + result = Expression.Block(new[] + { + result, + Expression.Return(returnTarget, taskComplete), + Expression.Label(returnTarget, taskComplete) + }); + } + else if (returnType != typeof(Task)) + { + throw new InvalidOperationException($"Method {type.Name}.{methodInfo.Name} should return Task or void."); + } + + return Expression.Lambda(result, context).Compile(); + } + + private Expression[] GetParameterValues(ParameterBuilder builder, MethodBase methodInfo) + { + var parameters = methodInfo.GetParameters(); + var parameterValues = new Expression[parameters.Length]; + + for (var i = 0; i < parameters.Length; i++) + { + var paramType = parameters[i].ParameterType; + var provider = parameterProviders.FirstOrDefault(p => paramType.IsAssignableFrom(p.ParameterType)); + Expression paramExpression; + + if (provider != null) + { + paramExpression = provider.Provide(builder); + } + else if (paramType == typeof(IContext)) + { + paramExpression = builder.Context; + } + else if (paramType == typeof(IMessage)) + { + paramExpression = Expression.Property(builder.Context, nameof(IContext.Message)); + } + else + { + paramExpression = builder.GetService(paramType); + } + + parameterValues[i] = paramExpression; + } + + return parameterValues; + } + } +} \ No newline at end of file diff --git a/src/Miki.Framework.Commands/Extensions/MiddlewareExtensions.cs b/src/Miki.Framework.Commands/Extensions/MiddlewareExtensions.cs new file mode 100644 index 0000000..1af65c7 --- /dev/null +++ b/src/Miki.Framework.Commands/Extensions/MiddlewareExtensions.cs @@ -0,0 +1,59 @@ +using System; +using System.Collections.Generic; +using Microsoft.EntityFrameworkCore.Internal; +using Miki.Discord.Common; +using Miki.Framework.Commands.Pipelines; +using Miki.Framework.Hosting; + +namespace Miki.Framework.Commands +{ + public static class MiddlewareExtensions + { + private const string CoreStageRegistered = "CoreStageRegistered"; + + private static IBotApplicationBuilder UseStageInternal(IBotApplicationBuilder app, IPipelineStage stage) + { + return app.Use(next => + { + return context => stage.CheckAsync( + (IDiscordMessage) context.Message.InnerMessage, + (IMutableContext) context, + () => next(context)); + }); + } + + public static IBotApplicationBuilder UseStage(this IBotApplicationBuilder app, IPipelineStage stage) + { + if (!app.Properties.TryGetValue(CoreStageRegistered, out var value) || !Equals(value, true)) + { + if (!(stage is CorePipelineStage)) + { + UseStageInternal(app, new CorePipelineStage()); + } + + app.Properties[CoreStageRegistered] = true; + } + + return UseStageInternal(app, stage); + } + + public static IBotApplicationBuilder UseStage(this IBotApplicationBuilder app) + where T : IPipelineStage + { + return UseStage(app, app.ApplicationServices.GetOrCreateService()); + } + + public static IBotApplicationBuilder UseCommandPipeline(this IBotApplicationBuilder app, Action configure) + { + var builder = new CommandPipelineBuilder(app.ApplicationServices); + configure(builder); + + foreach (var stage in builder.Stages) + { + UseStage(app, stage); + } + + return app; + } + } +} \ No newline at end of file diff --git a/src/Miki.Framework.Commands/Extensions/ServiceExtensions.cs b/src/Miki.Framework.Commands/Extensions/ServiceExtensions.cs new file mode 100644 index 0000000..27345a2 --- /dev/null +++ b/src/Miki.Framework.Commands/Extensions/ServiceExtensions.cs @@ -0,0 +1,22 @@ +using System; +using System.Reflection; +using Microsoft.Extensions.DependencyInjection; + +namespace Miki.Framework.Commands +{ + public static class ServiceExtensions + { + public static IServiceCollection AddCommands(this IServiceCollection services, Assembly assembly) + { + return AddCommands(services, builder => builder.AddAssembly(assembly)); + } + + public static IServiceCollection AddCommands(this IServiceCollection services, Action configure) + { + var builder = new CommandTreeBuilder(services); + configure(builder); + services.AddSingleton(provider => builder.Build(provider)); + return services; + } + } +} \ No newline at end of file diff --git a/src/Miki.Framework.Commands/Models/Nodes/Node.cs b/src/Miki.Framework.Commands/Models/Nodes/Node.cs index ebab718..9408d2c 100644 --- a/src/Miki.Framework.Commands/Models/Nodes/Node.cs +++ b/src/Miki.Framework.Commands/Models/Nodes/Node.cs @@ -16,15 +16,15 @@ public abstract class Node public IReadOnlyCollection Attributes => Type.GetCustomAttributes(false) .ToList(); - private MemberInfo Type { get; } + public Type Type { get; } - protected Node(CommandMetadata metadata, MemberInfo type) + protected Node(CommandMetadata metadata, Type type) { Metadata = metadata; Type = type; } - protected Node(CommandMetadata metadata, NodeContainer parent, MemberInfo type) + protected Node(CommandMetadata metadata, NodeContainer parent, Type type) : this(metadata, type) { Parent = parent ?? throw new InvalidOperationException("Parent cannot be null when explicitly set up."); diff --git a/src/Miki.Framework.Commands/Models/Nodes/NodeContainer.cs b/src/Miki.Framework.Commands/Models/Nodes/NodeContainer.cs index d4dd7bb..b203d36 100644 --- a/src/Miki.Framework.Commands/Models/Nodes/NodeContainer.cs +++ b/src/Miki.Framework.Commands/Models/Nodes/NodeContainer.cs @@ -10,22 +10,13 @@ public abstract class NodeContainer : Node { public List Children = new List(); - /// - /// Instance object for reflection. - /// - public object Instance { get; } - public NodeContainer(CommandMetadata metadata, Type t) : base(metadata, t) { } - public NodeContainer(CommandMetadata metadata, NodeContainer parent, IServiceProvider provider, Type t) + public NodeContainer(CommandMetadata metadata, NodeContainer parent, Type t) : base(metadata, parent, t) { - if(t != null) - { - Instance = ActivatorUtilities.CreateInstance(provider, t); - } } public virtual Node FindCommand(IArgumentPack pack) diff --git a/src/Miki.Framework.Commands/Models/Nodes/NodeExecutable.cs b/src/Miki.Framework.Commands/Models/Nodes/NodeExecutable.cs index d4f67be..52d9102 100644 --- a/src/Miki.Framework.Commands/Models/Nodes/NodeExecutable.cs +++ b/src/Miki.Framework.Commands/Models/Nodes/NodeExecutable.cs @@ -1,32 +1,26 @@ -namespace Miki.Framework.Commands.Nodes +using System.Linq.Expressions; + using Miki.Framework.Models; + + namespace Miki.Framework.Commands.Nodes { using System; using System.Linq; using System.Reflection; using System.Threading.Tasks; -#pragma warning disable CS1591 // Missing XML comment for publicly visible type or member 'CommandDelegate' public delegate Task CommandDelegate(IContext c); -#pragma warning restore CS1591 // Missing XML comment for publicly visible type or member 'CommandDelegate' -#pragma warning disable CS1591 // Missing XML comment for publicly visible type or member 'NodeExecutable' public class NodeExecutable : Node, IExecutable -#pragma warning restore CS1591 // Missing XML comment for publicly visible type or member 'NodeExecutable' { private readonly CommandDelegate runAsync; -#pragma warning disable CS1591 // Missing XML comment for publicly visible type or member 'NodeExecutable.NodeExecutable(CommandMetadata, NodeContainer, MethodInfo)' - public NodeExecutable(CommandMetadata metadata, NodeContainer parent, MethodInfo method) -#pragma warning restore CS1591 // Missing XML comment for publicly visible type or member 'NodeExecutable.NodeExecutable(CommandMetadata, NodeContainer, MethodInfo)' - : base(metadata, parent, method) + public NodeExecutable(CommandMetadata metadata, NodeContainer parent, CommandDelegate commandDelegate) + : base(metadata, parent, parent.Type) { - runAsync = (CommandDelegate)Delegate.CreateDelegate( - typeof(CommandDelegate), parent.Instance, method, true); + runAsync = commandDelegate; } -#pragma warning disable CS1591 // Missing XML comment for publicly visible type or member 'NodeExecutable.ExecuteAsync(IContext)' public async ValueTask ExecuteAsync(IContext e) -#pragma warning restore CS1591 // Missing XML comment for publicly visible type or member 'NodeExecutable.ExecuteAsync(IContext)' { if(runAsync == null) { diff --git a/src/Miki.Framework.Commands/Models/Nodes/NodeModule.cs b/src/Miki.Framework.Commands/Models/Nodes/NodeModule.cs index aa6ca37..fe5d5f8 100644 --- a/src/Miki.Framework.Commands/Models/Nodes/NodeModule.cs +++ b/src/Miki.Framework.Commands/Models/Nodes/NodeModule.cs @@ -5,10 +5,10 @@ public class NodeModule : NodeContainer { public NodeModule(string id, IServiceProvider provider, Type t) - : this(id, null, provider, t) + : this(id, (NodeContainer) null, t) { } - public NodeModule(string id, NodeContainer parent, IServiceProvider provider, Type t) - : base(new CommandMetadata { Identifiers = new[] { id } }, parent, provider, t) + public NodeModule(string id, NodeContainer parent, Type t) + : base(new CommandMetadata { Identifiers = new[] { id } }, parent, t) { } } } diff --git a/src/Miki.Framework.Commands/Models/Nodes/NodeNestedExecutable.cs b/src/Miki.Framework.Commands/Models/Nodes/NodeNestedExecutable.cs index 9171190..4b947fc 100644 --- a/src/Miki.Framework.Commands/Models/Nodes/NodeNestedExecutable.cs +++ b/src/Miki.Framework.Commands/Models/Nodes/NodeNestedExecutable.cs @@ -23,23 +23,21 @@ public NodeNestedExecutable( CommandMetadata metadata, IServiceProvider builder, Type t) - : this(metadata, null, builder, t) + : this(metadata, (NodeContainer) null, t) { } - /// - /// Creates a new Nested, Executable Node. - /// - /// Command properties. - /// - /// - /// - public NodeNestedExecutable( + /// + /// Creates a new Nested, Executable Node. + /// + /// Command properties. + /// + /// + public NodeNestedExecutable( CommandMetadata metadata, NodeContainer parent, - IServiceProvider builder, Type t) - : base(metadata, parent, builder, t) + : base(metadata, parent, t) { } diff --git a/src/Miki.Framework.Commands/Stages/CorePipelineStage.cs b/src/Miki.Framework.Commands/Stages/CorePipelineStage.cs index 6a14601..a39aa5a 100644 --- a/src/Miki.Framework.Commands/Stages/CorePipelineStage.cs +++ b/src/Miki.Framework.Commands/Stages/CorePipelineStage.cs @@ -1,4 +1,6 @@ -namespace Miki.Framework.Commands +using System; + +namespace Miki.Framework.Commands { using Miki.Discord.Common; using Miki.Framework.Commands.Pipelines; diff --git a/src/Miki.Framework.Discord/Extensions/DiscordServiceExtensions.cs b/src/Miki.Framework.Discord/Extensions/DiscordServiceExtensions.cs new file mode 100644 index 0000000..c6d7191 --- /dev/null +++ b/src/Miki.Framework.Discord/Extensions/DiscordServiceExtensions.cs @@ -0,0 +1,40 @@ +using System; +using Microsoft.Extensions.DependencyInjection; +using Miki.Discord.Common; +using Miki.Framework.Discord.Factories; +using Miki.Framework.Discord.Providers; +using Miki.Framework.Discord.Services; +using Miki.Framework.Hosting; + +// ReSharper disable once CheckNamespace +namespace Miki.Framework +{ + public static class DiscordServiceExtensions + { + public static IServiceCollection AddDiscord(this IServiceCollection services, Type factoryType, params object[] factoryArguments) + { + services.AddSingleton(); + + services.AddHostedService(provider => + { + var factory = (IDiscordClientFactory) ActivatorUtilities.CreateInstance(provider, factoryType, factoryArguments); + + return ActivatorUtilities.CreateInstance(provider, factory); + }); + + return services; + } + + public static IServiceCollection AddDiscord(this IServiceCollection services, params object[] factoryArguments) + where TFactory : IDiscordClientFactory + { + return AddDiscord(services, typeof(TFactory), factoryArguments); + } + + public static IServiceCollection AddDiscord(this IServiceCollection services, DiscordToken token) + { + AddDiscord(services, token); + return services; + } + } +} \ No newline at end of file diff --git a/src/Miki.Framework.Discord/Factories/DefaultDiscordClientFactory.cs b/src/Miki.Framework.Discord/Factories/DefaultDiscordClientFactory.cs new file mode 100644 index 0000000..ee94849 --- /dev/null +++ b/src/Miki.Framework.Discord/Factories/DefaultDiscordClientFactory.cs @@ -0,0 +1,44 @@ +using System.Threading.Tasks; +using Miki.Cache; +using Miki.Discord; +using Miki.Discord.Common; +using Miki.Discord.Gateway; +using Miki.Discord.Rest; + +namespace Miki.Framework.Discord.Factories +{ + public class DefaultDiscordClientFactory : IDiscordClientFactory + { + private readonly IExtendedCacheClient cacheClient; + private readonly DiscordToken token; + + public DefaultDiscordClientFactory(IExtendedCacheClient cacheClient, DiscordToken token) + { + this.cacheClient = cacheClient; + this.token = token; + } + + public Task CreateClientAsync() + { + var gateway = new GatewayShard(new GatewayProperties + { + ShardCount = 1, + ShardId = 0, + Token = token.Token, + AllowNonDispatchEvents = true, + Intents = GatewayIntents.AllDefault | GatewayIntents.GuildMembers + }); + + var apiClient = new DiscordApiClient(token, cacheClient); + + var configuration = new DiscordClientConfigurations + { + Gateway = gateway, + ApiClient = apiClient, + CacheClient = cacheClient + }; + + return Task.FromResult(new DiscordClient(configuration)); + } + } +} \ No newline at end of file diff --git a/src/Miki.Framework.Discord/Factories/IDiscordClientFactory.cs b/src/Miki.Framework.Discord/Factories/IDiscordClientFactory.cs new file mode 100644 index 0000000..d5b235c --- /dev/null +++ b/src/Miki.Framework.Discord/Factories/IDiscordClientFactory.cs @@ -0,0 +1,10 @@ +using System.Threading.Tasks; +using Miki.Discord.Common; + +namespace Miki.Framework.Discord.Factories +{ + public interface IDiscordClientFactory + { + Task CreateClientAsync(); + } +} \ No newline at end of file diff --git a/src/Miki.Framework.Discord/Miki.Framework.Discord.csproj b/src/Miki.Framework.Discord/Miki.Framework.Discord.csproj new file mode 100644 index 0000000..8749edc --- /dev/null +++ b/src/Miki.Framework.Discord/Miki.Framework.Discord.csproj @@ -0,0 +1,31 @@ + + + + netstandard2.1 + Debug;Release;Debug Production;Prod + + + + true + Velddev + + true + 2.3.1 + https://github.com/mikibot/miki.framework + https://github.com/mikibot/miki.framework + git + Velddev + 8.0 + + + + + + + + + + + + + diff --git a/src/Miki.Framework.Discord/Providers/DiscordClientParameterProvider.cs b/src/Miki.Framework.Discord/Providers/DiscordClientParameterProvider.cs new file mode 100644 index 0000000..b962ac7 --- /dev/null +++ b/src/Miki.Framework.Discord/Providers/DiscordClientParameterProvider.cs @@ -0,0 +1,17 @@ +using System; +using System.Linq.Expressions; +using Miki.Discord.Common; +using Miki.Framework.Hosting; + +namespace Miki.Framework.Discord.Providers +{ + public class DiscordClientParameterProvider : IParameterProvider + { + public Type ParameterType => typeof(IDiscordClient); + + public Expression Provide(ParameterBuilder context) + { + return context.GetContext(typeof(IDiscordClient), "DiscordClient"); + } + } +} \ No newline at end of file diff --git a/src/Miki.Framework.Discord/Services/DiscordHostedService.cs b/src/Miki.Framework.Discord/Services/DiscordHostedService.cs new file mode 100644 index 0000000..ca8c9b8 --- /dev/null +++ b/src/Miki.Framework.Discord/Services/DiscordHostedService.cs @@ -0,0 +1,72 @@ +using System; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Hosting; +using Miki.Discord.Common; +using Miki.Framework.Discord.Factories; +using Miki.Framework.Hosting; +using Miki.Framework.Models; +using Miki.Logging; + +namespace Miki.Framework.Discord.Services +{ + public class DiscordHostedService : IHostedService, IDisposable + { + private readonly IServiceProvider serviceProvider; + private readonly MessageDelegate invoke; + private readonly IDiscordClientFactory discordClientFactory; + private IDiscordClient discordClient; + + public DiscordHostedService( + IBotApplicationBuilderFactory factory, + IServiceProvider serviceProvider, + IDiscordClientFactory discordClientFactory) + { + this.serviceProvider = serviceProvider; + this.discordClientFactory = discordClientFactory; + invoke = factory.CreateBuilder().Build(); + } + + private async Task HandleMessageAsync(IDiscordMessage message) + { + using var scope = serviceProvider.CreateScope(); + using var context = new ContextObject(scope.ServiceProvider, new DiscordMessage(message)); + + context.SetContext("DiscordClient", discordClient); + + try + { + await invoke(context); + } + catch (Exception e) + { + Log.Error(e); + } + } + + public async Task StartAsync(CancellationToken cancellationToken) + { + if (discordClient == null) + { + discordClient = await discordClientFactory.CreateClientAsync(); + discordClient.MessageCreate += HandleMessageAsync; + } + + await discordClient.Gateway.StartAsync(); + } + + public Task StopAsync(CancellationToken cancellationToken) + { + return discordClient?.Gateway.StopAsync() ?? Task.CompletedTask; + } + + public void Dispose() + { + if (discordClient != null) + { + discordClient.MessageCreate -= HandleMessageAsync; + } + } + } +} \ No newline at end of file diff --git a/src/Miki.Framework.sln b/src/Miki.Framework.sln index 99ed09a..8daf26a 100644 --- a/src/Miki.Framework.sln +++ b/src/Miki.Framework.sln @@ -27,6 +27,8 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Miki.Framework.Commands.Fil EndProject Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "tests", "tests", "{4BE94594-5456-49B4-AA9D-4BBDCD77BCE3}" EndProject +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Miki.Framework.Discord", "Miki.Framework.Discord\Miki.Framework.Discord.csproj", "{B213EE14-D5E9-41CD-B53A-8F3ECDEE515F}" +EndProject Global GlobalSection(SolutionConfigurationPlatforms) = preSolution Debug Production|Any CPU = Debug Production|Any CPU @@ -123,6 +125,14 @@ Global {C8F412EB-645B-42EA-B21E-5A04D813D930}.Prod|Any CPU.Build.0 = Debug|Any CPU {C8F412EB-645B-42EA-B21E-5A04D813D930}.Release|Any CPU.ActiveCfg = Release|Any CPU {C8F412EB-645B-42EA-B21E-5A04D813D930}.Release|Any CPU.Build.0 = Release|Any CPU + {B213EE14-D5E9-41CD-B53A-8F3ECDEE515F}.Debug Production|Any CPU.ActiveCfg = Debug Production|Any CPU + {B213EE14-D5E9-41CD-B53A-8F3ECDEE515F}.Debug Production|Any CPU.Build.0 = Debug Production|Any CPU + {B213EE14-D5E9-41CD-B53A-8F3ECDEE515F}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {B213EE14-D5E9-41CD-B53A-8F3ECDEE515F}.Debug|Any CPU.Build.0 = Debug|Any CPU + {B213EE14-D5E9-41CD-B53A-8F3ECDEE515F}.Prod|Any CPU.ActiveCfg = Prod|Any CPU + {B213EE14-D5E9-41CD-B53A-8F3ECDEE515F}.Prod|Any CPU.Build.0 = Prod|Any CPU + {B213EE14-D5E9-41CD-B53A-8F3ECDEE515F}.Release|Any CPU.ActiveCfg = Release|Any CPU + {B213EE14-D5E9-41CD-B53A-8F3ECDEE515F}.Release|Any CPU.Build.0 = Release|Any CPU EndGlobalSection GlobalSection(SolutionProperties) = preSolution HideSolutionNode = FALSE diff --git a/src/Miki.Framework/ContextObject.cs b/src/Miki.Framework/ContextObject.cs index 8d9d9cc..d7b0bf9 100644 --- a/src/Miki.Framework/ContextObject.cs +++ b/src/Miki.Framework/ContextObject.cs @@ -1,4 +1,7 @@ -namespace Miki.Framework +using Miki.Discord.Common; +using Miki.Framework.Models; + +namespace Miki.Framework { using Microsoft.Extensions.DependencyInjection; using System; @@ -9,6 +12,11 @@ /// public interface IContext { + /// + /// The message received from discord. + /// + IMessage Message { get; } + /// /// The command executed in this current session. /// @@ -45,6 +53,9 @@ public class ContextObject : IMutableContext, IDisposable public IServiceProvider Services => scope.ServiceProvider; + /// + public IMessage Message { get; } + /// /// Current set Executable. /// @@ -53,8 +64,9 @@ public IServiceProvider Services /// /// Creates a scoped context object /// - public ContextObject(IServiceProvider p) + public ContextObject(IServiceProvider p, IMessage message) { + Message = message; contextObjects = new Dictionary(); scope = p.CreateScope(); } diff --git a/src/Miki.Framework/Extension/HostExtensions.cs b/src/Miki.Framework/Extension/HostExtensions.cs new file mode 100644 index 0000000..4734dbb --- /dev/null +++ b/src/Miki.Framework/Extension/HostExtensions.cs @@ -0,0 +1,22 @@ +using System; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Hosting; +using Miki.Framework.Hosting; + +namespace Miki.Framework +{ + public static class HostExtensions + { + public static IHostBuilder ConfigureBot( + this IHostBuilder hostBuilder, + Action configure = null) + { + hostBuilder.ConfigureServices((ctx, services) => + { + services.AddSingleton(provider => new BotApplicationBuilderFactory(provider, hostBuilder, configure)); + }); + + return hostBuilder; + } + } +} \ No newline at end of file diff --git a/src/Miki.Framework/Extension/MiddlewareExtensions.cs b/src/Miki.Framework/Extension/MiddlewareExtensions.cs new file mode 100644 index 0000000..d9e6ebe --- /dev/null +++ b/src/Miki.Framework/Extension/MiddlewareExtensions.cs @@ -0,0 +1,152 @@ +using System; +using System.Threading.Tasks; +using Microsoft.Extensions.DependencyInjection; +using Miki.Framework.Hosting; + +namespace Miki.Framework +{ + using Predicate = Func; + using PredicateAsync = Func>; + + public static class MiddlewareExtensions + { + /// + /// Adds a middleware delegate defined in-line to the application's request pipeline. + /// + /// The instance. + /// A function that handles the request or calls the given next function. + /// The instance. + public static IBotApplicationBuilder Use(this IBotApplicationBuilder app, Func, ValueTask> middleware) + { + if (app == null) + { + throw new ArgumentNullException(nameof(app)); + } + + if (middleware == null) + { + throw new ArgumentNullException(nameof(app)); + } + + return app.Use(next => + { + return context => + { + return middleware(context, () => next(context)); + }; + }); + } + + /// + /// Adds a terminal middleware delegate to the application's request pipeline. + /// + /// The instance. + /// A delegate that handles the request. + public static void Run(this IBotApplicationBuilder app, MessageDelegate handler) + { + if (app == null) + { + throw new ArgumentNullException(nameof(app)); + } + + if (handler == null) + { + throw new ArgumentNullException(nameof(handler)); + } + + app.Use(_ => handler); + } + + /// + /// Conditionally creates a branch in the request pipeline that is rejoined to the main pipeline. + /// + /// + /// Invoked with the request environment to determine if the branch should be taken + /// Configures a branch to take + /// + public static IBotApplicationBuilder UseWhen(this IBotApplicationBuilder app, Predicate predicate, Action configuration) + { + if (app == null) + { + throw new ArgumentNullException(nameof(app)); + } + + if (predicate == null) + { + throw new ArgumentNullException(nameof(predicate)); + } + + if (configuration == null) + { + throw new ArgumentNullException(nameof(configuration)); + } + + // Create and configure the branch builder right away; otherwise, + // we would end up running our branch after all the components + // that were subsequently added to the main builder. + var branchBuilder = app.New(); + configuration(branchBuilder); + + return app.Use(main => + { + // This is called only when the main application builder + // is built, not per request. + branchBuilder.Run(main); + var branch = branchBuilder.Build(); + + return context => predicate(context) ? branch(context) : main(context); + }); + } + + /// + /// Conditionally creates a branch in the request pipeline that is rejoined to the main pipeline. + /// + /// + /// Invoked with the request environment to determine if the branch should be taken + /// Configures a branch to take + /// + public static IBotApplicationBuilder UseWhen(this IBotApplicationBuilder app, PredicateAsync predicate, Action configuration) + { + if (app == null) + { + throw new ArgumentNullException(nameof(app)); + } + + if (predicate == null) + { + throw new ArgumentNullException(nameof(predicate)); + } + + if (configuration == null) + { + throw new ArgumentNullException(nameof(configuration)); + } + + // Create and configure the branch builder right away; otherwise, + // we would end up running our branch after all the components + // that were subsequently added to the main builder. + var branchBuilder = app.New(); + configuration(branchBuilder); + + return app.Use(main => + { + // This is called only when the main application builder + // is built, not per request. + branchBuilder.Run(main); + var branch = branchBuilder.Build(); + + return async context => + { + if (await predicate(context)) + { + await branch(context); + } + else + { + await main(context); + } + }; + }); + } + } +} \ No newline at end of file diff --git a/src/Miki.Framework/Extension/ServiceExtensions.cs b/src/Miki.Framework/Extension/ServiceExtensions.cs new file mode 100644 index 0000000..b9d875e --- /dev/null +++ b/src/Miki.Framework/Extension/ServiceExtensions.cs @@ -0,0 +1,24 @@ +using System; +using Microsoft.Extensions.DependencyInjection; + using Miki.Cache; + using Miki.Discord.Common; +using Miki.Framework.Hosting; + +namespace Miki.Framework +{ + public static class ServiceCollectionExtensions + { + public static T GetOrCreateService(this IServiceProvider provider) + { + return provider.GetService() ?? ActivatorUtilities.CreateInstance(provider); + } + + public static IServiceCollection AddCacheClient(this IServiceCollection services) + where T : class, IExtendedCacheClient + { + services.AddSingleton(provider => provider.GetService()); + services.AddSingleton(); + return services; + } + } +} \ No newline at end of file diff --git a/src/Miki.Framework/Hosting/BotApplicationBuilder.cs b/src/Miki.Framework/Hosting/BotApplicationBuilder.cs new file mode 100644 index 0000000..c4ca577 --- /dev/null +++ b/src/Miki.Framework/Hosting/BotApplicationBuilder.cs @@ -0,0 +1,64 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Threading.Tasks; +using Miki.Framework.Utils; + +namespace Miki.Framework.Hosting +{ + public class BotApplicationBuilder : IBotApplicationBuilder + { + private const string ApplicationServicesKey = "Miki.ApplicationServices"; + + private readonly IList> components = new List>(); + + public BotApplicationBuilder() + { + Properties = new Dictionary(); + } + + public BotApplicationBuilder(IBotApplicationBuilder builder) + { + Properties = new CopyOnWriteDictionary(builder.Properties, StringComparer.Ordinal); + } + + public IServiceProvider ApplicationServices + { + get => GetProperty(ApplicationServicesKey); + set => SetProperty(ApplicationServicesKey, value); + } + + public IDictionary Properties { get; } + + public T GetProperty(string key) + { + return Properties.TryGetValue(key, out var value) ? (T) value : default; + } + + public void SetProperty(string key, T value) + { + Properties[key] = value; + } + + public IBotApplicationBuilder Use(Func middleware) + { + components.Add(middleware); + return this; + } + + private static ValueTask InvokeAsync(IContext context) + { + return context.Executable?.ExecuteAsync(context) ?? default; + } + + public MessageDelegate Build() + { + return components.Reverse().Aggregate((MessageDelegate) InvokeAsync, (current, component) => component(current)); + } + + public IBotApplicationBuilder New() + { + return new BotApplicationBuilder(this); + } + } +} \ No newline at end of file diff --git a/src/Miki.Framework/Hosting/BotApplicationBuilderFactory.cs b/src/Miki.Framework/Hosting/BotApplicationBuilderFactory.cs new file mode 100644 index 0000000..a58b3e9 --- /dev/null +++ b/src/Miki.Framework/Hosting/BotApplicationBuilderFactory.cs @@ -0,0 +1,107 @@ +using System; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Hosting; + +namespace Miki.Framework.Hosting +{ + public class BotApplicationBuilderFactory : IBotApplicationBuilderFactory + { + private readonly Action configure; + private readonly IServiceProvider serviceProvider; + private readonly IHostBuilder hostBuilder; + + public BotApplicationBuilderFactory( + IServiceProvider serviceProvider, + IHostBuilder hostBuilder = null, + Action configure = null) + { + this.configure = configure; + this.serviceProvider = serviceProvider; + this.hostBuilder = hostBuilder; + } + + public IBotApplicationBuilder CreateBuilder() + { + var builder = new BotApplicationBuilder + { + ApplicationServices = serviceProvider + }; + + if (configure != null) + { + configure.Invoke(builder); + } + else if (hostBuilder != null + && hostBuilder.Properties.TryGetValue("UseStartup.StartupType", out var value) + && value is Type startupType) + { + InitializeStartup(serviceProvider, startupType, builder); + } + + return builder; + } + + /// + /// Configure the pipeline through the ASP.NET Core startup class. + /// + private static bool InitializeStartup(IServiceProvider provider, Type startupType, IBotApplicationBuilder builder) + { + var startupInterfaceType = Type.GetType("Microsoft.AspNetCore.Hosting.IStartup, Microsoft.AspNetCore.Hosting.Abstractions"); + object startup; + + if (startupInterfaceType != null) + { + startup = provider.GetService(startupInterfaceType); + + if (startup != null) + { + startupType = startup.GetType(); + } + } + else + { + startup = null; + } + + if (startup == null) + { + if (startupType != null) + { + startup = ActivatorUtilities.CreateInstance(provider, startupType); + } + else + { + return false; + } + } + + var configureMethod = startupType.GetMethod("ConfigureBot"); + + if (configureMethod == null) + { + return false; + } + + var parameterInfos = configureMethod.GetParameters(); + var parameters = new object[parameterInfos.Length]; + + for (var i = 0; i < parameterInfos.Length; i++) + { + var parameterType = parameterInfos[i].ParameterType; + + if (parameterType == typeof(IBotApplicationBuilder)) + { + parameters[i] = builder; + } + else + { + parameters[i] = provider.GetRequiredService(parameterType); + } + } + + configureMethod.Invoke(startup, parameters); + + return true; + } + } +} \ No newline at end of file diff --git a/src/Miki.Framework/Hosting/IBotApplicationBuilder.cs b/src/Miki.Framework/Hosting/IBotApplicationBuilder.cs new file mode 100644 index 0000000..a1402de --- /dev/null +++ b/src/Miki.Framework/Hosting/IBotApplicationBuilder.cs @@ -0,0 +1,55 @@ +using System; +using System.Collections.Generic; +using System.Threading.Tasks; + +namespace Miki.Framework.Hosting +{ + public delegate ValueTask MessageDelegate(IContext context); + + public interface IBotApplicationBuilder + { + /// + /// Gets or sets the that provides access to the application's service container. + /// + IServiceProvider ApplicationServices { get; set; } + + /// + /// Gets a key/value collection that can be used to share data between middleware. + /// + IDictionary Properties { get; } + + /// + /// Get the property by its key. + /// + /// The type of the property. + /// The key of the property. + /// The value of the property. + T GetProperty(string key); + + /// + /// Set the property by its key. + /// + /// The type of the property. + /// The key of the property. + /// The new value. + void SetProperty(string key, T value); + + /// + /// Adds a middleware delegate to the application's request pipeline. + /// + /// The delegate middleware. + /// The current application builder. + IBotApplicationBuilder Use(Func middleware); + + /// + /// Builds the delegate used by this application to process Discord messages. + /// + /// + MessageDelegate Build(); + + /// + /// Create a sub-builder. + /// + IBotApplicationBuilder New(); + } +} \ No newline at end of file diff --git a/src/Miki.Framework/Hosting/IBotApplicationBuilderFactory.cs b/src/Miki.Framework/Hosting/IBotApplicationBuilderFactory.cs new file mode 100644 index 0000000..25fafb5 --- /dev/null +++ b/src/Miki.Framework/Hosting/IBotApplicationBuilderFactory.cs @@ -0,0 +1,7 @@ +namespace Miki.Framework.Hosting +{ + public interface IBotApplicationBuilderFactory + { + IBotApplicationBuilder CreateBuilder(); + } +} \ No newline at end of file diff --git a/src/Miki.Framework/Hosting/IParameterProvider.cs b/src/Miki.Framework/Hosting/IParameterProvider.cs new file mode 100644 index 0000000..eff8d0a --- /dev/null +++ b/src/Miki.Framework/Hosting/IParameterProvider.cs @@ -0,0 +1,12 @@ +using System; +using System.Linq.Expressions; + +namespace Miki.Framework.Hosting +{ + public interface IParameterProvider + { + Type ParameterType { get; } + + Expression Provide(ParameterBuilder context); + } +} \ No newline at end of file diff --git a/src/Miki.Framework/Hosting/ParameterBuilder.cs b/src/Miki.Framework/Hosting/ParameterBuilder.cs new file mode 100644 index 0000000..c55d64b --- /dev/null +++ b/src/Miki.Framework/Hosting/ParameterBuilder.cs @@ -0,0 +1,35 @@ +using System; +using System.Linq; +using System.Linq.Expressions; +using System.Reflection; + +namespace Miki.Framework.Hosting +{ + public class ParameterBuilder + { + private static readonly MethodInfo GetServiceMethod = typeof(ContextExtensions) + .GetMethods() + .First(m => m.Name == nameof(ContextExtensions.GetService) && m.IsGenericMethod); + + private static readonly MethodInfo GetContextMethod = typeof(ContextExtensions) + .GetMethods() + .First(m => m.Name == nameof(ContextExtensions.GetContext) && m.IsGenericMethod); + + public ParameterBuilder(Expression context) + { + Context = context; + } + + public Expression Context { get; } + + public Expression GetService(Type type) + { + return Expression.Call(GetServiceMethod.MakeGenericMethod(type), Context); + } + + public Expression GetContext(Type type, string name) + { + return Expression.Call(GetContextMethod.MakeGenericMethod(type), Context, Expression.Constant(name)); + } + } +} \ No newline at end of file diff --git a/src/Miki.Framework/Miki.Framework.csproj b/src/Miki.Framework/Miki.Framework.csproj index 02c9f4b..60f6e6c 100644 --- a/src/Miki.Framework/Miki.Framework.csproj +++ b/src/Miki.Framework/Miki.Framework.csproj @@ -45,6 +45,7 @@ + diff --git a/src/Miki.Framework/Models/DiscordMessage.cs b/src/Miki.Framework/Models/DiscordMessage.cs new file mode 100644 index 0000000..ac0d948 --- /dev/null +++ b/src/Miki.Framework/Models/DiscordMessage.cs @@ -0,0 +1,18 @@ +using Miki.Discord.Common; + +namespace Miki.Framework.Models +{ + public class DiscordMessage : IMessage + { + private readonly IDiscordMessage message; + + public DiscordMessage(IDiscordMessage message) + { + this.message = message; + } + + public object InnerMessage => message; + + public string Content => message.Content; + } +} \ No newline at end of file diff --git a/src/Miki.Framework/Models/IChannel.cs b/src/Miki.Framework/Models/IChannel.cs deleted file mode 100644 index 4bb3c4a..0000000 --- a/src/Miki.Framework/Models/IChannel.cs +++ /dev/null @@ -1,9 +0,0 @@ -namespace Miki.Framework.Models -{ - using System.Threading.Tasks; - - public interface IChannel - { - Task CreateMessageAsync(string content); - } -} diff --git a/src/Miki.Framework/Models/IMessage.cs b/src/Miki.Framework/Models/IMessage.cs index 656c329..7d2a6d0 100644 --- a/src/Miki.Framework/Models/IMessage.cs +++ b/src/Miki.Framework/Models/IMessage.cs @@ -1,13 +1,9 @@ namespace Miki.Framework.Models { - using System.Threading.Tasks; - public interface IMessage { - Task DeleteAsync(); - - Task GetChannelAsync(); - - Task ModifyAsync(string content); + object InnerMessage { get; } + + string Content { get; } } } diff --git a/src/Miki.Framework/TestContextObject.cs b/src/Miki.Framework/TestContextObject.cs index afcfc02..75c3779 100644 --- a/src/Miki.Framework/TestContextObject.cs +++ b/src/Miki.Framework/TestContextObject.cs @@ -1,4 +1,7 @@ -namespace Miki.Framework +using Miki.Discord.Common; +using Miki.Framework.Models; + +namespace Miki.Framework { using System; using System.Collections.Generic; @@ -11,6 +14,9 @@ public class TestContextObject : IMutableContext private readonly Dictionary contextObjects = new Dictionary(); private readonly Dictionary serviceObjects = new Dictionary(); + /// + public IMessage Message { get; set; } + /// public IExecutable Executable { get; set; } diff --git a/src/Miki.Framework/Utils/CopyOnWriteDictionary.cs b/src/Miki.Framework/Utils/CopyOnWriteDictionary.cs new file mode 100644 index 0000000..adae12e --- /dev/null +++ b/src/Miki.Framework/Utils/CopyOnWriteDictionary.cs @@ -0,0 +1,155 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Collections; +using System.Collections.Generic; + +namespace Miki.Framework.Utils +{ + internal class CopyOnWriteDictionary : IDictionary + { + private readonly IDictionary sourceDictionary; + private readonly IEqualityComparer comparer; + private IDictionary innerDictionary; + + public CopyOnWriteDictionary( + IDictionary sourceDictionary, + IEqualityComparer comparer) + { + if (sourceDictionary == null) + { + throw new ArgumentNullException(nameof(sourceDictionary)); + } + + if (comparer == null) + { + throw new ArgumentNullException(nameof(comparer)); + } + + this.sourceDictionary = sourceDictionary; + this.comparer = comparer; + } + + private IDictionary ReadDictionary + { + get + { + return innerDictionary ?? sourceDictionary; + } + } + + private IDictionary WriteDictionary + { + get + { + if (innerDictionary == null) + { + innerDictionary = new Dictionary(sourceDictionary, + comparer); + } + + return innerDictionary; + } + } + + public virtual ICollection Keys + { + get + { + return ReadDictionary.Keys; + } + } + + public virtual ICollection Values + { + get + { + return ReadDictionary.Values; + } + } + + public virtual int Count + { + get + { + return ReadDictionary.Count; + } + } + + public virtual bool IsReadOnly + { + get + { + return false; + } + } + + public virtual TValue this[TKey key] + { + get + { + return ReadDictionary[key]; + } + set + { + WriteDictionary[key] = value; + } + } + + public virtual bool ContainsKey(TKey key) + { + return ReadDictionary.ContainsKey(key); + } + + public virtual void Add(TKey key, TValue value) + { + WriteDictionary.Add(key, value); + } + + public virtual bool Remove(TKey key) + { + return WriteDictionary.Remove(key); + } + + public virtual bool TryGetValue(TKey key, out TValue value) + { + return ReadDictionary.TryGetValue(key, out value); + } + + public virtual void Add(KeyValuePair item) + { + WriteDictionary.Add(item); + } + + public virtual void Clear() + { + WriteDictionary.Clear(); + } + + public virtual bool Contains(KeyValuePair item) + { + return ReadDictionary.Contains(item); + } + + public virtual void CopyTo(KeyValuePair[] array, int arrayIndex) + { + ReadDictionary.CopyTo(array, arrayIndex); + } + + public bool Remove(KeyValuePair item) + { + return WriteDictionary.Remove(item); + } + + public virtual IEnumerator> GetEnumerator() + { + return ReadDictionary.GetEnumerator(); + } + + IEnumerator IEnumerable.GetEnumerator() + { + return GetEnumerator(); + } + } +} \ No newline at end of file