Search in sources:

Example 1 with QueryVisitor

Use of graphql.analysis.QueryVisitor in the graphql-java project (by graphql-java).

Class Anonymizer, method rewriteQuery.

/**
 * Rewrites {@code query} so that schema-derived names (fields, field arguments, applied directives,
 * directive arguments, fragment type conditions, variable types) are replaced by the anonymized names
 * in {@code newNames}. Query-local names are replaced by generated placeholders: the operation name
 * becomes "operation", fragments "FragmentN", aliases "aliasN" and variables "varN". Argument values
 * and variable default values are passed through {@code replaceValue}.
 *
 * @param query     the query text to anonymize ({@code assertUniqueOperation} is checked on the parsed document)
 * @param schema    the schema whose original names the query uses; directives and types are looked up by original name
 * @param newNames  mapping from original schema elements to their anonymized names
 * @param variables variable values handed to the {@link QueryTraverser}
 * @return the rewritten query, printed with {@code AstPrinter.printAstCompact}
 */
private static String rewriteQuery(String query, GraphQLSchema schema, Map<GraphQLNamedSchemaElement, String> newNames, Map<String, Object> variables) {
    AtomicInteger fragmentCounter = new AtomicInteger(1);
    AtomicInteger variableCounter = new AtomicInteger(1);
    Map<Node, String> astNodeToNewName = new LinkedHashMap<>();
    Map<String, String> variableNames = new LinkedHashMap<>();
    Map<Field, GraphQLFieldDefinition> fieldToFieldDefinition = new LinkedHashMap<>();
    Document document = new Parser().parseDocument(query);
    assertUniqueOperation(document);
    QueryTraverser queryTraverser = QueryTraverser.newQueryTraverser().document(document).schema(schema).variables(variables).build();
    // First pass: walk the query with schema type information and record the new name of every AST node
    // (and every variable) that the second pass will rename.
    queryTraverser.visitDepthFirst(new QueryVisitor() {

        @Override
        public void visitField(QueryVisitorFieldEnvironment env) {
            Field field = env.getField();
            if (!env.isTypeNameIntrospectionField()) {
                fieldToFieldDefinition.put(field, env.getFieldDefinition());
                String newName = assertNotNull(newNames.get(env.getFieldDefinition()));
                astNodeToNewName.put(field, newName);
            }
            // __typename keeps its own name, but directives applied to it are still renamed by the second pass
            registerQueryDirectives(field.getDirectives(), schema, newNames, astNodeToNewName, variableNames, variableCounter);
        }

        @Override
        public void visitInlineFragment(QueryVisitorInlineFragmentEnvironment env) {
            registerQueryDirectives(env.getInlineFragment().getDirectives(), schema, newNames, astNodeToNewName, variableNames, variableCounter);
        }

        @Override
        public TraversalControl visitArgumentValue(QueryVisitorFieldArgumentValueEnvironment environment) {
            QueryVisitorFieldArgumentInputValue argumentInputValue = environment.getArgumentInputValue();
            if (argumentInputValue.getValue() instanceof VariableReference) {
                String name = ((VariableReference) argumentInputValue.getValue()).getName();
                registerVariableName(name, variableNames, variableCounter);
            }
            return CONTINUE;
        }

        @Override
        public void visitFragmentSpread(QueryVisitorFragmentSpreadEnvironment env) {
            FragmentDefinition fragmentDefinition = env.getFragmentDefinition();
            String newName;
            if (!astNodeToNewName.containsKey(fragmentDefinition)) {
                newName = "Fragment" + fragmentCounter.getAndIncrement();
                astNodeToNewName.put(fragmentDefinition, newName);
            } else {
                newName = astNodeToNewName.get(fragmentDefinition);
            }
            FragmentSpread fragmentSpread = env.getFragmentSpread();
            astNodeToNewName.put(fragmentSpread, newName);
            // directives may sit on the spread and on the definition; re-registering on repeated spreads is harmless
            registerQueryDirectives(fragmentSpread.getDirectives(), schema, newNames, astNodeToNewName, variableNames, variableCounter);
            registerQueryDirectives(fragmentDefinition.getDirectives(), schema, newNames, astNodeToNewName, variableNames, variableCounter);
        }

        @Override
        public TraversalControl visitArgument(QueryVisitorFieldArgumentEnvironment environment) {
            String newName = assertNotNull(newNames.get(environment.getGraphQLArgument()));
            astNodeToNewName.put(environment.getArgument(), newName);
            return CONTINUE;
        }
    });
    // Directives applied to the operation itself are not reached by the QueryVisitor callbacks above.
    for (OperationDefinition operationDefinition : document.getDefinitionsOfType(OperationDefinition.class)) {
        registerQueryDirectives(operationDefinition.getDirectives(), schema, newNames, astNodeToNewName, variableNames, variableCounter);
    }

    AstTransformer astTransformer = new AstTransformer();
    AtomicInteger aliasCounter = new AtomicInteger(1);
    // These counters are shared by variable default values and argument values.
    AtomicInteger defaultStringValueCounter = new AtomicInteger(1);
    AtomicInteger defaultIntValueCounter = new AtomicInteger(1);
    // Second pass: rebuild the document, applying the names collected above.
    Document newDocument = (Document) astTransformer.transform(document, new NodeVisitorStub() {

        @Override
        public TraversalControl visitDirective(Directive directive, TraverserContext<Node> context) {
            String newName = assertNotNull(astNodeToNewName.get(directive));
            GraphQLDirective directiveDefinition = schema.getDirective(directive.getName());
            // lets visitArgument resolve argument definitions against this directive
            context.setVar(GraphQLDirective.class, directiveDefinition);
            return changeNode(context, directive.transform(builder -> builder.name(newName)));
        }

        @Override
        public TraversalControl visitOperationDefinition(OperationDefinition node, TraverserContext<Node> context) {
            if (node.getName() != null) {
                return changeNode(context, node.transform(builder -> builder.name("operation")));
            } else {
                return CONTINUE;
            }
        }

        @Override
        public TraversalControl visitField(Field field, TraverserContext<Node> context) {
            String newAlias = null;
            if (field.getAlias() != null) {
                newAlias = "alias" + aliasCounter.getAndIncrement();
            }
            String newName;
            if (field.getName().equals(Introspection.TypeNameMetaFieldDef.getName())) {
                newName = Introspection.TypeNameMetaFieldDef.getName();
            } else {
                newName = assertNotNull(astNodeToNewName.get(field));
                // lets visitArgument resolve argument definitions against this field
                context.setVar(GraphQLFieldDefinition.class, assertNotNull(fieldToFieldDefinition.get(field)));
            }
            String finalNewAlias = newAlias;
            return changeNode(context, field.transform(builder -> builder.name(newName).alias(finalNewAlias)));
        }

        @Override
        public TraversalControl visitVariableDefinition(VariableDefinition node, TraverserContext<Node> context) {
            String newName = assertNotNull(variableNames.get(node.getName()));
            VariableDefinition newNode = node.transform(builder -> {
                builder.name(newName).comments(Collections.emptyList());
                // convert variable language type to renamed language type
                TypeName typeName = TypeUtil.unwrapAll(node.getType());
                GraphQLNamedType originalType = schema.getTypeAs(typeName.getName());
                // has the type name changed? (standard scalars such as String don't change)
                if (newNames.containsKey(originalType)) {
                    String newTypeName = newNames.get(originalType);
                    builder.type(replaceTypeName(node.getType(), newTypeName));
                }
                if (node.getDefaultValue() != null) {
                    Value<?> defaultValueLiteral = node.getDefaultValue();
                    GraphQLType graphQLType = fromTypeToGraphQLType(node.getType(), schema);
                    builder.defaultValue(replaceValue(defaultValueLiteral, (GraphQLInputType) graphQLType, newNames, defaultStringValueCounter, defaultIntValueCounter));
                }
            });
            return changeNode(context, newNode);
        }

        @Override
        public TraversalControl visitVariableReference(VariableReference node, TraverserContext<Node> context) {
            String newName = assertNotNull(variableNames.get(node.getName()), () -> format("No new variable name found for %s", node.getName()));
            return changeNode(context, node.transform(builder -> builder.name(newName)));
        }

        @Override
        public TraversalControl visitFragmentDefinition(FragmentDefinition node, TraverserContext<Node> context) {
            String newName = assertNotNull(astNodeToNewName.get(node));
            GraphQLType currentCondition = assertNotNull(schema.getType(node.getTypeCondition().getName()));
            String newCondition = newNames.get(currentCondition);
            return changeNode(context, node.transform(builder -> builder.name(newName).typeCondition(new TypeName(newCondition))));
        }

        @Override
        public TraversalControl visitInlineFragment(InlineFragment node, TraverserContext<Node> context) {
            // `... @include(if: $x) { ... }` has no type condition, so there is nothing to rename on the node itself
            if (node.getTypeCondition() == null) {
                return CONTINUE;
            }
            GraphQLType currentCondition = assertNotNull(schema.getType(node.getTypeCondition().getName()));
            String newCondition = newNames.get(currentCondition);
            return changeNode(context, node.transform(builder -> builder.typeCondition(new TypeName(newCondition))));
        }

        @Override
        public TraversalControl visitFragmentSpread(FragmentSpread node, TraverserContext<Node> context) {
            String newName = assertNotNull(astNodeToNewName.get(node));
            return changeNode(context, node.transform(builder -> builder.name(newName)));
        }

        @Override
        public TraversalControl visitArgument(Argument argument, TraverserContext<Node> context) {
            GraphQLArgument graphQLArgumentDefinition;
            // An argument is either from a applied query directive or from a field
            if (context.getVarFromParents(GraphQLDirective.class) != null) {
                GraphQLDirective directiveDefinition = context.getVarFromParents(GraphQLDirective.class);
                graphQLArgumentDefinition = directiveDefinition.getArgument(argument.getName());
            } else {
                GraphQLFieldDefinition graphQLFieldDefinition = assertNotNull(context.getVarFromParents(GraphQLFieldDefinition.class));
                graphQLArgumentDefinition = graphQLFieldDefinition.getArgument(argument.getName());
            }
            GraphQLInputType argumentType = graphQLArgumentDefinition.getType();
            String newName = assertNotNull(astNodeToNewName.get(argument));
            Value newValue = replaceValue(argument.getValue(), argumentType, newNames, defaultStringValueCounter, defaultIntValueCounter);
            return changeNode(context, argument.transform(builder -> builder.name(newName).value(newValue)));
        }
    });
    return AstPrinter.printAstCompact(newDocument);
}

/**
 * Records the anonymized name of each applied query directive and each of its arguments in
 * {@code astNodeToNewName}. Also assigns a generated name to every variable referenced anywhere
 * inside the argument values.
 */
private static void registerQueryDirectives(List<Directive> directives,
                                            GraphQLSchema schema,
                                            Map<GraphQLNamedSchemaElement, String> newNames,
                                            Map<Node, String> astNodeToNewName,
                                            Map<String, String> variableNames,
                                            AtomicInteger variableCounter) {
    for (Directive directive : directives) {
        // this is a directive definition
        GraphQLDirective directiveDefinition = assertNotNull(schema.getDirective(directive.getName()), () -> format("%s directive definition not found ", directive.getName()));
        String directiveName = directiveDefinition.getName();
        String newDirectiveName = assertNotNull(newNames.get(directiveDefinition), () -> format("No new name found for directive %s", directiveName));
        astNodeToNewName.put(directive, newDirectiveName);
        for (Argument argument : directive.getArguments()) {
            GraphQLArgument argumentDefinition = directiveDefinition.getArgument(argument.getName());
            String newArgumentName = assertNotNull(newNames.get(argumentDefinition), () -> format("%s no new name found for directive argument %s", directiveName, argument.getName()));
            astNodeToNewName.put(argument, newArgumentName);
            registerVariableReferences(argument.getValue(), variableNames, variableCounter);
        }
    }
}

/**
 * Assigns generated names to all variables referenced in {@code value}. It recurses into list and
 * object literals, so nested references such as {@code [$a, {k: $b}]} are found too.
 */
private static void registerVariableReferences(Value<?> value, Map<String, String> variableNames, AtomicInteger variableCounter) {
    if (value instanceof VariableReference) {
        registerVariableName(((VariableReference) value).getName(), variableNames, variableCounter);
    } else if (value instanceof ArrayValue) {
        for (Value<?> element : ((ArrayValue) value).getValues()) {
            registerVariableReferences(element, variableNames, variableCounter);
        }
    } else if (value instanceof ObjectValue) {
        for (ObjectField objectField : ((ObjectValue) value).getObjectFields()) {
            registerVariableReferences(objectField.getValue(), variableNames, variableCounter);
        }
    }
}

/**
 * Assigns the next "varN" name to {@code name} unless it already has one. The counter only
 * advances for previously unseen variables.
 */
private static void registerVariableName(String name, Map<String, String> variableNames, AtomicInteger variableCounter) {
    variableNames.computeIfAbsent(name, ignored -> "var" + variableCounter.getAndIncrement());
}
Also used : OperationDefinition(graphql.language.OperationDefinition) Value(graphql.language.Value) QueryVisitorInlineFragmentEnvironment(graphql.analysis.QueryVisitorInlineFragmentEnvironment) ValuesResolver(graphql.execution.ValuesResolver) GraphQLInputObjectType(graphql.schema.GraphQLInputObjectType) QueryTraverser(graphql.analysis.QueryTraverser) GraphQLFieldDefinition(graphql.schema.GraphQLFieldDefinition) FragmentSpread(graphql.language.FragmentSpread) GraphQLInterfaceType(graphql.schema.GraphQLInterfaceType) GraphQLInputObjectField(graphql.schema.GraphQLInputObjectField) GraphQLUnionType(graphql.schema.GraphQLUnionType) GraphQLEnumValueDefinition(graphql.schema.GraphQLEnumValueDefinition) GraphQLNamedSchemaElement(graphql.schema.GraphQLNamedSchemaElement) GraphQLTypeUtil.unwrapNonNullAs(graphql.schema.GraphQLTypeUtil.unwrapNonNullAs) Directives(graphql.Directives) GraphQLAppliedDirective(graphql.schema.GraphQLAppliedDirective) QueryVisitorFieldEnvironment(graphql.analysis.QueryVisitorFieldEnvironment) Type(graphql.language.Type) DirectiveInfo(graphql.schema.idl.DirectiveInfo) AtomicInteger(java.util.concurrent.atomic.AtomicInteger) Map(java.util.Map) QueryVisitorFragmentSpreadEnvironment(graphql.analysis.QueryVisitorFragmentSpreadEnvironment) BigInteger(java.math.BigInteger) TypeName(graphql.language.TypeName) TypeResolver(graphql.schema.TypeResolver) GraphQLObjectType(graphql.schema.GraphQLObjectType) GraphQLDirective(graphql.schema.GraphQLDirective) GraphQLNamedOutputType(graphql.schema.GraphQLNamedOutputType) GraphQLNonNull(graphql.schema.GraphQLNonNull) ObjectField(graphql.language.ObjectField) GraphQLInputType(graphql.schema.GraphQLInputType) Set(java.util.Set) GraphQLArgument(graphql.schema.GraphQLArgument) String.format(java.lang.String.format) AstPrinter(graphql.language.AstPrinter) List(java.util.List) QueryVisitor(graphql.analysis.QueryVisitor) ArrayValue(graphql.language.ArrayValue) Optional(java.util.Optional) 
FragmentDefinition(graphql.language.FragmentDefinition) NonNullType(graphql.language.NonNullType) GraphQLEnumType(graphql.schema.GraphQLEnumType) ListType(graphql.language.ListType) AstTransformer(graphql.language.AstTransformer) ObjectValue(graphql.language.ObjectValue) GraphQLSchemaElement(graphql.schema.GraphQLSchemaElement) GraphQLCodeRegistry(graphql.schema.GraphQLCodeRegistry) GraphQLNamedType(graphql.schema.GraphQLNamedType) Node(graphql.language.Node) SchemaTransformer(graphql.schema.SchemaTransformer) TreeTransformerUtil.changeNode(graphql.util.TreeTransformerUtil.changeNode) GraphQLScalarType(graphql.schema.GraphQLScalarType) QueryVisitorFieldArgumentEnvironment(graphql.analysis.QueryVisitorFieldArgumentEnvironment) QueryVisitorFieldArgumentInputValue(graphql.analysis.QueryVisitorFieldArgumentInputValue) EnumValue(graphql.language.EnumValue) HashMap(java.util.HashMap) GraphQLType(graphql.schema.GraphQLType) ArrayList(java.util.ArrayList) Introspection(graphql.introspection.Introspection) LinkedHashMap(java.util.LinkedHashMap) Scalars(graphql.Scalars) Parser(graphql.parser.Parser) Definition(graphql.language.Definition) VariableReference(graphql.language.VariableReference) BiConsumer(java.util.function.BiConsumer) GraphQLSchema(graphql.schema.GraphQLSchema) GraphQLAppliedDirectiveArgument(graphql.schema.GraphQLAppliedDirectiveArgument) QueryVisitorFieldArgumentValueEnvironment(graphql.analysis.QueryVisitorFieldArgumentValueEnvironment) LinkedHashSet(java.util.LinkedHashSet) NodeVisitorStub(graphql.language.NodeVisitorStub) SchemaGenerator.createdMockedSchema(graphql.schema.idl.SchemaGenerator.createdMockedSchema) ScalarInfo(graphql.schema.idl.ScalarInfo) CONTINUE(graphql.util.TraversalControl.CONTINUE) GraphQLArgument.newArgument(graphql.schema.GraphQLArgument.newArgument) Field(graphql.language.Field) GraphQLImplementingType(graphql.schema.GraphQLImplementingType) GraphQLTypeVisitorStub(graphql.schema.GraphQLTypeVisitorStub) 
SchemaUtil(graphql.schema.impl.SchemaUtil) AssertException(graphql.AssertException) Directive(graphql.language.Directive) Consumer(java.util.function.Consumer) Argument(graphql.language.Argument) Document(graphql.language.Document) TypeUtil(graphql.schema.idl.TypeUtil) VariableDefinition(graphql.language.VariableDefinition) GraphQLList(graphql.schema.GraphQLList) StringValue(graphql.language.StringValue) GraphQLTypeReference(graphql.schema.GraphQLTypeReference) GraphQLTypeUtil.unwrapNonNull(graphql.schema.GraphQLTypeUtil.unwrapNonNull) PublicApi(graphql.PublicApi) IntValue(graphql.language.IntValue) Assert.assertNotNull(graphql.Assert.assertNotNull) GraphQLTypeVisitor(graphql.schema.GraphQLTypeVisitor) GraphQLTypeUtil.unwrapOneAs(graphql.schema.GraphQLTypeUtil.unwrapOneAs) InlineFragment(graphql.language.InlineFragment) Collections(java.util.Collections) QueryVisitor(graphql.analysis.QueryVisitor) TypeName(graphql.language.TypeName) GraphQLArgument(graphql.schema.GraphQLArgument) GraphQLAppliedDirectiveArgument(graphql.schema.GraphQLAppliedDirectiveArgument) GraphQLArgument.newArgument(graphql.schema.GraphQLArgument.newArgument) Argument(graphql.language.Argument) VariableDefinition(graphql.language.VariableDefinition) QueryVisitorFieldEnvironment(graphql.analysis.QueryVisitorFieldEnvironment) Node(graphql.language.Node) TreeTransformerUtil.changeNode(graphql.util.TreeTransformerUtil.changeNode) GraphQLType(graphql.schema.GraphQLType) GraphQLFieldDefinition(graphql.schema.GraphQLFieldDefinition) Document(graphql.language.Document) LinkedHashMap(java.util.LinkedHashMap) GraphQLInputType(graphql.schema.GraphQLInputType) GraphQLInputObjectField(graphql.schema.GraphQLInputObjectField) ObjectField(graphql.language.ObjectField) Field(graphql.language.Field) List(java.util.List) ArrayList(java.util.ArrayList) GraphQLList(graphql.schema.GraphQLList) GraphQLNamedType(graphql.schema.GraphQLNamedType) OperationDefinition(graphql.language.OperationDefinition) 
QueryVisitorFieldArgumentEnvironment(graphql.analysis.QueryVisitorFieldArgumentEnvironment) VariableReference(graphql.language.VariableReference) FragmentDefinition(graphql.language.FragmentDefinition) QueryVisitorFragmentSpreadEnvironment(graphql.analysis.QueryVisitorFragmentSpreadEnvironment) GraphQLArgument(graphql.schema.GraphQLArgument) GraphQLDirective(graphql.schema.GraphQLDirective) QueryVisitorFieldArgumentValueEnvironment(graphql.analysis.QueryVisitorFieldArgumentValueEnvironment) Parser(graphql.parser.Parser) FragmentSpread(graphql.language.FragmentSpread) QueryVisitorFieldArgumentInputValue(graphql.analysis.QueryVisitorFieldArgumentInputValue) AstTransformer(graphql.language.AstTransformer) QueryTraverser(graphql.analysis.QueryTraverser) AtomicInteger(java.util.concurrent.atomic.AtomicInteger) QueryVisitorInlineFragmentEnvironment(graphql.analysis.QueryVisitorInlineFragmentEnvironment) Value(graphql.language.Value) ArrayValue(graphql.language.ArrayValue) ObjectValue(graphql.language.ObjectValue) QueryVisitorFieldArgumentInputValue(graphql.analysis.QueryVisitorFieldArgumentInputValue) EnumValue(graphql.language.EnumValue) StringValue(graphql.language.StringValue) IntValue(graphql.language.IntValue) NodeVisitorStub(graphql.language.NodeVisitorStub) GraphQLAppliedDirective(graphql.schema.GraphQLAppliedDirective) GraphQLDirective(graphql.schema.GraphQLDirective) Directive(graphql.language.Directive) InlineFragment(graphql.language.InlineFragment)

Aggregations

Assert.assertNotNull (graphql.Assert.assertNotNull)1 AssertException (graphql.AssertException)1 Directives (graphql.Directives)1 PublicApi (graphql.PublicApi)1 Scalars (graphql.Scalars)1 QueryTraverser (graphql.analysis.QueryTraverser)1 QueryVisitor (graphql.analysis.QueryVisitor)1 QueryVisitorFieldArgumentEnvironment (graphql.analysis.QueryVisitorFieldArgumentEnvironment)1 QueryVisitorFieldArgumentInputValue (graphql.analysis.QueryVisitorFieldArgumentInputValue)1 QueryVisitorFieldArgumentValueEnvironment (graphql.analysis.QueryVisitorFieldArgumentValueEnvironment)1 QueryVisitorFieldEnvironment (graphql.analysis.QueryVisitorFieldEnvironment)1 QueryVisitorFragmentSpreadEnvironment (graphql.analysis.QueryVisitorFragmentSpreadEnvironment)1 QueryVisitorInlineFragmentEnvironment (graphql.analysis.QueryVisitorInlineFragmentEnvironment)1 ValuesResolver (graphql.execution.ValuesResolver)1 Introspection (graphql.introspection.Introspection)1 Argument (graphql.language.Argument)1 ArrayValue (graphql.language.ArrayValue)1 AstPrinter (graphql.language.AstPrinter)1 AstTransformer (graphql.language.AstTransformer)1 Definition (graphql.language.Definition)1