Use of graphql.analysis.QueryVisitorFragmentSpreadEnvironment in the project graphql-java by graphql-java.
In the class Anonymizer, method rewriteQuery.
/**
 * Rewrites {@code query} so that every name taken from {@code schema} is replaced by its anonymized
 * counterpart from {@code newNames}.
 * <p>
 * The rewrite happens in two passes:
 * <ol>
 * <li>A {@link QueryTraverser} walk resolves fields, arguments and directives against the schema and
 * records the new name of every relevant AST node. Referenced variables are named
 * {@code var1, var2, ...} and fragments {@code Fragment1, Fragment2, ...}, in visiting order.</li>
 * <li>An {@link AstTransformer} walk rebuilds the document using the recorded names, replaces aliases
 * with {@code alias1, alias2, ...}, renames a named operation to {@code operation}, and anonymizes
 * argument literals and variable default values via {@code replaceValue}.</li>
 * </ol>
 *
 * @param query     the query text to anonymize (validated by {@code assertUniqueOperation})
 * @param schema    the original schema the query is written against
 * @param newNames  the anonymized name for each original schema element
 * @param variables variable values handed to the {@link QueryTraverser}
 * @return the rewritten document, printed with {@code AstPrinter.printAstCompact}
 */
private static String rewriteQuery(String query, GraphQLSchema schema, Map<GraphQLNamedSchemaElement, String> newNames, Map<String, Object> variables) {
    AtomicInteger fragmentCounter = new AtomicInteger(1);
    AtomicInteger variableCounter = new AtomicInteger(1);
    Map<Node, String> astNodeToNewName = new LinkedHashMap<>();
    Map<String, String> variableNames = new LinkedHashMap<>();
    Map<Field, GraphQLFieldDefinition> fieldToFieldDefinition = new LinkedHashMap<>();
    Document document = new Parser().parseDocument(query);
    assertUniqueOperation(document);
    QueryTraverser queryTraverser = QueryTraverser.newQueryTraverser().document(document).schema(schema).variables(variables).build();

    // Pass 1: resolve nodes against the schema and record the new name of each one.
    queryTraverser.visitDepthFirst(new QueryVisitor() {
        @Override
        public void visitField(QueryVisitorFieldEnvironment env) {
            if (env.isTypeNameIntrospectionField()) {
                // __typename is kept verbatim by the second pass
                return;
            }
            fieldToFieldDefinition.put(env.getField(), env.getFieldDefinition());
            String newName = assertNotNull(newNames.get(env.getFieldDefinition()));
            Field field = env.getField();
            astNodeToNewName.put(field, newName);
            recordDirectives(field.getDirectives());
        }

        @Override
        public void visitInlineFragment(QueryVisitorInlineFragmentEnvironment queryVisitorInlineFragmentEnvironment) {
            // Inline fragments can carry directives (e.g. @include/@skip); the second pass renames every
            // directive it meets, so they must be recorded here as well.
            recordDirectives(queryVisitorInlineFragmentEnvironment.getInlineFragment().getDirectives());
        }

        @Override
        public TraversalControl visitArgumentValue(QueryVisitorFieldArgumentValueEnvironment environment) {
            // The traverser calls this for nested input values too, so only the value itself is inspected.
            QueryVisitorFieldArgumentInputValue argumentInputValue = environment.getArgumentInputValue();
            if (argumentInputValue.getValue() instanceof VariableReference) {
                recordVariable(((VariableReference) argumentInputValue.getValue()).getName());
            }
            return CONTINUE;
        }

        @Override
        public void visitFragmentSpread(QueryVisitorFragmentSpreadEnvironment queryVisitorFragmentSpreadEnvironment) {
            FragmentDefinition fragmentDefinition = queryVisitorFragmentSpreadEnvironment.getFragmentDefinition();
            String newName;
            if (!astNodeToNewName.containsKey(fragmentDefinition)) {
                newName = "Fragment" + fragmentCounter.getAndIncrement();
                astNodeToNewName.put(fragmentDefinition, newName);
                // The definition itself may carry directives; record them the first time it is seen.
                recordDirectives(fragmentDefinition.getDirectives());
            } else {
                newName = astNodeToNewName.get(fragmentDefinition);
            }
            FragmentSpread fragmentSpread = queryVisitorFragmentSpreadEnvironment.getFragmentSpread();
            astNodeToNewName.put(fragmentSpread, newName);
            recordDirectives(fragmentSpread.getDirectives());
        }

        @Override
        public TraversalControl visitArgument(QueryVisitorFieldArgumentEnvironment environment) {
            String newName = assertNotNull(newNames.get(environment.getGraphQLArgument()));
            astNodeToNewName.put(environment.getArgument(), newName);
            return CONTINUE;
        }

        /**
         * Records new names for the given applied directives and their arguments, and registers
         * every variable referenced anywhere inside the argument values.
         */
        private void recordDirectives(List<Directive> directives) {
            for (Directive directive : directives) {
                // resolve the applied directive to its schema definition
                GraphQLDirective directiveDefinition = assertNotNull(schema.getDirective(directive.getName()), () -> format("%s directive definition not found ", directive.getName()));
                String directiveName = directiveDefinition.getName();
                String newDirectiveName = assertNotNull(newNames.get(directiveDefinition), () -> format("No new name found for directive %s", directiveName));
                astNodeToNewName.put(directive, newDirectiveName);
                for (Argument argument : directive.getArguments()) {
                    GraphQLArgument argumentDefinition = directiveDefinition.getArgument(argument.getName());
                    String newArgumentName = assertNotNull(newNames.get(argumentDefinition), () -> format("%s no new name found for directive argument %s", directiveName, argument.getName()));
                    astNodeToNewName.put(argument, newArgumentName);
                    recordVariables(argument.getValue());
                }
            }
        }

        /**
         * Registers every variable reference found in {@code node} or any of its descendants,
         * so variables nested in list or object literals are not missed.
         */
        private void recordVariables(Node node) {
            if (node instanceof VariableReference) {
                recordVariable(((VariableReference) node).getName());
                return;
            }
            for (Object child : node.getChildren()) {
                recordVariables((Node) child);
            }
        }

        /** Assigns the next {@code varN} name to {@code name} unless it already has one. */
        private void recordVariable(String name) {
            variableNames.computeIfAbsent(name, ignored -> "var" + variableCounter.getAndIncrement());
        }
    });

    AstTransformer astTransformer = new AstTransformer();
    AtomicInteger aliasCounter = new AtomicInteger(1);
    // Shared by argument literals and variable default values so generated literals stay distinct.
    AtomicInteger defaultStringValueCounter = new AtomicInteger(1);
    AtomicInteger defaultIntValueCounter = new AtomicInteger(1);

    // Pass 2: rebuild the document using the names recorded in pass 1.
    Document newDocument = (Document) astTransformer.transform(document, new NodeVisitorStub() {
        @Override
        public TraversalControl visitDirective(Directive directive, TraverserContext<Node> context) {
            String newName = assertNotNull(astNodeToNewName.get(directive));
            GraphQLDirective directiveDefinition = schema.getDirective(directive.getName());
            // lets visitArgument tell directive arguments apart from field arguments
            context.setVar(GraphQLDirective.class, directiveDefinition);
            return changeNode(context, directive.transform(builder -> builder.name(newName)));
        }

        @Override
        public TraversalControl visitOperationDefinition(OperationDefinition node, TraverserContext<Node> context) {
            if (node.getName() != null) {
                return changeNode(context, node.transform(builder -> builder.name("operation")));
            } else {
                return CONTINUE;
            }
        }

        @Override
        public TraversalControl visitField(Field field, TraverserContext<Node> context) {
            String newAlias = null;
            if (field.getAlias() != null) {
                newAlias = "alias" + aliasCounter.getAndIncrement();
            }
            String newName;
            if (field.getName().equals(Introspection.TypeNameMetaFieldDef.getName())) {
                newName = Introspection.TypeNameMetaFieldDef.getName();
            } else {
                newName = assertNotNull(astNodeToNewName.get(field));
                // lets visitArgument resolve this field's argument definitions
                context.setVar(GraphQLFieldDefinition.class, assertNotNull(fieldToFieldDefinition.get(field)));
            }
            String finalNewAlias = newAlias;
            return changeNode(context, field.transform(builder -> builder.name(newName).alias(finalNewAlias)));
        }

        @Override
        public TraversalControl visitVariableDefinition(VariableDefinition node, TraverserContext<Node> context) {
            String newName = assertNotNull(variableNames.get(node.getName()));
            VariableDefinition newNode = node.transform(builder -> {
                builder.name(newName).comments(Collections.emptyList());
                // convert variable language type to renamed language type
                TypeName typeName = TypeUtil.unwrapAll(node.getType());
                GraphQLNamedType originalType = schema.getTypeAs(typeName.getName());
                // has the type name changed? (standard scalars such as String don't change)
                if (newNames.containsKey(originalType)) {
                    String newTypeName = newNames.get(originalType);
                    builder.type(replaceTypeName(node.getType(), newTypeName));
                }
                if (node.getDefaultValue() != null) {
                    Value<?> defaultValueLiteral = node.getDefaultValue();
                    GraphQLType graphQLType = fromTypeToGraphQLType(node.getType(), schema);
                    builder.defaultValue(replaceValue(defaultValueLiteral, (GraphQLInputType) graphQLType, newNames, defaultStringValueCounter, defaultIntValueCounter));
                }
            });
            return changeNode(context, newNode);
        }

        @Override
        public TraversalControl visitVariableReference(VariableReference node, TraverserContext<Node> context) {
            String newName = assertNotNull(variableNames.get(node.getName()), () -> format("No new variable name found for %s", node.getName()));
            return changeNode(context, node.transform(builder -> builder.name(newName)));
        }

        @Override
        public TraversalControl visitFragmentDefinition(FragmentDefinition node, TraverserContext<Node> context) {
            String newName = assertNotNull(astNodeToNewName.get(node));
            GraphQLType currentCondition = assertNotNull(schema.getType(node.getTypeCondition().getName()));
            String newCondition = assertNotNull(newNames.get(currentCondition), () -> format("No new name found for type %s", node.getTypeCondition().getName()));
            return changeNode(context, node.transform(builder -> builder.name(newName).typeCondition(new TypeName(newCondition))));
        }

        @Override
        public TraversalControl visitInlineFragment(InlineFragment node, TraverserContext<Node> context) {
            // The type condition is optional ("... @include(if: $x) { ... }"); nothing to rename then.
            if (node.getTypeCondition() == null) {
                return CONTINUE;
            }
            GraphQLType currentCondition = assertNotNull(schema.getType(node.getTypeCondition().getName()));
            String newCondition = assertNotNull(newNames.get(currentCondition), () -> format("No new name found for type %s", node.getTypeCondition().getName()));
            return changeNode(context, node.transform(builder -> builder.typeCondition(new TypeName(newCondition))));
        }

        @Override
        public TraversalControl visitFragmentSpread(FragmentSpread node, TraverserContext<Node> context) {
            String newName = assertNotNull(astNodeToNewName.get(node));
            return changeNode(context, node.transform(builder -> builder.name(newName)));
        }

        @Override
        public TraversalControl visitArgument(Argument argument, TraverserContext<Node> context) {
            GraphQLArgument graphQLArgumentDefinition;
            // An argument belongs either to an applied query directive or to a field
            if (context.getVarFromParents(GraphQLDirective.class) != null) {
                GraphQLDirective directiveDefinition = context.getVarFromParents(GraphQLDirective.class);
                graphQLArgumentDefinition = directiveDefinition.getArgument(argument.getName());
            } else {
                GraphQLFieldDefinition graphQLFieldDefinition = assertNotNull(context.getVarFromParents(GraphQLFieldDefinition.class));
                graphQLArgumentDefinition = graphQLFieldDefinition.getArgument(argument.getName());
            }
            GraphQLInputType argumentType = graphQLArgumentDefinition.getType();
            String newName = assertNotNull(astNodeToNewName.get(argument));
            Value<?> newValue = replaceValue(argument.getValue(), argumentType, newNames, defaultStringValueCounter, defaultIntValueCounter);
            return changeNode(context, argument.transform(builder -> builder.name(newName).value(newValue)));
        }
    });
    return AstPrinter.printAstCompact(newDocument);
}
Aggregations