Search in sources :

Example 6 with StructField

use of org.apache.spark.sql.types.StructField in project incubator-systemml by apache.

In the class RDDConverterUtilsExt, the method addIDToDataFrame.

/**
 * Adds a row-index column as the last column of a DataFrame.
 *
 * <p>The new column is appended to the existing schema and marked non-nullable.
 * The index values come from {@code zipWithIndex} (0-based row position);
 * the {@code AddRowID} mapper is responsible for folding that index into each
 * row. NOTE(review): the column is declared as DoubleType while zipWithIndex
 * yields a Long — assumes AddRowID performs the Long-to-double conversion;
 * confirm against AddRowID's implementation.
 *
 * @param df input data frame
 * @param sparkSession the Spark Session
 * @param nameOfCol name of index column
 * @return new data frame with the index column appended
 */
public static Dataset<Row> addIDToDataFrame(Dataset<Row> df, SparkSession sparkSession, String nameOfCol) {
    StructField[] oldSchema = df.schema().fields();
    StructField[] newSchema = new StructField[oldSchema.length + 1];
    // Copy the existing fields in bulk and append the new index field at the end.
    System.arraycopy(oldSchema, 0, newSchema, 0, oldSchema.length);
    newSchema[oldSchema.length] = DataTypes.createStructField(nameOfCol, DataTypes.DoubleType, false);
    // zipWithIndex pairs each row with its 0-based position; AddRowID merges the pair into a Row.
    JavaRDD<Row> newRows = df.rdd().toJavaRDD().zipWithIndex().map(new AddRowID());
    return sparkSession.createDataFrame(newRows, new StructType(newSchema));
}
Also used : StructField(org.apache.spark.sql.types.StructField) StructType(org.apache.spark.sql.types.StructType) Row(org.apache.spark.sql.Row)

Example 7 with StructField

use of org.apache.spark.sql.types.StructField in project incubator-systemml by apache.

In the class RDDConverterUtilsExtTest, the method testStringDataFrameToVectorDataFrameNonNumbers.

@Test(expected = SparkException.class)
public void testStringDataFrameToVectorDataFrameNonNumbers() {
    // A vector string with non-numeric entries must make the conversion fail
    // with a SparkException when the result is evaluated.
    List<String> inputs = new ArrayList<String>();
    inputs.add("[cheeseburger,fries]");
    JavaRDD<Row> rowRdd = sc.parallelize(inputs).map(new StringToRow());
    SparkSession sparkSession = SparkSession.builder().sparkContext(sc.sc()).getOrCreate();
    // Single nullable string column holding the raw vector text.
    List<StructField> schemaFields = new ArrayList<StructField>();
    schemaFields.add(DataTypes.createStructField("C1", DataTypes.StringType, true));
    StructType schema = DataTypes.createStructType(schemaFields);
    Dataset<Row> inDF = sparkSession.createDataFrame(rowRdd, schema);
    Dataset<Row> outDF = RDDConverterUtilsExt.stringDataFrameToVectorDataFrame(sparkSession, inDF);
    // Force evaluation so the expected exception is actually raised.
    outDF.collectAsList();
}
Also used : SparkSession(org.apache.spark.sql.SparkSession) StructField(org.apache.spark.sql.types.StructField) StructType(org.apache.spark.sql.types.StructType) ArrayList(java.util.ArrayList) Row(org.apache.spark.sql.Row) Test(org.junit.Test)

Example 8 with StructField

use of org.apache.spark.sql.types.StructField in project incubator-systemml by apache.

In the class RDDConverterUtilsExtTest, the method testStringDataFrameToVectorDataFrameNull.

@Test
public void testStringDataFrameToVectorDataFrameNull() {
    // A null input cell should pass through the conversion as a null vector
    // rather than causing a failure.
    List<String> inputs = new ArrayList<String>();
    inputs.add("[1.2, 3.4]");
    inputs.add(null);
    JavaRDD<Row> rowRdd = sc.parallelize(inputs).map(new StringToRow());
    SparkSession sparkSession = SparkSession.builder().sparkContext(sc.sc()).getOrCreate();
    // Single nullable string column holding the raw vector text.
    List<StructField> schemaFields = new ArrayList<StructField>();
    schemaFields.add(DataTypes.createStructField("C1", DataTypes.StringType, true));
    StructType schema = DataTypes.createStructType(schemaFields);
    Dataset<Row> inDF = sparkSession.createDataFrame(rowRdd, schema);
    Dataset<Row> outDF = RDDConverterUtilsExt.stringDataFrameToVectorDataFrame(sparkSession, inDF);
    // Output rows are compared by their string form; order is not asserted.
    List<String> expectedResults = new ArrayList<String>();
    expectedResults.add("[[1.2,3.4]]");
    expectedResults.add("[null]");
    for (Row row : outDF.collectAsList()) {
        assertTrue("Expected results don't contain: " + row, expectedResults.contains(row.toString()));
    }
}
Also used : SparkSession(org.apache.spark.sql.SparkSession) StructType(org.apache.spark.sql.types.StructType) ArrayList(java.util.ArrayList) StructField(org.apache.spark.sql.types.StructField) Row(org.apache.spark.sql.Row) Test(org.junit.Test)

Example 9 with StructField

use of org.apache.spark.sql.types.StructField in project incubator-systemml by apache.

In the class RDDConverterUtilsExtTest, the method testStringDataFrameToVectorDataFrame.

@Test
public void testStringDataFrameToVectorDataFrame() {
    // Several accepted vector-string spellings (parens, brackets, nesting)
    // should all normalize to the same bracketed vector form.
    List<String> inputs = new ArrayList<String>();
    inputs.add("((1.2, 4.3, 3.4))");
    inputs.add("(1.2, 3.4, 2.2)");
    inputs.add("[[1.2, 34.3, 1.2, 1.25]]");
    inputs.add("[1.2, 3.4]");
    JavaRDD<Row> rowRdd = sc.parallelize(inputs).map(new StringToRow());
    SparkSession sparkSession = SparkSession.builder().sparkContext(sc.sc()).getOrCreate();
    // Single nullable string column holding the raw vector text.
    List<StructField> schemaFields = new ArrayList<StructField>();
    schemaFields.add(DataTypes.createStructField("C1", DataTypes.StringType, true));
    StructType schema = DataTypes.createStructType(schemaFields);
    Dataset<Row> inDF = sparkSession.createDataFrame(rowRdd, schema);
    Dataset<Row> outDF = RDDConverterUtilsExt.stringDataFrameToVectorDataFrame(sparkSession, inDF);
    // Output rows are compared by their string form; order is not asserted.
    List<String> expectedResults = new ArrayList<String>();
    expectedResults.add("[[1.2,4.3,3.4]]");
    expectedResults.add("[[1.2,3.4,2.2]]");
    expectedResults.add("[[1.2,34.3,1.2,1.25]]");
    expectedResults.add("[[1.2,3.4]]");
    for (Row row : outDF.collectAsList()) {
        assertTrue("Expected results don't contain: " + row, expectedResults.contains(row.toString()));
    }
}
Also used : SparkSession(org.apache.spark.sql.SparkSession) StructType(org.apache.spark.sql.types.StructType) ArrayList(java.util.ArrayList) StructField(org.apache.spark.sql.types.StructField) Row(org.apache.spark.sql.Row) Test(org.junit.Test)

Example 10 with StructField

use of org.apache.spark.sql.types.StructField in project incubator-systemml by apache.

In the class MLContextTest, the method testDataFrameSumDMLDoublesWithNoIDColumn.

@Test
public void testDataFrameSumDMLDoublesWithNoIDColumn() {
    System.out.println("MLContextTest - DataFrame sum DML, doubles with no ID column");
    // 3x3 double matrix supplied as CSV rows; its element sum is 450.
    List<String> csvRows = new ArrayList<String>();
    csvRows.add("10,20,30");
    csvRows.add("40,50,60");
    csvRows.add("70,80,90");
    JavaRDD<Row> rowRdd = sc.parallelize(csvRows).map(new CommaSeparatedValueStringToDoubleArrayRow());
    // Three nullable double columns, no ID column.
    List<StructField> schemaFields = new ArrayList<StructField>();
    schemaFields.add(DataTypes.createStructField("C1", DataTypes.DoubleType, true));
    schemaFields.add(DataTypes.createStructField("C2", DataTypes.DoubleType, true));
    schemaFields.add(DataTypes.createStructField("C3", DataTypes.DoubleType, true));
    StructType schema = DataTypes.createStructType(schemaFields);
    Dataset<Row> dataFrame = spark.createDataFrame(rowRdd, schema);
    // Metadata tells MLContext the frame is plain doubles (no row-index column).
    MatrixMetadata mm = new MatrixMetadata(MatrixFormat.DF_DOUBLES);
    Script script = dml("print('sum: ' + sum(M));").in("M", dataFrame, mm);
    setExpectedStdOut("sum: 450.0");
    ml.execute(script);
}
Also used : Script(org.apache.sysml.api.mlcontext.Script) StructType(org.apache.spark.sql.types.StructType) ArrayList(java.util.ArrayList) StructField(org.apache.spark.sql.types.StructField) Row(org.apache.spark.sql.Row) MatrixMetadata(org.apache.sysml.api.mlcontext.MatrixMetadata) Test(org.junit.Test)

Aggregations

StructField (org.apache.spark.sql.types.StructField)52 StructType (org.apache.spark.sql.types.StructType)48 Row (org.apache.spark.sql.Row)45 ArrayList (java.util.ArrayList)43 Test (org.junit.Test)37 Script (org.apache.sysml.api.mlcontext.Script)34 VectorUDT (org.apache.spark.ml.linalg.VectorUDT)20 MatrixMetadata (org.apache.sysml.api.mlcontext.MatrixMetadata)17 DenseVector (org.apache.spark.ml.linalg.DenseVector)15 Vector (org.apache.spark.ml.linalg.Vector)15 Tuple2 (scala.Tuple2)7 SparkSession (org.apache.spark.sql.SparkSession)6 DataType (org.apache.spark.sql.types.DataType)5 MLResults (org.apache.sysml.api.mlcontext.MLResults)5 MatrixBlock (org.apache.sysml.runtime.matrix.data.MatrixBlock)5 FrameMetadata (org.apache.sysml.api.mlcontext.FrameMetadata)4 CommaSeparatedValueStringToDoubleArrayRow (org.apache.sysml.test.integration.mlcontext.MLContextTest.CommaSeparatedValueStringToDoubleArrayRow)4 DMLRuntimeException (org.apache.sysml.runtime.DMLRuntimeException)3 JavaRDD (org.apache.spark.api.java.JavaRDD)2 JavaSparkContext (org.apache.spark.api.java.JavaSparkContext)2