use of org.apache.spark.sql.types.StructType in project Gaffer by gchq.
the class SchemaToStructTypeConverter method buildSchema.
/**
 * Builds the Spark SQL schema for the configured {@code groups}.
 *
 * <p>The work is done in four steps:
 * <ol>
 * <li>For every group a per-group {@link StructType} is created and stored in
 * {@code structTypeByGroup}. Entity groups contribute a vertex column, all other groups are
 * treated as edge groups and contribute source and destination columns. Each property is then
 * added, either with its default type from {@code getType} or with the type produced by the
 * first converter that can handle it.</li>
 * <li>A reverse map from field name to the set of distinct {@link StructField}s is built.</li>
 * <li>The reverse map is checked: a field name that maps to more than one distinct
 * {@link StructField} across groups is rejected.</li>
 * <li>The per-group schemas are merged, in group order, into {@code structType}; the first
 * column is always the non-nullable group column.</li>
 * </ol>
 *
 * <p>Besides {@code structTypeByGroup} and {@code structType} this method fills
 * {@code entityOrEdgeByGroup}, {@code propertyNeedsConversion}, {@code converterByProperty} and
 * {@code usedProperties}.
 *
 * @throws RuntimeException if a vertex, source or destination class is not a recognised type
 * @throws IllegalArgumentException if the same field name has different definitions in
 * different groups
 */
private void buildSchema() {
    LOGGER.info("Building Spark SQL schema for groups {}", StringUtils.join(groups, ','));
    for (final String group : groups) {
        final SchemaElementDefinition elementDefn = schema.getElement(group);
        final List<StructField> structFieldList = new ArrayList<>();
        if (elementDefn instanceof SchemaEntityDefinition) {
            entityOrEdgeByGroup.put(group, EntityOrEdge.ENTITY);
            final SchemaEntityDefinition entityDefinition = (SchemaEntityDefinition) elementDefn;
            final String vertexClass = schema.getType(entityDefinition.getVertex()).getClassString();
            final DataType vertexType = getType(vertexClass);
            if (vertexType == null) {
                throw new RuntimeException("Vertex must be a recognised type: found " + vertexClass);
            }
            LOGGER.info("Group {} is an entity group - {} is of type {}", group, VERTEX_COL_NAME, vertexType);
            structFieldList.add(new StructField(VERTEX_COL_NAME, vertexType, true, Metadata.empty()));
        } else {
            // Anything that is not an entity definition is handled as an edge definition
            entityOrEdgeByGroup.put(group, EntityOrEdge.EDGE);
            final SchemaEdgeDefinition edgeDefinition = (SchemaEdgeDefinition) elementDefn;
            final String srcClass = schema.getType(edgeDefinition.getSource()).getClassString();
            final String dstClass = schema.getType(edgeDefinition.getDestination()).getClassString();
            final DataType srcType = getType(srcClass);
            final DataType dstType = getType(dstClass);
            if (srcType == null || dstType == null) {
                throw new RuntimeException("Both source and destination must be recognised types: source was " + srcClass + " destination was " + dstClass);
            }
            LOGGER.info("Group {} is an edge group - {} is of type {}, {} is of type {}", group, SRC_COL_NAME, srcType, DST_COL_NAME, dstType);
            structFieldList.add(new StructField(SRC_COL_NAME, srcType, true, Metadata.empty()));
            structFieldList.add(new StructField(DST_COL_NAME, dstType, true, Metadata.empty()));
        }
        final Set<String> properties = elementDefn.getProperties();
        for (final String property : properties) {
            // Check if property is of a known type that can be handled by default
            final String propertyClass = elementDefn.getPropertyClass(property).getCanonicalName();
            DataType propertyType = getType(propertyClass);
            if (propertyType != null) {
                propertyNeedsConversion.put(property, needsConversion(propertyClass));
                structFieldList.add(new StructField(property, propertyType, true, Metadata.empty()));
                LOGGER.info("Property {} is of type {}", property, propertyType);
            } else {
                // Check if any of the provided converters can handle it; the first match wins
                if (converters != null) {
                    for (final Converter converter : converters) {
                        if (converter.canHandle(elementDefn.getPropertyClass(property))) {
                            propertyNeedsConversion.put(property, true);
                            propertyType = converter.convertedType();
                            converterByProperty.put(property, converter);
                            structFieldList.add(new StructField(property, propertyType, true, Metadata.empty()));
                            LOGGER.info("Property {} of type {} will be converted by {} to {}", property, propertyClass, converter.getClass().getName(), propertyType);
                            break;
                        }
                    }
                }
                // Warn whenever the property is left out of the schema, including the case where
                // no converters were supplied at all (previously that case was silent).
                if (propertyType == null) {
                    LOGGER.warn("Ignoring property {} as it is not a recognised type and none of the provided " + "converters can handle it", property);
                }
            }
        }
        structTypeByGroup.put(group, new StructType(structFieldList.toArray(new StructField[structFieldList.size()])));
    }
    // Create reverse map of field name to StructField
    final Map<String, Set<StructField>> fieldToStructs = new HashMap<>();
    for (final String group : groups) {
        final StructType groupSchema = structTypeByGroup.get(group);
        for (final String field : groupSchema.fieldNames()) {
            Set<StructField> structs = fieldToStructs.get(field);
            if (structs == null) {
                structs = new HashSet<>();
                fieldToStructs.put(field, structs);
            }
            structs.add(groupSchema.apply(field));
        }
    }
    // Check consistency, i.e. if the same field appears in multiple groups then the types are consistent
    for (final Entry<String, Set<StructField>> entry : fieldToStructs.entrySet()) {
        final Set<StructField> schemas = entry.getValue();
        if (schemas.size() > 1) {
            throw new IllegalArgumentException("Inconsistent fields: the field " + entry.getKey() + " has more than one definition: " + StringUtils.join(schemas, ','));
        }
    }
    // Merge schemas for groups together - fields should appear in the order the groups were provided
    final LinkedHashSet<StructField> fields = new LinkedHashSet<>();
    fields.add(new StructField(GROUP, DataTypes.StringType, false, Metadata.empty()));
    usedProperties.add(GROUP);
    for (final String group : groups) {
        final StructType groupSchema = structTypeByGroup.get(group);
        for (final String field : groupSchema.fieldNames()) {
            // Set.add returns false for a struct that has already been merged in
            if (fields.add(groupSchema.apply(field))) {
                usedProperties.add(field);
            }
        }
    }
    structType = new StructType(fields.toArray(new StructField[fields.size()]));
    LOGGER.info("Schema is {}", structType);
    LOGGER.debug("properties -> conversion: {}", StringUtils.join(propertyNeedsConversion.entrySet(), ','));
}
use of org.apache.spark.sql.types.StructType in project carbondata by apache.
the class VectorizedCarbonRecordReader method initBatch.
/**
 * Initialises the {@code columnarBatch} and {@code carbonColumnarBatch} fields used by this
 * reader, allocating the batch in the given memory mode.
 *
 * <p>One {@link StructField} is created per query dimension and per query measure and placed
 * at the index given by the column's query order. As noted by the original authors, the batch
 * is reused for all rows and this method should be called before any calls to
 * nextKeyValue/nextBatch.
 *
 * @param memMode memory mode passed to {@code ColumnarBatch.allocate}
 */
private void initBatch(MemoryMode memMode) {
    List<QueryDimension> dimensions = queryModel.getQueryDimension();
    List<QueryMeasure> measures = queryModel.getQueryMeasures();
    StructField[] fields = new StructField[dimensions.size() + measures.size()];

    for (QueryDimension dim : dimensions) {
        final org.apache.spark.sql.types.DataType sparkType;
        if (dim.getDimension().hasEncoding(Encoding.DIRECT_DICTIONARY)) {
            // Direct-dictionary columns use the return type of their generator
            DirectDictionaryGenerator generator = DirectDictionaryKeyGeneratorFactory.getDirectDictionaryGenerator(dim.getDimension().getDataType());
            sparkType = CarbonScalaUtil.convertCarbonToSparkDataType(generator.getReturnType());
        } else if (!dim.getDimension().hasEncoding(Encoding.DICTIONARY) || dim.getDimension().isComplex()) {
            // Non-dictionary and complex columns keep their own data type
            sparkType = CarbonScalaUtil.convertCarbonToSparkDataType(dim.getDimension().getDataType());
        } else {
            // Remaining dictionary-encoded columns are declared as INT
            sparkType = CarbonScalaUtil.convertCarbonToSparkDataType(DataType.INT);
        }
        fields[dim.getQueryOrder()] = new StructField(dim.getColumnName(), sparkType, true, null);
    }

    for (QueryMeasure msr : measures) {
        final org.apache.spark.sql.types.DataType sparkType;
        switch (msr.getMeasure().getDataType()) {
            case SHORT:
            case INT:
            case LONG:
                sparkType = CarbonScalaUtil.convertCarbonToSparkDataType(msr.getMeasure().getDataType());
                break;
            case DECIMAL:
                sparkType = new DecimalType(msr.getMeasure().getPrecision(), msr.getMeasure().getScale());
                break;
            default:
                // Every other measure type is declared as DOUBLE
                sparkType = CarbonScalaUtil.convertCarbonToSparkDataType(DataType.DOUBLE);
                break;
        }
        fields[msr.getQueryOrder()] = new StructField(msr.getColumnName(), sparkType, true, null);
    }

    columnarBatch = ColumnarBatch.allocate(new StructType(fields), memMode);
    final int capacity = columnarBatch.capacity();
    // The same filteredRows array is handed to every column wrapper and to the carbon batch
    boolean[] filteredRows = new boolean[capacity];
    CarbonColumnVector[] vectors = new CarbonColumnVector[fields.length];
    for (int col = 0; col < vectors.length; col++) {
        vectors[col] = new ColumnarVectorWrapper(columnarBatch.column(col), filteredRows);
    }
    carbonColumnarBatch = new CarbonColumnarBatch(vectors, capacity, filteredRows);
}
use of org.apache.spark.sql.types.StructType in project incubator-systemml by apache.
the class RDDConverterUtilsExtTest method testStringDataFrameToVectorDataFrameNull.
/**
 * Converts a single-column string DataFrame that contains one vector string and one
 * {@code null} entry, and checks that every resulting row is one of the expected rows.
 */
@Test
public void testStringDataFrameToVectorDataFrameNull() throws DMLRuntimeException {
    // Input: one parsable vector string and one null value
    List<String> input = new ArrayList<>();
    input.add("[1.2, 3.4]");
    input.add(null);
    JavaRDD<Row> rows = sc.parallelize(input).map(new StringToRow());

    SparkSession sparkSession = SparkSession.builder().sparkContext(sc.sc()).getOrCreate();

    // Schema with a single nullable string column "C1"
    List<StructField> schemaFields = new ArrayList<>();
    schemaFields.add(DataTypes.createStructField("C1", DataTypes.StringType, true));
    StructType schema = DataTypes.createStructType(schemaFields);

    Dataset<Row> inDF = sparkSession.createDataFrame(rows, schema);
    Dataset<Row> outDF = RDDConverterUtilsExt.stringDataFrameToVectorDataFrame(sparkSession, inDF);

    List<String> expected = new ArrayList<>();
    expected.add("[[1.2,3.4]]");
    expected.add("[null]");

    for (Row actual : outDF.collectAsList()) {
        assertTrue("Expected results don't contain: " + actual, expected.contains(actual.toString()));
    }
}
use of org.apache.spark.sql.types.StructType in project incubator-systemml by apache.
the class RDDConverterUtilsExtTest method testStringDataFrameToVectorDataFrameNonNumbers.
/**
 * Feeds a vector string with non-numeric entries through the conversion and expects a
 * {@code SparkException} once the result is evaluated.
 */
@Test(expected = SparkException.class)
public void testStringDataFrameToVectorDataFrameNonNumbers() throws DMLRuntimeException {
    List<String> input = new ArrayList<>();
    input.add("[cheeseburger,fries]");
    JavaRDD<Row> rows = sc.parallelize(input).map(new StringToRow());

    SparkSession sparkSession = SparkSession.builder().sparkContext(sc.sc()).getOrCreate();

    // Schema with a single nullable string column "C1"
    List<StructField> schemaFields = new ArrayList<>();
    schemaFields.add(DataTypes.createStructField("C1", DataTypes.StringType, true));
    StructType schema = DataTypes.createStructType(schemaFields);

    Dataset<Row> inDF = sparkSession.createDataFrame(rows, schema);
    Dataset<Row> outDF = RDDConverterUtilsExt.stringDataFrameToVectorDataFrame(sparkSession, inDF);

    // trigger evaluation to throw exception
    outDF.collectAsList();
}
use of org.apache.spark.sql.types.StructType in project incubator-systemml by apache.
the class RDDConverterUtilsExtTest method testStringDataFrameToVectorDataFrame.
/**
 * Converts a single-column string DataFrame holding vector strings in several bracket styles
 * and checks that every resulting row is one of the expected rows.
 */
@Test
public void testStringDataFrameToVectorDataFrame() throws DMLRuntimeException {
    // Inputs use parentheses, doubled parentheses, brackets and doubled brackets
    List<String> input = new ArrayList<>();
    input.add("((1.2, 4.3, 3.4))");
    input.add("(1.2, 3.4, 2.2)");
    input.add("[[1.2, 34.3, 1.2, 1.25]]");
    input.add("[1.2, 3.4]");
    JavaRDD<Row> rows = sc.parallelize(input).map(new StringToRow());

    SparkSession sparkSession = SparkSession.builder().sparkContext(sc.sc()).getOrCreate();

    // Schema with a single nullable string column "C1"
    List<StructField> schemaFields = new ArrayList<>();
    schemaFields.add(DataTypes.createStructField("C1", DataTypes.StringType, true));
    StructType schema = DataTypes.createStructType(schemaFields);

    Dataset<Row> inDF = sparkSession.createDataFrame(rows, schema);
    Dataset<Row> outDF = RDDConverterUtilsExt.stringDataFrameToVectorDataFrame(sparkSession, inDF);

    List<String> expected = new ArrayList<>();
    expected.add("[[1.2,4.3,3.4]]");
    expected.add("[[1.2,3.4,2.2]]");
    expected.add("[[1.2,34.3,1.2,1.25]]");
    expected.add("[[1.2,3.4]]");

    for (Row actual : outDF.collectAsList()) {
        assertTrue("Expected results don't contain: " + actual, expected.contains(actual.toString()));
    }
}
Aggregations