Search in sources :

Example 1 with SparkSession

use of org.apache.spark.sql.SparkSession in project incubator-systemml by apache.

The following example is from the class RDDConverterUtilsExtTest, method testStringDataFrameToVectorDataFrameNull.

@Test
public void testStringDataFrameToVectorDataFrameNull() throws DMLRuntimeException {
    // One parseable vector string plus a null entry; the converter is
    // expected to pass the null through as a null vector.
    List<String> input = new ArrayList<String>();
    input.add("[1.2, 3.4]");
    input.add(null);
    JavaRDD<Row> rows = sc.parallelize(input).map(new StringToRow());
    SparkSession session = SparkSession.builder().sparkContext(sc.sc()).getOrCreate();
    // Single nullable string column holding the raw vector text.
    List<StructField> schemaFields = new ArrayList<StructField>();
    schemaFields.add(DataTypes.createStructField("C1", DataTypes.StringType, true));
    Dataset<Row> stringDF = session.createDataFrame(rows, DataTypes.createStructType(schemaFields));
    Dataset<Row> vectorDF = RDDConverterUtilsExt.stringDataFrameToVectorDataFrame(session, stringDF);
    List<String> expected = new ArrayList<String>();
    expected.add("[[1.2,3.4]]");
    expected.add("[null]");
    // Order is not guaranteed, so only membership is checked.
    for (Row resultRow : vectorDF.collectAsList()) {
        assertTrue("Expected results don't contain: " + resultRow, expected.contains(resultRow.toString()));
    }
}
Also used : SparkSession(org.apache.spark.sql.SparkSession) StructType(org.apache.spark.sql.types.StructType) ArrayList(java.util.ArrayList) StructField(org.apache.spark.sql.types.StructField) Row(org.apache.spark.sql.Row) Test(org.junit.Test)

Example 2 with SparkSession

use of org.apache.spark.sql.SparkSession in project incubator-systemml by apache.

The following example is from the class RDDConverterUtilsExtTest, method testStringDataFrameToVectorDataFrameNonNumbers.

@Test(expected = SparkException.class)
public void testStringDataFrameToVectorDataFrameNonNumbers() throws DMLRuntimeException {
    // A non-numeric vector string must cause the conversion to fail at
    // evaluation time with a SparkException.
    List<String> input = new ArrayList<String>();
    input.add("[cheeseburger,fries]");
    JavaRDD<Row> rows = sc.parallelize(input).map(new StringToRow());
    SparkSession session = SparkSession.builder().sparkContext(sc.sc()).getOrCreate();
    List<StructField> schemaFields = new ArrayList<StructField>();
    schemaFields.add(DataTypes.createStructField("C1", DataTypes.StringType, true));
    Dataset<Row> stringDF = session.createDataFrame(rows, DataTypes.createStructType(schemaFields));
    Dataset<Row> vectorDF = RDDConverterUtilsExt.stringDataFrameToVectorDataFrame(session, stringDF);
    // The conversion is lazy; force evaluation so the expected exception fires.
    vectorDF.collectAsList();
}
Also used : SparkSession(org.apache.spark.sql.SparkSession) StructField(org.apache.spark.sql.types.StructField) StructType(org.apache.spark.sql.types.StructType) ArrayList(java.util.ArrayList) Row(org.apache.spark.sql.Row) Test(org.junit.Test)

Example 3 with SparkSession

use of org.apache.spark.sql.SparkSession in project incubator-systemml by apache.

The following example is from the class AutomatedTestBase, method createSystemMLSparkSession.

/**
 * Create a SystemML-preferred Spark Session, applying the configuration
 * settings SystemML tests rely on.
 *
 * @param appName the application name, or null to leave unset
 * @param master the master value (i.e., "local", etc.), or null to leave unset
 * @return Spark Session
 */
public static SparkSession createSystemMLSparkSession(String appName, String master) {
    Builder sessionBuilder = SparkSession.builder();
    if (appName != null) {
        sessionBuilder.appName(appName);
    }
    if (master != null) {
        sessionBuilder.master(master);
    }
    // Unlimited driver result size so large collects in tests do not fail.
    sessionBuilder.config("spark.driver.maxResultSize", "0");
    if (SparkExecutionContext.FAIR_SCHEDULER_MODE) {
        sessionBuilder.config("spark.scheduler.mode", "FAIR");
    }
    sessionBuilder.config("spark.locality.wait", "5s");
    return sessionBuilder.getOrCreate();
}
Also used : SparkSession(org.apache.spark.sql.SparkSession) ParameterBuilder(org.apache.sysml.utils.ParameterBuilder) Builder(org.apache.spark.sql.SparkSession.Builder)

Example 4 with SparkSession

use of org.apache.spark.sql.SparkSession in project incubator-systemml by apache.

The following example is from the class RDDConverterUtilsExtTest, method testStringDataFrameToVectorDataFrame.

@Test
public void testStringDataFrameToVectorDataFrame() throws DMLRuntimeException {
    // Vector strings in several bracket/paren variants; all should be
    // normalized into the same vector representation.
    List<String> input = new ArrayList<String>();
    input.add("((1.2, 4.3, 3.4))");
    input.add("(1.2, 3.4, 2.2)");
    input.add("[[1.2, 34.3, 1.2, 1.25]]");
    input.add("[1.2, 3.4]");
    JavaRDD<Row> rows = sc.parallelize(input).map(new StringToRow());
    SparkSession session = SparkSession.builder().sparkContext(sc.sc()).getOrCreate();
    // Single nullable string column holding the raw vector text.
    List<StructField> schemaFields = new ArrayList<StructField>();
    schemaFields.add(DataTypes.createStructField("C1", DataTypes.StringType, true));
    Dataset<Row> stringDF = session.createDataFrame(rows, DataTypes.createStructType(schemaFields));
    Dataset<Row> vectorDF = RDDConverterUtilsExt.stringDataFrameToVectorDataFrame(session, stringDF);
    List<String> expected = new ArrayList<String>();
    expected.add("[[1.2,4.3,3.4]]");
    expected.add("[[1.2,3.4,2.2]]");
    expected.add("[[1.2,34.3,1.2,1.25]]");
    expected.add("[[1.2,3.4]]");
    // Order is not guaranteed, so only membership is checked.
    for (Row resultRow : vectorDF.collectAsList()) {
        assertTrue("Expected results don't contain: " + resultRow, expected.contains(resultRow.toString()));
    }
}
Also used : SparkSession(org.apache.spark.sql.SparkSession) StructType(org.apache.spark.sql.types.StructType) ArrayList(java.util.ArrayList) StructField(org.apache.spark.sql.types.StructField) Row(org.apache.spark.sql.Row) Test(org.junit.Test)

Example 5 with SparkSession

use of org.apache.spark.sql.SparkSession in project incubator-systemml by apache.

The following example is from the class MLContextMultipleScriptsTest, method runMLContextTestMultipleScript.

/**
 * Execute a chain of three DML scripts through a single MLContext,
 * passing matrix outputs of one script as inputs to the next.
 *
 * @param platform the runtime platform to execute on
 * @param wRead whether to use the "b" (read-based) script variants
 */
private void runMLContextTestMultipleScript(RUNTIME_PLATFORM platform, boolean wRead) {
    RUNTIME_PLATFORM previousPlatform = DMLScript.rtplatform;
    DMLScript.rtplatform = platform;
    //create mlcontext
    SparkSession spark = createSystemMLSparkSession("MLContextMultipleScriptsTest", "local");
    MLContext ml = new MLContext(spark);
    ml.setExplain(true);
    String dml1 = baseDirectory + File.separator + "MultiScript1.dml";
    String dml2 = baseDirectory + File.separator + (wRead ? "MultiScript2b.dml" : "MultiScript2.dml");
    String dml3 = baseDirectory + File.separator + (wRead ? "MultiScript3b.dml" : "MultiScript3.dml");
    try {
        //run script 1
        Script script1 = dmlFromFile(dml1).in("$rows", rows).in("$cols", cols).out("X");
        Matrix X = ml.execute(script1).getMatrix("X");
        // feed X into script 2, then X and Y into script 3
        Script script2 = dmlFromFile(dml2).in("X", X).out("Y");
        Matrix Y = ml.execute(script2).getMatrix("Y");
        Script script3 = dmlFromFile(dml3).in("X", X).in("Y", Y).out("z");
        String z = ml.execute(script3).getString("z");
        System.out.println(z);
    } finally {
        DMLScript.rtplatform = previousPlatform;
        // stop underlying spark context to allow single jvm tests (otherwise the
        // next test that tries to create a SparkContext would fail)
        spark.stop();
        // clear status mlcontext and spark exec context
        ml.close();
    }
}
Also used : RUNTIME_PLATFORM(org.apache.sysml.api.DMLScript.RUNTIME_PLATFORM) Script(org.apache.sysml.api.mlcontext.Script) DMLScript(org.apache.sysml.api.DMLScript) SparkSession(org.apache.spark.sql.SparkSession) Matrix(org.apache.sysml.api.mlcontext.Matrix) MLContext(org.apache.sysml.api.mlcontext.MLContext)

Aggregations

SparkSession (org.apache.spark.sql.SparkSession)8 Row (org.apache.spark.sql.Row)4 StructType (org.apache.spark.sql.types.StructType)4 ArrayList (java.util.ArrayList)3 StructField (org.apache.spark.sql.types.StructField)3 Script (org.apache.sysml.api.mlcontext.Script)3 Test (org.junit.Test)3 DMLScript (org.apache.sysml.api.DMLScript)2 RUNTIME_PLATFORM (org.apache.sysml.api.DMLScript.RUNTIME_PLATFORM)2 MLContext (org.apache.sysml.api.mlcontext.MLContext)2 Matrix (org.apache.sysml.api.mlcontext.Matrix)2 MatrixBlock (org.apache.sysml.runtime.matrix.data.MatrixBlock)2 MatrixIndexes (org.apache.sysml.runtime.matrix.data.MatrixIndexes)2 LongWritable (org.apache.hadoop.io.LongWritable)1 Text (org.apache.hadoop.io.Text)1 JavaPairRDD (org.apache.spark.api.java.JavaPairRDD)1 JavaRDD (org.apache.spark.api.java.JavaRDD)1 JavaSparkContext (org.apache.spark.api.java.JavaSparkContext)1 Dataset (org.apache.spark.sql.Dataset)1 Builder (org.apache.spark.sql.SparkSession.Builder)1