Use of org.apache.sysml.runtime.matrix.operators.QuaternaryOperator in project incubator-systemml by Apache.
Class QuaternarySPInstruction, method parseInstruction:
/**
 * Parses a serialized distributed (Spark) quaternary instruction into a
 * {@link QuaternarySPInstruction}.
 *
 * <p>Handled opcodes are the map-side and reduce-side variants of wsloss, wumm, wdivmm,
 * wsigmoid and wcemm. Reduce-side variants carry two extra trailing boolean fields
 * (cacheU, cacheV) that select broadcast (true) or shuffled RDD (false) access for U and V.
 * Map-side variants always access U and V through broadcasts.
 *
 * @param str the serialized instruction string
 * @return the parsed instruction; never {@code null}
 * @throws DMLRuntimeException if the opcode is not a distributed quaternary opcode, or is one
 *         that this parser does not handle
 */
public static QuaternarySPInstruction parseInstruction(String str) {
    String[] parts = InstructionUtils.getInstructionPartsWithValueType(str);
    String opcode = parts[0];
    // validity check
    if (!InstructionUtils.isDistQuaternaryOpcode(opcode)) {
        throw new DMLRuntimeException("Quaternary.parseInstruction():: Unknown opcode " + opcode);
    }
    // instruction parsing
    if (WeightedSquaredLoss.OPCODE.equalsIgnoreCase(opcode) || WeightedSquaredLossR.OPCODE.equalsIgnoreCase(opcode)) {
        // wsloss layout: in1, in2, in3, in4, out, weights type [, cacheU, cacheV]
        boolean isRed = WeightedSquaredLossR.OPCODE.equalsIgnoreCase(opcode);
        InstructionUtils.checkNumFields(parts, isRed ? 8 : 6);
        CPOperand in1 = new CPOperand(parts[1]);
        CPOperand in2 = new CPOperand(parts[2]);
        CPOperand in3 = new CPOperand(parts[3]);
        CPOperand in4 = new CPOperand(parts[4]);
        CPOperand out = new CPOperand(parts[5]);
        WeightsType wtype = WeightsType.valueOf(parts[6]);
        // in mappers always through distcache, in reducers through distcache/shuffle
        boolean cacheU = isRed ? Boolean.parseBoolean(parts[7]) : true;
        boolean cacheV = isRed ? Boolean.parseBoolean(parts[8]) : true;
        return new QuaternarySPInstruction(new QuaternaryOperator(wtype), in1, in2, in3, in4, out, cacheU, cacheV, opcode, str);
    } else if (WeightedUnaryMM.OPCODE.equalsIgnoreCase(opcode) || WeightedUnaryMMR.OPCODE.equalsIgnoreCase(opcode)) {
        // wumm layout: unary opcode, in1, in2, in3, out, type [, cacheU, cacheV]
        boolean isRed = WeightedUnaryMMR.OPCODE.equalsIgnoreCase(opcode);
        InstructionUtils.checkNumFields(parts, isRed ? 8 : 6);
        String uopcode = parts[1];
        CPOperand in1 = new CPOperand(parts[2]);
        CPOperand in2 = new CPOperand(parts[3]);
        CPOperand in3 = new CPOperand(parts[4]);
        CPOperand out = new CPOperand(parts[5]);
        WUMMType wtype = WUMMType.valueOf(parts[6]);
        // in mappers always through distcache, in reducers through distcache/shuffle
        boolean cacheU = isRed ? Boolean.parseBoolean(parts[7]) : true;
        boolean cacheV = isRed ? Boolean.parseBoolean(parts[8]) : true;
        return new QuaternarySPInstruction(new QuaternaryOperator(wtype, uopcode), in1, in2, in3, null, out, cacheU, cacheV, opcode, str);
    } else if (WeightedDivMM.OPCODE.equalsIgnoreCase(opcode) || WeightedDivMMR.OPCODE.equalsIgnoreCase(opcode)) {
        // wdivmm layout: in1, in2, in3, in4 (matrix or scalar literal), out, type [, cacheU, cacheV]
        // Detect the reduce-side variant the same (case-insensitive) way the branch was entered,
        // so an opcode accepted above is never misclassified as map-side.
        boolean isRed = WeightedDivMMR.OPCODE.equalsIgnoreCase(opcode);
        InstructionUtils.checkNumFields(parts, isRed ? 8 : 6);
        CPOperand in1 = new CPOperand(parts[1]);
        CPOperand in2 = new CPOperand(parts[2]);
        CPOperand in3 = new CPOperand(parts[3]);
        CPOperand in4 = new CPOperand(parts[4]);
        CPOperand out = new CPOperand(parts[5]);
        // in mappers always through distcache, in reducers through distcache/shuffle
        boolean cacheU = isRed ? Boolean.parseBoolean(parts[7]) : true;
        boolean cacheV = isRed ? Boolean.parseBoolean(parts[8]) : true;
        final WDivMMType wt = WDivMMType.valueOf(parts[6]);
        // types with a scalar operand take it from the fourth input's name (a literal value)
        QuaternaryOperator qop = (wt.hasScalar() ? new QuaternaryOperator(wt, Double.parseDouble(in4.getName())) : new QuaternaryOperator(wt));
        return new QuaternarySPInstruction(qop, in1, in2, in3, in4, out, cacheU, cacheV, opcode, str);
    } else {
        // map/redwsigmoid layout: in1, in2, in3, out, type [, cacheU, cacheV]
        // map/redwcemm layout:    in1, in2, in3, in4, out, type [, cacheU, cacheV]
        boolean isRed = opcode.startsWith("red");
        int addInput4 = opcode.endsWith("wcemm") ? 1 : 0;
        InstructionUtils.checkNumFields(parts, (isRed ? 7 : 5) + addInput4);
        CPOperand in1 = new CPOperand(parts[1]);
        CPOperand in2 = new CPOperand(parts[2]);
        CPOperand in3 = new CPOperand(parts[3]);
        CPOperand out = new CPOperand(parts[4 + addInput4]);
        // in mappers always through distcache, in reducers through distcache/shuffle
        boolean cacheU = isRed ? Boolean.parseBoolean(parts[6 + addInput4]) : true;
        boolean cacheV = isRed ? Boolean.parseBoolean(parts[7 + addInput4]) : true;
        if (opcode.endsWith("wsigmoid")) {
            return new QuaternarySPInstruction(new QuaternaryOperator(WSigmoidType.valueOf(parts[5])), in1, in2, in3, null, out, cacheU, cacheV, opcode, str);
        } else if (opcode.endsWith("wcemm")) {
            CPOperand in4 = new CPOperand(parts[4]);
            final WCeMMType wt = WCeMMType.valueOf(parts[6]);
            // types with four inputs take the scalar from the fourth input's name (a literal value)
            QuaternaryOperator qop = (wt.hasFourInputs() ? new QuaternaryOperator(wt, Double.parseDouble(in4.getName())) : new QuaternaryOperator(wt));
            return new QuaternarySPInstruction(qop, in1, in2, in3, in4, out, cacheU, cacheV, opcode, str);
        }
    }
    // Only reachable if isDistQuaternaryOpcode accepts an opcode none of the branches above
    // handle; fail loudly here instead of handing a null instruction to the caller.
    throw new DMLRuntimeException("Quaternary.parseInstruction():: Unsupported opcode " + opcode);
}
Use of org.apache.sysml.runtime.matrix.operators.QuaternaryOperator in project systemml by Apache.
Class QuaternaryCPInstruction, method parseInstruction:
/**
 * Parses a serialized single-node (CP) quaternary instruction into a
 * {@link QuaternaryCPInstruction}.
 *
 * <p>Field layouts after the opcode (the last field is the thread count):
 * <ul>
 * <li>wsloss / wdivmm / wcemm: in1, in2, in3, in4, out, type, k</li>
 * <li>wsigmoid: in1, in2, in3, out, type, k</li>
 * <li>wumm: unary opcode, in1, in2, in3, out, type, k</li>
 * </ul>
 *
 * @param inst the serialized instruction string
 * @return the parsed instruction
 * @throws DMLRuntimeException if the opcode is not one of the above
 */
public static QuaternaryCPInstruction parseInstruction(String inst) {
    final String[] fields = InstructionUtils.getInstructionPartsWithValueType(inst);
    final String op = fields[0];

    final boolean isWsloss = op.equalsIgnoreCase("wsloss");
    final boolean isWdivmm = op.equalsIgnoreCase("wdivmm");
    final boolean isWcemm = op.equalsIgnoreCase("wcemm");

    // four-input family: the weight type enum depends on the opcode
    if (isWsloss || isWdivmm || isWcemm) {
        InstructionUtils.checkNumFields(fields, 7);
        final CPOperand first = new CPOperand(fields[1]);
        final CPOperand second = new CPOperand(fields[2]);
        final CPOperand third = new CPOperand(fields[3]);
        final CPOperand fourth = new CPOperand(fields[4]);
        final CPOperand target = new CPOperand(fields[5]);
        final int numThreads = Integer.parseInt(fields[7]);
        final QuaternaryOperator operator;
        if (isWsloss) {
            operator = new QuaternaryOperator(WeightsType.valueOf(fields[6]));
        } else if (isWdivmm) {
            operator = new QuaternaryOperator(WDivMMType.valueOf(fields[6]));
        } else {
            operator = new QuaternaryOperator(WCeMMType.valueOf(fields[6]));
        }
        return new QuaternaryCPInstruction(operator, first, second, third, fourth, target, numThreads, op, inst);
    }

    // three-input sigmoid variant (no fourth operand)
    if (op.equalsIgnoreCase("wsigmoid")) {
        InstructionUtils.checkNumFields(fields, 6);
        final CPOperand first = new CPOperand(fields[1]);
        final CPOperand second = new CPOperand(fields[2]);
        final CPOperand third = new CPOperand(fields[3]);
        final CPOperand target = new CPOperand(fields[4]);
        final int numThreads = Integer.parseInt(fields[6]);
        return new QuaternaryCPInstruction(new QuaternaryOperator(WSigmoidType.valueOf(fields[5])), first, second, third, null, target, numThreads, op, inst);
    }

    // unary-op variant: the unary opcode precedes the operands
    if (op.equalsIgnoreCase("wumm")) {
        InstructionUtils.checkNumFields(fields, 7);
        final String unaryOpcode = fields[1];
        final CPOperand first = new CPOperand(fields[2]);
        final CPOperand second = new CPOperand(fields[3]);
        final CPOperand third = new CPOperand(fields[4]);
        final CPOperand target = new CPOperand(fields[5]);
        final int numThreads = Integer.parseInt(fields[7]);
        return new QuaternaryCPInstruction(new QuaternaryOperator(WUMMType.valueOf(fields[6]), unaryOpcode), first, second, third, null, target, numThreads, op, inst);
    }

    throw new DMLRuntimeException("Unexpected opcode in QuaternaryCPInstruction: " + inst);
}
Use of org.apache.sysml.runtime.matrix.operators.QuaternaryOperator in project systemml by Apache.
Class QuaternaryInstruction, method computeMatrixCharacteristics:
/**
 * Derives the output matrix characteristics of this quaternary operation and writes them
 * into {@code dimOut}. Block sizes are always taken from the main input {@code mc1}.
 *
 * <ul>
 * <li>wsloss / wcemm (wtype1 / wtype4): 1x1 scalar result.</li>
 * <li>wsigmoid / wumm (wtype2 / wtype5): same dimensions as {@code mc1}.</li>
 * <li>wdivmm (wtype3): dimensions from {@code WDivMMType.computeOutputCharacteristics}
 *     using mc1's dimensions and the factor rank.</li>
 * </ul>
 * If none of these weight types is set, {@code dimOut} is left unchanged.
 */
public void computeMatrixCharacteristics(MatrixCharacteristics mc1, MatrixCharacteristics mc2, MatrixCharacteristics mc3, MatrixCharacteristics dimOut) {
    final QuaternaryOperator op = (QuaternaryOperator) optr;

    // wsloss/wcemm: scalar result regardless of chain type
    if (op.wtype1 != null || op.wtype4 != null) {
        dimOut.set(1, 1, mc1.getRowsPerBlock(), mc1.getColsPerBlock());
        return;
    }

    // wsigmoid/wumm: result shaped like the main input
    if (op.wtype2 != null || op.wtype5 != null) {
        dimOut.set(mc1.getRows(), mc1.getCols(), mc1.getRowsPerBlock(), mc1.getColsPerBlock());
        return;
    }

    if (op.wtype3 == null) {
        return;
    }

    // wdivmm: mc2/mc3 cannot be used directly for redwdivmm because the rep instruction
    // changed their dimensions; in that case the original rank is carried in the nnz field.
    final boolean bothBroadcast = _cacheU && _cacheV;
    final MatrixCharacteristics factor = op.wtype3.isLeft() ? mc3 : mc2;
    final long rank = bothBroadcast ? factor.getCols() : factor.getNonZeros();
    final MatrixCharacteristics derived = op.wtype3.computeOutputCharacteristics(mc1.getRows(), mc1.getCols(), rank);
    dimOut.set(derived.getRows(), derived.getCols(), mc1.getRowsPerBlock(), mc1.getColsPerBlock());
}
Use of org.apache.sysml.runtime.matrix.operators.QuaternaryOperator in project systemml by Apache.
Class QuaternarySPInstruction, method processInstruction:
/**
 * Executes this quaternary operation on Spark.
 *
 * <p>The main input (input1) is always read as a blocked RDD.
 * <ul>
 * <li>Map-side opcodes (wsloss, wsigmoid, wdivmm, wcemm, wumm): input2 and input3 are read as
 *     broadcasts, and the result is a single mapPartitionsToPair over input1.</li>
 * <li>Reduce-side opcodes: input2/input3 are read as broadcasts when _cacheU/_cacheV are set.
 *     Otherwise they are read as RDDs, replicated with ReplicateBlockFunction (V is also
 *     transposed first) and joined with input1. A non-literal fourth input is joined as an RDD.</li>
 * </ul>
 *
 * <p>For wsloss/wcemm (wtype1/wtype4) all result blocks are summed into one value. That value
 * is bound to the output variable as a DoubleObject. For all other types, the result RDD is
 * bound to the output variable. For non-basic wdivmm it is first summed by key. Lineage to
 * the RDD and broadcast inputs is recorded, and the output matrix characteristics are updated.
 *
 * @param ec execution context; cast unconditionally to SparkExecutionContext
 */
@Override
public void processInstruction(ExecutionContext ec) {
SparkExecutionContext sec = (SparkExecutionContext) ec;
QuaternaryOperator qop = (QuaternaryOperator) _optr;
// tracking of rdds and broadcasts (for lineage maintenance)
ArrayList<String> rddVars = new ArrayList<>();
ArrayList<String> bcVars = new ArrayList<>();
// main input as blocked RDD; its dims/block sizes parameterize the replication of U and V below
JavaPairRDD<MatrixIndexes, MatrixBlock> in = sec.getBinaryBlockRDDHandleForVariable(input1.getName());
JavaPairRDD<MatrixIndexes, MatrixBlock> out = null;
MatrixCharacteristics inMc = sec.getMatrixCharacteristics(input1.getName());
long rlen = inMc.getRows();
long clen = inMc.getCols();
int brlen = inMc.getRowsPerBlock();
int bclen = inMc.getColsPerBlock();
// pre-filter input1 blocks via FilterNonEmptyBlocksFunction (presumably drops empty blocks)
// (map/redwsloss, map/redwcemm); safe because these ops produce a scalar
// NOTE(review): this assumes an empty input1 block contributes nothing to the final sum —
// confirm for every WeightsType, in particular types that do not weight by input1/W.
if (qop.wtype1 != null || qop.wtype4 != null) {
in = in.filter(new FilterNonEmptyBlocksFunction());
}
// map-side only operation (one rdd input, two broadcasts)
if (WeightedSquaredLoss.OPCODE.equalsIgnoreCase(getOpcode()) || WeightedSigmoid.OPCODE.equalsIgnoreCase(getOpcode()) || WeightedDivMM.OPCODE.equalsIgnoreCase(getOpcode()) || WeightedCrossEntropy.OPCODE.equalsIgnoreCase(getOpcode()) || WeightedUnaryMM.OPCODE.equalsIgnoreCase(getOpcode())) {
PartitionedBroadcast<MatrixBlock> bc1 = sec.getBroadcastForVariable(input2.getName());
PartitionedBroadcast<MatrixBlock> bc2 = sec.getBroadcastForVariable(input3.getName());
// partitioning-preserving mappartitions (key access required for broadcast lookup)
// only wdivmm changes keys
boolean noKeyChange = (qop.wtype3 == null || qop.wtype3.isBasic());
out = in.mapPartitionsToPair(new RDDQuaternaryFunction1(qop, bc1, bc2), noKeyChange);
rddVars.add(input1.getName());
bcVars.add(input2.getName());
bcVars.add(input3.getName());
} else // reduce-side operation (two/three/four rdd inputs, zero/one/two broadcasts)
{
// each of U (input2) and V (input3) is either broadcast or an RDD, never both
PartitionedBroadcast<MatrixBlock> bc1 = _cacheU ? sec.getBroadcastForVariable(input2.getName()) : null;
PartitionedBroadcast<MatrixBlock> bc2 = _cacheV ? sec.getBroadcastForVariable(input3.getName()) : null;
JavaPairRDD<MatrixIndexes, MatrixBlock> inU = (!_cacheU) ? sec.getBinaryBlockRDDHandleForVariable(input2.getName()) : null;
JavaPairRDD<MatrixIndexes, MatrixBlock> inV = (!_cacheV) ? sec.getBinaryBlockRDDHandleForVariable(input3.getName()) : null;
// fourth input is joined only if the operator has one and it is not a literal scalar
JavaPairRDD<MatrixIndexes, MatrixBlock> inW = (qop.hasFourInputs() && !_input4.isLiteral()) ? sec.getBinaryBlockRDDHandleForVariable(_input4.getName()) : null;
// preparation of transposed and replicated U
if (inU != null)
inU = inU.flatMapToPair(new ReplicateBlockFunction(clen, bclen, true));
// preparation of transposed and replicated V
if (inV != null)
inV = inV.mapToPair(new TransposeFactorIndexesFunction()).flatMapToPair(new ReplicateBlockFunction(rlen, brlen, false));
// dispatch on which of inU/inV/inW are RDDs; the chain below covers all 8 combinations
// functions calls w/ two rdd inputs
if (inU != null && inV == null && inW == null)
out = in.join(inU).mapToPair(new RDDQuaternaryFunction2(qop, bc1, bc2));
else if (inU == null && inV != null && inW == null)
out = in.join(inV).mapToPair(new RDDQuaternaryFunction2(qop, bc1, bc2));
else if (inU == null && inV == null && inW != null)
out = in.join(inW).mapToPair(new RDDQuaternaryFunction2(qop, bc1, bc2));
else // function calls w/ three rdd inputs
if (inU != null && inV != null && inW == null)
out = in.join(inU).join(inV).mapToPair(new RDDQuaternaryFunction3(qop, bc1, bc2));
else if (inU != null && inV == null && inW != null)
out = in.join(inU).join(inW).mapToPair(new RDDQuaternaryFunction3(qop, bc1, bc2));
else if (inU == null && inV != null && inW != null)
out = in.join(inV).join(inW).mapToPair(new RDDQuaternaryFunction3(qop, bc1, bc2));
else if (inU == null && inV == null && inW == null) {
// no rdd side inputs (both factors broadcast); keys may change (wdivmm), so no
// partitioning preservation is claimed
out = in.mapPartitionsToPair(new RDDQuaternaryFunction1(qop, bc1, bc2), false);
} else
// function call w/ four rdd inputs
// need keys in case of wdivmm
out = in.join(inU).join(inV).join(inW).mapToPair(new RDDQuaternaryFunction4(qop));
// keep variable names for lineage maintenance
if (inU == null)
bcVars.add(input2.getName());
else
rddVars.add(input2.getName());
if (inV == null)
bcVars.add(input3.getName());
else
rddVars.add(input3.getName());
if (inW != null)
rddVars.add(_input4.getName());
}
// output handling, incl aggregation
if (// map/redwsloss, map/redwcemm
qop.wtype1 != null || qop.wtype4 != null) {
// full aggregate and cast to scalar
MatrixBlock tmp = RDDAggregateUtils.sumStable(out);
DoubleObject ret = new DoubleObject(tmp.getValue(0, 0));
sec.setVariable(output.getName(), ret);
} else // map/redwsigmoid, map/redwdivmm, map/redwumm
{
// aggregation if required (map/redwdivmm)
if (qop.wtype3 != null && !qop.wtype3.isBasic())
out = RDDAggregateUtils.sumByKeyStable(out, false);
// put output RDD handle into symbol table
sec.setRDDHandleForVariable(output.getName(), out);
// maintain lineage information for output rdd
for (String rddVar : rddVars) sec.addLineageRDD(output.getName(), rddVar);
for (String bcVar : bcVars) sec.addLineageBroadcast(output.getName(), bcVar);
// update matrix characteristics
updateOutputMatrixCharacteristics(sec, qop);
}
}
Aggregations