Use of org.apache.sysml.api.mlcontext.MLContext in project incubator-systemml by Apache.
The class FrameTest, method setUpClass.
/**
 * Builds the class-wide fixtures once before any test in the class runs: a local
 * Spark session, an MLContext wrapping it, and the JavaSparkContext obtained from
 * that MLContext.
 */
@BeforeClass
public static void setUpClass() {
    final String master = "local";
    spark = createSystemMLSparkSession("FrameTest", master);
    ml = new MLContext(spark);
    sc = MLContextUtil.getJavaSparkContext(ml);
}
Use of org.apache.sysml.api.mlcontext.MLContext in project incubator-systemml by Apache.
The class GNMFTest, method setUpClass.
/**
 * Builds the class-wide fixtures once before any test in the class runs: a local
 * Spark session, an MLContext wrapping it, and the JavaSparkContext obtained from
 * that MLContext.
 */
@BeforeClass
public static void setUpClass() {
    final String master = "local";
    spark = createSystemMLSparkSession("GNMFTest", master);
    ml = new MLContext(spark);
    sc = MLContextUtil.getJavaSparkContext(ml);
}
Use of org.apache.sysml.api.mlcontext.MLContext in project incubator-systemml by Apache.
The class DataFrameVectorScriptTest, method setUpClass.
/**
 * Builds the class-wide fixtures once before any test in the class runs: a local
 * Spark session and an MLContext wrapping it, with explain output switched on.
 */
@BeforeClass
public static void setUpClass() {
    final String master = "local";
    spark = createSystemMLSparkSession("DataFrameVectorScriptTest", master);
    ml = new MLContext(spark);
    // print execution plans for every script run through this context
    ml.setExplain(true);
}
Use of org.apache.sysml.api.mlcontext.MLContext in project incubator-systemml by Apache.
The class MLContextScratchCleanupTest, method runMLContextTestMultipleScript.
/**
 * Runs two DML scripts through one MLContext. The first script produces matrix X. X's
 * in-memory data is then cleared, and X is fed into the second script, whose string
 * output z is printed.
 *
 * <p>The global {@code DMLScript.rtplatform} is switched to {@code platform} for the run
 * and is always restored. The Spark session and MLContext are always torn down, even if
 * creating them or executing a script fails.
 *
 * @param platform runtime platform to execute under
 * @param wRead if true, the second script is ScratchCleanup2b.dml, otherwise ScratchCleanup2.dml
 * @throws RuntimeException wrapping any exception raised during setup or script execution
 */
private static void runMLContextTestMultipleScript(RUNTIME_PLATFORM platform, boolean wRead) {
    RUNTIME_PLATFORM oldplatform = DMLScript.rtplatform;
    DMLScript.rtplatform = platform;
    String dml1 = baseDirectory + File.separator + "ScratchCleanup1.dml";
    String dml2 = baseDirectory + File.separator + (wRead ? "ScratchCleanup2b.dml" : "ScratchCleanup2.dml");
    SparkSession spark = null;
    MLContext ml = null;
    try {
        // create mlcontext inside the try so a setup failure still restores rtplatform
        // and stops whatever was already created
        spark = createSystemMLSparkSession("MLContextScratchCleanupTest", "local");
        ml = new MLContext(spark);
        ml.setExplain(true);
        Script script1 = dmlFromFile(dml1).in("$rows", rows).in("$cols", cols).out("X");
        Matrix X = ml.execute(script1).getMatrix("X");
        // clear in-memory/cached data to emulate on-disk storage
        X.toMatrixObject().clearData();
        Script script2 = dmlFromFile(dml2).in("X", X).out("z");
        String z = ml.execute(script2).getString("z");
        System.out.println(z);
    } catch (Exception ex) {
        throw new RuntimeException(ex);
    } finally {
        DMLScript.rtplatform = oldplatform;
        try {
            // stop underlying spark context to allow single jvm tests (otherwise the
            // next test that tries to create a SparkContext would fail)
            if (spark != null)
                spark.stop();
        } finally {
            // clear status mlcontext and spark exec context, even if stop() failed
            if (ml != null)
                ml.close();
        }
    }
}
Use of org.apache.sysml.api.mlcontext.MLContext in project incubator-systemml by Apache.
The class GPUTests, method assertEqualMatrices.
/**
 * Asserts that the values in two matrices agree to within {@code getTHRESHOLD()}.
 *
 * <p>A DML script first counts cells whose relative error exceeds the threshold. If none
 * do, the matrices are accepted. Otherwise every cell is checked individually:
 * <ul>
 * <li>Finite, non-zero expected values must have a relative error below the threshold.
 * Unless {@code FLOATING_POINT_PRECISION} is "double", an absolute error below the
 * threshold is also accepted.</li>
 * <li>Zero or non-finite expected values are compared with
 * {@code Assert.assertEquals(double, double, delta)}.</li>
 * </ul>
 *
 * @param expected expected matrix
 * @param actual actual matrix
 * @throws AssertionError if the dimensions differ or a cell is outside the tolerance
 * @throws RuntimeException wrapping any {@code DMLRuntimeException} from reading the matrices
 */
private void assertEqualMatrices(Matrix expected, Matrix actual) {
    try {
        // Faster way to compare two matrices. abs() wraps the whole quotient: with
        // abs(X - Y) / X a negative expected value gives a negative "error" that can never
        // exceed the threshold, which would hide the mismatch behind the early return.
        long num_mismatch;
        MLContext cpuMLC = new MLContext(spark);
        try {
            String scriptStr = "num_mismatch = sum(abs((X - Y) / X) > " + getTHRESHOLD() + ");";
            Script script = ScriptFactory.dmlFromString(scriptStr).in("X", expected).in("Y", actual).out("num_mismatch");
            num_mismatch = cpuMLC.execute(script).getLong("num_mismatch");
        } finally {
            cpuMLC.close();
        }
        if (num_mismatch == 0)
            return;

        // If error, check every cell to report the actual incorrect values. Each
        // acquireRead() is paired with a release() even when an assertion fails.
        MatrixBlock expectedMB = expected.toMatrixObject().acquireRead();
        try {
            MatrixBlock actualMB = actual.toMatrixObject().acquireRead();
            try {
                long rows = expectedMB.getNumRows();
                long cols = expectedMB.getNumColumns();
                Assert.assertEquals(rows, actualMB.getNumRows());
                Assert.assertEquals(cols, actualMB.getNumColumns());
                if (PRINT_MAT_ERROR)
                    printMatrixIfNotEqual(expectedMB, actualMB);
                double threshold = getTHRESHOLD();
                boolean doublePrecision = FLOATING_POINT_PRECISION.equals("double");
                for (int i = 0; i < rows; i++) {
                    for (int j = 0; j < cols; j++) {
                        double expectedDouble = expectedMB.quickGetValue(i, j);
                        double actualDouble = actualMB.quickGetValue(i, j);
                        // isFinite already excludes NaN
                        if (expectedDouble != 0.0 && Double.isFinite(expectedDouble)) {
                            double relativeError = Math.abs((expectedDouble - actualDouble) / expectedDouble);
                            double absoluteError = Math.abs(expectedDouble - actualDouble);
                            boolean withinTolerance = doublePrecision
                                    ? relativeError < threshold
                                    : relativeError < threshold || absoluteError < threshold;
                            // build the message only on failure instead of once per cell
                            if (!withinTolerance) {
                                Assert.fail(String.format(
                                        "Relative error(%f) is more than threshold (%f). Expected = %f, Actual = %f, differed at [%d, %d]",
                                        relativeError, threshold, expectedDouble, actualDouble, i, j));
                            }
                        } else {
                            Assert.assertEquals(expectedDouble, actualDouble, threshold);
                        }
                    }
                }
            } finally {
                actual.toMatrixObject().release();
            }
        } finally {
            expected.toMatrixObject().release();
        }
    } catch (DMLRuntimeException e) {
        throw new RuntimeException(e);
    }
}
Aggregations