Usage of org.apache.spark.sql.sources.Filter in project Gaffer (gchq).
Class FiltersToOperationConverter, method applyPropertyFilters.
/**
 * Sets the view of {@code operation} according to the property filters in {@code filters}.
 * <p>
 * A group is permitted only if every filter that reports related groups (non-null, non-empty
 * {@code getGroupsFromFilter}) includes it. The filter functions returned by
 * {@code getFunctionsFromFilter} are merged per group.
 * <p>
 * If at least one permitted group of {@code derivedView} has filter functions, a new view is built.
 * It contains exactly those groups, each with its definition from {@code derivedView} merged with
 * the functions as post-aggregation filters. Otherwise {@code derivedView} is set unchanged.
 * <p>
 * NOTE(review): when a new view is built, permitted groups that have no filter functions are
 * omitted from it (only logged). Confirm this is intended.
 *
 * @param derivedView the view to start from; only its entity and edge groups are considered
 * @param operation   the operation whose view is replaced
 * @return {@code operation}, with its view set
 */
private AbstractGetRDD<?> applyPropertyFilters(final View derivedView, final AbstractGetRDD<?> operation) {
    final Set<String> permittedGroups = getGroupsPermittedByFilters(derivedView);
    final Map<String, List<ConsumerFunctionContext<String, FilterFunction>>> groupToFunctions = getFunctionsByGroup();

    LOGGER.info("The following functions will be applied for the given group:");
    for (final Entry<String, List<ConsumerFunctionContext<String, FilterFunction>>> entry : groupToFunctions.entrySet()) {
        LOGGER.info("Group = {}: ", entry.getKey());
        logFunctions(entry.getValue());
    }

    // Entities are added before edges, matching the order of the derived view.
    boolean updated = false;
    View.Builder builder = new View.Builder();
    for (final String group : derivedView.getEntityGroups()) {
        final List<ConsumerFunctionContext<String, FilterFunction>> functions =
                getFunctionsToApply(group, permittedGroups, groupToFunctions);
        if (null != functions) {
            builder = builder.entity(group, new ViewElementDefinition.Builder()
                    .merge(derivedView.getEntity(group))
                    .postAggregationFilterFunctions(functions)
                    .build());
            updated = true;
        }
    }
    for (final String group : derivedView.getEdgeGroups()) {
        final List<ConsumerFunctionContext<String, FilterFunction>> functions =
                getFunctionsToApply(group, permittedGroups, groupToFunctions);
        if (null != functions) {
            builder = builder.edge(group, new ViewElementDefinition.Builder()
                    .merge(derivedView.getEdge(group))
                    .postAggregationFilterFunctions(functions)
                    .build());
            updated = true;
        }
    }

    operation.setView(updated ? builder.build() : derivedView);
    return operation;
}

/**
 * Returns the groups of {@code derivedView} (entities and edges) that lie in the related-group set
 * of every filter that reports one. A filter whose related-group set is null or empty places no
 * restriction.
 */
private Set<String> getGroupsPermittedByFilters(final View derivedView) {
    final List<Set<String>> groupsRelatedToFilters = new ArrayList<>();
    for (final Filter filter : filters) {
        final Set<String> groupsRelatedToFilter = getGroupsFromFilter(filter);
        if (null != groupsRelatedToFilter && !groupsRelatedToFilter.isEmpty()) {
            groupsRelatedToFilters.add(groupsRelatedToFilter);
        }
        LOGGER.info("Groups {} are related to filter {}", StringUtils.join(groupsRelatedToFilter, ','), filter);
    }
    LOGGER.info("Groups related to filters are: {}", StringUtils.join(groupsRelatedToFilters, ','));

    // Take the intersection of this list of groups - only these groups can be related to the query
    final Set<String> intersection = new HashSet<>(derivedView.getEntityGroups());
    intersection.addAll(derivedView.getEdgeGroups());
    for (final Set<String> groupsRelatedToFilter : groupsRelatedToFilters) {
        intersection.retainAll(groupsRelatedToFilter);
    }
    LOGGER.info("Groups that can be returned are: {}", StringUtils.join(intersection, ','));
    return intersection;
}

/**
 * Merges the per-group filter functions of all filters into one map from group to functions.
 * Within each group, functions keep the order of the filters.
 */
private Map<String, List<ConsumerFunctionContext<String, FilterFunction>>> getFunctionsByGroup() {
    final Map<String, List<ConsumerFunctionContext<String, FilterFunction>>> groupToFunctions = new HashMap<>();
    for (final Filter filter : filters) {
        for (final Entry<String, List<ConsumerFunctionContext<String, FilterFunction>>> entry
                : getFunctionsFromFilter(filter).entrySet()) {
            List<ConsumerFunctionContext<String, FilterFunction>> functions = groupToFunctions.get(entry.getKey());
            if (null == functions) {
                functions = new ArrayList<>();
                groupToFunctions.put(entry.getKey(), functions);
            }
            functions.addAll(entry.getValue());
        }
    }
    return groupToFunctions;
}

/**
 * Returns the filter functions to attach to {@code group}.
 * <p>
 * Returns {@code null} when the group is not permitted, or when it is permitted but has no
 * functions; the second case is logged.
 */
private List<ConsumerFunctionContext<String, FilterFunction>> getFunctionsToApply(
        final String group,
        final Set<String> permittedGroups,
        final Map<String, List<ConsumerFunctionContext<String, FilterFunction>>> groupToFunctions) {
    if (!permittedGroups.contains(group)) {
        return null;
    }
    final List<ConsumerFunctionContext<String, FilterFunction>> functions = groupToFunctions.get(group);
    if (null == functions) {
        LOGGER.info("Not adding any filter functions to the view for group {}", group);
        return null;
    }
    LOGGER.info("Adding the following filter functions to the view for group {}:", group);
    logFunctions(functions);
    return functions;
}

/** Logs each function's selection and function, one per line. */
private void logFunctions(final List<ConsumerFunctionContext<String, FilterFunction>> functions) {
    for (final ConsumerFunctionContext<String, FilterFunction> cfc : functions) {
        LOGGER.info("\t{} {}", StringUtils.join(cfc.getSelection(), ','), cfc.getFunction());
    }
}
Usage of org.apache.spark.sql.sources.Filter in project Gaffer (gchq).
Class FilterToOperationConverterTest, method testSingleGroupNotInSchema.
/**
 * A group filter naming a group that is not in the schema must make the converter return no operation.
 */
@Test
public void testSingleGroupNotInSchema() {
    final Schema schema = getSchema();
    final Filter[] filters = {new EqualTo(SchemaToStructTypeConverter.GROUP, "random")};

    final FiltersToOperationConverter converter =
            new FiltersToOperationConverter(getViewFromSchema(schema), schema, filters);

    assertNull(converter.getOperation());
}
Usage of org.apache.spark.sql.sources.Filter in project Gaffer (gchq).
Class FilterToOperationConverterTest, method testSingleGroup.
/**
 * A group filter on ENTITY_GROUP must yield a GetRDDOfAllElements whose view contains only that
 * entity group and no edge groups.
 */
@Test
public void testSingleGroup() {
    final Schema schema = getSchema();
    final Filter[] filters = {new EqualTo(SchemaToStructTypeConverter.GROUP, ENTITY_GROUP)};

    final FiltersToOperationConverter converter =
            new FiltersToOperationConverter(getViewFromSchema(schema), schema, filters);
    final Operation operation = converter.getOperation();

    // isInstanceOf reports the actual type on failure, unlike assertTrue(x instanceof Y)
    assertThat(operation).isInstanceOf(GetRDDOfAllElements.class);
    final View view = ((GraphFilters) operation).getView();
    assertEquals(Collections.singleton(ENTITY_GROUP), view.getEntityGroups());
    assertEquals(0, view.getEdgeGroups().size());
}
Usage of org.apache.spark.sql.sources.Filter in project Gaffer (gchq).
Class FilterToOperationConverterTest, method testSpecifyVertexAndPropertyFilter.
/**
 * A vertex filter combined with property filters must yield a GetRDDOfElements seeded with the vertex.
 * Its view must contain only ENTITY_GROUP, carrying one post-aggregation filter per property
 * filter, in filter order.
 */
@Test
public void testSpecifyVertexAndPropertyFilter() {
    final Schema schema = getSchema();

    // Specify vertex and a filter on property1
    Filter[] filters = new Filter[2];
    filters[0] = new GreaterThan("property1", 5);
    filters[1] = new EqualTo(SchemaToStructTypeConverter.VERTEX_COL_NAME, "0");
    FiltersToOperationConverter converter = new FiltersToOperationConverter(getViewFromSchema(schema), schema, filters);
    Operation operation = converter.getOperation();
    assertThat(operation).isInstanceOf(GetRDDOfElements.class);
    View opView = ((GraphFilters) operation).getView();
    assertEquals(Collections.singleton(ENTITY_GROUP), opView.getEntityGroups());
    assertEquals(0, opView.getEdgeGroups().size());
    final Set<EntityId> seeds = new HashSet<>();
    for (final Object seed : ((GetRDDOfElements) operation).getInput()) {
        seeds.add((EntitySeed) seed);
    }
    assertEquals(Collections.singleton(new EntitySeed("0")), seeds);
    List<TupleAdaptedPredicate<String, ?>> entityPostAggFilters =
            opView.getEntity(ENTITY_GROUP).getPostAggregationFilterFunctions();
    assertThat(entityPostAggFilters).hasSize(1);
    assertArrayEquals(new String[] {"property1"}, entityPostAggFilters.get(0).getSelection());
    assertEquals(new IsMoreThan(5, false), entityPostAggFilters.get(0).getPredicate());

    // Specify vertex and filters on properties property1 and property4
    filters = new Filter[3];
    filters[0] = new GreaterThan("property1", 5);
    filters[1] = new EqualTo(SchemaToStructTypeConverter.VERTEX_COL_NAME, "0");
    filters[2] = new LessThan("property4", 8);
    converter = new FiltersToOperationConverter(getViewFromSchema(schema), schema, filters);
    operation = converter.getOperation();
    assertThat(operation).isInstanceOf(GetRDDOfElements.class);
    opView = ((GraphFilters) operation).getView();
    assertEquals(Collections.singleton(ENTITY_GROUP), opView.getEntityGroups());
    assertEquals(0, opView.getEdgeGroups().size());
    seeds.clear();
    for (final Object seed : ((GetRDDOfElements) operation).getInput()) {
        seeds.add((EntitySeed) seed);
    }
    assertEquals(Collections.singleton(new EntitySeed("0")), seeds);
    entityPostAggFilters = opView.getEntity(ENTITY_GROUP).getPostAggregationFilterFunctions();
    assertThat(entityPostAggFilters).hasSize(2);
    assertArrayEquals(new String[] {"property1"}, entityPostAggFilters.get(0).getSelection());
    assertEquals(new IsMoreThan(5, false), entityPostAggFilters.get(0).getPredicate());
    assertArrayEquals(new String[] {"property4"}, entityPostAggFilters.get(1).getSelection());
    assertEquals(new IsLessThan(8, false), entityPostAggFilters.get(1).getPredicate());
}
Usage of org.apache.spark.sql.sources.Filter in project Gaffer (gchq).
Class FilterToOperationConverterTest, method testSpecifyPropertyFilters.
/**
 * Property-only filters must yield a GetRDDOfAllElements.
 * <p>
 * Its view must be restricted to the groups that have the filtered properties. Each of those
 * groups must carry the translated predicates as post-aggregation filters, in filter order.
 */
@Test
public void testSpecifyPropertyFilters() {
    final Schema schema = getSchema();
    final Filter[] filters = new Filter[1];

    // GreaterThan: ENTITY_GROUP and every group in EDGE_GROUPS get the filter
    filters[0] = new GreaterThan("property1", 5);
    FiltersToOperationConverter converter = new FiltersToOperationConverter(getViewFromSchema(schema), schema, filters);
    Operation operation = converter.getOperation();
    assertThat(operation).isInstanceOf(GetRDDOfAllElements.class);
    View opView = ((GraphFilters) operation).getView();
    List<TupleAdaptedPredicate<String, ?>> entityPostAggFilters =
            opView.getEntity(ENTITY_GROUP).getPostAggregationFilterFunctions();
    assertThat(entityPostAggFilters).hasSize(1);
    assertArrayEquals(new String[] {"property1"}, entityPostAggFilters.get(0).getSelection());
    assertEquals(new IsMoreThan(5, false), entityPostAggFilters.get(0).getPredicate());
    for (final String edgeGroup : EDGE_GROUPS) {
        final List<TupleAdaptedPredicate<String, ?>> edgePostAggFilters =
                opView.getEdge(edgeGroup).getPostAggregationFilterFunctions();
        assertThat(edgePostAggFilters).hasSize(1);
        assertArrayEquals(new String[] {"property1"}, edgePostAggFilters.get(0).getSelection());
        assertEquals(new IsMoreThan(5, false), edgePostAggFilters.get(0).getPredicate());
    }

    // LessThan: only groups ENTITY_GROUP and EDGE_GROUP should be in the view as only they have property4
    filters[0] = new LessThan("property4", 8L);
    converter = new FiltersToOperationConverter(getViewFromSchema(schema), schema, filters);
    operation = converter.getOperation();
    assertThat(operation).isInstanceOf(GetRDDOfAllElements.class);
    opView = ((GraphFilters) operation).getView();
    assertEquals(Collections.singleton(ENTITY_GROUP), opView.getEntityGroups());
    assertEquals(Collections.singleton(EDGE_GROUP), opView.getEdgeGroups());
    entityPostAggFilters = opView.getEntity(ENTITY_GROUP).getPostAggregationFilterFunctions();
    assertThat(entityPostAggFilters).hasSize(1);
    assertArrayEquals(new String[] {"property4"}, entityPostAggFilters.get(0).getSelection());
    assertEquals(new IsLessThan(8L, false), entityPostAggFilters.get(0).getPredicate());
    List<TupleAdaptedPredicate<String, ?>> edgePostAggFilters =
            opView.getEdge(EDGE_GROUP).getPostAggregationFilterFunctions();
    assertThat(edgePostAggFilters).hasSize(1);
    assertArrayEquals(new String[] {"property4"}, edgePostAggFilters.get(0).getSelection());
    assertEquals(new IsLessThan(8L, false), edgePostAggFilters.get(0).getPredicate());

    // And: only groups ENTITY_GROUP and EDGE_GROUP should be in the view as only they have property1 and property4
    filters[0] = new And(new GreaterThan("property1", 5), new GreaterThan("property4", 8L));
    converter = new FiltersToOperationConverter(getViewFromSchema(schema), schema, filters);
    operation = converter.getOperation();
    assertThat(operation).isInstanceOf(GetRDDOfAllElements.class);
    opView = ((GraphFilters) operation).getView();
    assertEquals(Collections.singleton(ENTITY_GROUP), opView.getEntityGroups());
    assertEquals(Collections.singleton(EDGE_GROUP), opView.getEdgeGroups());
    entityPostAggFilters = opView.getEntity(ENTITY_GROUP).getPostAggregationFilterFunctions();
    assertThat(entityPostAggFilters).hasSize(2);
    assertArrayEquals(new String[] {"property1"}, entityPostAggFilters.get(0).getSelection());
    assertEquals(new IsMoreThan(5, false), entityPostAggFilters.get(0).getPredicate());
    assertArrayEquals(new String[] {"property4"}, entityPostAggFilters.get(1).getSelection());
    assertEquals(new IsMoreThan(8L, false), entityPostAggFilters.get(1).getPredicate());
    // The edge predicates were previously unchecked; they must match the entity ones
    edgePostAggFilters = opView.getEdge(EDGE_GROUP).getPostAggregationFilterFunctions();
    assertThat(edgePostAggFilters).hasSize(2);
    assertArrayEquals(new String[] {"property1"}, edgePostAggFilters.get(0).getSelection());
    assertEquals(new IsMoreThan(5, false), edgePostAggFilters.get(0).getPredicate());
    assertArrayEquals(new String[] {"property4"}, edgePostAggFilters.get(1).getSelection());
    assertEquals(new IsMoreThan(8L, false), edgePostAggFilters.get(1).getPredicate());
}
Aggregations