Search in sources:

Example 21 with MLContext

Use of org.apache.sysml.api.mlcontext.MLContext in the incubator-systemml project by Apache.

From class FrameTest, method setUpClass.

/**
 * Creates the class-wide fixtures: a Spark session obtained from
 * createSystemMLSparkSession("FrameTest", "local"), an MLContext wrapping that
 * session, and the JavaSparkContext derived from the MLContext.
 */
@BeforeClass
public static void setUpClass() {
    final String appName = "FrameTest";
    spark = createSystemMLSparkSession(appName, "local");
    MLContext context = new MLContext(spark);
    ml = context;
    sc = MLContextUtil.getJavaSparkContext(context);
}
Also used : MLContext(org.apache.sysml.api.mlcontext.MLContext) BeforeClass(org.junit.BeforeClass)

Example 22 with MLContext

Use of org.apache.sysml.api.mlcontext.MLContext in the incubator-systemml project by Apache.

From class GNMFTest, method setUpClass.

/**
 * Creates the class-wide fixtures: a Spark session obtained from
 * createSystemMLSparkSession("GNMFTest", "local"), an MLContext wrapping that
 * session, and the JavaSparkContext derived from the MLContext.
 */
@BeforeClass
public static void setUpClass() {
    final String appName = "GNMFTest";
    spark = createSystemMLSparkSession(appName, "local");
    MLContext context = new MLContext(spark);
    ml = context;
    sc = MLContextUtil.getJavaSparkContext(context);
}
Also used : MLContext(org.apache.sysml.api.mlcontext.MLContext) BeforeClass(org.junit.BeforeClass)

Example 23 with MLContext

Use of org.apache.sysml.api.mlcontext.MLContext in the incubator-systemml project by Apache.

From class DataFrameVectorScriptTest, method setUpClass.

/**
 * Creates the class-wide fixtures: a Spark session obtained from
 * createSystemMLSparkSession("DataFrameVectorScriptTest", "local") and an
 * MLContext wrapping that session, configured with setExplain(true).
 */
@BeforeClass
public static void setUpClass() {
    final String appName = "DataFrameVectorScriptTest";
    spark = createSystemMLSparkSession(appName, "local");
    MLContext context = new MLContext(spark);
    ml = context;
    context.setExplain(true);
}
Also used : MLContext(org.apache.sysml.api.mlcontext.MLContext) BeforeClass(org.junit.BeforeClass)

Example 24 with MLContext

Use of org.apache.sysml.api.mlcontext.MLContext in the incubator-systemml project by Apache.

From class MLContextScratchCleanupTest, method runMLContextTestMultipleScript.

/**
 * Runs two DML scripts through a fresh MLContext on the given runtime platform.
 * The first script's output "X" has its in-memory data cleared before it is fed to
 * the second script, and the second script's string output "z" is printed.
 * <p>
 * The global {@code DMLScript.rtplatform} is always restored. The Spark session and
 * MLContext are always cleaned up, including when creating them fails part-way
 * and when {@code spark.stop()} itself throws.
 *
 * @param platform runtime platform to install while the scripts execute
 * @param wRead    if true, the second script is ScratchCleanup2b.dml, otherwise ScratchCleanup2.dml
 * @throws RuntimeException wrapping any exception raised while building or executing the scripts
 */
private static void runMLContextTestMultipleScript(RUNTIME_PLATFORM platform, boolean wRead) {
    RUNTIME_PLATFORM oldplatform = DMLScript.rtplatform;
    DMLScript.rtplatform = platform;
    SparkSession spark = null;
    MLContext ml = null;
    try {
        // create mlcontext inside the try so that a failure here still restores
        // the global platform and stops an already-created session
        spark = createSystemMLSparkSession("MLContextScratchCleanupTest", "local");
        ml = new MLContext(spark);
        ml.setExplain(true);
        String dml1 = baseDirectory + File.separator + "ScratchCleanup1.dml";
        String dml2 = baseDirectory + File.separator + (wRead ? "ScratchCleanup2b.dml" : "ScratchCleanup2.dml");
        try {
            Script script1 = dmlFromFile(dml1).in("$rows", rows).in("$cols", cols).out("X");
            Matrix X = ml.execute(script1).getMatrix("X");
            // clear in-memory/cached data to emulate on-disk storage
            X.toMatrixObject().clearData();
            Script script2 = dmlFromFile(dml2).in("X", X).out("z");
            String z = ml.execute(script2).getString("z");
            System.out.println(z);
        } catch (Exception ex) {
            throw new RuntimeException(ex);
        }
    } finally {
        DMLScript.rtplatform = oldplatform;
        try {
            // stop underlying spark context to allow single jvm tests (otherwise the
            // next test that tries to create a SparkContext would fail)
            if (spark != null)
                spark.stop();
        } finally {
            // clear status mlcontext and spark exec context; runs even if stop() throws
            if (ml != null)
                ml.close();
        }
    }
}
Also used : RUNTIME_PLATFORM(org.apache.sysml.api.DMLScript.RUNTIME_PLATFORM) Script(org.apache.sysml.api.mlcontext.Script) DMLScript(org.apache.sysml.api.DMLScript) SparkSession(org.apache.spark.sql.SparkSession) Matrix(org.apache.sysml.api.mlcontext.Matrix) MLContext(org.apache.sysml.api.mlcontext.MLContext)

Example 25 with MLContext

Use of org.apache.sysml.api.mlcontext.MLContext in the incubator-systemml project by Apache.

From class GPUTests, method assertEqualMatrices.

/**
 * Asserts that every cell of {@code actual} is within {@link #getTHRESHOLD()} of the
 * corresponding cell of {@code expected}.
 * <p>
 * First, a DML script counts the cells whose relative error exceeds the threshold. If none
 * do, the method returns early. Otherwise both matrices are read and compared cell by cell:
 * <ul>
 * <li>Finite, non-zero, non-NaN expected values are checked by relative error. Unless
 * FLOATING_POINT_PRECISION is "double", an absolute error below the threshold is also
 * accepted.</li>
 * <li>All other expected values are checked with an absolute delta.</li>
 * </ul>
 *
 * @param expected expected matrix
 * @param actual   actual matrix
 * @throws AssertionError   if the dimensions differ or a cell is outside the threshold
 * @throws RuntimeException wrapping a DMLRuntimeException raised while reading the matrices
 */
private void assertEqualMatrices(Matrix expected, Matrix actual) {
    try {
        // Faster way to compare two matrices. The denominator must be abs(X): with a plain X,
        // cells whose expected value is negative yield a negative "relative error" that never
        // exceeds the threshold, so genuine mismatches were skipped by the early return.
        // NOTE(review): cells involving NaN are presumably not counted here (NaN > t should be 0);
        // confirm whether NaN mismatches need a separate check.
        long numMismatch;
        MLContext cpuMLC = new MLContext(spark);
        try {
            String scriptStr = "num_mismatch = sum((abs(X - Y) / abs(X)) > " + getTHRESHOLD() + ");";
            Script script = ScriptFactory.dmlFromString(scriptStr).in("X", expected).in("Y", actual).out("num_mismatch");
            numMismatch = cpuMLC.execute(script).getLong("num_mismatch");
        } finally {
            // close the context even if script execution fails
            cpuMLC.close();
        }
        if (numMismatch == 0)
            return;
        // If error, print the actual incorrect values. Each acquireRead() is paired with a
        // release() in a finally block, so a failing assertion does not leave it acquired.
        MatrixBlock expectedMB = expected.toMatrixObject().acquireRead();
        try {
            MatrixBlock actualMB = actual.toMatrixObject().acquireRead();
            try {
                long rows = expectedMB.getNumRows();
                long cols = expectedMB.getNumColumns();
                Assert.assertEquals(rows, actualMB.getNumRows());
                Assert.assertEquals(cols, actualMB.getNumColumns());
                if (PRINT_MAT_ERROR)
                    printMatrixIfNotEqual(expectedMB, actualMB);
                double threshold = getTHRESHOLD();
                for (int i = 0; i < rows; i++) {
                    for (int j = 0; j < cols; j++) {
                        double expectedDouble = expectedMB.quickGetValue(i, j);
                        double actualDouble = actualMB.quickGetValue(i, j);
                        if (expectedDouble != 0.0 && !Double.isNaN(expectedDouble) && Double.isFinite(expectedDouble)) {
                            double relativeError = Math.abs((expectedDouble - actualDouble) / expectedDouble);
                            double absoluteError = Math.abs(expectedDouble - actualDouble);
                            // non-double precision additionally accepts a small absolute error
                            boolean withinThreshold = relativeError < threshold
                                    || (!FLOATING_POINT_PRECISION.equals("double") && absoluteError < threshold);
                            if (!withinThreshold) {
                                // build the message only on failure rather than for every cell
                                Assert.fail(String.format(
                                        "Relative error(%f) is more than threshold (%f). Expected = %f, Actual = %f, differed at [%d, %d]",
                                        relativeError, threshold, expectedDouble, actualDouble, i, j));
                            }
                        } else {
                            Assert.assertEquals(expectedDouble, actualDouble, threshold);
                        }
                    }
                }
            } finally {
                actual.toMatrixObject().release();
            }
        } finally {
            expected.toMatrixObject().release();
        }
    } catch (DMLRuntimeException e) {
        throw new RuntimeException(e);
    }
}
Also used : Script(org.apache.sysml.api.mlcontext.Script) MatrixBlock(org.apache.sysml.runtime.matrix.data.MatrixBlock) DMLRuntimeException(org.apache.sysml.runtime.DMLRuntimeException) Formatter(java.util.Formatter) MLContext(org.apache.sysml.api.mlcontext.MLContext)

Aggregations

MLContext (org.apache.sysml.api.mlcontext.MLContext) 30, Script (org.apache.sysml.api.mlcontext.Script) 18, Matrix (org.apache.sysml.api.mlcontext.Matrix) 10, BeforeClass (org.junit.BeforeClass) 8, DMLScript (org.apache.sysml.api.DMLScript) 6, ArrayList (java.util.ArrayList) 4, SparkConf (org.apache.spark.SparkConf) 4, JavaSparkContext (org.apache.spark.api.java.JavaSparkContext) 4, SparkSession (org.apache.spark.sql.SparkSession) 4, RUNTIME_PLATFORM (org.apache.sysml.api.DMLScript.RUNTIME_PLATFORM) 4, MLResults (org.apache.sysml.api.mlcontext.MLResults) 4, MatrixBlock (org.apache.sysml.runtime.matrix.data.MatrixBlock) 4, FileNotFoundException (java.io.FileNotFoundException) 2, IOException (java.io.IOException) 2, Formatter (java.util.Formatter) 2, CannotCompileException (javassist.CannotCompileException) 2, ClassPool (javassist.ClassPool) 2, CtClass (javassist.CtClass) 2, CtMethod (javassist.CtMethod) 2, NotFoundException (javassist.NotFoundException) 2