Usage of org.apache.sysml.api.jmlc.PreparedScript in the Apache incubator-systemml project.
Class BuildLiteExecution, method jmlcLinReg:
/**
 * Runs the LinearRegDS and LinearRegCG scripts through JMLC on randomly generated training data,
 * then scores a randomly generated test set with GLM-predict using the CG betas. Inputs,
 * coefficients and predicted means are written to the debug log.
 *
 * <p>The connection is closed on every exit path, including when a script fails to read,
 * prepare or execute.
 *
 * @throws Exception if any script cannot be read, prepared or executed
 */
public static void jmlcLinReg() throws Exception {
    Connection conn = getConfiguredConnection();
    try {
        // --- LinearRegDS on a 500 x 3 feature matrix and 500 x 1 label vector in [0, 100) ---
        String linRegDS = conn.readScript("scripts/algorithms/LinearRegDS.dml");
        PreparedScript linRegDSScript = conn.prepareScript(linRegDS, new String[] { "X", "y" }, new String[] { "beta_out" }, false);
        double[][] trainData = randomUniformMatrix(500, 3);
        linRegDSScript.setMatrix("X", trainData);
        log.debug(displayMatrix(trainData));
        double[][] trainLabels = randomUniformMatrix(500, 1);
        linRegDSScript.setMatrix("y", trainLabels);
        log.debug(displayMatrix(trainLabels));
        ResultVariables linRegDSResults = linRegDSScript.executeScript();
        double[][] dsBetas = linRegDSResults.getMatrix("beta_out");
        log.debug("DS BETAS:");
        log.debug(displayMatrix(dsBetas));

        // --- LinearRegCG on the same training data ---
        String linRegCG = conn.readScript("scripts/algorithms/LinearRegCG.dml");
        PreparedScript linRegCGScript = conn.prepareScript(linRegCG, new String[] { "X", "y" }, new String[] { "beta_out" }, false);
        linRegCGScript.setMatrix("X", trainData);
        linRegCGScript.setMatrix("y", trainLabels);
        ResultVariables linRegCGResults = linRegCGScript.executeScript();
        double[][] cgBetas = linRegCGResults.getMatrix("beta_out");
        log.debug("CG BETAS:");
        log.debug(displayMatrix(cgBetas));

        // --- GLM-predict on fresh random test data, using the CG coefficients ---
        String glmPredict = conn.readScript("scripts/algorithms/GLM-predict.dml");
        PreparedScript glmPredictScript = conn.prepareScript(glmPredict, new String[] { "X", "Y", "B_full" }, new String[] { "means" }, false);
        double[][] testData = randomUniformMatrix(500, 3);
        glmPredictScript.setMatrix("X", testData);
        double[][] testLabels = randomUniformMatrix(500, 1);
        glmPredictScript.setMatrix("Y", testLabels);
        glmPredictScript.setMatrix("B_full", cgBetas);
        ResultVariables glmPredictResults = glmPredictScript.executeScript();
        double[][] means = glmPredictResults.getMatrix("means");
        log.debug("GLM PREDICT MEANS:");
        log.debug(displayMatrix(means));
    } finally {
        // Previously only reached on success; a failing script leaked the connection.
        conn.close();
    }
}

/**
 * Builds a {@code rows x cols} matrix whose entries are drawn uniformly from [0, 100).
 * Entries are filled row by row, left to right.
 *
 * @param rows number of rows
 * @param cols number of columns
 * @return a newly allocated random matrix
 */
private static double[][] randomUniformMatrix(int rows, int cols) {
    ThreadLocalRandom rnd = ThreadLocalRandom.current();
    double[][] matrix = new double[rows][cols];
    for (int i = 0; i < rows; i++) {
        for (int j = 0; j < cols; j++) {
            matrix[i][j] = rnd.nextDouble(0, 100);
        }
    }
    return matrix;
}
Usage of org.apache.sysml.api.jmlc.PreparedScript in the Apache incubator-systemml project.
Class SystemTMulticlassSVMScoreTest, method execDMLScriptviaJMLC:
/**
 * Scores the given inputs with the multiclass-SVM test script through JMLC. The script is
 * prepared and the model is read once; the script is then executed {@code nRuns} times.
 *
 * <p>Prints the total elapsed time to standard output.
 *
 * @param X input feature matrices. Entry {@code i} is bound to script input "X" in run
 *          {@code i}, so at least {@code nRuns} entries are expected.
 * @return the "predicted_y" output of each run, in run order
 * @throws IOException if reading, preparing or executing the script fails. Exceptions that
 *          are not already {@code IOException}s are wrapped, preserving the original cause.
 */
private ArrayList<double[][]> execDMLScriptviaJMLC(ArrayList<double[][]> X) throws IOException {
    Timing time = new Timing(true);
    ArrayList<double[][]> ret = new ArrayList<double[][]>();
    //establish connection to SystemML
    Connection conn = new Connection();
    try {
        // For now, JMLC pipeline only allows dml
        boolean parsePyDML = false;
        //read and precompile script
        String script = conn.readScript(SCRIPT_DIR + TEST_DIR + TEST_NAME + ".dml");
        PreparedScript pstmt = conn.prepareScript(script, new String[] { "X", "W" }, new String[] { "predicted_y" }, parsePyDML);
        //read model
        String modelData = conn.readScript(SCRIPT_DIR + TEST_DIR + MODEL_FILE);
        double[][] W = conn.convertToDoubleMatrix(modelData, rows, cols);
        //execute script multiple times
        for (int i = 0; i < nRuns; i++) {
            // NOTE(review): W is rebound on every run, as in the original. Presumably inputs
            // are not retained across executeScript calls; confirm before hoisting this.
            pstmt.setMatrix("W", W);
            pstmt.setMatrix("X", X.get(i));
            //execute script
            ResultVariables rs = pstmt.executeScript();
            //get output parameter and keep result for comparison
            ret.add(rs.getMatrix("predicted_y"));
        }
    } catch (IOException ex) {
        // Already the declared type; rethrow without adding a redundant wrapper.
        throw ex;
    } catch (Exception ex) {
        // The wrapped cause carries the full stack trace, so printing it here would duplicate it.
        throw new IOException(ex);
    } finally {
        // conn is always assigned by the constructor above, so no null check is needed.
        conn.close();
    }
    System.out.println("JMLC scoring w/ " + nRuns + " runs in " + time.stop() + "ms.");
    return ret;
}
Aggregations