Use of org.apache.sysml.runtime.instructions.spark.functions.FilterNonEmptyBlocksFunction in project incubator-systemml by Apache.
The class QuaternarySPInstruction, method processInstruction.
/**
 * Executes this quaternary instruction on Spark.
 *
 * <p>The first input is always read as a binary-block RDD. For the map-side
 * opcodes the second and third inputs are read as broadcasts; otherwise each
 * of them is read as a broadcast or as an RDD depending on {@code _cacheU} /
 * {@code _cacheV}, and a non-literal fourth input (if the operator has four
 * inputs) is joined as an RDD. If {@code qop.wtype1} or {@code qop.wtype4} is
 * set, the result is summed and stored as a scalar; otherwise the output RDD
 * handle is stored under the output variable together with lineage
 * references to all consumed RDDs and broadcasts.
 *
 * @param ec execution context; cast to {@link SparkExecutionContext}
 * @throws DMLRuntimeException declared by the overridden method
 */
@Override
public void processInstruction(ExecutionContext ec) throws DMLRuntimeException {
    SparkExecutionContext sec = (SparkExecutionContext) ec;
    QuaternaryOperator qop = (QuaternaryOperator) _optr;

    //tracking of rdds and broadcasts (for lineage maintenance)
    ArrayList<String> rddVars = new ArrayList<String>();
    ArrayList<String> bcVars = new ArrayList<String>();

    JavaPairRDD<MatrixIndexes, MatrixBlock> in = sec.getBinaryBlockRDDHandleForVariable(input1.getName());
    JavaPairRDD<MatrixIndexes, MatrixBlock> out = null;

    //input1 is consumed as an rdd by the map-side AND the reduce-side path,
    //so it is registered once here (previously only the map-side path did so,
    //which left reduce-side outputs without lineage to their primary input)
    rddVars.add(input1.getName());

    MatrixCharacteristics inMc = sec.getMatrixCharacteristics(input1.getName());
    long rlen = inMc.getRows();
    long clen = inMc.getCols();
    int brlen = inMc.getRowsPerBlock();
    int bclen = inMc.getColsPerBlock();

    //full aggregates (map/redwsloss, map/redwcemm) produce a scalar
    boolean scalarOutput = (qop.wtype1 != null || qop.wtype4 != null);

    //pre-filter empty blocks; safe because these ops produce a scalar
    if (scalarOutput) {
        in = in.filter(new FilterNonEmptyBlocksFunction());
    }

    String opcode = getOpcode();
    boolean mapSideOnly = WeightedSquaredLoss.OPCODE.equalsIgnoreCase(opcode)
        || WeightedSigmoid.OPCODE.equalsIgnoreCase(opcode)
        || WeightedDivMM.OPCODE.equalsIgnoreCase(opcode)
        || WeightedCrossEntropy.OPCODE.equalsIgnoreCase(opcode)
        || WeightedUnaryMM.OPCODE.equalsIgnoreCase(opcode);

    if (mapSideOnly) {
        //map-side only operation (one rdd input, two broadcasts)
        PartitionedBroadcast<MatrixBlock> bc1 = sec.getBroadcastForVariable(input2.getName());
        PartitionedBroadcast<MatrixBlock> bc2 = sec.getBroadcastForVariable(input3.getName());

        //partitioning-preserving mappartitions (key access required for broadcast lookup);
        //only wdivmm changes keys
        boolean noKeyChange = (qop.wtype3 == null || qop.wtype3.isBasic());
        out = in.mapPartitionsToPair(new RDDQuaternaryFunction1(qop, bc1, bc2), noKeyChange);

        bcVars.add(input2.getName());
        bcVars.add(input3.getName());
    } else {
        //reduce-side operation (two/three/four rdd inputs, zero/one/two broadcasts)
        PartitionedBroadcast<MatrixBlock> bc1 = _cacheU ? sec.getBroadcastForVariable(input2.getName()) : null;
        PartitionedBroadcast<MatrixBlock> bc2 = _cacheV ? sec.getBroadcastForVariable(input3.getName()) : null;
        JavaPairRDD<MatrixIndexes, MatrixBlock> inU = (!_cacheU) ? sec.getBinaryBlockRDDHandleForVariable(input2.getName()) : null;
        JavaPairRDD<MatrixIndexes, MatrixBlock> inV = (!_cacheV) ? sec.getBinaryBlockRDDHandleForVariable(input3.getName()) : null;
        JavaPairRDD<MatrixIndexes, MatrixBlock> inW = (qop.hasFourInputs() && !_input4.isLiteral()) ? sec.getBinaryBlockRDDHandleForVariable(_input4.getName()) : null;

        //preparation of transposed and replicated U
        if (inU != null) {
            inU = inU.flatMapToPair(new ReplicateBlocksFunction(clen, bclen, true));
        }

        //preparation of transposed and replicated V
        if (inV != null) {
            inV = inV.mapToPair(new TransposeFactorIndexesFunction())
                .flatMapToPair(new ReplicateBlocksFunction(rlen, brlen, false));
        }

        boolean hasU = (inU != null);
        boolean hasV = (inV != null);
        boolean hasW = (inW != null);

        if (!hasU && !hasV && !hasW) {
            //no additional rdd input: plain map over the first input
            out = in.mapPartitionsToPair(new RDDQuaternaryFunction1(qop, bc1, bc2), false);
        } else if (hasU && hasV && hasW) {
            //function call w/ four rdd inputs (need keys in case of wdivmm)
            out = in.join(inU).join(inV).join(inW).mapToPair(new RDDQuaternaryFunction4(qop));
        } else if (hasU && hasV) {
            //function calls w/ three rdd inputs
            out = in.join(inU).join(inV).mapToPair(new RDDQuaternaryFunction3(qop, bc1, bc2));
        } else if (hasU && hasW) {
            out = in.join(inU).join(inW).mapToPair(new RDDQuaternaryFunction3(qop, bc1, bc2));
        } else if (hasV && hasW) {
            out = in.join(inV).join(inW).mapToPair(new RDDQuaternaryFunction3(qop, bc1, bc2));
        } else if (hasU) {
            //function calls w/ two rdd inputs
            out = in.join(inU).mapToPair(new RDDQuaternaryFunction2(qop, bc1, bc2));
        } else if (hasV) {
            out = in.join(inV).mapToPair(new RDDQuaternaryFunction2(qop, bc1, bc2));
        } else {
            out = in.join(inW).mapToPair(new RDDQuaternaryFunction2(qop, bc1, bc2));
        }

        //keep variable names for lineage maintenance
        if (hasU) {
            rddVars.add(input2.getName());
        } else {
            bcVars.add(input2.getName());
        }
        if (hasV) {
            rddVars.add(input3.getName());
        } else {
            bcVars.add(input3.getName());
        }
        if (hasW) {
            rddVars.add(_input4.getName());
        }
    }

    //output handling, incl aggregation
    if (scalarOutput) {
        //map/redwsloss, map/redwcemm: full aggregate and cast to scalar
        MatrixBlock tmp = RDDAggregateUtils.sumStable(out);
        DoubleObject ret = new DoubleObject(tmp.getValue(0, 0));
        sec.setVariable(output.getName(), ret);
    } else {
        //map/redwsigmoid, map/redwdivmm, map/redwumm

        //aggregation if required (map/redwdivmm)
        if (qop.wtype3 != null && !qop.wtype3.isBasic()) {
            out = RDDAggregateUtils.sumByKeyStable(out, false);
        }

        //put output RDD handle into symbol table
        sec.setRDDHandleForVariable(output.getName(), out);

        //maintain lineage information for output rdd
        for (String rddVar : rddVars) {
            sec.addLineageRDD(output.getName(), rddVar);
        }
        for (String bcVar : bcVars) {
            sec.addLineageBroadcast(output.getName(), bcVar);
        }

        //update matrix characteristics
        updateOutputMatrixCharacteristics(sec, qop);
    }
}
Use of org.apache.sysml.runtime.instructions.spark.functions.FilterNonEmptyBlocksFunction in project incubator-systemml by Apache.
The class MapmmSPInstruction, method processInstruction.
/**
 * Executes the mapmm operation on Spark.
 *
 * <p>One input is read as a binary-block RDD and the other as a partitioned
 * broadcast; which is which follows from {@code _type}, and the roles are
 * swapped when the flipped assignment yields a larger repartitioning count.
 * For {@code SINGLE_BLOCK} aggregation the result block is stored directly
 * as matrix output; otherwise the output RDD handle is stored together with
 * lineage references to the RDD and broadcast inputs.
 *
 * @param ec execution context; cast to {@link SparkExecutionContext}
 * @throws DMLRuntimeException declared by the overridden method
 */
@Override
public void processInstruction(ExecutionContext ec) throws DMLRuntimeException {
    SparkExecutionContext sctx = (SparkExecutionContext) ec;
    CacheType cacheType = _type;

    // decide which input is the rdd side and which the broadcast side
    String rddName;
    String bcName;
    if (cacheType.isRight()) {
        rddName = input1.getName();
        bcName = input2.getName();
    } else {
        rddName = input2.getName();
        bcName = input1.getName();
    }
    MatrixCharacteristics rddStats = sctx.getMatrixCharacteristics(rddName);
    MatrixCharacteristics bcStats = sctx.getMatrixCharacteristics(bcName);

    // rdd-side input
    JavaPairRDD<MatrixIndexes, MatrixBlock> rddIn = sctx.getBinaryBlockRDDHandleForVariable(rddName);

    // possibly swap rdd and broadcast roles; required to ensure moderately
    // sized output partitions (2GB limitation)
    if (requiresFlatMapFunction(cacheType, bcStats)
        && requiresRepartitioning(cacheType, rddStats, bcStats, rddIn.getNumPartitions())) {
        int numParts = getNumRepartitioning(cacheType, rddStats, bcStats);
        CacheType flipped = cacheType.getFlipped();
        int numParts2 = getNumRepartitioning(flipped, bcStats, rddStats);
        if (numParts2 > numParts) {
            // flip required: re-derive names, characteristics and rdd handle
            cacheType = flipped;
            if (cacheType.isRight()) {
                rddName = input1.getName();
                bcName = input2.getName();
            } else {
                rddName = input2.getName();
                bcName = input1.getName();
            }
            rddStats = sctx.getMatrixCharacteristics(rddName);
            bcStats = sctx.getMatrixCharacteristics(bcName);
            rddIn = sctx.getBinaryBlockRDDHandleForVariable(rddName);
            LOG.warn("Mapmm: Switching rdd ('" + bcName + "') and broadcast ('" + rddName + "') inputs " + "for repartitioning because this allows better control of output partition " + "sizes (" + numParts + " < " + numParts2 + ").");
        }
    }

    // broadcast-side input
    PartitionedBroadcast<MatrixBlock> bcIn = sctx.getBroadcastForVariable(bcName);

    // drop empty input blocks unless empty outputs are requested
    if (!_outputEmpty) {
        rddIn = rddIn.filter(new FilterNonEmptyBlocksFunction());
    }

    if (_aggtype == SparkAggType.SINGLE_BLOCK) {
        JavaRDD<MatrixBlock> partials = rddIn.map(new RDDMapMMFunction2(cacheType, bcIn));
        MatrixBlock summed = RDDAggregateUtils.sumStable(partials);

        // put output block into symbol table (no lineage because single block);
        // this also includes implicit maintenance of matrix characteristics
        sctx.setMatrixOutput(output.getName(), summed);
        return;
    }

    // MULTI_BLOCK or NONE
    JavaPairRDD<MatrixIndexes, MatrixBlock> result;
    if (requiresFlatMapFunction(cacheType, bcStats)) {
        if (requiresRepartitioning(cacheType, rddStats, bcStats, rddIn.getNumPartitions())) {
            int numParts = getNumRepartitioning(cacheType, rddStats, bcStats);
            LOG.warn("Mapmm: Repartition input rdd '" + rddName + "' from " + rddIn.getNumPartitions() + " to " + numParts + " partitions to satisfy size restrictions of output partitions.");
            rddIn = rddIn.repartition(numParts);
        }
        result = rddIn.flatMapToPair(new RDDFlatMapMMFunction(cacheType, bcIn));
    } else if (preservesPartitioning(rddStats, cacheType)) {
        result = rddIn.mapPartitionsToPair(new RDDMapMMPartitionFunction(cacheType, bcIn), true);
    } else {
        result = rddIn.mapToPair(new RDDMapMMFunction(cacheType, bcIn));
    }

    // drop empty output blocks unless empty outputs are requested
    if (!_outputEmpty) {
        result = result.filter(new FilterNonEmptyBlocksFunction());
    }

    if (_aggtype == SparkAggType.MULTI_BLOCK) {
        result = RDDAggregateUtils.sumByKeyStable(result, false);
    }

    // register output rdd handle and its lineage in the symbol table
    String outName = output.getName();
    sctx.setRDDHandleForVariable(outName, result);
    sctx.addLineageRDD(outName, rddName);
    sctx.addLineageBroadcast(outName, bcName);

    // update output statistics if not inferred
    updateBinaryMMOutputMatrixCharacteristics(sctx, true);
}
Aggregations