Use of org.apache.sysml.runtime.instructions.cp.ScalarObject in project incubator-systemml by Apache.
In class BinarySPInstruction, method processMatrixScalarBinaryInstruction:
/**
 * Executes a matrix-scalar binary operation on Spark by applying the scalar
 * operator to every block of the matrix input.
 *
 * <p>Whichever of {@code input1}/{@code input2} has data type MATRIX supplies the
 * RDD; the other operand is read as the scalar constant of the operator.
 *
 * @param ec execution context; cast to {@link SparkExecutionContext}
 */
protected void processMatrixScalarBinaryInstruction(ExecutionContext ec) {
    final SparkExecutionContext sparkCtx = (SparkExecutionContext) ec;
    final boolean matrixIsFirst = input1.getDataType() == DataType.MATRIX;

    // split operands into the matrix (RDD) side and the scalar side
    final String matrixVar;
    final CPOperand scalarOperand;
    if (matrixIsFirst) {
        matrixVar = input1.getName();
        scalarOperand = input2;
    } else {
        matrixVar = input2.getName();
        scalarOperand = input1;
    }
    final JavaPairRDD<MatrixIndexes, MatrixBlock> blocks = sparkCtx.getBinaryBlockRDDHandleForVariable(matrixVar);

    // bind the scalar value into the operator
    final ScalarObject scalarValue = (ScalarObject) ec.getScalarInput(
            scalarOperand.getName(), scalarOperand.getValueType(), scalarOperand.isLiteral());
    final ScalarOperator boundOp = ((ScalarOperator) _optr).setConstant(scalarValue.getDoubleValue());

    // apply the operator block-wise
    final JavaPairRDD<MatrixIndexes, MatrixBlock> result = blocks.mapValues(new MatrixScalarUnaryFunction(boundOp));

    // register output characteristics, RDD handle and lineage
    final String outName = output.getName();
    updateUnaryOutputMatrixCharacteristics(sparkCtx, matrixVar, outName);
    sparkCtx.setRDDHandleForVariable(outName, result);
    sparkCtx.addLineageRDD(outName, matrixVar);
}
Use of org.apache.sysml.runtime.instructions.cp.ScalarObject in project incubator-systemml by Apache.
In class CentralMomentSPInstruction, method processInstruction:
/**
 * Computes a central moment of the matrix {@code input1} on Spark, optionally
 * weighted by {@code input2}, and stores the result as a double scalar output.
 *
 * <p>The 'order' argument is the last operand: {@code input3} when weights are
 * present, otherwise {@code input2}.
 *
 * @param ec execution context; cast to {@link SparkExecutionContext}
 */
@Override
public void processInstruction(ExecutionContext ec) {
    final SparkExecutionContext sparkCtx = (SparkExecutionContext) ec;
    final boolean hasWeights = input3 != null;

    // parse 'order' input argument
    final CPOperand orderOperand = hasWeights ? input3 : input2;
    final ScalarObject orderValue = ec.getScalarInput(
            orderOperand.getName(), orderOperand.getValueType(), orderOperand.isLiteral());
    final CMOperator cmOp = (CMOperator) _optr;
    if (cmOp.getAggOpType() == AggregateOperationTypes.INVALID) {
        // NOTE(review): setCMAggOp is invoked on this instruction's own _optr; if it
        // updates the operator in place, the order resolved on the first execution
        // would presumably stick for later executions - confirm this is intended.
        cmOp.setCMAggOp((int) orderValue.getLongValue());
    }

    final JavaPairRDD<MatrixIndexes, MatrixBlock> values = sparkCtx.getBinaryBlockRDDHandleForVariable(input1.getName());

    // aggregate partial central-moment objects across all blocks
    final CM_COV_Object aggregate;
    if (hasWeights) {
        final JavaPairRDD<MatrixIndexes, MatrixBlock> weights = sparkCtx.getBinaryBlockRDDHandleForVariable(input2.getName());
        aggregate = values.join(weights).values()
                .map(new RDDCMWeightsFunction(cmOp))
                .fold(new CM_COV_Object(), new RDDCMReduceFunction(cmOp));
    } else {
        aggregate = values.values()
                .map(new RDDCMFunction(cmOp))
                .fold(new CM_COV_Object(), new RDDCMReduceFunction(cmOp));
    }

    // create scalar output (no lineage information required)
    ec.setScalarOutput(output.getName(), new DoubleObject(aggregate.getRequiredResult(_optr)));
}
Use of org.apache.sysml.runtime.instructions.cp.ScalarObject in project incubator-systemml by Apache.
In class LiteralReplacement, method replaceLiteralValueTypeCastScalarRead:
/**
 * Replaces an {@code as.double}/{@code as.integer}/{@code as.boolean} cast over a
 * scalar read with a literal holding the converted value of the bound variable.
 *
 * @param c    candidate hop
 * @param vars current variable bindings
 * @return the replacement literal, or {@code null} if {@code c} does not match the
 *         pattern or its input variable is not bound to a scalar value
 */
private static LiteralOp replaceLiteralValueTypeCastScalarRead(Hop c, LocalVariableMap vars) {
    // as.double/as.integer/as.boolean over scalar read - literal replacement
    if (!(c instanceof UnaryOp)) {
        return null;
    }
    final OpOp1 op = ((UnaryOp) c).getOp();
    final boolean isValueTypeCast = op == OpOp1.CAST_AS_DOUBLE
            || op == OpOp1.CAST_AS_INT
            || op == OpOp1.CAST_AS_BOOLEAN;
    if (!isValueTypeCast || !(c.getInput().get(0) instanceof DataOp) || c.getDataType() != DataType.SCALAR) {
        return null;
    }

    final Data dat = vars.get(c.getInput().get(0).getName());
    // Unbound variables are skipped (required for selective constant propagation).
    // Non-scalar bindings are skipped as well, rather than failing on the cast below.
    if (!(dat instanceof ScalarObject)) {
        return null;
    }
    final ScalarObject sdat = (ScalarObject) dat;
    switch (op) {
        case CAST_AS_INT:
            return new LiteralOp(sdat.getLongValue());
        case CAST_AS_DOUBLE:
            return new LiteralOp(sdat.getDoubleValue());
        case CAST_AS_BOOLEAN:
            return new LiteralOp(sdat.getBooleanValue());
        default:
            return null;
    }
}
Use of org.apache.sysml.runtime.instructions.cp.ScalarObject in project incubator-systemml by Apache.
In class ForProgramBlock, method executePredicateInstructions:
/**
 * Evaluates one of the for-loop predicates ('from', 'to' or 'increment') and
 * returns its value as an {@link IntObject}.
 *
 * @param pos          predicate selector: 1 = from, 2 = to, 3 = increment
 * @param instructions instructions computing the predicate
 * @param ec           execution context
 * @return the predicate value; a non-{@code IntObject} scalar result is converted
 *         via {@code getLongValue()}
 * @throws DMLRuntimeException if evaluating the predicate fails
 */
protected IntObject executePredicateInstructions(int pos, ArrayList<Instruction> instructions, ExecutionContext ec) {
    // Kept as a generic ScalarObject (no eager IntObject cast) so that non-int
    // predicate results reach the int conversion at the end instead of failing
    // with a ClassCastException inside the try block.
    ScalarObject tmp = null;
    try {
        if (_sb != null) {
            // set program block specific remote memory
            if (DMLScript.isActiveAM()) {
                DMLAppMasterUtils.setupProgramBlockRemoteMaxMemory(this);
            }
            ForStatementBlock fsb = (ForStatementBlock) _sb;
            Hop predHops = null;
            boolean recompile = false;
            if (pos == 1) {
                predHops = fsb.getFromHops();
                recompile = fsb.requiresFromRecompilation();
            } else if (pos == 2) {
                predHops = fsb.getToHops();
                recompile = fsb.requiresToRecompilation();
            } else if (pos == 3) {
                predHops = fsb.getIncrementHops();
                recompile = fsb.requiresIncrementRecompilation();
            }
            tmp = executePredicate(instructions, predHops, recompile, ValueType.INT, ec);
        } else {
            tmp = executePredicate(instructions, null, false, ValueType.INT, ec);
        }
    } catch (Exception ex) {
        String predNameStr = null;
        if (pos == 1) {
            predNameStr = "from";
        } else if (pos == 2) {
            predNameStr = "to";
        } else if (pos == 3) {
            predNameStr = "increment";
        }
        throw new DMLRuntimeException(this.printBlockErrorLocation() + "Error evaluating '" + predNameStr + "' predicate", ex);
    }
    // final check of resulting int object (guaranteed to be non-null, see executePredicate)
    if (tmp instanceof IntObject) {
        return (IntObject) tmp;
    }
    // downcast to int if necessary
    return new IntObject(tmp.getLongValue());
}
Use of org.apache.sysml.runtime.instructions.cp.ScalarObject in project incubator-systemml by Apache.
In class ProgramBlock, method executePredicateInstructions:
/**
 * Runs the given predicate instructions and returns the scalar bound to
 * {@code PRED_VAR}, converted to {@code retType} if its value type differs.
 * The predicate variable is removed from the context before returning.
 *
 * @param inst    instructions that compute the predicate
 * @param retType expected value type of the predicate result
 * @param ec      execution context
 * @return the (possibly converted) predicate result
 */
protected ScalarObject executePredicateInstructions(ArrayList<Instruction> inst, ValueType retType, ExecutionContext ec) {
    // run every instruction, reporting its position to the debugger
    int debugIndex = 0;
    for (Instruction instruction : inst) {
        ec.updateDebugState(debugIndex);
        debugIndex++;
        executeSingleInstruction(instruction, ec);
    }

    // fetch the predicate value and coerce it to the requested type (incl. double to int)
    ScalarObject result = (ScalarObject) ec.getScalarInput(PRED_VAR, retType, false);
    if (result.getValueType() != retType) {
        switch (retType) {
            case BOOLEAN:
                result = new BooleanObject(result.getBooleanValue());
                break;
            case INT:
                result = new IntObject(result.getLongValue());
                break;
            case DOUBLE:
                result = new DoubleObject(result.getDoubleValue());
                break;
            case STRING:
                result = new StringObject(result.getStringValue());
                break;
            default:
                // other value types are returned unchanged
                break;
        }
    }

    // the predicate variable is only a temporary; drop it from the context
    ec.removeVariable(PRED_VAR);
    return result;
}
Aggregations