diff --git a/src/GraphQL.Authorization.Tests/AuthorizationValidationRuleTests.cs b/src/GraphQL.Authorization.Tests/AuthorizationValidationRuleTests.cs index 61cb32d..2fedd7c 100644 --- a/src/GraphQL.Authorization.Tests/AuthorizationValidationRuleTests.cs +++ b/src/GraphQL.Authorization.Tests/AuthorizationValidationRuleTests.cs @@ -206,6 +206,54 @@ public void fails_on_missing_claim_on_connection_type() }); } + [Fact] + public void passes_when_field_is_not_included() + { + Settings.AddPolicy("FieldPolicy", _ => _.RequireClaim("admin")); + + ShouldPassRule(_ => + { + _.Query = @"query { post @include(if: false) }"; + _.Schema = BasicSchema(); + }); + } + + [Fact] + public void fails_when_field_is_included() + { + Settings.AddPolicy("FieldPolicy", _ => _.RequireClaim("admin")); + + ShouldFailRule(_ => + { + _.Query = @"query { post @include(if: true) }"; + _.Schema = BasicSchema(); + }); + } + + [Fact] + public void passes_when_field_is_skipped() + { + Settings.AddPolicy("FieldPolicy", _ => _.RequireClaim("admin")); + + ShouldPassRule(_ => + { + _.Query = @"query { post @skip(if: true) }"; + _.Schema = BasicSchema(); + }); + } + + [Fact] + public void fails_when_field_is_not_skipped() + { + Settings.AddPolicy("FieldPolicy", _ => _.RequireClaim("admin")); + + ShouldFailRule(_ => + { + _.Query = @"query { post @skip(if: false) }"; + _.Schema = BasicSchema(); + }); + } + private static ISchema BasicSchema() { string defs = @" diff --git a/src/GraphQL.Authorization/AuthorizationValidationRule.cs b/src/GraphQL.Authorization/AuthorizationValidationRule.cs index 892c542..a6f4cd8 100644 --- a/src/GraphQL.Authorization/AuthorizationValidationRule.cs +++ b/src/GraphQL.Authorization/AuthorizationValidationRule.cs @@ -1,4 +1,6 @@ +using System.Linq; using System.Threading.Tasks; +using GraphQL.Execution; using GraphQL.Language.AST; using GraphQL.Types; using GraphQL.Validation; @@ -56,7 +58,7 @@ public Task ValidateAsync(ValidationContext context) { var fieldDef = context.TypeInfo.GetFieldDef(); - if (fieldDef == null) + if (fieldDef == null || SkipAuthCheck(fieldAst, context)) return; // check target field @@ -67,6 +69,46 @@ public Task ValidateAsync(ValidationContext context) )); } + private bool SkipAuthCheck(Field field, ValidationContext context) + { + if (field.Directives == null || !field.Directives.Any()) + return false; + + var operationName = context.OperationName; + var documentOperations = context.Document.Operations; + var operation = !string.IsNullOrWhiteSpace(operationName) + ? documentOperations.WithName(operationName) + : documentOperations.FirstOrDefault(); + var variables = ExecutionHelper.GetVariableValues(context.Document, context.Schema, + operation?.Variables, context.Inputs); + + var includeField = GetDirectiveValue(context, field.Directives, DirectiveGraphType.Include, variables); + if (includeField.HasValue) + return !includeField.Value; + + var skipField = GetDirectiveValue(context, field.Directives, DirectiveGraphType.Skip, variables); + if (skipField.HasValue) + return skipField.Value; + + return false; + } + + private static bool? GetDirectiveValue(ValidationContext context, Directives directives, DirectiveGraphType directiveType, Variables variables) + { + var directive = directives.Find(directiveType.Name); + if (directive == null) + return null; + + var argumentValues = ExecutionHelper.GetArgumentValues( + context.Schema, + directiveType.Arguments, + directive.Arguments, + variables); + + argumentValues.TryGetValue("if", out object ifObj); + return bool.TryParse(ifObj?.ToString() ?? string.Empty, out bool ifVal) && ifVal; + } + private void CheckAuth( INode node, IProvideMetadata provider,