Example usage of org.apache.spark.sql.Row in the Gaffer project (gchq): class AccumuloStoreRelationTest, method testBuildScanWithView.
/**
 * Builds a scan over a freshly initialised mock Accumulo store for the given
 * view and checks that the rows returned by the RDD match the expected
 * elements (those accepted by the supplied predicate) converted to Rows.
 *
 * @param name          name used to create the SQLContext
 * @param view          the view applied to the scan
 * @param returnElement predicate selecting which of the known elements are expected back
 * @throws OperationException if executing against the store fails
 * @throws StoreException     if the store cannot be initialised
 */
private void testBuildScanWithView(final String name, final View view, final Predicate<Element> returnElement) throws OperationException, StoreException {
    // Given
    final SQLContext sqlContext = getSqlContext(name);
    final Schema schema = getSchema();
    final AccumuloProperties properties = AccumuloProperties.loadStoreProperties(AccumuloStoreRelationTest.class.getResourceAsStream("/store.properties"));
    final SingleUseMockAccumuloStore store = new SingleUseMockAccumuloStore();
    store.initialise(schema, properties);
    addElements(store);

    // When
    final AccumuloStoreRelation relation = new AccumuloStoreRelation(sqlContext, Collections.emptyList(), view, store, new User());
    final RDD<Row> rdd = relation.buildScan();
    final Row[] returnedElements = (Row[]) rdd.collect();

    // Then
    // - Actual results are:
    final Set<Row> results = new HashSet<>();
    // Bulk-copy the returned array into the set instead of an index loop.
    Collections.addAll(results, returnedElements);

    // - Expected results are: every known element passing the predicate,
    //   converted with the same schema-driven converter the relation uses.
    final SchemaToStructTypeConverter schemaConverter = new SchemaToStructTypeConverter(schema, view, new ArrayList<>());
    final ConvertElementToRow elementConverter = new ConvertElementToRow(schemaConverter.getUsedProperties(), schemaConverter.getPropertyNeedsConversion(), schemaConverter.getConverterByProperty());
    final Set<Row> expectedRows = new HashSet<>();
    StreamSupport.stream(getElements().spliterator(), false)
            .filter(returnElement)
            .map(elementConverter::apply)
            .forEach(expectedRows::add);

    assertEquals(expectedRows, results);

    sqlContext.sparkContext().stop();
}
Example usage of org.apache.spark.sql.Row in the Gaffer project (gchq): class GetDataFrameOfElementsHandlerTest, method checkCanDealWithUserDefinedConversion.
/**
 * Checks that a user-supplied {@code Converter} is honoured when building a
 * DataFrame of elements: the edge group and the entity group are fetched in
 * turn and compared against manually assembled expected rows.
 */
@Test
public void checkCanDealWithUserDefinedConversion() throws OperationException {
    final Graph graph = getGraph("/schema-DataFrame/dataSchemaUserDefinedConversion.json", getElementsForUserDefinedConversion());
    final SQLContext sqlContext = getSqlContext("checkCanDealWithUserDefinedConversion");
    final List<Converter> converters = new ArrayList<>();
    converters.add(new MyPropertyConverter());

    // Edges group - check get correct edges
    GetDataFrameOfElements dfOperation = new GetDataFrameOfElements.Builder()
            .sqlContext(sqlContext)
            .view(new View.Builder().edge(EDGE_GROUP).build())
            .converters(converters)
            .build();
    Dataset<Row> dataFrame = graph.execute(dfOperation, new User());
    Set<Row> results = new HashSet<>(dataFrame.collectAsList());

    // Assemble the single expected edge row.
    final MutableList<Object> edgeFields = new MutableList<>();
    Map<String, Long> edgeFreqMap = Map$.MODULE$.empty();
    edgeFreqMap.put("Y", 1000L);
    edgeFreqMap.put("Z", 10000L);
    edgeFields.appendElem(EDGE_GROUP);
    edgeFields.appendElem("B");
    edgeFields.appendElem("C");
    edgeFields.appendElem(edgeFreqMap);
    final HyperLogLogPlus edgeSketch = new HyperLogLogPlus(5, 5);
    edgeSketch.offer("AAA");
    edgeSketch.offer("BBB");
    edgeFields.appendElem(edgeSketch.cardinality());
    edgeFields.appendElem(50);
    final Set<Row> expectedRows = new HashSet<>();
    expectedRows.add(Row$.MODULE$.fromSeq(edgeFields));
    assertEquals(expectedRows, results);

    // Entities group - check get correct entities
    dfOperation = new GetDataFrameOfElements.Builder()
            .sqlContext(sqlContext)
            .view(new View.Builder().entity(ENTITY_GROUP).build())
            .converters(converters)
            .build();
    dataFrame = graph.execute(dfOperation, new User());
    results = new HashSet<>(dataFrame.collectAsList());

    // Assemble the single expected entity row (fresh objects rather than
    // clearing and reusing the edge ones).
    final MutableList<Object> entityFields = new MutableList<>();
    Map<String, Long> entityFreqMap = Map$.MODULE$.empty();
    entityFreqMap.put("W", 10L);
    entityFreqMap.put("X", 100L);
    entityFields.appendElem(ENTITY_GROUP);
    entityFields.appendElem("A");
    entityFields.appendElem(entityFreqMap);
    final HyperLogLogPlus entitySketch = new HyperLogLogPlus(5, 5);
    entitySketch.offer("AAA");
    entityFields.appendElem(entitySketch.cardinality());
    entityFields.appendElem(10);
    expectedRows.clear();
    expectedRows.add(Row$.MODULE$.fromSeq(entityFields));
    assertEquals(expectedRows, results);

    sqlContext.sparkContext().stop();
}
Example usage of org.apache.spark.sql.Row in the Gaffer project (gchq): class GetDataFrameOfElementsHandlerTest, method checkGetCorrectElementsInDataFrameMultipleGroups.
/**
 * Checks that a DataFrame built over both the entity and the edge group
 * contains the correct rows (with nulls in columns that do not apply to a
 * group), and that subsequently restricting the view to the entity group
 * returns only the entity rows.
 */
@Test
public void checkGetCorrectElementsInDataFrameMultipleGroups() throws OperationException {
    final Graph graph = getGraph("/schema-DataFrame/dataSchema.json", getElements());
    final SQLContext sqlContext = getSqlContext("checkGetCorrectElementsInDataFrameMultipleGroups");

    // Use entity and edges group - check get correct data
    GetDataFrameOfElements dfOperation = new GetDataFrameOfElements.Builder()
            .sqlContext(sqlContext)
            .view(new View.Builder().entity(ENTITY_GROUP).edge(EDGE_GROUP).build())
            .build();
    Dataset<Row> dataFrame = graph.execute(dfOperation, new User());
    final Set<Row> results = new HashSet<>(dataFrame.collectAsList());
    final Set<Row> expectedRows = new HashSet<>();
    for (int i = 0; i < NUM_ELEMENTS; i++) {
        // Edge row with destination "B"; the vertex column is null for edges.
        final MutableList<Object> fields1 = new MutableList<>();
        fields1.appendElem(EDGE_GROUP);
        fields1.appendElem(null);
        fields1.appendElem(1);
        fields1.appendElem(2);
        fields1.appendElem(3.0F);
        fields1.appendElem(4.0D);
        fields1.appendElem(5L);
        fields1.appendElem(100L);
        fields1.appendElem("" + i);
        fields1.appendElem("B");
        expectedRows.add(Row$.MODULE$.fromSeq(fields1));
        // Edge row with destination "C".
        final MutableList<Object> fields2 = new MutableList<>();
        fields2.appendElem(EDGE_GROUP);
        fields2.appendElem(null);
        fields2.appendElem(6);
        fields2.appendElem(7);
        fields2.appendElem(8.0F);
        fields2.appendElem(9.0D);
        fields2.appendElem(10L);
        fields2.appendElem(i * 200L);
        fields2.appendElem("" + i);
        fields2.appendElem("C");
        expectedRows.add(Row$.MODULE$.fromSeq(fields2));
        // Entity row; the source/destination columns are null for entities.
        final MutableList<Object> fields3 = new MutableList<>();
        fields3.appendElem(ENTITY_GROUP);
        fields3.appendElem("" + i);
        fields3.appendElem(1);
        fields3.appendElem(i);
        fields3.appendElem(3.0F);
        fields3.appendElem(4.0D);
        fields3.appendElem(5L);
        fields3.appendElem(6);
        fields3.appendElem(null);
        fields3.appendElem(null);
        expectedRows.add(Row$.MODULE$.fromSeq(fields3));
    }
    assertEquals(expectedRows, results);

    // Entities group - check get correct entities
    dfOperation = new GetDataFrameOfElements.Builder()
            .sqlContext(sqlContext)
            .view(new View.Builder().entity(ENTITY_GROUP).build())
            .build();
    dataFrame = graph.execute(dfOperation, new User());
    results.clear();
    results.addAll(dataFrame.collectAsList());
    expectedRows.clear();
    for (int i = 0; i < NUM_ELEMENTS; i++) {
        // Freshly constructed each iteration, so no clear() is needed
        // (a redundant clear() was removed here).
        final MutableList<Object> fields1 = new MutableList<>();
        fields1.appendElem(ENTITY_GROUP);
        fields1.appendElem("" + i);
        fields1.appendElem(1);
        fields1.appendElem(i);
        fields1.appendElem(3.0F);
        fields1.appendElem(4.0D);
        fields1.appendElem(5L);
        fields1.appendElem(6);
        expectedRows.add(Row$.MODULE$.fromSeq(fields1));
    }
    assertEquals(expectedRows, results);
    sqlContext.sparkContext().stop();
}
Example usage of org.apache.spark.sql.Row in the Gaffer project (gchq): class GetDataFrameOfElementsHandlerTest, method checkGetCorrectElementsInDataFrameWithProjectionAndFiltering.
/**
 * Checks that projection ({@code select}) and filtering ({@code filter}) on
 * the DataFrame return the correct rows for the edge group.
 */
@Test
public void checkGetCorrectElementsInDataFrameWithProjectionAndFiltering() throws OperationException {
    final Graph graph = getGraph("/schema-DataFrame/dataSchema.json", getElements());
    final SQLContext sqlContext = getSqlContext("checkGetCorrectElementsInDataFrameWithProjectionAndFiltering");

    // Get DataFrame
    final GetDataFrameOfElements dfOperation = new GetDataFrameOfElements.Builder()
            .sqlContext(sqlContext)
            .view(new View.Builder().edge(EDGE_GROUP).build())
            .build();
    final Dataset<Row> dataFrame = graph.execute(dfOperation, new User());

    // Check get correct rows when ask for all columns but only rows where property2 > 4.0
    Set<Row> results = new HashSet<>(dataFrame.filter("property2 > 4.0").collectAsList());
    final Set<Row> expectedRows = new HashSet<>();
    for (int i = 0; i < NUM_ELEMENTS; i++) {
        final MutableList<Object> fields = new MutableList<>();
        fields.appendElem(EDGE_GROUP);
        fields.appendElem("" + i);
        fields.appendElem("C");
        fields.appendElem(6);
        fields.appendElem(7);
        fields.appendElem(8.0F);
        fields.appendElem(9.0D);
        fields.appendElem(10L);
        fields.appendElem(i * 200L);
        expectedRows.add(Row$.MODULE$.fromSeq(fields));
    }
    assertEquals(expectedRows, results);

    // Check get correct rows when ask for columns property2 and property3 but only rows where property2 > 4.0
    results = new HashSet<>(dataFrame.select("property2", "property3").filter("property2 > 4.0").collectAsList());
    expectedRows.clear();
    // Every matching row projects to the same (property2, property3) pair, so
    // the expected set contains exactly one row - the former per-element loop
    // only re-added the identical row NUM_ELEMENTS times.
    final MutableList<Object> projectedFields = new MutableList<>();
    projectedFields.appendElem(8.0F);
    projectedFields.appendElem(9.0D);
    expectedRows.add(Row$.MODULE$.fromSeq(projectedFields));
    assertEquals(expectedRows, results);

    sqlContext.sparkContext().stop();
}
Example usage of org.apache.spark.sql.Row in the Gaffer project (gchq): class GetDataFrameOfElementsExample, method getDataFrameOfElementsWithEntityGroup.
/**
 * Demonstrates fetching a DataFrame of entities from a graph, then filtering
 * it by vertex and by a property, logging each stage as markdown.
 *
 * @param sqlc  the SQLContext used when building the DataFrame
 * @param graph the graph to execute the operation against
 * @throws OperationException if the operation fails
 */
public void getDataFrameOfElementsWithEntityGroup(final SQLContext sqlc, final Graph graph) throws OperationException {
    ROOT_LOGGER.setLevel(Level.INFO);
    log("#### " + getMethodNameAsSentence(0) + "\n");
    printGraph();
    ROOT_LOGGER.setLevel(Level.OFF);

    final GetDataFrameOfElements operation = new GetDataFrameOfElements.Builder()
            .view(new View.Builder().entity("entity").build())
            .sqlContext(sqlc)
            .build();
    final Dataset<Row> df = graph.execute(operation, new User("user01"));

    // Show
    logJavaAndResult("GetDataFrameOfElements operation = new GetDataFrameOfElements.Builder()\n" + " .view(new View.Builder()\n" + " .entity(\"entity\")\n" + " .build()).\n" + " .sqlContext(sqlc)\n" + " .build();\n" + "Dataset<Row> df = getGraph().execute(operation, new User(\"user01\"));\n" + "df.show();", df.showString(100, 20));

    // Restrict to entities involving certain vertices
    final Dataset<Row> seeded = df.filter("vertex = 1 OR vertex = 2");
    logJavaAndResult("df.filter(\"vertex = 1 OR vertex = 2\").show();", seeded.showString(100, 20));

    // Filter by property
    final Dataset<Row> filtered = df.filter("count > 1");
    logJavaAndResult("df.filter(\"count > 1\").show();", filtered.showString(100, 20));
}

/**
 * Logs a demonstrated Java snippet followed by the DataFrame output inside a
 * fenced markdown block, toggling the root logger around the writes.
 *
 * @param javaSnippet the Java code being demonstrated
 * @param result      output of {@code Dataset#showString}
 */
private void logJavaAndResult(final String javaSnippet, final String result) {
    ROOT_LOGGER.setLevel(Level.INFO);
    printJava(javaSnippet);
    log("The results are:");
    log("```");
    // showString ends with trailing newline characters - trim the last two before logging.
    log(result.substring(0, result.length() - 2));
    log("```");
    ROOT_LOGGER.setLevel(Level.OFF);
}
Aggregations