<html><head><meta name="color-scheme" content="light dark"></head><body><pre style="word-wrap: break-word; white-space: pre-wrap;">from typing import cast, Any, Dict, List, Optional, Tuple, Union

from ...error import GraphQLError
from ...language import (
    DirectiveLocation,
    DirectiveDefinitionNode,
    DirectiveNode,
    Node,
    OperationDefinitionNode,
)
from ...type import specified_directives
from . import ASTValidationRule, SDLValidationContext, ValidationContext

__all__ = ["KnownDirectivesRule"]


class KnownDirectivesRule(ASTValidationRule):
    """Known directives

    A GraphQL document is only valid if all ``@directives`` are known by the schema and
    legally positioned.

    See https://spec.graphql.org/draft/#sec-Directives-Are-Defined
    """

    context: Union[ValidationContext, SDLValidationContext]

    def __init__(self, context: Union[ValidationContext, SDLValidationContext]):
        """Build the map from directive name to the locations it may appear in.

        The map is seeded from the directives of ``context.schema`` or, when the
        context has no schema, from ``specified_directives``. Directive definitions
        found in ``context.document`` are added afterwards, so a definition in the
        document replaces an entry with the same name.
        """
        super().__init__(context)
        locations_map: Dict[str, Tuple[DirectiveLocation, ...]] = {}

        schema = context.schema
        defined_directives = (
            schema.directives if schema else cast(List, specified_directives)
        )
        for directive in defined_directives:
            locations_map[directive.name] = directive.locations

        for definition in context.document.definitions:
            if isinstance(definition, DirectiveDefinitionNode):
                # Location names in the AST are looked up as enum member names.
                locations_map[definition.name.value] = tuple(
                    DirectiveLocation[name.value] for name in definition.locations
                )
        self.locations_map = locations_map

    def enter_directive(
        self,
        node: DirectiveNode,
        _key: Any,
        _parent: Any,
        _path: Any,
        ancestors: List[Node],
    ) -> None:
        """Report a directive that is unknown or used in a disallowed location.

        ``ancestors`` is passed on to ``get_directive_location_for_ast_path`` to
        determine the location the directive is applied to; the remaining visitor
        arguments are unused.
        """
        name = node.name.value
        locations = self.locations_map.get(name)
        # Test for None explicitly: a directive registered with an empty collection
        # of locations is still a *known* directive and must not be reported as
        # unknown (an empty tuple is falsy).
        if locations is None:
            self.report_error(GraphQLError(f"Unknown directive '@{name}'.", node))
            return

        candidate_location = get_directive_location_for_ast_path(ancestors)
        # No candidate location means the parent kind is not one we can classify,
        # in which case nothing is reported.
        if candidate_location and candidate_location not in locations:
            self.report_error(
                GraphQLError(
                    f"Directive '@{name}'"
                    f" may not be used on {candidate_location.value}.",
                    node,
                )
            )


# Directive location for an operation definition, keyed by the string value of the
# operation type (looked up with ``applied_to.operation.value``).
_operation_location: Dict[str, DirectiveLocation] = dict(
    query=DirectiveLocation.QUERY,
    mutation=DirectiveLocation.MUTATION,
    subscription=DirectiveLocation.SUBSCRIPTION,
)

# Directive location for every other AST node kind that maps to exactly one
# location. Type definitions and their extensions share the same location.
# Kinds that need more context (operation definitions, input value definitions)
# are resolved separately and are intentionally absent here.
_directive_location: Dict[str, DirectiveLocation] = dict(
    # executable locations
    field=DirectiveLocation.FIELD,
    fragment_spread=DirectiveLocation.FRAGMENT_SPREAD,
    inline_fragment=DirectiveLocation.INLINE_FRAGMENT,
    fragment_definition=DirectiveLocation.FRAGMENT_DEFINITION,
    variable_definition=DirectiveLocation.VARIABLE_DEFINITION,
    # type system locations
    schema_definition=DirectiveLocation.SCHEMA,
    schema_extension=DirectiveLocation.SCHEMA,
    scalar_type_definition=DirectiveLocation.SCALAR,
    scalar_type_extension=DirectiveLocation.SCALAR,
    object_type_definition=DirectiveLocation.OBJECT,
    object_type_extension=DirectiveLocation.OBJECT,
    field_definition=DirectiveLocation.FIELD_DEFINITION,
    interface_type_definition=DirectiveLocation.INTERFACE,
    interface_type_extension=DirectiveLocation.INTERFACE,
    union_type_definition=DirectiveLocation.UNION,
    union_type_extension=DirectiveLocation.UNION,
    enum_type_definition=DirectiveLocation.ENUM,
    enum_type_extension=DirectiveLocation.ENUM,
    enum_value_definition=DirectiveLocation.ENUM_VALUE,
    input_object_type_definition=DirectiveLocation.INPUT_OBJECT,
    input_object_type_extension=DirectiveLocation.INPUT_OBJECT,
)


def get_directive_location_for_ast_path(
    ancestors: List[Node],
) -> Optional[DirectiveLocation]:
    """Return the directive location for the node a directive is applied to.

    The node the directive is applied to is taken from the end of ``ancestors``.

    Returns ``None`` when the kind of that node has no known directive location.

    Raises ``TypeError`` if the last ancestor is not a ``Node``.
    """
    applied_to = ancestors[-1]
    if not isinstance(applied_to, Node):  # pragma: no cover
        raise TypeError("Unexpected error in directive.")
    kind = applied_to.kind

    if kind == "operation_definition":
        # The location depends on whether this is a query, mutation or subscription.
        applied_to = cast(OperationDefinitionNode, applied_to)
        return _operation_location[applied_to.operation.value]

    if kind == "input_value_definition":
        # An input value is either a field of an input object type or an argument;
        # the node two levels further up the ancestor chain tells which. Input
        # object type *extensions* declare input fields just like definitions do,
        # so both kinds must be recognized here.
        parent_node = ancestors[-3]
        return (
            DirectiveLocation.INPUT_FIELD_DEFINITION
            if parent_node.kind
            in ("input_object_type_definition", "input_object_type_extension")
            else DirectiveLocation.ARGUMENT_DEFINITION
        )

    return _directive_location.get(kind)
</pre></body></html>