Search in sources:

Example 1 with FrameSchema

Use of org.apache.sysml.api.mlcontext.FrameSchema in project systemml by Apache.

Class FrameTest, method testFrameGeneral.

/**
 * End-to-end MLContext frame test. Generates random input frames A and B, runs the
 * FrameTest DML script on them through MLContext, saves the script outputs "A" (as AB)
 * and "C", and verifies them against the outputs of the FrameTest R script.
 *
 * @param iinfo          format used to read the saved DML outputs back; TextCellInputInfo
 *                       saves the outputs as IJV text, anything else as CSV
 * @param oinfo          format in which the input frames are written; TextCellOutputInfo
 *                       declares the text-RDD inputs as IJV, anything else as CSV
 * @param bFromDataFrame pass A and B to the script as Spark DataFrames instead of text RDDs
 * @param bToDataFrame   fetch the outputs as DataFrames and save them as binary blocks
 *                       instead of saving the text RDDs directly
 */
private void testFrameGeneral(InputInfo iinfo, OutputInfo oinfo, boolean bFromDataFrame, boolean bToDataFrame) throws IOException, DMLException, ParseException {
    // Save the global execution settings and restore them in the finally block. The try
    // deliberately spans all of the setup (input generation, DataFrame creation) so that a
    // failure there cannot leak HYBRID_SPARK / local-Spark mode into subsequent tests.
    boolean oldConfig = DMLScript.USE_LOCAL_SPARK_CONFIG;
    RUNTIME_PLATFORM oldRT = DMLScript.rtplatform;
    DMLScript.USE_LOCAL_SPARK_CONFIG = true;
    DMLScript.rtplatform = RUNTIME_PLATFORM.HYBRID_SPARK;
    try {
        // 1-based inclusive index ranges handed to the DML and R scripts; presumably B is
        // written into A at [rowstart:rowend, colstart:colend] and C is read from
        // [rowstartC:rowendC, colstartC:colendC] -- confirm against FrameTest.dml.
        int rowstart = 234, rowend = 1478, colstart = 125, colend = 568;
        int bRows = rowend - rowstart + 1, bCols = colend - colstart + 1;
        int rowstartC = 124, rowendC = 1178, colstartC = 143, colendC = 368;
        int cRows = rowendC - rowstartC + 1, cCols = colendC - colstartC + 1;
        HashMap<String, ValueType[]> outputSchema = new HashMap<String, ValueType[]>();
        HashMap<String, MatrixCharacteristics> outputMC = new HashMap<String, MatrixCharacteristics>();
        TestConfiguration config = getTestConfiguration(TEST_NAME);
        loadTestConfiguration(config);

        // Positional arguments $1..$16 of the DML script, in order.
        List<String> proArgs = new ArrayList<String>();
        proArgs.add(input("A"));
        proArgs.add(Integer.toString(rows));
        proArgs.add(Integer.toString(cols));
        proArgs.add(input("B"));
        proArgs.add(Integer.toString(bRows));
        proArgs.add(Integer.toString(bCols));
        proArgs.add(Integer.toString(rowstart));
        proArgs.add(Integer.toString(rowend));
        proArgs.add(Integer.toString(colstart));
        proArgs.add(Integer.toString(colend));
        proArgs.add(output("A"));
        proArgs.add(Integer.toString(rowstartC));
        proArgs.add(Integer.toString(rowendC));
        proArgs.add(Integer.toString(colstartC));
        proArgs.add(Integer.toString(colendC));
        proArgs.add(output("C"));
        fullDMLScriptName = SCRIPT_DIR + TEST_DIR + TEST_NAME + ".dml";
        fullRScriptName = SCRIPT_DIR + TEST_DIR + TEST_NAME + ".R";
        rCmd = "Rscript" + " " + fullRScriptName + " " + inputDir() + " " + rowstart + " " + rowend + " " + colstart + " " + colend + " " + expectedDir() + " " + rowstartC + " " + rowendC + " " + colstartC + " " + colendC;

        // The schemas of B and C are the column slices of A's schema they correspond to.
        ValueType[] schema = schemaMixedLarge;
        List<ValueType> lschema = Arrays.asList(schema);
        ValueType[] schemaB = new ValueType[bCols];
        for (int i = 0; i < bCols; ++i) {
            schemaB[i] = schema[colstart - 1 + i];
        }
        List<ValueType> lschemaB = Arrays.asList(schemaB);
        ValueType[] schemaC = new ValueType[cCols];
        for (int i = 0; i < cCols; ++i) {
            schemaC[i] = schema[colstartC - 1 + i];
        }

        // Initialize the frame data.
        double[][] A = getRandomMatrix(rows, cols, min, max, sparsity1, 1111);
        writeInputFrameWithMTD("A", A, true, schema, oinfo);
        double[][] B = getRandomMatrix(bRows, bCols, min, max, sparsity2, 2345);
        writeInputFrameWithMTD("B", B, true, schemaB, oinfo);

        Dataset<Row> dfA = null, dfB = null;
        if (bFromDataFrame) {
            // Create DataFrame for input A
            StructType dfSchemaA = FrameRDDConverterUtils.convertFrameSchemaToDFSchema(schema, false);
            JavaRDD<Row> rowRDDA = FrameRDDConverterUtils.csvToRowRDD(sc, input("A"), DataExpression.DEFAULT_DELIM_DELIMITER, schema);
            dfA = spark.createDataFrame(rowRDDA, dfSchemaA);
            // Create DataFrame for input B
            StructType dfSchemaB = FrameRDDConverterUtils.convertFrameSchemaToDFSchema(schemaB, false);
            JavaRDD<Row> rowRDDB = FrameRDDConverterUtils.csvToRowRDD(sc, input("B"), DataExpression.DEFAULT_DELIM_DELIMITER, schemaB);
            dfB = spark.createDataFrame(rowRDDB, dfSchemaB);
        }

        Script script = ScriptFactory.dmlFromFile(fullDMLScriptName);
        // Text-cell input files are declared as IJV, everything else as CSV.
        FrameFormat inFormat = (oinfo == OutputInfo.TextCellOutputInfo) ? FrameFormat.IJV : FrameFormat.CSV;
        if (bFromDataFrame) {
            script.in("A", dfA);
            script.in("B", dfB);
        } else {
            JavaRDD<String> aIn = sc.textFile(input("A"));
            script.in("A", aIn, new FrameMetadata(inFormat, new FrameSchema(lschema), rows, cols));
            JavaRDD<String> bIn = sc.textFile(input("B"));
            script.in("B", bIn, new FrameMetadata(inFormat, new FrameSchema(lschemaB), bRows, bCols));
        }
        // Output one frame to HDFS and get one as RDD //TODO HDFS input/output to do
        script.out("A", "C");
        // set positional argument values
        for (int argNum = 1; argNum <= proArgs.size(); argNum++) {
            script.in("$" + argNum, proArgs.get(argNum - 1));
        }
        MLResults results = ml.execute(script);

        // The saved outputs are read back with iinfo, so write them in the matching format.
        String outFormat = (iinfo == InputInfo.TextCellInputInfo) ? "text" : "csv";
        saveFrameOutput(results, "A", output("AB"), outFormat, bFromDataFrame, bToDataFrame, rows, cols);
        saveFrameOutput(results, "C", output("C"), outFormat, bFromDataFrame, bToDataFrame, cRows, cCols);

        runRScript(true);
        outputSchema.put("AB", schema);
        outputMC.put("AB", new MatrixCharacteristics(rows, cols, -1, -1));
        outputSchema.put("C", schemaC);
        outputMC.put("C", new MatrixCharacteristics(cRows, cCols, -1, -1));
        for (String file : config.getOutputFiles()) {
            MatrixCharacteristics md = outputMC.get(file);
            FrameBlock frameBlock = readDMLFrameFromHDFS(file, iinfo, md);
            FrameBlock frameRBlock = readRFrameFromHDFS(file + ".csv", InputInfo.CSVInputInfo, md);
            ValueType[] schemaOut = outputSchema.get(file);
            verifyFrameData(frameBlock, frameRBlock, schemaOut);
            System.out.println("File " + file + " processed successfully.");
        }
        System.out.println("Frame MLContext test completed successfully.");
    } finally {
        DMLScript.rtplatform = oldRT;
        DMLScript.USE_LOCAL_SPARK_CONFIG = oldConfig;
    }
}

/**
 * Saves MLContext output frame {@code varName} to {@code fName}, deleting any existing file
 * first. Without DataFrame output the frame is written as IJV text (when {@code format} is
 * "text") or CSV; with DataFrame output it is converted back to binary blocks.
 *
 * @param results        results of the executed script
 * @param varName        script output variable to save
 * @param fName          destination path
 * @param format         "text" for IJV, anything else for CSV (unused when bToDataFrame)
 * @param bFromDataFrame forwarded as the last argument of dataFrameToBinaryBlock
 * @param bToDataFrame   fetch the output as a DataFrame and save binary blocks
 * @param nRows          number of rows of the output frame
 * @param nCols          number of columns of the output frame
 */
private void saveFrameOutput(MLResults results, String varName, String fName, String format, boolean bFromDataFrame, boolean bToDataFrame, long nRows, long nCols) throws DMLException {
    try {
        MapReduceTool.deleteFileIfExistOnHDFS(fName);
    } catch (IOException e) {
        // Keep the underlying HDFS failure as the cause.
        throw new DMLRuntimeException("Error: While deleting file on HDFS", e);
    }
    if (!bToDataFrame) {
        JavaRDD<String> lines = "text".equals(format) ? results.getJavaRDDStringIJV(varName) : results.getJavaRDDStringCSV(varName);
        lines.saveAsTextFile(fName);
    } else {
        Dataset<Row> df = results.getDataFrame(varName);
        // Convert back DataFrame to binary block for comparison using original binary to converted DF and back to binary
        // NOTE(review): bFromDataFrame is passed where dataFrameToBinaryBlock presumably expects
        // "DataFrame contains an ID column"; output DataFrames may always carry one -- confirm.
        MatrixCharacteristics mc = new MatrixCharacteristics(nRows, nCols, -1, -1, -1);
        JavaPairRDD<LongWritable, FrameBlock> rddOut = FrameRDDConverterUtils.dataFrameToBinaryBlock(sc, df, mc, bFromDataFrame).mapToPair(new LongFrameToLongWritableFrameFunction());
        rddOut.saveAsHadoopFile(fName, LongWritable.class, FrameBlock.class, OutputInfo.BinaryBlockOutputInfo.outputFormatClass);
    }
}
Also used : FrameFormat(org.apache.sysml.api.mlcontext.FrameFormat) StructType(org.apache.spark.sql.types.StructType) HashMap(java.util.HashMap) MLResults(org.apache.sysml.api.mlcontext.MLResults) TestConfiguration(org.apache.sysml.test.integration.TestConfiguration) ArrayList(java.util.ArrayList) FrameBlock(org.apache.sysml.runtime.matrix.data.FrameBlock) LongWritable(org.apache.hadoop.io.LongWritable) LongFrameToLongWritableFrameFunction(org.apache.sysml.runtime.instructions.spark.utils.FrameRDDConverterUtils.LongFrameToLongWritableFrameFunction) Script(org.apache.sysml.api.mlcontext.Script) DMLScript(org.apache.sysml.api.DMLScript) ValueType(org.apache.sysml.parser.Expression.ValueType) FrameSchema(org.apache.sysml.api.mlcontext.FrameSchema) IOException(java.io.IOException) MatrixCharacteristics(org.apache.sysml.runtime.matrix.MatrixCharacteristics) DMLRuntimeException(org.apache.sysml.runtime.DMLRuntimeException) RUNTIME_PLATFORM(org.apache.sysml.api.DMLScript.RUNTIME_PLATFORM) Row(org.apache.spark.sql.Row) FrameMetadata(org.apache.sysml.api.mlcontext.FrameMetadata)

Example 2 with FrameSchema

Use of org.apache.sysml.api.mlcontext.FrameSchema in project incubator-systemml by Apache.

Class FrameTest, method testFrameGeneral.

/**
 * End-to-end MLContext frame test. Generates random input frames A and B, runs the
 * FrameTest DML script on them through MLContext, saves the script outputs "A" (as AB)
 * and "C", and verifies them against the outputs of the FrameTest R script.
 *
 * @param iinfo          format used to read the saved DML outputs back; TextCellInputInfo
 *                       saves the outputs as IJV text, anything else as CSV
 * @param oinfo          format in which the input frames are written; TextCellOutputInfo
 *                       declares the text-RDD inputs as IJV, anything else as CSV
 * @param bFromDataFrame pass A and B to the script as Spark DataFrames instead of text RDDs
 * @param bToDataFrame   fetch the outputs as DataFrames and save them as binary blocks
 *                       instead of saving the text RDDs directly
 */
private void testFrameGeneral(InputInfo iinfo, OutputInfo oinfo, boolean bFromDataFrame, boolean bToDataFrame) throws IOException, DMLException, ParseException {
    // Save the global execution settings and restore them in the finally block. The try
    // deliberately spans all of the setup (input generation, DataFrame creation) so that a
    // failure there cannot leak HYBRID_SPARK / local-Spark mode into subsequent tests.
    boolean oldConfig = DMLScript.USE_LOCAL_SPARK_CONFIG;
    RUNTIME_PLATFORM oldRT = DMLScript.rtplatform;
    DMLScript.USE_LOCAL_SPARK_CONFIG = true;
    DMLScript.rtplatform = RUNTIME_PLATFORM.HYBRID_SPARK;
    try {
        // 1-based inclusive index ranges handed to the DML and R scripts; presumably B is
        // written into A at [rowstart:rowend, colstart:colend] and C is read from
        // [rowstartC:rowendC, colstartC:colendC] -- confirm against FrameTest.dml.
        int rowstart = 234, rowend = 1478, colstart = 125, colend = 568;
        int bRows = rowend - rowstart + 1, bCols = colend - colstart + 1;
        int rowstartC = 124, rowendC = 1178, colstartC = 143, colendC = 368;
        int cRows = rowendC - rowstartC + 1, cCols = colendC - colstartC + 1;
        HashMap<String, ValueType[]> outputSchema = new HashMap<String, ValueType[]>();
        HashMap<String, MatrixCharacteristics> outputMC = new HashMap<String, MatrixCharacteristics>();
        TestConfiguration config = getTestConfiguration(TEST_NAME);
        loadTestConfiguration(config);

        // Positional arguments $1..$16 of the DML script, in order.
        List<String> proArgs = new ArrayList<String>();
        proArgs.add(input("A"));
        proArgs.add(Integer.toString(rows));
        proArgs.add(Integer.toString(cols));
        proArgs.add(input("B"));
        proArgs.add(Integer.toString(bRows));
        proArgs.add(Integer.toString(bCols));
        proArgs.add(Integer.toString(rowstart));
        proArgs.add(Integer.toString(rowend));
        proArgs.add(Integer.toString(colstart));
        proArgs.add(Integer.toString(colend));
        proArgs.add(output("A"));
        proArgs.add(Integer.toString(rowstartC));
        proArgs.add(Integer.toString(rowendC));
        proArgs.add(Integer.toString(colstartC));
        proArgs.add(Integer.toString(colendC));
        proArgs.add(output("C"));
        fullDMLScriptName = SCRIPT_DIR + TEST_DIR + TEST_NAME + ".dml";
        fullRScriptName = SCRIPT_DIR + TEST_DIR + TEST_NAME + ".R";
        rCmd = "Rscript" + " " + fullRScriptName + " " + inputDir() + " " + rowstart + " " + rowend + " " + colstart + " " + colend + " " + expectedDir() + " " + rowstartC + " " + rowendC + " " + colstartC + " " + colendC;

        // The schemas of B and C are the column slices of A's schema they correspond to.
        ValueType[] schema = schemaMixedLarge;
        List<ValueType> lschema = Arrays.asList(schema);
        ValueType[] schemaB = new ValueType[bCols];
        for (int i = 0; i < bCols; ++i) {
            schemaB[i] = schema[colstart - 1 + i];
        }
        List<ValueType> lschemaB = Arrays.asList(schemaB);
        ValueType[] schemaC = new ValueType[cCols];
        for (int i = 0; i < cCols; ++i) {
            schemaC[i] = schema[colstartC - 1 + i];
        }

        // Initialize the frame data.
        double[][] A = getRandomMatrix(rows, cols, min, max, sparsity1, 1111);
        writeInputFrameWithMTD("A", A, true, schema, oinfo);
        double[][] B = getRandomMatrix(bRows, bCols, min, max, sparsity2, 2345);
        writeInputFrameWithMTD("B", B, true, schemaB, oinfo);

        Dataset<Row> dfA = null, dfB = null;
        if (bFromDataFrame) {
            // Create DataFrame for input A
            StructType dfSchemaA = FrameRDDConverterUtils.convertFrameSchemaToDFSchema(schema, false);
            JavaRDD<Row> rowRDDA = FrameRDDConverterUtils.csvToRowRDD(sc, input("A"), DataExpression.DEFAULT_DELIM_DELIMITER, schema);
            dfA = spark.createDataFrame(rowRDDA, dfSchemaA);
            // Create DataFrame for input B
            StructType dfSchemaB = FrameRDDConverterUtils.convertFrameSchemaToDFSchema(schemaB, false);
            JavaRDD<Row> rowRDDB = FrameRDDConverterUtils.csvToRowRDD(sc, input("B"), DataExpression.DEFAULT_DELIM_DELIMITER, schemaB);
            dfB = spark.createDataFrame(rowRDDB, dfSchemaB);
        }

        Script script = ScriptFactory.dmlFromFile(fullDMLScriptName);
        // Text-cell input files are declared as IJV, everything else as CSV.
        FrameFormat inFormat = (oinfo == OutputInfo.TextCellOutputInfo) ? FrameFormat.IJV : FrameFormat.CSV;
        if (bFromDataFrame) {
            script.in("A", dfA);
            script.in("B", dfB);
        } else {
            JavaRDD<String> aIn = sc.textFile(input("A"));
            script.in("A", aIn, new FrameMetadata(inFormat, new FrameSchema(lschema), rows, cols));
            JavaRDD<String> bIn = sc.textFile(input("B"));
            script.in("B", bIn, new FrameMetadata(inFormat, new FrameSchema(lschemaB), bRows, bCols));
        }
        // Output one frame to HDFS and get one as RDD //TODO HDFS input/output to do
        script.out("A", "C");
        // set positional argument values
        for (int argNum = 1; argNum <= proArgs.size(); argNum++) {
            script.in("$" + argNum, proArgs.get(argNum - 1));
        }
        MLResults results = ml.execute(script);

        // The saved outputs are read back with iinfo, so write them in the matching format.
        String outFormat = (iinfo == InputInfo.TextCellInputInfo) ? "text" : "csv";
        saveFrameOutput(results, "A", output("AB"), outFormat, bFromDataFrame, bToDataFrame, rows, cols);
        saveFrameOutput(results, "C", output("C"), outFormat, bFromDataFrame, bToDataFrame, cRows, cCols);

        runRScript(true);
        outputSchema.put("AB", schema);
        outputMC.put("AB", new MatrixCharacteristics(rows, cols, -1, -1));
        outputSchema.put("C", schemaC);
        outputMC.put("C", new MatrixCharacteristics(cRows, cCols, -1, -1));
        for (String file : config.getOutputFiles()) {
            MatrixCharacteristics md = outputMC.get(file);
            FrameBlock frameBlock = readDMLFrameFromHDFS(file, iinfo, md);
            FrameBlock frameRBlock = readRFrameFromHDFS(file + ".csv", InputInfo.CSVInputInfo, md);
            ValueType[] schemaOut = outputSchema.get(file);
            verifyFrameData(frameBlock, frameRBlock, schemaOut);
            System.out.println("File " + file + " processed successfully.");
        }
        System.out.println("Frame MLContext test completed successfully.");
    } finally {
        DMLScript.rtplatform = oldRT;
        DMLScript.USE_LOCAL_SPARK_CONFIG = oldConfig;
    }
}

/**
 * Saves MLContext output frame {@code varName} to {@code fName}, deleting any existing file
 * first. Without DataFrame output the frame is written as IJV text (when {@code format} is
 * "text") or CSV; with DataFrame output it is converted back to binary blocks.
 *
 * @param results        results of the executed script
 * @param varName        script output variable to save
 * @param fName          destination path
 * @param format         "text" for IJV, anything else for CSV (unused when bToDataFrame)
 * @param bFromDataFrame forwarded as the last argument of dataFrameToBinaryBlock
 * @param bToDataFrame   fetch the output as a DataFrame and save binary blocks
 * @param nRows          number of rows of the output frame
 * @param nCols          number of columns of the output frame
 */
private void saveFrameOutput(MLResults results, String varName, String fName, String format, boolean bFromDataFrame, boolean bToDataFrame, long nRows, long nCols) throws DMLException {
    try {
        MapReduceTool.deleteFileIfExistOnHDFS(fName);
    } catch (IOException e) {
        // Keep the underlying HDFS failure as the cause.
        throw new DMLRuntimeException("Error: While deleting file on HDFS", e);
    }
    if (!bToDataFrame) {
        JavaRDD<String> lines = "text".equals(format) ? results.getJavaRDDStringIJV(varName) : results.getJavaRDDStringCSV(varName);
        lines.saveAsTextFile(fName);
    } else {
        Dataset<Row> df = results.getDataFrame(varName);
        // Convert back DataFrame to binary block for comparison using original binary to converted DF and back to binary
        // NOTE(review): bFromDataFrame is passed where dataFrameToBinaryBlock presumably expects
        // "DataFrame contains an ID column"; output DataFrames may always carry one -- confirm.
        MatrixCharacteristics mc = new MatrixCharacteristics(nRows, nCols, -1, -1, -1);
        JavaPairRDD<LongWritable, FrameBlock> rddOut = FrameRDDConverterUtils.dataFrameToBinaryBlock(sc, df, mc, bFromDataFrame).mapToPair(new LongFrameToLongWritableFrameFunction());
        rddOut.saveAsHadoopFile(fName, LongWritable.class, FrameBlock.class, OutputInfo.BinaryBlockOutputInfo.outputFormatClass);
    }
}
Also used : FrameFormat(org.apache.sysml.api.mlcontext.FrameFormat) StructType(org.apache.spark.sql.types.StructType) HashMap(java.util.HashMap) MLResults(org.apache.sysml.api.mlcontext.MLResults) TestConfiguration(org.apache.sysml.test.integration.TestConfiguration) ArrayList(java.util.ArrayList) FrameBlock(org.apache.sysml.runtime.matrix.data.FrameBlock) LongWritable(org.apache.hadoop.io.LongWritable) LongFrameToLongWritableFrameFunction(org.apache.sysml.runtime.instructions.spark.utils.FrameRDDConverterUtils.LongFrameToLongWritableFrameFunction) Script(org.apache.sysml.api.mlcontext.Script) DMLScript(org.apache.sysml.api.DMLScript) ValueType(org.apache.sysml.parser.Expression.ValueType) FrameSchema(org.apache.sysml.api.mlcontext.FrameSchema) IOException(java.io.IOException) MatrixCharacteristics(org.apache.sysml.runtime.matrix.MatrixCharacteristics) DMLRuntimeException(org.apache.sysml.runtime.DMLRuntimeException) RUNTIME_PLATFORM(org.apache.sysml.api.DMLScript.RUNTIME_PLATFORM) Row(org.apache.spark.sql.Row) FrameMetadata(org.apache.sysml.api.mlcontext.FrameMetadata)

Example 3 with FrameSchema

Use of org.apache.sysml.api.mlcontext.FrameSchema in project incubator-systemml by Apache.

Class MLContextFrameTest, method testFrame.

/**
 * Runs a small frame-indexing script through MLContext and verifies the output frames.
 * <p>
 * Frame A (3x4, schema INT/STRING/DOUBLE/BOOLEAN) receives frame B (2x3, schema
 * STRING/DOUBLE/BOOLEAN) via {@code A[2:3,2:4]=B}, and {@code C=A[2:3,2:3]} is taken from the
 * updated A. Inputs are passed as files, DataFrames, JavaRDDs or RDDs depending on
 * {@code inputType}; A and C are fetched in the representation selected by {@code outputType}
 * and compared against hard-coded expected values.
 *
 * @param format      encoding (CSV or IJV) of the inputs
 * @param script_type whether the script is written in DML or PYDML
 * @param inputType   representation used to pass A and B to the script
 * @param outputType  representation used to fetch A and C from the results
 */
public void testFrame(FrameFormat format, SCRIPT_TYPE script_type, IO_TYPE inputType, IO_TYPE outputType) {
    System.out.println("MLContextTest - Frame JavaRDD<String> for format: " + format + " Script: " + script_type);
    // The PYDML variant is presumably equivalent to the DML one with $X=1, $Y=3, $Z=4
    // (0-based, exclusive upper bounds); both variants share the expected values below.
    String dmlScript = "A[2:3,2:4]=B;C=A[2:3,2:3]";
    String pydmlScript = "A[$X:$Y,$X:$Z]=B\nC=A[$X:$Y,$X:$Y]";
    List<String> listA = new ArrayList<String>();
    List<String> listB = new ArrayList<String>();
    FrameMetadata fmA = null, fmB = null;
    Script script = null;
    ValueType[] schemaA = { ValueType.INT, ValueType.STRING, ValueType.DOUBLE, ValueType.BOOLEAN };
    List<ValueType> lschemaA = Arrays.asList(schemaA);
    FrameSchema fschemaA = new FrameSchema(lschemaA);
    ValueType[] schemaB = { ValueType.STRING, ValueType.DOUBLE, ValueType.BOOLEAN };
    List<ValueType> lschemaB = Arrays.asList(schemaB);
    FrameSchema fschemaB = new FrameSchema(lschemaB);
    if (inputType != IO_TYPE.FILE) {
        if (format == FrameFormat.CSV) {
            listA.add("1,Str2,3.0,true");
            listA.add("4,Str5,6.0,false");
            listA.add("7,Str8,9.0,true");
            listB.add("Str12,13.0,true");
            listB.add("Str25,26.0,false");
            fmA = new FrameMetadata(FrameFormat.CSV, fschemaA, 3, 4);
            fmB = new FrameMetadata(FrameFormat.CSV, fschemaB, 2, 3);
        } else if (format == FrameFormat.IJV) {
            listA.add("1 1 1");
            listA.add("1 2 Str2");
            listA.add("1 3 3.0");
            listA.add("1 4 true");
            listA.add("2 1 4");
            listA.add("2 2 Str5");
            listA.add("2 3 6.0");
            listA.add("2 4 false");
            listA.add("3 1 7");
            listA.add("3 2 Str8");
            listA.add("3 3 9.0");
            listA.add("3 4 true");
            listB.add("1 1 Str12");
            listB.add("1 2 13.0");
            listB.add("1 3 true");
            listB.add("2 1 Str25");
            listB.add("2 2 26.0");
            listB.add("2 3 false");
            fmA = new FrameMetadata(FrameFormat.IJV, fschemaA, 3, 4);
            fmB = new FrameMetadata(FrameFormat.IJV, fschemaB, 2, 3);
        }
        JavaRDD<String> javaRDDA = sc.parallelize(listA);
        JavaRDD<String> javaRDDB = sc.parallelize(listB);
        if (inputType == IO_TYPE.DATAFRAME) {
            JavaRDD<Row> javaRddRowA = FrameRDDConverterUtils.csvToRowRDD(sc, javaRDDA, CSV_DELIM, schemaA);
            JavaRDD<Row> javaRddRowB = FrameRDDConverterUtils.csvToRowRDD(sc, javaRDDB, CSV_DELIM, schemaB);
            // Create DataFrame
            StructType dfSchemaA = FrameRDDConverterUtils.convertFrameSchemaToDFSchema(schemaA, false);
            Dataset<Row> dataFrameA = spark.createDataFrame(javaRddRowA, dfSchemaA);
            StructType dfSchemaB = FrameRDDConverterUtils.convertFrameSchemaToDFSchema(schemaB, false);
            Dataset<Row> dataFrameB = spark.createDataFrame(javaRddRowB, dfSchemaB);
            if (script_type == SCRIPT_TYPE.DML) {
                script = dml(dmlScript).in("A", dataFrameA, fmA).in("B", dataFrameB, fmB).out("A").out("C");
            } else if (script_type == SCRIPT_TYPE.PYDML) {
                // DO NOT USE ; at the end of any statement, it throws NPE
                script = pydml(pydmlScript).in("A", dataFrameA, fmA).in("B", dataFrameB, fmB).in("$X", 1).in("$Y", 3).in("$Z", 4).out("A").out("C");
            }
        } else {
            if (inputType == IO_TYPE.JAVA_RDD_STR_CSV || inputType == IO_TYPE.JAVA_RDD_STR_IJV) {
                if (script_type == SCRIPT_TYPE.DML) {
                    script = dml(dmlScript).in("A", javaRDDA, fmA).in("B", javaRDDB, fmB).out("A").out("C");
                } else if (script_type == SCRIPT_TYPE.PYDML) {
                    // DO NOT USE ; at the end of any statement, it throws NPE
                    script = pydml(pydmlScript).in("A", javaRDDA, fmA).in("B", javaRDDB, fmB).in("$X", 1).in("$Y", 3).in("$Z", 4).out("A").out("C");
                }
            } else if (inputType == IO_TYPE.RDD_STR_CSV || inputType == IO_TYPE.RDD_STR_IJV) {
                RDD<String> rddA = JavaRDD.toRDD(javaRDDA);
                RDD<String> rddB = JavaRDD.toRDD(javaRDDB);
                if (script_type == SCRIPT_TYPE.DML) {
                    script = dml(dmlScript).in("A", rddA, fmA).in("B", rddB, fmB).out("A").out("C");
                } else if (script_type == SCRIPT_TYPE.PYDML) {
                    // DO NOT USE ; at the end of any statement, it throws NPE
                    script = pydml(pydmlScript).in("A", rddA, fmA).in("B", rddB, fmB).in("$X", 1).in("$Y", 3).in("$Z", 4).out("A").out("C");
                }
            }
        }
    } else {
        // Input type is file
        String fileA = null, fileB = null;
        if (format == FrameFormat.CSV) {
            fileA = baseDirectory + File.separator + "FrameA.csv";
            fileB = baseDirectory + File.separator + "FrameB.csv";
        } else if (format == FrameFormat.IJV) {
            fileA = baseDirectory + File.separator + "FrameA.ijv";
            fileB = baseDirectory + File.separator + "FrameB.ijv";
        }
        if (script_type == SCRIPT_TYPE.DML) {
            // NOTE(review): fmA/fmB are still null on this path. Unlike the PYDML variant, this
            // script also sets A[1,1]=234, while the CSV/IJV/DataFrame checks below expect A[1,1]
            // to be 1 -- presumably FILE input is only paired with the 2D-array check; confirm.
            script = dml("A=read($A); B=read($B);A[2:3,2:4]=B;C=A[2:3,2:3];A[1,1]=234").in("$A", fileA, fmA).in("$B", fileB, fmB).out("A").out("C");
        } else if (script_type == SCRIPT_TYPE.PYDML) {
            // DO NOT USE ; at the end of any statement, it throws NPE
            script = pydml("A=load($A)\nB=load($B)\nA[$X:$Y,$X:$Z]=B\nC=A[$X:$Y,$X:$Y]").in("$A", fileA).in("$B", fileB).in("$X", 1).in("$Y", 3).in("$Z", 4).out("A").out("C");
        }
    }
    MLResults mlResults = ml.execute(script);
    // Validate output schema
    List<ValueType> lschemaOutA = Arrays.asList(mlResults.getFrameObject("A").getSchema());
    List<ValueType> lschemaOutC = Arrays.asList(mlResults.getFrameObject("C").getSchema());
    Assert.assertEquals(ValueType.INT, lschemaOutA.get(0));
    Assert.assertEquals(ValueType.STRING, lschemaOutA.get(1));
    Assert.assertEquals(ValueType.DOUBLE, lschemaOutA.get(2));
    Assert.assertEquals(ValueType.BOOLEAN, lschemaOutA.get(3));
    Assert.assertEquals(ValueType.STRING, lschemaOutC.get(0));
    Assert.assertEquals(ValueType.DOUBLE, lschemaOutC.get(1));

    // Expected contents of A (after A[2:3,2:4]=B) and C, shared by the JavaRDD and RDD checks
    // so that both representations verify the same cells.
    String[] expectedCsvA = { "1,Str2,3.0,true", "4,Str12,13.0,true", "7,Str25,26.0,false" };
    String[] expectedCsvC = { "Str12,13.0", "Str25,26.0" };
    String[] expectedIjvA = { "1 1 1", "1 2 Str2", "1 3 3.0", "1 4 true", "2 1 4", "2 2 Str12", "2 3 13.0", "2 4 true", "3 1 7", "3 2 Str25", "3 3 26.0", "3 4 false" };
    String[] expectedIjvC = { "1 1 Str12", "1 2 13.0", "2 1 Str25", "2 2 26.0" };

    if (outputType == IO_TYPE.JAVA_RDD_STR_CSV) {
        assertLeadingLines(expectedCsvA, mlResults.getJavaRDDStringCSV("A").collect());
        assertLeadingLines(expectedCsvC, mlResults.getJavaRDDStringCSV("C").collect());
    } else if (outputType == IO_TYPE.JAVA_RDD_STR_IJV) {
        assertLeadingLines(expectedIjvA, mlResults.getJavaRDDStringIJV("A").collect());
        assertLeadingLines(expectedIjvC, mlResults.getJavaRDDStringIJV("C").collect());
    } else if (outputType == IO_TYPE.RDD_STR_CSV) {
        assertLeadingLines(expectedCsvA, mlResults.getRDDStringCSV("A").toLocalIterator());
        assertLeadingLines(expectedCsvC, mlResults.getRDDStringCSV("C").toLocalIterator());
    } else if (outputType == IO_TYPE.RDD_STR_IJV) {
        assertLeadingLines(expectedIjvA, mlResults.getRDDStringIJV("A").toLocalIterator());
        assertLeadingLines(expectedIjvC, mlResults.getRDDStringIJV("C").toLocalIterator());
    } else if (outputType == IO_TYPE.DATAFRAME) {
        Dataset<Row> dataFrameA = mlResults.getDataFrame("A").drop(RDDConverterUtils.DF_ID_COLUMN);
        StructType dfschemaA = dataFrameA.schema();
        StructField structTypeA = dfschemaA.apply(0);
        Assert.assertEquals(DataTypes.LongType, structTypeA.dataType());
        structTypeA = dfschemaA.apply(1);
        Assert.assertEquals(DataTypes.StringType, structTypeA.dataType());
        structTypeA = dfschemaA.apply(2);
        Assert.assertEquals(DataTypes.DoubleType, structTypeA.dataType());
        structTypeA = dfschemaA.apply(3);
        Assert.assertEquals(DataTypes.BooleanType, structTypeA.dataType());
        List<Row> listAOut = dataFrameA.collectAsList();
        Row row1 = listAOut.get(0);
        Assert.assertEquals("Mismatch with expected value", Long.valueOf(1), row1.get(0));
        Assert.assertEquals("Mismatch with expected value", "Str2", row1.get(1));
        Assert.assertEquals("Mismatch with expected value", 3.0, row1.get(2));
        Assert.assertEquals("Mismatch with expected value", true, row1.get(3));
        Row row2 = listAOut.get(1);
        Assert.assertEquals("Mismatch with expected value", Long.valueOf(4), row2.get(0));
        Assert.assertEquals("Mismatch with expected value", "Str12", row2.get(1));
        Assert.assertEquals("Mismatch with expected value", 13.0, row2.get(2));
        Assert.assertEquals("Mismatch with expected value", true, row2.get(3));
        // Third row of A, also overwritten by B (checked by the CSV/IJV branches as well).
        Row row2b = listAOut.get(2);
        Assert.assertEquals("Mismatch with expected value", Long.valueOf(7), row2b.get(0));
        Assert.assertEquals("Mismatch with expected value", "Str25", row2b.get(1));
        Assert.assertEquals("Mismatch with expected value", 26.0, row2b.get(2));
        Assert.assertEquals("Mismatch with expected value", false, row2b.get(3));
        Dataset<Row> dataFrameC = mlResults.getDataFrame("C").drop(RDDConverterUtils.DF_ID_COLUMN);
        StructType dfschemaC = dataFrameC.schema();
        StructField structTypeC = dfschemaC.apply(0);
        Assert.assertEquals(DataTypes.StringType, structTypeC.dataType());
        structTypeC = dfschemaC.apply(1);
        Assert.assertEquals(DataTypes.DoubleType, structTypeC.dataType());
        List<Row> listCOut = dataFrameC.collectAsList();
        Row row3 = listCOut.get(0);
        Assert.assertEquals("Mismatch with expected value", "Str12", row3.get(0));
        Assert.assertEquals("Mismatch with expected value", 13.0, row3.get(1));
        Row row4 = listCOut.get(1);
        Assert.assertEquals("Mismatch with expected value", "Str25", row4.get(0));
        Assert.assertEquals("Mismatch with expected value", 26.0, row4.get(1));
    } else {
        String[][] frameA = mlResults.getFrameAs2DStringArray("A");
        Assert.assertEquals("Str2", frameA[0][1]);
        Assert.assertEquals("3.0", frameA[0][2]);
        Assert.assertEquals("13.0", frameA[1][2]);
        Assert.assertEquals("true", frameA[1][3]);
        Assert.assertEquals("Str25", frameA[2][1]);
        String[][] frameC = mlResults.getFrameAs2DStringArray("C");
        Assert.assertEquals("Str12", frameC[0][0]);
        Assert.assertEquals("Str25", frameC[1][0]);
        Assert.assertEquals("13.0", frameC[0][1]);
        Assert.assertEquals("26.0", frameC[1][1]);
    }
}

/** Asserts that the first {@code expected.length} collected lines equal {@code expected}, in order. */
private static void assertLeadingLines(String[] expected, List<String> actual) {
    for (int i = 0; i < expected.length; i++) {
        Assert.assertEquals(expected[i], actual.get(i));
    }
}

/** Asserts that the next {@code expected.length} iterator elements equal {@code expected}, in order. */
private static void assertLeadingLines(String[] expected, Iterator<String> actual) {
    for (String line : expected) {
        Assert.assertEquals(line, actual.next());
    }
}
Also used: Script(org.apache.sysml.api.mlcontext.Script) StructType(org.apache.spark.sql.types.StructType) ValueType(org.apache.sysml.parser.Expression.ValueType) MLResults(org.apache.sysml.api.mlcontext.MLResults) ArrayList(java.util.ArrayList) FrameSchema(org.apache.sysml.api.mlcontext.FrameSchema) JavaRDD(org.apache.spark.api.java.JavaRDD) RDD(org.apache.spark.rdd.RDD) StructField(org.apache.spark.sql.types.StructField) Iterator(scala.collection.Iterator) List(java.util.List) Row(org.apache.spark.sql.Row) CommaSeparatedValueStringToDoubleArrayRow(org.apache.sysml.test.integration.mlcontext.MLContextTest.CommaSeparatedValueStringToDoubleArrayRow) FrameMetadata(org.apache.sysml.api.mlcontext.FrameMetadata)

Example 4 with FrameSchema

use of org.apache.sysml.api.mlcontext.FrameSchema in project systemml by apache.

the class MLContextFrameTest method testFrame.

/**
 * Exercises frame left/right indexing through MLContext for one combination of frame format,
 * script language, input type and output type.
 *
 * <p>Frame A (3 x 4, schema INT, STRING, DOUBLE, BOOLEAN) and frame B (2 x 3, schema STRING,
 * DOUBLE, BOOLEAN) are bound as script inputs. The DML script assigns B into A[2:3,2:4] and
 * extracts C = A[2:3,2:3]. The PyDML script is parameterized via $X/$Y/$Z and is verified
 * against the same expected values, so it must address the same region. A and C are then
 * read back in the representation selected by {@code outputType} and compared cell by cell.
 *
 * @param format      CSV or IJV; selects the in-memory encoding of A and B, or the input file
 *                    extension when {@code inputType} is {@code IO_TYPE.FILE}
 * @param script_type DML or PYDML
 * @param inputType   how A and B are handed to the script
 * @param outputType  how A and C are retrieved from the {@link MLResults}
 */
public void testFrame(FrameFormat format, SCRIPT_TYPE script_type, IO_TYPE inputType, IO_TYPE outputType) {
    System.out.println("MLContextTest - Frame JavaRDD<String> for format: " + format + " Script: " + script_type);
    List<String> listA = new ArrayList<String>();
    List<String> listB = new ArrayList<String>();
    // Metadata is only populated for in-memory inputs below; for IO_TYPE.FILE it stays null.
    FrameMetadata fmA = null, fmB = null;
    Script script = null;
    ValueType[] schemaA = { ValueType.INT, ValueType.STRING, ValueType.DOUBLE, ValueType.BOOLEAN };
    List<ValueType> lschemaA = Arrays.asList(schemaA);
    FrameSchema fschemaA = new FrameSchema(lschemaA);
    ValueType[] schemaB = { ValueType.STRING, ValueType.DOUBLE, ValueType.BOOLEAN };
    List<ValueType> lschemaB = Arrays.asList(schemaB);
    FrameSchema fschemaB = new FrameSchema(lschemaB);
    if (inputType != IO_TYPE.FILE) {
        if (format == FrameFormat.CSV) {
            listA.add("1,Str2,3.0,true");
            listA.add("4,Str5,6.0,false");
            listA.add("7,Str8,9.0,true");
            listB.add("Str12,13.0,true");
            listB.add("Str25,26.0,false");
            fmA = new FrameMetadata(FrameFormat.CSV, fschemaA, 3, 4);
            fmB = new FrameMetadata(FrameFormat.CSV, fschemaB, 2, 3);
        } else if (format == FrameFormat.IJV) {
            listA.add("1 1 1");
            listA.add("1 2 Str2");
            listA.add("1 3 3.0");
            listA.add("1 4 true");
            listA.add("2 1 4");
            listA.add("2 2 Str5");
            listA.add("2 3 6.0");
            listA.add("2 4 false");
            listA.add("3 1 7");
            listA.add("3 2 Str8");
            listA.add("3 3 9.0");
            listA.add("3 4 true");
            listB.add("1 1 Str12");
            listB.add("1 2 13.0");
            listB.add("1 3 true");
            listB.add("2 1 Str25");
            listB.add("2 2 26.0");
            listB.add("2 3 false");
            fmA = new FrameMetadata(FrameFormat.IJV, fschemaA, 3, 4);
            fmB = new FrameMetadata(FrameFormat.IJV, fschemaB, 2, 3);
        }
        JavaRDD<String> javaRDDA = sc.parallelize(listA);
        JavaRDD<String> javaRDDB = sc.parallelize(listB);
        if (inputType == IO_TYPE.DATAFRAME) {
            JavaRDD<Row> javaRddRowA = FrameRDDConverterUtils.csvToRowRDD(sc, javaRDDA, CSV_DELIM, schemaA);
            JavaRDD<Row> javaRddRowB = FrameRDDConverterUtils.csvToRowRDD(sc, javaRDDB, CSV_DELIM, schemaB);
            // Create DataFrame
            StructType dfSchemaA = FrameRDDConverterUtils.convertFrameSchemaToDFSchema(schemaA, false);
            Dataset<Row> dataFrameA = spark.createDataFrame(javaRddRowA, dfSchemaA);
            StructType dfSchemaB = FrameRDDConverterUtils.convertFrameSchemaToDFSchema(schemaB, false);
            Dataset<Row> dataFrameB = spark.createDataFrame(javaRddRowB, dfSchemaB);
            if (script_type == SCRIPT_TYPE.DML)
                script = dml("A[2:3,2:4]=B;C=A[2:3,2:3]").in("A", dataFrameA, fmA).in("B", dataFrameB, fmB).out("A").out("C");
            else if (script_type == SCRIPT_TYPE.PYDML)
                // DO NOT USE ; at the end of any statement, it throws NPE
                script = pydml("A[$X:$Y,$X:$Z]=B\nC=A[$X:$Y,$X:$Y]").in("A", dataFrameA, fmA).in("B", dataFrameB, fmB).in("$X", 1).in("$Y", 3).in("$Z", 4).out("A").out("C");
        } else {
            if (inputType == IO_TYPE.JAVA_RDD_STR_CSV || inputType == IO_TYPE.JAVA_RDD_STR_IJV) {
                if (script_type == SCRIPT_TYPE.DML)
                    script = dml("A[2:3,2:4]=B;C=A[2:3,2:3]").in("A", javaRDDA, fmA).in("B", javaRDDB, fmB).out("A").out("C");
                else if (script_type == SCRIPT_TYPE.PYDML)
                    // DO NOT USE ; at the end of any statement, it throws NPE
                    script = pydml("A[$X:$Y,$X:$Z]=B\nC=A[$X:$Y,$X:$Y]").in("A", javaRDDA, fmA).in("B", javaRDDB, fmB).in("$X", 1).in("$Y", 3).in("$Z", 4).out("A").out("C");
            } else if (inputType == IO_TYPE.RDD_STR_CSV || inputType == IO_TYPE.RDD_STR_IJV) {
                RDD<String> rddA = JavaRDD.toRDD(javaRDDA);
                RDD<String> rddB = JavaRDD.toRDD(javaRDDB);
                if (script_type == SCRIPT_TYPE.DML)
                    script = dml("A[2:3,2:4]=B;C=A[2:3,2:3]").in("A", rddA, fmA).in("B", rddB, fmB).out("A").out("C");
                else if (script_type == SCRIPT_TYPE.PYDML)
                    // DO NOT USE ; at the end of any statement, it throws NPE
                    script = pydml("A[$X:$Y,$X:$Z]=B\nC=A[$X:$Y,$X:$Y]").in("A", rddA, fmA).in("B", rddB, fmB).in("$X", 1).in("$Y", 3).in("$Z", 4).out("A").out("C");
            }
        }
    } else {
        // Input type is file
        String fileA = null, fileB = null;
        if (format == FrameFormat.CSV) {
            fileA = baseDirectory + File.separator + "FrameA.csv";
            fileB = baseDirectory + File.separator + "FrameB.csv";
        } else if (format == FrameFormat.IJV) {
            fileA = baseDirectory + File.separator + "FrameA.ijv";
            fileB = baseDirectory + File.separator + "FrameB.ijv";
        }
        // NOTE(review): unlike every other path, this DML script also sets A[1,1]=234, which
        // conflicts with the output checks below that pin A's first cell to 1. Confirm which
        // output types are actually combined with FILE input + DML.
        if (script_type == SCRIPT_TYPE.DML)
            script = dml("A=read($A); B=read($B);A[2:3,2:4]=B;C=A[2:3,2:3];A[1,1]=234").in("$A", fileA, fmA).in("$B", fileB, fmB).out("A").out("C");
        else if (script_type == SCRIPT_TYPE.PYDML)
            // DO NOT USE ; at the end of any statement, it throws NPE
            script = pydml("A=load($A)\nB=load($B)\nA[$X:$Y,$X:$Z]=B\nC=A[$X:$Y,$X:$Y]").in("$A", fileA).in("$B", fileB).in("$X", 1).in("$Y", 3).in("$Z", 4).out("A").out("C");
    }
    // Fail with a descriptive message instead of passing a null script to execute().
    Assert.assertNotNull("No script was built for script type " + script_type + " and input type " + inputType, script);
    MLResults mlResults = ml.execute(script);
    // Validate output schema
    List<ValueType> lschemaOutA = Arrays.asList(mlResults.getFrameObject("A").getSchema());
    List<ValueType> lschemaOutC = Arrays.asList(mlResults.getFrameObject("C").getSchema());
    Assert.assertEquals(ValueType.INT, lschemaOutA.get(0));
    Assert.assertEquals(ValueType.STRING, lschemaOutA.get(1));
    Assert.assertEquals(ValueType.DOUBLE, lschemaOutA.get(2));
    Assert.assertEquals(ValueType.BOOLEAN, lschemaOutA.get(3));
    Assert.assertEquals(ValueType.STRING, lschemaOutC.get(0));
    Assert.assertEquals(ValueType.DOUBLE, lschemaOutC.get(1));
    if (outputType == IO_TYPE.JAVA_RDD_STR_CSV) {
        JavaRDD<String> javaRDDStringCSVA = mlResults.getJavaRDDStringCSV("A");
        List<String> linesA = javaRDDStringCSVA.collect();
        Assert.assertEquals("1,Str2,3.0,true", linesA.get(0));
        Assert.assertEquals("4,Str12,13.0,true", linesA.get(1));
        Assert.assertEquals("7,Str25,26.0,false", linesA.get(2));
        JavaRDD<String> javaRDDStringCSVC = mlResults.getJavaRDDStringCSV("C");
        List<String> linesC = javaRDDStringCSVC.collect();
        Assert.assertEquals("Str12,13.0", linesC.get(0));
        Assert.assertEquals("Str25,26.0", linesC.get(1));
    } else if (outputType == IO_TYPE.JAVA_RDD_STR_IJV) {
        JavaRDD<String> javaRDDStringIJVA = mlResults.getJavaRDDStringIJV("A");
        List<String> linesA = javaRDDStringIJVA.collect();
        Assert.assertEquals("1 1 1", linesA.get(0));
        Assert.assertEquals("1 2 Str2", linesA.get(1));
        Assert.assertEquals("1 3 3.0", linesA.get(2));
        Assert.assertEquals("1 4 true", linesA.get(3));
        Assert.assertEquals("2 1 4", linesA.get(4));
        Assert.assertEquals("2 2 Str12", linesA.get(5));
        Assert.assertEquals("2 3 13.0", linesA.get(6));
        Assert.assertEquals("2 4 true", linesA.get(7));
        // Row 3 of A, checked here as in the RDD_STR_IJV branch.
        Assert.assertEquals("3 1 7", linesA.get(8));
        Assert.assertEquals("3 2 Str25", linesA.get(9));
        Assert.assertEquals("3 3 26.0", linesA.get(10));
        Assert.assertEquals("3 4 false", linesA.get(11));
        JavaRDD<String> javaRDDStringIJVC = mlResults.getJavaRDDStringIJV("C");
        List<String> linesC = javaRDDStringIJVC.collect();
        Assert.assertEquals("1 1 Str12", linesC.get(0));
        Assert.assertEquals("1 2 13.0", linesC.get(1));
        Assert.assertEquals("2 1 Str25", linesC.get(2));
        Assert.assertEquals("2 2 26.0", linesC.get(3));
    } else if (outputType == IO_TYPE.RDD_STR_CSV) {
        RDD<String> rddStringCSVA = mlResults.getRDDStringCSV("A");
        Iterator<String> iteratorA = rddStringCSVA.toLocalIterator();
        Assert.assertEquals("1,Str2,3.0,true", iteratorA.next());
        Assert.assertEquals("4,Str12,13.0,true", iteratorA.next());
        Assert.assertEquals("7,Str25,26.0,false", iteratorA.next());
        RDD<String> rddStringCSVC = mlResults.getRDDStringCSV("C");
        Iterator<String> iteratorC = rddStringCSVC.toLocalIterator();
        Assert.assertEquals("Str12,13.0", iteratorC.next());
        Assert.assertEquals("Str25,26.0", iteratorC.next());
    } else if (outputType == IO_TYPE.RDD_STR_IJV) {
        RDD<String> rddStringIJVA = mlResults.getRDDStringIJV("A");
        Iterator<String> iteratorA = rddStringIJVA.toLocalIterator();
        Assert.assertEquals("1 1 1", iteratorA.next());
        Assert.assertEquals("1 2 Str2", iteratorA.next());
        Assert.assertEquals("1 3 3.0", iteratorA.next());
        Assert.assertEquals("1 4 true", iteratorA.next());
        Assert.assertEquals("2 1 4", iteratorA.next());
        Assert.assertEquals("2 2 Str12", iteratorA.next());
        Assert.assertEquals("2 3 13.0", iteratorA.next());
        Assert.assertEquals("2 4 true", iteratorA.next());
        Assert.assertEquals("3 1 7", iteratorA.next());
        Assert.assertEquals("3 2 Str25", iteratorA.next());
        Assert.assertEquals("3 3 26.0", iteratorA.next());
        Assert.assertEquals("3 4 false", iteratorA.next());
        RDD<String> rddStringIJVC = mlResults.getRDDStringIJV("C");
        Iterator<String> iteratorC = rddStringIJVC.toLocalIterator();
        Assert.assertEquals("1 1 Str12", iteratorC.next());
        Assert.assertEquals("1 2 13.0", iteratorC.next());
        Assert.assertEquals("2 1 Str25", iteratorC.next());
        Assert.assertEquals("2 2 26.0", iteratorC.next());
    } else if (outputType == IO_TYPE.DATAFRAME) {
        // The generated row-ID column is dropped so column positions match the frame schema.
        Dataset<Row> dataFrameA = mlResults.getDataFrame("A").drop(RDDConverterUtils.DF_ID_COLUMN);
        StructType dfschemaA = dataFrameA.schema();
        StructField structTypeA = dfschemaA.apply(0);
        Assert.assertEquals(DataTypes.LongType, structTypeA.dataType());
        structTypeA = dfschemaA.apply(1);
        Assert.assertEquals(DataTypes.StringType, structTypeA.dataType());
        structTypeA = dfschemaA.apply(2);
        Assert.assertEquals(DataTypes.DoubleType, structTypeA.dataType());
        structTypeA = dfschemaA.apply(3);
        Assert.assertEquals(DataTypes.BooleanType, structTypeA.dataType());
        List<Row> listAOut = dataFrameA.collectAsList();
        Row row1 = listAOut.get(0);
        Assert.assertEquals("Mismatch with expected value", Long.valueOf(1), row1.get(0));
        Assert.assertEquals("Mismatch with expected value", "Str2", row1.get(1));
        Assert.assertEquals("Mismatch with expected value", 3.0, row1.get(2));
        Assert.assertEquals("Mismatch with expected value", true, row1.get(3));
        Row row2 = listAOut.get(1);
        Assert.assertEquals("Mismatch with expected value", Long.valueOf(4), row2.get(0));
        Assert.assertEquals("Mismatch with expected value", "Str12", row2.get(1));
        Assert.assertEquals("Mismatch with expected value", 13.0, row2.get(2));
        Assert.assertEquals("Mismatch with expected value", true, row2.get(3));
        // Third row of A, matching the CSV/IJV expectations "7,Str25,26.0,false".
        Row rowA3 = listAOut.get(2);
        Assert.assertEquals("Mismatch with expected value", Long.valueOf(7), rowA3.get(0));
        Assert.assertEquals("Mismatch with expected value", "Str25", rowA3.get(1));
        Assert.assertEquals("Mismatch with expected value", 26.0, rowA3.get(2));
        Assert.assertEquals("Mismatch with expected value", false, rowA3.get(3));
        Dataset<Row> dataFrameC = mlResults.getDataFrame("C").drop(RDDConverterUtils.DF_ID_COLUMN);
        StructType dfschemaC = dataFrameC.schema();
        StructField structTypeC = dfschemaC.apply(0);
        Assert.assertEquals(DataTypes.StringType, structTypeC.dataType());
        structTypeC = dfschemaC.apply(1);
        Assert.assertEquals(DataTypes.DoubleType, structTypeC.dataType());
        List<Row> listCOut = dataFrameC.collectAsList();
        Row row3 = listCOut.get(0);
        Assert.assertEquals("Mismatch with expected value", "Str12", row3.get(0));
        Assert.assertEquals("Mismatch with expected value", 13.0, row3.get(1));
        Row row4 = listCOut.get(1);
        Assert.assertEquals("Mismatch with expected value", "Str25", row4.get(0));
        Assert.assertEquals("Mismatch with expected value", 26.0, row4.get(1));
    } else {
        // Fallback: read both frames back as 2D string arrays (0-based indices). A's first cell
        // is intentionally not checked here.
        String[][] frameA = mlResults.getFrameAs2DStringArray("A");
        Assert.assertEquals("Str2", frameA[0][1]);
        Assert.assertEquals("3.0", frameA[0][2]);
        Assert.assertEquals("13.0", frameA[1][2]);
        Assert.assertEquals("true", frameA[1][3]);
        Assert.assertEquals("Str25", frameA[2][1]);
        String[][] frameC = mlResults.getFrameAs2DStringArray("C");
        Assert.assertEquals("Str12", frameC[0][0]);
        Assert.assertEquals("Str25", frameC[1][0]);
        Assert.assertEquals("13.0", frameC[0][1]);
        Assert.assertEquals("26.0", frameC[1][1]);
    }
}
Also used: Script(org.apache.sysml.api.mlcontext.Script) StructType(org.apache.spark.sql.types.StructType) ValueType(org.apache.sysml.parser.Expression.ValueType) MLResults(org.apache.sysml.api.mlcontext.MLResults) ArrayList(java.util.ArrayList) FrameSchema(org.apache.sysml.api.mlcontext.FrameSchema) JavaRDD(org.apache.spark.api.java.JavaRDD) RDD(org.apache.spark.rdd.RDD) StructField(org.apache.spark.sql.types.StructField) Iterator(scala.collection.Iterator) List(java.util.List) Row(org.apache.spark.sql.Row) CommaSeparatedValueStringToDoubleArrayRow(org.apache.sysml.test.integration.mlcontext.MLContextTest.CommaSeparatedValueStringToDoubleArrayRow) FrameMetadata(org.apache.sysml.api.mlcontext.FrameMetadata)

Aggregations

ArrayList (java.util.ArrayList): 4, Row (org.apache.spark.sql.Row): 4, StructType (org.apache.spark.sql.types.StructType): 4, FrameMetadata (org.apache.sysml.api.mlcontext.FrameMetadata): 4, FrameSchema (org.apache.sysml.api.mlcontext.FrameSchema): 4, MLResults (org.apache.sysml.api.mlcontext.MLResults): 4, Script (org.apache.sysml.api.mlcontext.Script): 4, ValueType (org.apache.sysml.parser.Expression.ValueType): 4, IOException (java.io.IOException): 2, HashMap (java.util.HashMap): 2, List (java.util.List): 2, LongWritable (org.apache.hadoop.io.LongWritable): 2, JavaRDD (org.apache.spark.api.java.JavaRDD): 2, RDD (org.apache.spark.rdd.RDD): 2, StructField (org.apache.spark.sql.types.StructField): 2, DMLScript (org.apache.sysml.api.DMLScript): 2, RUNTIME_PLATFORM (org.apache.sysml.api.DMLScript.RUNTIME_PLATFORM): 2, FrameFormat (org.apache.sysml.api.mlcontext.FrameFormat): 2, DMLRuntimeException (org.apache.sysml.runtime.DMLRuntimeException): 2, LongFrameToLongWritableFrameFunction (org.apache.sysml.runtime.instructions.spark.utils.FrameRDDConverterUtils.LongFrameToLongWritableFrameFunction): 2