Use of org.apache.sysml.hops.AggBinaryOp.MMultMethod in project incubator-systemml by Apache.
Class ZipMMSparkMatrixMultiplicationTest, method runZipMMMatrixMultiplicationTest.
/**
 * Runs the ZipMM matrix-multiplication test script ({@code TEST_NAME1}) and compares the
 * DML output matrix "C" against the output of the corresponding R script.
 * <p>
 * For the duration of the run three global settings are overridden and then restored in a
 * {@code finally} block: the runtime platform (MR maps to HADOOP, SPARK to SPARK, anything
 * else to HYBRID), {@code DMLScript.USE_LOCAL_SPARK_CONFIG} (set to true only when the
 * platform is SPARK), and {@code AggBinaryOp.FORCED_MMULT_METHOD} (forced to ZIPMM).
 *
 * @param sparseM1 if true, input A is generated with {@code sparsity2}, otherwise with {@code sparsity1}
 * @param sparseM2 if true, input B is generated with {@code sparsity2}, otherwise with {@code sparsity1}
 * @param instType execution type used to select the runtime platform
 * @param vectorM2 if true, input B has {@code colsB1} columns, otherwise {@code colsB2}
 */
private void runZipMMMatrixMultiplicationTest(boolean sparseM1, boolean sparseM2, ExecType instType, boolean vectorM2) {
	// pick the runtime platform from the execution type; keep the old one for restoring
	final RUNTIME_PLATFORM prevPlatform = rtplatform;
	switch (instType) {
		case MR:    rtplatform = RUNTIME_PLATFORM.HADOOP; break;
		case SPARK: rtplatform = RUNTIME_PLATFORM.SPARK;  break;
		default:    rtplatform = RUNTIME_PLATFORM.HYBRID; break;
	}

	// on Spark, switch on the local Spark configuration
	final boolean prevLocalSparkConfig = DMLScript.USE_LOCAL_SPARK_CONFIG;
	if (rtplatform == RUNTIME_PLATFORM.SPARK) {
		DMLScript.USE_LOCAL_SPARK_CONFIG = true;
	}

	// force the ZIPMM matrix-multiplication method
	final MMultMethod prevMMultMethod = AggBinaryOp.FORCED_MMULT_METHOD;
	AggBinaryOp.FORCED_MMULT_METHOD = MMultMethod.ZIPMM;

	final int numColsB = vectorM2 ? colsB1 : colsB2;
	final String testName = TEST_NAME1;
	try {
		getAndLoadTestConfiguration(testName);

		// DML and R script locations plus their command-line arguments
		final String scriptBase = SCRIPT_DIR + TEST_DIR + testName;
		fullDMLScriptName = scriptBase + ".dml";
		programArgs = new String[] { "-explain", "-args", input("A"), input("B"), output("C") };
		fullRScriptName = scriptBase + ".R";
		rCmd = "Rscript" + " " + fullRScriptName + " " + inputDir() + " " + expectedDir();

		// generate and write inputs A (seed 7) and B (seed 3) with values in [0, 1]
		final double sparsityA = sparseM1 ? sparsity2 : sparsity1;
		final double sparsityB = sparseM2 ? sparsity2 : sparsity1;
		writeInputMatrixWithMTD("A", getRandomMatrix(rowsA, colsA, 0, 1, sparsityA, 7), true);
		writeInputMatrixWithMTD("B", getRandomMatrix(rowsB, numColsB, 0, 1, sparsityB, 3), true);

		runTest(true, false, null, -1);
		runRScript(true);

		// the DML result must match the R result within eps
		HashMap<CellIndex, Double> dmlResult = readDMLMatrixFromHDFS("C");
		HashMap<CellIndex, Double> rResult = readRMatrixFromFS("C");
		TestUtils.compareMatrices(dmlResult, rResult, eps, "Stat-DML", "Stat-R");
	} finally {
		// undo every global override made above
		AggBinaryOp.FORCED_MMULT_METHOD = prevMMultMethod;
		DMLScript.USE_LOCAL_SPARK_CONFIG = prevLocalSparkConfig;
		rtplatform = prevPlatform;
	}
}
Use of org.apache.sysml.hops.AggBinaryOp.MMultMethod in project incubator-systemml by Apache.
Class FullDistributedMatrixMultiplicationTest, method runDistributedMatrixMatrixMultiplicationTest.
/**
 * Runs the distributed matrix-matrix multiplication test script ({@code TEST_NAME}) with a
 * forced multiplication method and compares the DML output matrix "C" against the output of
 * the corresponding R script.
 * <p>
 * For the duration of the run three global settings are overridden and then restored in a
 * {@code finally} block: the runtime platform (MR maps to HADOOP, SPARK to SPARK, anything
 * else to HYBRID), {@code DMLScript.USE_LOCAL_SPARK_CONFIG} (set to true only when the
 * platform is SPARK), and {@code AggBinaryOp.FORCED_MMULT_METHOD} (set to {@code method}).
 * When {@code TEST_CACHE_ENABLED} is true, the test configuration is loaded with a cache
 * sub-directory named {@code "<sparsityA>_<sparsityB>/"}.
 *
 * @param sparseM1 if true, input A is generated with {@code sparsity2}, otherwise with {@code sparsity1}
 * @param sparseM2 if true, input B is generated with {@code sparsity2}, otherwise with {@code sparsity1}
 * @param method   matrix-multiplication method assigned to {@code AggBinaryOp.FORCED_MMULT_METHOD}
 * @param instType execution type used to select the runtime platform
 */
private void runDistributedMatrixMatrixMultiplicationTest(boolean sparseM1, boolean sparseM2, MMultMethod method, ExecType instType) {
	// pick the runtime platform from the execution type; keep the old one for restoring
	final RUNTIME_PLATFORM prevPlatform = rtplatform;
	switch (instType) {
		case MR:    rtplatform = RUNTIME_PLATFORM.HADOOP; break;
		case SPARK: rtplatform = RUNTIME_PLATFORM.SPARK;  break;
		default:    rtplatform = RUNTIME_PLATFORM.HYBRID; break;
	}

	// on Spark, switch on the local Spark configuration
	final boolean prevLocalSparkConfig = DMLScript.USE_LOCAL_SPARK_CONFIG;
	if (rtplatform == RUNTIME_PLATFORM.SPARK) {
		DMLScript.USE_LOCAL_SPARK_CONFIG = true;
	}

	// force the requested matrix-multiplication method
	final MMultMethod prevMMultMethod = AggBinaryOp.FORCED_MMULT_METHOD;
	AggBinaryOp.FORCED_MMULT_METHOD = method;

	try {
		TestConfiguration testConfig = getTestConfiguration(TEST_NAME);
		final double sparsityA = sparseM1 ? sparsity2 : sparsity1;
		final double sparsityB = sparseM2 ? sparsity2 : sparsity1;

		// cache sub-directory keyed by the two sparsities (empty when caching is off)
		final String cacheDir = TEST_CACHE_ENABLED ? sparsityA + "_" + sparsityB + "/" : "";
		loadTestConfiguration(testConfig, cacheDir);

		// DML and R script locations plus their command-line arguments
		final String scriptBase = SCRIPT_DIR + TEST_DIR + TEST_NAME;
		fullDMLScriptName = scriptBase + ".dml";
		programArgs = new String[] { "-args", input("A"), input("B"), output("C") };
		fullRScriptName = scriptBase + ".R";
		rCmd = "Rscript" + " " + fullRScriptName + " " + inputDir() + " " + expectedDir();

		// generate and write inputs A (seed 12357) and B (seed 9873) with values in [0, 1]
		writeInputMatrixWithMTD("A", getRandomMatrix(rowsA, colsA, 0, 1, sparsityA, 12357), true);
		writeInputMatrixWithMTD("B", getRandomMatrix(rowsB, colsB, 0, 1, sparsityB, 9873), true);

		runTest(true, false, null, -1);
		runRScript(true);

		// the DML result must match the R result within eps
		HashMap<CellIndex, Double> dmlResult = readDMLMatrixFromHDFS("C");
		HashMap<CellIndex, Double> rResult = readRMatrixFromFS("C");
		TestUtils.compareMatrices(dmlResult, rResult, eps, "Stat-DML", "Stat-R");
	} finally {
		// undo every global override made above
		AggBinaryOp.FORCED_MMULT_METHOD = prevMMultMethod;
		DMLScript.USE_LOCAL_SPARK_CONFIG = prevLocalSparkConfig;
		rtplatform = prevPlatform;
	}
}
Aggregations