Search in sources:

Example 36 with PreparedScript

Use of org.apache.sysml.api.jmlc.PreparedScript in project incubator-systemml by Apache.

Source: class BuildLiteExecution, method jmlcLinReg.

/**
 * Fits linear-regression coefficients on randomly generated data and scores new random data.
 *
 * <p>The method runs these steps in order:
 * <ol>
 * <li>Runs {@code LinearRegDS.dml} and {@code LinearRegCG.dml} on the same 500x3 training
 * matrix {@code X} and 500x1 label matrix {@code y}.</li>
 * <li>Passes the {@code beta_out} produced by the CG script as {@code B_full} to
 * {@code GLM-predict.dml}, together with a fresh 500x3 test matrix and 500x1 label matrix.</li>
 * <li>Writes every generated and returned matrix to the debug log.</li>
 * </ol>
 *
 * <p>All random values are drawn uniformly from [0, 100). The JMLC connection is closed in a
 * {@code finally} block, so it is released even if reading, preparing or executing a script
 * throws.
 *
 * @throws Exception if any script cannot be read, prepared or executed
 */
public static void jmlcLinReg() throws Exception {
    final int numRows = 500;
    final int numFeatures = 3;
    Connection conn = getConfiguredConnection();
    try {
        String linRegDS = conn.readScript("scripts/algorithms/LinearRegDS.dml");
        PreparedScript linRegDSScript = conn.prepareScript(linRegDS, new String[] { "X", "y" }, new String[] { "beta_out" }, false);
        double[][] trainData = randomUniformMatrix(numRows, numFeatures);
        linRegDSScript.setMatrix("X", trainData);
        log.debug(displayMatrix(trainData));
        double[][] trainLabels = randomUniformMatrix(numRows, 1);
        linRegDSScript.setMatrix("y", trainLabels);
        log.debug(displayMatrix(trainLabels));
        ResultVariables linRegDSResults = linRegDSScript.executeScript();
        double[][] dsBetas = linRegDSResults.getMatrix("beta_out");
        log.debug("DS BETAS:");
        log.debug(displayMatrix(dsBetas));

        // Same training data, different solver script.
        String linRegCG = conn.readScript("scripts/algorithms/LinearRegCG.dml");
        PreparedScript linRegCGScript = conn.prepareScript(linRegCG, new String[] { "X", "y" }, new String[] { "beta_out" }, false);
        linRegCGScript.setMatrix("X", trainData);
        linRegCGScript.setMatrix("y", trainLabels);
        ResultVariables linRegCGResults = linRegCGScript.executeScript();
        double[][] cgBetas = linRegCGResults.getMatrix("beta_out");
        log.debug("CG BETAS:");
        log.debug(displayMatrix(cgBetas));

        // Score fresh random data with the CG coefficients.
        String glmPredict = conn.readScript("scripts/algorithms/GLM-predict.dml");
        PreparedScript glmPredictScript = conn.prepareScript(glmPredict, new String[] { "X", "Y", "B_full" }, new String[] { "means" }, false);
        double[][] testData = randomUniformMatrix(numRows, numFeatures);
        glmPredictScript.setMatrix("X", testData);
        double[][] testLabels = randomUniformMatrix(numRows, 1);
        glmPredictScript.setMatrix("Y", testLabels);
        glmPredictScript.setMatrix("B_full", cgBetas);
        ResultVariables glmPredictResults = glmPredictScript.executeScript();
        double[][] means = glmPredictResults.getMatrix("means");
        log.debug("GLM PREDICT MEANS:");
        log.debug(displayMatrix(means));
    } finally {
        // Release the connection on both the success and the failure path.
        conn.close();
    }
}

/**
 * Builds a {@code rows x cols} matrix of values drawn uniformly from [0, 100).
 *
 * <p>Values are drawn row by row, left to right.
 *
 * @param rows number of rows
 * @param cols number of columns
 * @return a new, fully populated matrix
 */
private static double[][] randomUniformMatrix(int rows, int cols) {
    ThreadLocalRandom rnd = ThreadLocalRandom.current();
    double[][] matrix = new double[rows][cols];
    for (int i = 0; i < rows; i++) {
        for (int j = 0; j < cols; j++) {
            matrix[i][j] = rnd.nextDouble(0, 100);
        }
    }
    return matrix;
}
Also used: PreparedScript (org.apache.sysml.api.jmlc.PreparedScript), ResultVariables (org.apache.sysml.api.jmlc.ResultVariables), Connection (org.apache.sysml.api.jmlc.Connection)

Example 37 with PreparedScript

Use of org.apache.sysml.api.jmlc.PreparedScript in project incubator-systemml by Apache.

Source: class SystemTMulticlassSVMScoreTest, method execDMLScriptviaJMLC.

/**
 * Scores each input matrix with this test's DML script through JMLC.
 *
 * <p>The script is read and precompiled once. The model matrix {@code W} is read from
 * {@code MODEL_FILE} and converted to a {@code rows x cols} matrix. The prepared script is then
 * executed {@code nRuns} times, with {@code X.get(i)} bound as {@code X} in run {@code i}.
 * The elapsed time is printed to standard output.
 *
 * @param X input feature matrices; entries {@code 0 .. nRuns-1} are scored
 * @return the {@code predicted_y} output of every run, in run order
 * @throws IOException if the script or model cannot be read, or if preparing or executing the
 *         script fails; non-I/O failures are wrapped with the original exception as the cause
 */
private ArrayList<double[][]> execDMLScriptviaJMLC(ArrayList<double[][]> X) throws IOException {
    Timing time = new Timing(true);
    ArrayList<double[][]> ret = new ArrayList<double[][]>();
    //establish connection to SystemML
    Connection conn = new Connection();
    try {
        // For now, JMLC pipeline only allows dml
        boolean parsePyDML = false;
        //read and precompile script
        String script = conn.readScript(SCRIPT_DIR + TEST_DIR + TEST_NAME + ".dml");
        PreparedScript pstmt = conn.prepareScript(script, new String[] { "X", "W" }, new String[] { "predicted_y" }, parsePyDML);
        //read model
        String modelData = conn.readScript(SCRIPT_DIR + TEST_DIR + MODEL_FILE);
        double[][] W = conn.convertToDoubleMatrix(modelData, rows, cols);
        //execute script multiple times
        for (int i = 0; i < nRuns; i++) {
            // Both inputs are rebound before every execution, as in the original test.
            // NOTE(review): presumably JMLC input bindings do not survive executeScript();
            // confirm before hoisting the W binding out of the loop.
            pstmt.setMatrix("W", W);
            pstmt.setMatrix("X", X.get(i));
            ResultVariables rs = pstmt.executeScript();
            //keep result for comparison
            ret.add(rs.getMatrix("predicted_y"));
        }
    } catch (IOException ex) {
        // Already the declared exception type: rethrow as-is instead of double-wrapping.
        throw ex;
    } catch (Exception ex) {
        // Wrap everything else; the stack trace is preserved through the cause chain.
        throw new IOException(ex);
    } finally {
        // conn is assigned before the try, so it cannot be null here.
        conn.close();
    }
    System.out.println("JMLC scoring w/ " + nRuns + " runs in " + time.stop() + "ms.");
    return ret;
}
Also used: PreparedScript (org.apache.sysml.api.jmlc.PreparedScript), ResultVariables (org.apache.sysml.api.jmlc.ResultVariables), ArrayList (java.util.ArrayList), Connection (org.apache.sysml.api.jmlc.Connection), Timing (org.apache.sysml.runtime.controlprogram.parfor.stat.Timing), IOException (java.io.IOException), DMLException (org.apache.sysml.api.DMLException)

Aggregations

Connection (org.apache.sysml.api.jmlc.Connection): 37, PreparedScript (org.apache.sysml.api.jmlc.PreparedScript): 37, ResultVariables (org.apache.sysml.api.jmlc.ResultVariables): 14, Test (org.junit.Test): 14, IOException (java.io.IOException): 12, HashMap (java.util.HashMap): 11, ArrayList (java.util.ArrayList): 10, Timing (org.apache.sysml.runtime.controlprogram.parfor.stat.Timing): 9, DMLException (org.apache.sysml.api.DMLException): 2, File (java.io.File): 1, ExecutorService (java.util.concurrent.ExecutorService): 1, Future (java.util.concurrent.Future): 1, SparkConf (org.apache.spark.SparkConf): 1, JavaSparkContext (org.apache.spark.api.java.JavaSparkContext): 1, DMLScript (org.apache.sysml.api.DMLScript): 1, MLContext (org.apache.sysml.api.mlcontext.MLContext): 1, Script (org.apache.sysml.api.mlcontext.Script): 1, ScalarObject (org.apache.sysml.runtime.instructions.cp.ScalarObject): 1, FrameBlock (org.apache.sysml.runtime.matrix.data.FrameBlock): 1, MatrixBlock (org.apache.sysml.runtime.matrix.data.MatrixBlock): 1