Example 1 with StructField

Use of org.apache.spark.sql.types.StructField in project Gaffer by gchq.

The class SchemaToStructTypeConverter, method buildSchema:

private void buildSchema() {
    LOGGER.info("Building Spark SQL schema for groups {}", StringUtils.join(groups, ','));
    for (final String group : groups) {
        final SchemaElementDefinition elementDefn = schema.getElement(group);
        final List<StructField> structFieldList = new ArrayList<>();
        if (elementDefn instanceof SchemaEntityDefinition) {
            entityOrEdgeByGroup.put(group, EntityOrEdge.ENTITY);
            final SchemaEntityDefinition entityDefinition = (SchemaEntityDefinition) elementDefn;
            final String vertexClass = schema.getType(entityDefinition.getVertex()).getClassString();
            final DataType vertexType = getType(vertexClass);
            if (vertexType == null) {
                throw new RuntimeException("Vertex must be a recognised type: found " + vertexClass);
            }
            LOGGER.info("Group {} is an entity group - {} is of type {}", group, VERTEX_COL_NAME, vertexType);
            structFieldList.add(new StructField(VERTEX_COL_NAME, vertexType, true, Metadata.empty()));
        } else {
            entityOrEdgeByGroup.put(group, EntityOrEdge.EDGE);
            final SchemaEdgeDefinition edgeDefinition = (SchemaEdgeDefinition) elementDefn;
            final String srcClass = schema.getType(edgeDefinition.getSource()).getClassString();
            final String dstClass = schema.getType(edgeDefinition.getDestination()).getClassString();
            final DataType srcType = getType(srcClass);
            final DataType dstType = getType(dstClass);
            if (srcType == null || dstType == null) {
                throw new RuntimeException("Both source and destination must be recognised types: source was " + srcClass + " destination was " + dstClass);
            }
            LOGGER.info("Group {} is an edge group - {} is of type {}, {} is of type {}", group, SRC_COL_NAME, srcType, DST_COL_NAME, dstType);
            structFieldList.add(new StructField(SRC_COL_NAME, srcType, true, Metadata.empty()));
            structFieldList.add(new StructField(DST_COL_NAME, dstType, true, Metadata.empty()));
        }
        final Set<String> properties = elementDefn.getProperties();
        for (final String property : properties) {
            // Check if property is of a known type that can be handled by default
            final String propertyClass = elementDefn.getPropertyClass(property).getCanonicalName();
            DataType propertyType = getType(propertyClass);
            if (propertyType != null) {
                propertyNeedsConversion.put(property, needsConversion(propertyClass));
                structFieldList.add(new StructField(property, propertyType, true, Metadata.empty()));
                LOGGER.info("Property {} is of type {}", property, propertyType);
            } else {
                // Check if any of the provided converters can handle it
                if (converters != null) {
                    for (final Converter converter : converters) {
                        if (converter.canHandle(elementDefn.getPropertyClass(property))) {
                            propertyNeedsConversion.put(property, true);
                            propertyType = converter.convertedType();
                            converterByProperty.put(property, converter);
                            structFieldList.add(new StructField(property, propertyType, true, Metadata.empty()));
                            LOGGER.info("Property {} of type {} will be converted by {} to {}", property, propertyClass, converter.getClass().getName(), propertyType);
                            break;
                        }
                    }
                    if (propertyType == null) {
                        LOGGER.warn("Ignoring property {} as it is not a recognised type and none of the provided " + "converters can handle it", property);
                    }
                }
            }
        }
        structTypeByGroup.put(group, new StructType(structFieldList.toArray(new StructField[structFieldList.size()])));
    }
    // Create reverse map of field name to StructField
    final Map<String, Set<StructField>> fieldToStructs = new HashMap<>();
    for (final String group : groups) {
        final StructType groupSchema = structTypeByGroup.get(group);
        for (final String field : groupSchema.fieldNames()) {
            if (fieldToStructs.get(field) == null) {
                fieldToStructs.put(field, new HashSet<StructField>());
            }
            fieldToStructs.get(field).add(groupSchema.apply(field));
        }
    }
    // Check consistency, i.e. if the same field appears in multiple groups then the types are consistent
    for (final Entry<String, Set<StructField>> entry : fieldToStructs.entrySet()) {
        final Set<StructField> schemas = entry.getValue();
        if (schemas.size() > 1) {
            throw new IllegalArgumentException("Inconsistent fields: the field " + entry.getKey() + " has more than one definition: " + StringUtils.join(schemas, ','));
        }
    }
    // Merge schemas for groups together - fields should appear in the order the groups were provided
    final LinkedHashSet<StructField> fields = new LinkedHashSet<>();
    fields.add(new StructField(GROUP, DataTypes.StringType, false, Metadata.empty()));
    usedProperties.add(GROUP);
    for (final String group : groups) {
        final StructType groupSchema = structTypeByGroup.get(group);
        for (final String field : groupSchema.fieldNames()) {
            final StructField struct = groupSchema.apply(field);
            // Add struct to fields unless it has already been added
            if (!fields.contains(struct)) {
                fields.add(struct);
                usedProperties.add(field);
            }
        }
    }
    structType = new StructType(fields.toArray(new StructField[fields.size()]));
    LOGGER.info("Schema is {}", structType);
    LOGGER.debug("properties -> conversion: {}", StringUtils.join(propertyNeedsConversion.entrySet(), ','));
}
Also used: LinkedHashSet(java.util.LinkedHashSet) HashSet(java.util.HashSet) Set(java.util.Set) StructType(org.apache.spark.sql.types.StructType) HashMap(java.util.HashMap) ArrayList(java.util.ArrayList) SchemaEntityDefinition(uk.gov.gchq.gaffer.store.schema.SchemaEntityDefinition) StructField(org.apache.spark.sql.types.StructField) DataType(org.apache.spark.sql.types.DataType) SchemaEdgeDefinition(uk.gov.gchq.gaffer.store.schema.SchemaEdgeDefinition) Converter(uk.gov.gchq.gaffer.spark.operation.dataframe.converter.property.Converter) HyperLogLogPlusConverter(uk.gov.gchq.gaffer.spark.operation.dataframe.converter.property.impl.HyperLogLogPlusConverter) FreqMapConverter(uk.gov.gchq.gaffer.spark.operation.dataframe.converter.property.impl.FreqMapConverter) UnionConverter(uk.gov.gchq.gaffer.spark.operation.dataframe.converter.property.impl.datasketches.theta.UnionConverter) SchemaElementDefinition(uk.gov.gchq.gaffer.store.schema.SchemaElementDefinition)
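
For reference, the StructField constructor used throughout buildSchema takes a column name, a Spark DataType, a nullability flag, and column metadata, and StructField compares by value, which is what the cross-group consistency check above relies on. A minimal, self-contained sketch of that API (the column names here are illustrative, not part of any Gaffer schema):

import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.Metadata;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;

public class StructTypeSketch {
    public static void main(String[] args) {
        // One StructField per column: name, Spark DataType, nullability, metadata.
        final StructField vertex = new StructField("vertex", DataTypes.StringType, true, Metadata.empty());
        final StructField count = new StructField("count", DataTypes.LongType, true, Metadata.empty());
        // A StructType is an ordered collection of StructFields.
        final StructType schema = new StructType(new StructField[] { vertex, count });
        // StructField has value equality, so the same column declared twice with
        // the same name, type, nullability and metadata compares equal - this is
        // why the consistency check can compare definitions across groups.
        final StructField vertexAgain = new StructField("vertex", DataTypes.StringType, true, Metadata.empty());
        // Prints: true
        System.out.println(vertex.equals(vertexAgain));
        // apply(fieldName) looks a field up by name, as buildSchema does.
        // Prints: LongType
        System.out.println(schema.apply("count").dataType());
    }
}

The same value equality also makes the LinkedHashSet deduplication in the merge step work: a field already added by an earlier group is silently skipped by a later one.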

Example 2 with StructField

Use of org.apache.spark.sql.types.StructField in project carbondata by apache.

The class VectorizedCarbonRecordReader, method initBatch:

/**
   * Returns the ColumnarBatch object that will be used for all rows returned by this reader.
   * This object is reused. Calling this enables the vectorized reader. This should be called
   * before any calls to nextKeyValue/nextBatch.
   */
private void initBatch(MemoryMode memMode) {
    List<QueryDimension> queryDimension = queryModel.getQueryDimension();
    List<QueryMeasure> queryMeasures = queryModel.getQueryMeasures();
    StructField[] fields = new StructField[queryDimension.size() + queryMeasures.size()];
    for (int i = 0; i < queryDimension.size(); i++) {
        QueryDimension dim = queryDimension.get(i);
        if (dim.getDimension().hasEncoding(Encoding.DIRECT_DICTIONARY)) {
            DirectDictionaryGenerator generator = DirectDictionaryKeyGeneratorFactory.getDirectDictionaryGenerator(dim.getDimension().getDataType());
            fields[dim.getQueryOrder()] = new StructField(dim.getColumnName(), CarbonScalaUtil.convertCarbonToSparkDataType(generator.getReturnType()), true, null);
        } else if (!dim.getDimension().hasEncoding(Encoding.DICTIONARY)) {
            fields[dim.getQueryOrder()] = new StructField(dim.getColumnName(), CarbonScalaUtil.convertCarbonToSparkDataType(dim.getDimension().getDataType()), true, null);
        } else if (dim.getDimension().isComplex()) {
            fields[dim.getQueryOrder()] = new StructField(dim.getColumnName(), CarbonScalaUtil.convertCarbonToSparkDataType(dim.getDimension().getDataType()), true, null);
        } else {
            fields[dim.getQueryOrder()] = new StructField(dim.getColumnName(), CarbonScalaUtil.convertCarbonToSparkDataType(DataType.INT), true, null);
        }
    }
    for (int i = 0; i < queryMeasures.size(); i++) {
        QueryMeasure msr = queryMeasures.get(i);
        switch(msr.getMeasure().getDataType()) {
            case SHORT:
            case INT:
            case LONG:
                fields[msr.getQueryOrder()] = new StructField(msr.getColumnName(), CarbonScalaUtil.convertCarbonToSparkDataType(msr.getMeasure().getDataType()), true, null);
                break;
            case DECIMAL:
                fields[msr.getQueryOrder()] = new StructField(msr.getColumnName(), new DecimalType(msr.getMeasure().getPrecision(), msr.getMeasure().getScale()), true, null);
                break;
            default:
                fields[msr.getQueryOrder()] = new StructField(msr.getColumnName(), CarbonScalaUtil.convertCarbonToSparkDataType(DataType.DOUBLE), true, null);
        }
    }
    columnarBatch = ColumnarBatch.allocate(new StructType(fields), memMode);
    CarbonColumnVector[] vectors = new CarbonColumnVector[fields.length];
    boolean[] filteredRows = new boolean[columnarBatch.capacity()];
    for (int i = 0; i < fields.length; i++) {
        vectors[i] = new ColumnarVectorWrapper(columnarBatch.column(i), filteredRows);
    }
    carbonColumnarBatch = new CarbonColumnarBatch(vectors, columnarBatch.capacity(), filteredRows);
}
Also used: StructType(org.apache.spark.sql.types.StructType) CarbonColumnarBatch(org.apache.carbondata.core.scan.result.vector.CarbonColumnarBatch) CarbonColumnVector(org.apache.carbondata.core.scan.result.vector.CarbonColumnVector) StructField(org.apache.spark.sql.types.StructField) QueryMeasure(org.apache.carbondata.core.scan.model.QueryMeasure) DecimalType(org.apache.spark.sql.types.DecimalType) DirectDictionaryGenerator(org.apache.carbondata.core.keygenerator.directdictionary.DirectDictionaryGenerator) QueryDimension(org.apache.carbondata.core.scan.model.QueryDimension)
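
One detail worth noting in the measure loop above: DECIMAL measures keep their declared precision and scale via Spark's DecimalType instead of being widened to double like the default branch. A hedged sketch of that mapping in isolation (the column name and the precision/scale values are illustrative):

import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.DecimalType;
import org.apache.spark.sql.types.Metadata;
import org.apache.spark.sql.types.StructField;

public class DecimalFieldSketch {
    public static void main(String[] args) {
        // DecimalType(precision, scale): total digits and fractional digits,
        // so a DECIMAL(10, 2) measure round-trips without loss of precision.
        final StructField price = new StructField("price", new DecimalType(10, 2), true, Metadata.empty());
        // The DataTypes factory builds an equivalent field (it also uses empty metadata).
        final StructField samePrice = DataTypes.createStructField("price", DataTypes.createDecimalType(10, 2), true);
        // Prints: true
        System.out.println(price.equals(samePrice));
    }
}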

Example 3 with StructField

Use of org.apache.spark.sql.types.StructField in project incubator-systemml by apache.

The class RDDConverterUtilsExtTest, method testStringDataFrameToVectorDataFrameNull:

@Test
public void testStringDataFrameToVectorDataFrameNull() throws DMLRuntimeException {
    List<String> list = new ArrayList<String>();
    list.add("[1.2, 3.4]");
    list.add(null);
    JavaRDD<String> javaRddString = sc.parallelize(list);
    JavaRDD<Row> javaRddRow = javaRddString.map(new StringToRow());
    SparkSession sparkSession = SparkSession.builder().sparkContext(sc.sc()).getOrCreate();
    List<StructField> fields = new ArrayList<StructField>();
    fields.add(DataTypes.createStructField("C1", DataTypes.StringType, true));
    StructType schema = DataTypes.createStructType(fields);
    Dataset<Row> inDF = sparkSession.createDataFrame(javaRddRow, schema);
    Dataset<Row> outDF = RDDConverterUtilsExt.stringDataFrameToVectorDataFrame(sparkSession, inDF);
    List<String> expectedResults = new ArrayList<String>();
    expectedResults.add("[[1.2,3.4]]");
    expectedResults.add("[null]");
    List<Row> outputList = outDF.collectAsList();
    for (Row row : outputList) {
        assertTrue("Expected results don't contain: " + row, expectedResults.contains(row.toString()));
    }
}
Also used: SparkSession(org.apache.spark.sql.SparkSession) StructType(org.apache.spark.sql.types.StructType) ArrayList(java.util.ArrayList) StructField(org.apache.spark.sql.types.StructField) Row(org.apache.spark.sql.Row) Test(org.junit.Test)

Example 4 with StructField

Use of org.apache.spark.sql.types.StructField in project incubator-systemml by apache.

The class RDDConverterUtilsExtTest, method testStringDataFrameToVectorDataFrameNonNumbers:

@Test(expected = SparkException.class)
public void testStringDataFrameToVectorDataFrameNonNumbers() throws DMLRuntimeException {
    List<String> list = new ArrayList<String>();
    list.add("[cheeseburger,fries]");
    JavaRDD<String> javaRddString = sc.parallelize(list);
    JavaRDD<Row> javaRddRow = javaRddString.map(new StringToRow());
    SparkSession sparkSession = SparkSession.builder().sparkContext(sc.sc()).getOrCreate();
    List<StructField> fields = new ArrayList<StructField>();
    fields.add(DataTypes.createStructField("C1", DataTypes.StringType, true));
    StructType schema = DataTypes.createStructType(fields);
    Dataset<Row> inDF = sparkSession.createDataFrame(javaRddRow, schema);
    Dataset<Row> outDF = RDDConverterUtilsExt.stringDataFrameToVectorDataFrame(sparkSession, inDF);
    // trigger evaluation to throw exception
    outDF.collectAsList();
}
Also used: SparkSession(org.apache.spark.sql.SparkSession) StructField(org.apache.spark.sql.types.StructField) StructType(org.apache.spark.sql.types.StructType) ArrayList(java.util.ArrayList) Row(org.apache.spark.sql.Row) Test(org.junit.Test)
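
The "trigger evaluation" comment in this test points at a general Spark behaviour: transformations are lazy, so the parse failure for "[cheeseburger,fries]" only surfaces when an action such as collectAsList() runs, wrapped in a SparkException. A minimal sketch of the same effect, independent of SystemML (the failing map is illustrative):

import java.util.Arrays;

import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;

public class LazyEvaluationSketch {
    public static void main(String[] args) {
        final JavaSparkContext sc = new JavaSparkContext("local", "lazy-sketch");
        try {
            // No exception here: map() only records the transformation.
            final JavaRDD<Double> parsed = sc.parallelize(Arrays.asList("1.2", "fries"))
                    .map(Double::parseDouble);
            // The action triggers execution; the NumberFormatException thrown on
            // the executor arrives wrapped in an org.apache.spark.SparkException.
            parsed.collect();
        } catch (final Exception e) {
            System.out.println("Failed at the action, not at map(): " + e);
        } finally {
            sc.stop();
        }
    }
}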

Example 5 with StructField

Use of org.apache.spark.sql.types.StructField in project incubator-systemml by apache.

The class RDDConverterUtilsExtTest, method testStringDataFrameToVectorDataFrame:

@Test
public void testStringDataFrameToVectorDataFrame() throws DMLRuntimeException {
    List<String> list = new ArrayList<String>();
    list.add("((1.2, 4.3, 3.4))");
    list.add("(1.2, 3.4, 2.2)");
    list.add("[[1.2, 34.3, 1.2, 1.25]]");
    list.add("[1.2, 3.4]");
    JavaRDD<String> javaRddString = sc.parallelize(list);
    JavaRDD<Row> javaRddRow = javaRddString.map(new StringToRow());
    SparkSession sparkSession = SparkSession.builder().sparkContext(sc.sc()).getOrCreate();
    List<StructField> fields = new ArrayList<StructField>();
    fields.add(DataTypes.createStructField("C1", DataTypes.StringType, true));
    StructType schema = DataTypes.createStructType(fields);
    Dataset<Row> inDF = sparkSession.createDataFrame(javaRddRow, schema);
    Dataset<Row> outDF = RDDConverterUtilsExt.stringDataFrameToVectorDataFrame(sparkSession, inDF);
    List<String> expectedResults = new ArrayList<String>();
    expectedResults.add("[[1.2,4.3,3.4]]");
    expectedResults.add("[[1.2,3.4,2.2]]");
    expectedResults.add("[[1.2,34.3,1.2,1.25]]");
    expectedResults.add("[[1.2,3.4]]");
    List<Row> outputList = outDF.collectAsList();
    for (Row row : outputList) {
        assertTrue("Expected results don't contain: " + row, expectedResults.contains(row.toString()));
    }
}
Also used: SparkSession(org.apache.spark.sql.SparkSession) StructType(org.apache.spark.sql.types.StructType) ArrayList(java.util.ArrayList) StructField(org.apache.spark.sql.types.StructField) Row(org.apache.spark.sql.Row) Test(org.junit.Test)

Aggregations

StructField (org.apache.spark.sql.types.StructField): 44
StructType (org.apache.spark.sql.types.StructType): 40
ArrayList (java.util.ArrayList): 39
Row (org.apache.spark.sql.Row): 38
Test (org.junit.Test): 32
Script (org.apache.sysml.api.mlcontext.Script): 30
VectorUDT (org.apache.spark.ml.linalg.VectorUDT): 18
MatrixMetadata (org.apache.sysml.api.mlcontext.MatrixMetadata): 17
Vector (org.apache.spark.ml.linalg.Vector): 13
Tuple2 (scala.Tuple2): 7
DataType (org.apache.spark.sql.types.DataType): 5
DMLRuntimeException (org.apache.sysml.runtime.DMLRuntimeException): 5
FrameMetadata (org.apache.sysml.api.mlcontext.FrameMetadata): 4
MLResults (org.apache.sysml.api.mlcontext.MLResults): 4
CommaSeparatedValueStringToDoubleArrayRow (org.apache.sysml.test.integration.mlcontext.MLContextTest.CommaSeparatedValueStringToDoubleArrayRow): 4
SparkSession (org.apache.spark.sql.SparkSession): 3
ValueType (org.apache.sysml.parser.Expression.ValueType): 3
JavaSparkContext (org.apache.spark.api.java.JavaSparkContext): 2
DenseVector (org.apache.spark.ml.linalg.DenseVector): 2
BinaryBlockMatrix (org.apache.sysml.api.mlcontext.BinaryBlockMatrix): 2