Use of org.apache.sysml.lops.Aggregate in project incubator-systemml by apache.
From class TernaryOp, method constructLopsCtable():
/**
 * Method to construct LOPs when op = CTABLE.
 */
private void constructLopsCtable() {
    if (_op != OpOp3.CTABLE)
        throw new HopsException("Unexpected operation: " + _op + ", expecting " + OpOp3.CTABLE);
    /*
     * We must handle three different cases:
     *   case 1: all three inputs are vectors (e.g., F = ctable(A,B,W))
     *   case 2: two vectors and one scalar (e.g., F = ctable(A,B))
     *   case 3: one vector and two scalars (e.g., F = ctable(A))
     */
    // identify the particular case
    // F = ctable(A,B,W)
    DataType dt1 = getInput().get(0).getDataType();
    DataType dt2 = getInput().get(1).getDataType();
    DataType dt3 = getInput().get(2).getDataType();
    Ctable.OperationTypes ternaryOpOrig = Ctable.findCtableOperationByInputDataTypes(dt1, dt2, dt3);
    // compute lops for all inputs
    Lop[] inputLops = new Lop[getInput().size()];
    for (int i = 0; i < getInput().size(); i++) {
        inputLops[i] = getInput().get(i).constructLops();
    }
    ExecType et = optFindExecType();
    // reset reblock requirement (see MR ctable / construct lops)
    setRequiresReblock(false);
    if (et == ExecType.CP || et == ExecType.SPARK) {
        // for CP we support only ctable expand left
        Ctable.OperationTypes ternaryOp = isSequenceRewriteApplicable(true) ?
            Ctable.OperationTypes.CTABLE_EXPAND_SCALAR_WEIGHT : ternaryOpOrig;
        boolean ignoreZeros = false;
        if (isMatrixIgnoreZeroRewriteApplicable()) {
            // table - rmempty - rshape
            ignoreZeros = true;
            inputLops[0] = ((ParameterizedBuiltinOp) getInput().get(0)).getTargetHop().getInput().get(0).constructLops();
            inputLops[1] = ((ParameterizedBuiltinOp) getInput().get(1)).getTargetHop().getInput().get(0).constructLops();
        }
        Ctable ternary = new Ctable(inputLops, ternaryOp, getDataType(), getValueType(), ignoreZeros, et);
        ternary.getOutputParameters().setDimensions(_dim1, _dim2, getRowsInBlock(), getColsInBlock(), -1);
        setLineNumbers(ternary);
        // force blocked output in CP (see below), otherwise binarycell
        if (et == ExecType.SPARK) {
            ternary.getOutputParameters().setDimensions(_dim1, _dim2, -1, -1, -1);
            setRequiresReblock(true);
        } else
            ternary.getOutputParameters().setDimensions(_dim1, _dim2, getRowsInBlock(), getColsInBlock(), -1);
        // ternary op, w/o reblock in CP
        setLops(ternary);
    } else { // MR
        // for MR we support both ctable expand left and right
        Ctable.OperationTypes ternaryOp = isSequenceRewriteApplicable() ?
            Ctable.OperationTypes.CTABLE_EXPAND_SCALAR_WEIGHT : ternaryOpOrig;
        Group group1 = null, group2 = null, group3 = null, group4 = null;
        group1 = new Group(inputLops[0], Group.OperationTypes.Sort, getDataType(), getValueType());
        group1.getOutputParameters().setDimensions(getDim1(), getDim2(), getRowsInBlock(), getColsInBlock(), getNnz());
        setLineNumbers(group1);
        Ctable ternary = null;
        // create "group" lops for MATRIX inputs
        switch (ternaryOp) {
            case CTABLE_TRANSFORM:
                // F = ctable(A,B,W)
                group2 = new Group(inputLops[1], Group.OperationTypes.Sort, getDataType(), getValueType());
                group2.getOutputParameters().setDimensions(getDim1(), getDim2(), getRowsInBlock(), getColsInBlock(), getNnz());
                setLineNumbers(group2);
                group3 = new Group(inputLops[2], Group.OperationTypes.Sort, getDataType(), getValueType());
                group3.getOutputParameters().setDimensions(getDim1(), getDim2(), getRowsInBlock(), getColsInBlock(), getNnz());
                setLineNumbers(group3);
                if (inputLops.length == 3)
                    ternary = new Ctable(new Lop[] { group1, group2, group3 },
                        ternaryOp, getDataType(), getValueType(), et);
                else // output dimensions are given
                    ternary = new Ctable(new Lop[] { group1, group2, group3, inputLops[3], inputLops[4] },
                        ternaryOp, getDataType(), getValueType(), et);
                break;
            case CTABLE_TRANSFORM_SCALAR_WEIGHT:
                // F = ctable(A,B) or F = ctable(A,B,1)
                group2 = new Group(inputLops[1], Group.OperationTypes.Sort, getDataType(), getValueType());
                group2.getOutputParameters().setDimensions(getDim1(), getDim2(), getRowsInBlock(), getColsInBlock(), getNnz());
                setLineNumbers(group2);
                if (inputLops.length == 3)
                    ternary = new Ctable(new Lop[] { group1, group2, inputLops[2] },
                        ternaryOp, getDataType(), getValueType(), et);
                else
                    ternary = new Ctable(new Lop[] { group1, group2, inputLops[2], inputLops[3], inputLops[4] },
                        ternaryOp, getDataType(), getValueType(), et);
                break;
            case CTABLE_EXPAND_SCALAR_WEIGHT:
                // F = ctable(seq(1,N),A) or F = ctable(seq,A,1)
                // left 1, right 0 (index of input data)
                int left = isSequenceRewriteApplicable(true) ? 1 : 0;
                Group group = new Group(getInput().get(left).constructLops(), Group.OperationTypes.Sort, getDataType(), getValueType());
                group.getOutputParameters().setDimensions(getDim1(), getDim2(), getRowsInBlock(), getColsInBlock(), getNnz());
                if (inputLops.length == 3)
                    ternary = new Ctable(new Lop[] {
                            group,                              // matrix
                            getInput().get(2).constructLops(),  // weight
                            new LiteralOp(left).constructLops() // left
                        }, ternaryOp, getDataType(), getValueType(), et);
                else
                    ternary = new Ctable(new Lop[] {
                            group,                               // matrix
                            getInput().get(2).constructLops(),   // weight
                            new LiteralOp(left).constructLops(), // left
                            inputLops[3], inputLops[4]
                        }, ternaryOp, getDataType(), getValueType(), et);
                break;
            case CTABLE_TRANSFORM_HISTOGRAM:
                // F = ctable(A,1) or F = ctable(A,1,1)
                if (inputLops.length == 3)
                    ternary = new Ctable(new Lop[] { group1, getInput().get(1).constructLops(), getInput().get(2).constructLops() },
                        ternaryOp, getDataType(), getValueType(), et);
                else
                    ternary = new Ctable(new Lop[] { group1, getInput().get(1).constructLops(), getInput().get(2).constructLops(), inputLops[3], inputLops[4] },
                        ternaryOp, getDataType(), getValueType(), et);
                break;
            case CTABLE_TRANSFORM_WEIGHTED_HISTOGRAM:
                // F = ctable(A,1,W)
                group3 = new Group(getInput().get(2).constructLops(), Group.OperationTypes.Sort, getDataType(), getValueType());
                group3.getOutputParameters().setDimensions(getDim1(), getDim2(), getRowsInBlock(), getColsInBlock(), getNnz());
                setLineNumbers(group3);
                if (inputLops.length == 3)
                    ternary = new Ctable(new Lop[] { group1, getInput().get(1).constructLops(), group3 },
                        ternaryOp, getDataType(), getValueType(), et);
                else
                    ternary = new Ctable(new Lop[] { group1, getInput().get(1).constructLops(), group3, inputLops[3], inputLops[4] },
                        ternaryOp, getDataType(), getValueType(), et);
                break;
            default:
                throw new HopsException("Invalid ternary operator type: " + _op);
        }
        // output dimensions are not known at compilation time
        ternary.getOutputParameters().setDimensions(_dim1, _dim2,
            (_dimInputsPresent ? getRowsInBlock() : -1), (_dimInputsPresent ? getColsInBlock() : -1), -1);
        setLineNumbers(ternary);
        Lop lctable = ternary;
        // no need for aggregation if (1) inputs are indexed disjoint, or (2) one side is a sequence w/ increment 1
        if (!(_disjointInputs || ternaryOp == Ctable.OperationTypes.CTABLE_EXPAND_SCALAR_WEIGHT)) {
            group4 = new Group(ternary, Group.OperationTypes.Sort, getDataType(), getValueType());
            group4.getOutputParameters().setDimensions(_dim1, _dim2,
                (_dimInputsPresent ? getRowsInBlock() : -1), (_dimInputsPresent ? getColsInBlock() : -1), -1);
            setLineNumbers(group4);
            Aggregate agg1 = new Aggregate(group4, HopsAgg2Lops.get(AggOp.SUM), getDataType(), getValueType(), ExecType.MR);
            agg1.getOutputParameters().setDimensions(_dim1, _dim2,
                (_dimInputsPresent ? getRowsInBlock() : -1), (_dimInputsPresent ? getColsInBlock() : -1), -1);
            setLineNumbers(agg1);
            // kahanSum is used for aggregation but inputs do not have correction values
            agg1.setupCorrectionLocation(CorrectionLocationType.NONE);
            lctable = agg1;
        }
        setLops(lctable);
        // introduce a reblock lop, since table itself outputs in blocked format only if dims are known
        if (!dimsKnown() && !_dimInputsPresent) {
            setRequiresReblock(true);
        }
    }
}
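For context, ctable builds a contingency table: output cell F[A[i], B[i]] accumulates the weight W[i] (1.0 if no weights are given). Below is a minimal standalone sketch of these semantics, covering the three cases from the comment above. The helper names, the dense double[][] representation, and the assumption of positive integer-valued inputs are illustrative only, not part of the SystemML API:

// case 1: ctable(A,B,W) -- F[a[i]-1][b[i]-1] += w[i] (1-based indices, summed weights)
static double[][] ctable(double[] a, double[] b, double[] w) {
    int rows = 0, cols = 0;
    for (int i = 0; i < a.length; i++) {
        rows = Math.max(rows, (int) a[i]);
        cols = Math.max(cols, (int) b[i]);
    }
    double[][] F = new double[rows][cols];
    for (int i = 0; i < a.length; i++)
        F[(int) a[i] - 1][(int) b[i] - 1] += w[i];
    return F;
}

// case 2: ctable(A,B) == ctable(A,B,1)
static double[][] ctable(double[] a, double[] b) {
    double[] ones = new double[a.length];
    java.util.Arrays.fill(ones, 1.0);
    return ctable(a, b, ones);
}

// case 3: ctable(A) == ctable(A,1,1), i.e., a histogram column vector
static double[][] ctable(double[] a) {
    double[] ones = new double[a.length];
    java.util.Arrays.fill(ones, 1.0);
    return ctable(a, ones, ones);
}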
Use of org.apache.sysml.lops.Aggregate in project incubator-systemml by apache.
From class UnaryOp, method constructLopsIQM():
private Lop constructLopsIQM() {
    ExecType et = optFindExecType();
    Hop input = getInput().get(0);
    if (et == ExecType.MR) {
        CombineUnary combine = CombineUnary.constructCombineLop(input.constructLops(), DataType.MATRIX, getValueType());
        combine.getOutputParameters().setDimensions(input.getDim1(), input.getDim2(), input.getRowsInBlock(), input.getColsInBlock(), input.getNnz());
        SortKeys sort = SortKeys.constructSortByValueLop(combine, SortKeys.OperationTypes.WithoutWeights, DataType.MATRIX, ValueType.DOUBLE, ExecType.MR);
        // sort dimensions are the same as the first input
        sort.getOutputParameters().setDimensions(input.getDim1(), input.getDim2(), input.getRowsInBlock(), input.getColsInBlock(), input.getNnz());
        Data lit = Data.createLiteralLop(ValueType.DOUBLE, Double.toString(0.25));
        lit.setAllPositions(this.getFilename(), this.getBeginLine(), this.getBeginColumn(), this.getEndLine(), this.getEndColumn());
        PickByCount pick = new PickByCount(sort, lit, DataType.MATRIX, getValueType(), PickByCount.OperationTypes.RANGEPICK);
        pick.getOutputParameters().setDimensions(-1, -1, getRowsInBlock(), getColsInBlock(), -1);
        setLineNumbers(pick);
        PartialAggregate pagg = new PartialAggregate(pick, HopsAgg2Lops.get(Hop.AggOp.SUM), HopsDirection2Lops.get(Hop.Direction.RowCol), DataType.MATRIX, getValueType());
        setLineNumbers(pagg);
        // set the dimensions of the PartialAggregate lop based on the direction in which aggregation is performed
        pagg.setDimensionsBasedOnDirection(getDim1(), getDim2(), getRowsInBlock(), getColsInBlock());
        Group group1 = new Group(pagg, Group.OperationTypes.Sort, DataType.MATRIX, getValueType());
        group1.getOutputParameters().setDimensions(getDim1(), getDim2(), getRowsInBlock(), getColsInBlock(), getNnz());
        setLineNumbers(group1);
        Aggregate agg1 = new Aggregate(group1, HopsAgg2Lops.get(Hop.AggOp.SUM), DataType.MATRIX, getValueType(), ExecType.MR);
        agg1.getOutputParameters().setDimensions(getDim1(), getDim2(), getRowsInBlock(), getColsInBlock(), getNnz());
        agg1.setupCorrectionLocation(pagg.getCorrectionLocation());
        setLineNumbers(agg1);
        UnaryCP unary1 = new UnaryCP(agg1, HopsOpOp1LopsUS.get(OpOp1.CAST_AS_SCALAR), getDataType(), getValueType());
        unary1.getOutputParameters().setDimensions(0, 0, 0, 0, -1);
        setLineNumbers(unary1);
        Unary iqm = new Unary(sort, unary1, Unary.OperationTypes.MR_IQM, DataType.SCALAR, ValueType.DOUBLE, ExecType.CP);
        iqm.getOutputParameters().setDimensions(0, 0, 0, 0, -1);
        setLineNumbers(iqm);
        return iqm;
    } else {
        SortKeys sort = SortKeys.constructSortByValueLop(input.constructLops(), SortKeys.OperationTypes.WithoutWeights, DataType.MATRIX, ValueType.DOUBLE, et);
        sort.getOutputParameters().setDimensions(input.getDim1(), input.getDim2(), input.getRowsInBlock(), input.getColsInBlock(), input.getNnz());
        PickByCount pick = new PickByCount(sort, null, getDataType(), getValueType(), PickByCount.OperationTypes.IQM, et, true);
        pick.getOutputParameters().setDimensions(getDim1(), getDim2(), getRowsInBlock(), getColsInBlock(), getNnz());
        setLineNumbers(pick);
        return pick;
    }
}
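The MR lop chain above computes the interquartile mean: sort the values, range-pick the central 50% (between the 0.25 and 0.75 quantiles, hence the 0.25 literal), sum-aggregate, and combine into the final scalar via MR_IQM. A simplified in-memory sketch of that computation is shown below; it assumes a length divisible by 4 and skips the fractional boundary weighting that the real operator handles:

// Simplified interquartile mean: mean of the middle 50% of the sorted values.
// Assumes values.length is divisible by 4; a full implementation weights the
// boundary quantiles fractionally, as the MR_IQM unary above does.
static double iqm(double[] values) {
    double[] v = values.clone();
    java.util.Arrays.sort(v);
    int n = v.length;
    int lo = n / 4, hi = (3 * n) / 4; // indices of the 25% and 75% cut points
    double sum = 0;
    for (int i = lo; i < hi; i++)
        sum += v[i];
    return sum / (hi - lo);
}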
Use of org.apache.sysml.lops.Aggregate in project systemml by apache.
From class ParameterizedBuiltinOp, method constructLopsGroupedAggregate():
private void constructLopsGroupedAggregate(HashMap<String, Lop> inputlops, ExecType et) {
    // reset reblock requirement (see MR aggregate / construct lops)
    setRequiresReblock(false);
    // determine output dimensions
    long outputDim1 = -1, outputDim2 = -1;
    Lop numGroups = inputlops.get(Statement.GAGG_NUM_GROUPS);
    if (!dimsKnown() && numGroups != null && numGroups instanceof Data && ((Data) numGroups).isLiteral()) {
        long ngroups = ((Data) numGroups).getLongValue();
        Lop input = inputlops.get(GroupedAggregate.COMBINEDINPUT);
        long inDim1 = input.getOutputParameters().getNumRows();
        long inDim2 = input.getOutputParameters().getNumCols();
        boolean rowwise = (inDim1 == 1 && inDim2 > 1);
        if (rowwise) { // vector
            outputDim1 = ngroups;
            outputDim2 = 1;
        } else { // vector or matrix
            outputDim1 = inDim2;
            outputDim2 = ngroups;
        }
    }
    // construct lops
    if (et == ExecType.MR) {
        Lop grp_agg = null;
        // construct necessary lops: combineBinary/combineTernary and groupedAgg
        boolean isWeighted = (_paramIndexMap.get(Statement.GAGG_WEIGHTS) != null);
        if (isWeighted) {
            Lop append = BinaryOp.constructAppendLopChain(
                getInput().get(_paramIndexMap.get(Statement.GAGG_TARGET)),
                getInput().get(_paramIndexMap.get(Statement.GAGG_GROUPS)),
                getInput().get(_paramIndexMap.get(Statement.GAGG_WEIGHTS)),
                DataType.MATRIX, getValueType(), true,
                getInput().get(_paramIndexMap.get(Statement.GAGG_TARGET)));
            // add the combine lop to the parameter list, under the new name "combinedinput"
            inputlops.put(GroupedAggregate.COMBINEDINPUT, append);
            inputlops.remove(Statement.GAGG_TARGET);
            inputlops.remove(Statement.GAGG_GROUPS);
            inputlops.remove(Statement.GAGG_WEIGHTS);
            grp_agg = new GroupedAggregate(inputlops, isWeighted, getDataType(), getValueType());
            grp_agg.getOutputParameters().setDimensions(outputDim1, outputDim2, getRowsInBlock(), getColsInBlock(), -1);
            setRequiresReblock(true);
        } else {
            Hop target = getInput().get(_paramIndexMap.get(Statement.GAGG_TARGET));
            Hop groups = getInput().get(_paramIndexMap.get(Statement.GAGG_GROUPS));
            Lop append = null;
            // physical operator selection
            double groupsSizeP = OptimizerUtils.estimatePartitionedSizeExactSparsity(
                groups.getDim1(), groups.getDim2(), groups.getRowsInBlock(), groups.getColsInBlock(), groups.getNnz());
            if (groupsSizeP < OptimizerUtils.getRemoteMemBudgetMap(true) // mapgroupedagg
                && getParameterHop(Statement.GAGG_FN) instanceof LiteralOp
                && ((LiteralOp) getParameterHop(Statement.GAGG_FN)).getStringValue().equals("sum")
                && inputlops.get(Statement.GAGG_NUM_GROUPS) != null) {
                // pre-partitioning
                boolean needPart = (groups.dimsKnown() && groups.getDim1() * groups.getDim2() > DistributedCacheInput.PARTITION_SIZE);
                if (needPart) {
                    // operator selection
                    ExecType etPart = (OptimizerUtils.estimateSizeExactSparsity(groups.getDim1(), groups.getDim2(), 1.0)
                        < OptimizerUtils.getLocalMemBudget()) ? ExecType.CP : ExecType.MR;
                    Lop dcinput = new DataPartition(groups.constructLops(), DataType.MATRIX, ValueType.DOUBLE, etPart, PDataPartitionFormat.ROW_BLOCK_WISE_N);
                    dcinput.getOutputParameters().setDimensions(groups.getDim1(), groups.getDim2(), target.getRowsInBlock(), target.getColsInBlock(), groups.getNnz());
                    setLineNumbers(dcinput);
                    inputlops.put(Statement.GAGG_GROUPS, dcinput);
                }
                Lop grp_agg_m = new GroupedAggregateM(inputlops, getDataType(), getValueType(), needPart, ExecType.MR);
                grp_agg_m.getOutputParameters().setDimensions(outputDim1, outputDim2, target.getRowsInBlock(), target.getColsInBlock(), -1);
                setLineNumbers(grp_agg_m);
                // post aggregation
                Group grp = new Group(grp_agg_m, Group.OperationTypes.Sort, getDataType(), getValueType());
                grp.getOutputParameters().setDimensions(outputDim1, outputDim2, target.getRowsInBlock(), target.getColsInBlock(), -1);
                setLineNumbers(grp);
                Aggregate agg1 = new Aggregate(grp, HopsAgg2Lops.get(AggOp.SUM), getDataType(), getValueType(), ExecType.MR);
                agg1.setupCorrectionLocation(CorrectionLocationType.NONE);
                agg1.getOutputParameters().setDimensions(outputDim1, outputDim2, target.getRowsInBlock(), target.getColsInBlock(), -1);
                grp_agg = agg1;
                // note: no reblock required
            } else { // general case: groupedagg
                if (target.getDim2() >= target.getColsInBlock() // multi-column-block result matrix
                    || target.getDim2() <= 0) {                 // unknown
                    long m1_dim1 = target.getDim1();
                    long m1_dim2 = target.getDim2();
                    long m2_dim1 = groups.getDim1();
                    long m2_dim2 = groups.getDim2();
                    long m3_dim1 = m1_dim1;
                    long m3_dim2 = ((m1_dim2 >= 0 && m2_dim2 >= 0) ? (m1_dim2 + m2_dim2) : -1);
                    long m3_nnz = (target.getNnz() > 0 && groups.getNnz() > 0) ? (target.getNnz() + groups.getNnz()) : -1;
                    long brlen = target.getRowsInBlock();
                    long bclen = target.getColsInBlock();
                    Lop offset = createOffsetLop(target, true);
                    Lop rep = new RepMat(groups.constructLops(), offset, true, groups.getDataType(), groups.getValueType());
                    setOutputDimensions(rep);
                    setLineNumbers(rep);
                    Group group1 = new Group(target.constructLops(), Group.OperationTypes.Sort, DataType.MATRIX, target.getValueType());
                    group1.getOutputParameters().setDimensions(m1_dim1, m1_dim2, brlen, bclen, target.getNnz());
                    setLineNumbers(group1);
                    Group group2 = new Group(rep, Group.OperationTypes.Sort, DataType.MATRIX, groups.getValueType());
                    group2.getOutputParameters().setDimensions(m2_dim1, m2_dim2, brlen, bclen, groups.getNnz());
                    setLineNumbers(group2);
                    append = new AppendR(group1, group2, DataType.MATRIX, ValueType.DOUBLE, true, ExecType.MR);
                    append.getOutputParameters().setDimensions(m3_dim1, m3_dim2, brlen, bclen, m3_nnz);
                    setLineNumbers(append);
                } else { // single-column-block vector or matrix
                    append = BinaryOp.constructMRAppendLop(target, groups, DataType.MATRIX, getValueType(), true, target);
                }
                // add the combine lop to the parameter list, under the new name "combinedinput"
                inputlops.put(GroupedAggregate.COMBINEDINPUT, append);
                inputlops.remove(Statement.GAGG_TARGET);
                inputlops.remove(Statement.GAGG_GROUPS);
                grp_agg = new GroupedAggregate(inputlops, isWeighted, getDataType(), getValueType());
                grp_agg.getOutputParameters().setDimensions(outputDim1, outputDim2, getRowsInBlock(), getColsInBlock(), -1);
                setRequiresReblock(true);
            }
        }
        setLineNumbers(grp_agg);
        setLops(grp_agg);
    } else { // CP/Spark
        Lop grp_agg = null;
        if (et == ExecType.CP) {
            int k = OptimizerUtils.getConstrainedNumThreads(_maxNumThreads);
            grp_agg = new GroupedAggregate(inputlops, getDataType(), getValueType(), et, k);
            grp_agg.getOutputParameters().setDimensions(outputDim1, outputDim2, getRowsInBlock(), getColsInBlock(), -1);
        } else if (et == ExecType.SPARK) {
            // physical operator selection
            Hop groups = getParameterHop(Statement.GAGG_GROUPS);
            boolean broadcastGroups = (_paramIndexMap.get(Statement.GAGG_WEIGHTS) == null
                && OptimizerUtils.checkSparkBroadcastMemoryBudget(groups.getDim1(), groups.getDim2(),
                    groups.getRowsInBlock(), groups.getColsInBlock(), groups.getNnz()));
            if (broadcastGroups // mapgroupedagg
                && getParameterHop(Statement.GAGG_FN) instanceof LiteralOp
                && ((LiteralOp) getParameterHop(Statement.GAGG_FN)).getStringValue().equals("sum")
                && inputlops.get(Statement.GAGG_NUM_GROUPS) != null) {
                Hop target = getTargetHop();
                grp_agg = new GroupedAggregateM(inputlops, getDataType(), getValueType(), true, ExecType.SPARK);
                grp_agg.getOutputParameters().setDimensions(outputDim1, outputDim2, target.getRowsInBlock(), target.getColsInBlock(), -1);
                // no reblock required (directly outputs binary block)
            } else { // groupedagg (w/ or w/o broadcast)
                grp_agg = new GroupedAggregate(inputlops, getDataType(), getValueType(), et, broadcastGroups);
                grp_agg.getOutputParameters().setDimensions(outputDim1, outputDim2, -1, -1, -1);
                setRequiresReblock(true);
            }
        }
        setLineNumbers(grp_agg);
        setLops(grp_agg);
    }
}
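Semantically, grouped aggregation with fn = "sum" produces one output cell per group id: every target entry whose corresponding group id equals g (optionally scaled by a weight) is summed into position g, which is exactly what the map-side GroupedAggregateM plus the final Aggregate(SUM) compute in a distributed way. A standalone sketch under that reading; the method name and the assumption of 1-based, in-range group ids are illustrative, not part of the SystemML API:

// Standalone sketch of groupedAggregate(target, groups, [weights], fn="sum"):
// out[g-1] accumulates weights[i] * target[i] for all i with groups[i] == g.
static double[] groupedSum(double[] target, double[] groups, double[] weights, int numGroups) {
    double[] out = new double[numGroups];
    for (int i = 0; i < target.length; i++) {
        int g = (int) groups[i];                         // group ids assumed 1-based
        double w = (weights != null) ? weights[i] : 1.0; // unweighted case uses weight 1
        out[g - 1] += w * target[i];
    }
    return out;
}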
Use of org.apache.sysml.lops.Aggregate in project systemml by apache.
From class ParameterizedBuiltinOp, method constructLopsRExpand():
private void constructLopsRExpand(HashMap<String, Lop> inputlops, ExecType et) {
    if (et == ExecType.CP || et == ExecType.SPARK) {
        int k = OptimizerUtils.getConstrainedNumThreads(_maxNumThreads);
        ParameterizedBuiltin pbilop = new ParameterizedBuiltin(inputlops, HopsParameterizedBuiltinLops.get(_op), getDataType(), getValueType(), et, k);
        setOutputDimensions(pbilop);
        setLineNumbers(pbilop);
        setLops(pbilop);
    } else if (et == ExecType.MR) {
        ParameterizedBuiltin pbilop = new ParameterizedBuiltin(inputlops, HopsParameterizedBuiltinLops.get(_op), getDataType(), getValueType(), et);
        setOutputDimensions(pbilop);
        setLineNumbers(pbilop);
        Group group1 = new Group(pbilop, Group.OperationTypes.Sort, getDataType(), getValueType());
        setOutputDimensions(group1);
        setLineNumbers(group1);
        Aggregate finalagg = new Aggregate(group1, Aggregate.OperationTypes.Sum, DataType.MATRIX, getValueType(), ExecType.MR);
        setOutputDimensions(finalagg);
        setLineNumbers(finalagg);
        setLops(finalagg);
    }
}
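rexpand expands an n x 1 vector into an indicator matrix with at most one nonzero per input entry; in the MR path above, the final Group/Aggregate.Sum pair merges the partial output blocks. A minimal sketch of the expansion semantics for one direction; the method name is illustrative, and the handling of the cast/ignore parameters of the real builtin is omitted (here, out-of-range entries are simply skipped):

// Standalone sketch of rexpand-style one-hot expansion, row-wise:
// out[i][v-1] = 1 where v = round(target[i]); entries outside [1, max] are ignored.
static double[][] rexpandRows(double[] target, int max) {
    double[][] out = new double[target.length][max];
    for (int i = 0; i < target.length; i++) {
        int v = (int) Math.round(target[i]);
        if (v >= 1 && v <= max)
            out[i][v - 1] = 1.0;
    }
    return out;
}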
Use of org.apache.sysml.lops.Aggregate in project systemml by apache.
From class QuaternaryOp, method constructMRLopsWeightedDivMM():
private void constructMRLopsWeightedDivMM(WDivMMType wtype) {
    // NOTE: the common case for wdivmm is factors U/V with a rank of 10s to 100s; the current runtime only
    // supports single-block outer products (U/V rank <= blocksize, i.e., 1000 by default); we enforce this
    // by applying the hop rewrite for weighted divmm only if this constraint holds.
    Hop W = getInput().get(0);
    Hop U = getInput().get(1);
    Hop V = getInput().get(2);
    Hop X = getInput().get(3);
    // MR operator selection, part 1
    double m1Size = OptimizerUtils.estimateSize(U.getDim1(), U.getDim2()); // size of U
    double m2Size = OptimizerUtils.estimateSize(V.getDim1(), V.getDim2()); // size of V
    boolean isMapWdivmm = ((!wtype.hasFourInputs() || wtype.hasScalar())
        && m1Size + m2Size < OptimizerUtils.getRemoteMemBudgetMap(true));
    if (!FORCE_REPLICATION && isMapWdivmm) { // broadcast
        // partitioning of U
        boolean needPartU = !U.dimsKnown() || U.getDim1() * U.getDim2() > DistributedCacheInput.PARTITION_SIZE;
        Lop lU = U.constructLops();
        if (needPartU) {
            // requires partitioning
            lU = new DataPartition(lU, DataType.MATRIX, ValueType.DOUBLE,
                (m1Size > OptimizerUtils.getLocalMemBudget()) ? ExecType.MR : ExecType.CP,
                PDataPartitionFormat.ROW_BLOCK_WISE_N);
            lU.getOutputParameters().setDimensions(U.getDim1(), U.getDim2(), getRowsInBlock(), getColsInBlock(), U.getNnz());
            setLineNumbers(lU);
        }
        // partitioning of V
        boolean needPartV = !V.dimsKnown() || V.getDim1() * V.getDim2() > DistributedCacheInput.PARTITION_SIZE;
        Lop lV = V.constructLops();
        if (needPartV) {
            // requires partitioning
            lV = new DataPartition(lV, DataType.MATRIX, ValueType.DOUBLE,
                (m2Size > OptimizerUtils.getLocalMemBudget()) ? ExecType.MR : ExecType.CP,
                PDataPartitionFormat.ROW_BLOCK_WISE_N);
            lV.getOutputParameters().setDimensions(V.getDim1(), V.getDim2(), getRowsInBlock(), getColsInBlock(), V.getNnz());
            setLineNumbers(lV);
        }
        // map-side wdivmm, always with broadcast
        Lop wdivmm = new WeightedDivMM(W.constructLops(), lU, lV, X.constructLops(),
            DataType.MATRIX, ValueType.DOUBLE, wtype, ExecType.MR);
        setOutputDimensions(wdivmm);
        setLineNumbers(wdivmm);
        setLops(wdivmm);
    } else { // general case
        // MR operator selection, part 2 (both cannot happen for wdivmm, otherwise mapwdivmm)
        boolean cacheU = !FORCE_REPLICATION && (m1Size < OptimizerUtils.getRemoteMemBudgetReduce());
        boolean cacheV = !FORCE_REPLICATION && ((!cacheU && m2Size < OptimizerUtils.getRemoteMemBudgetReduce())
            || (cacheU && m1Size + m2Size < OptimizerUtils.getRemoteMemBudgetReduce()));
        Group grpW = new Group(W.constructLops(), Group.OperationTypes.Sort, DataType.MATRIX, ValueType.DOUBLE);
        grpW.getOutputParameters().setDimensions(W.getDim1(), W.getDim2(), W.getRowsInBlock(), W.getColsInBlock(), W.getNnz());
        setLineNumbers(grpW);
        Lop grpX = X.constructLops();
        if (wtype.hasFourInputs() && (X.getDataType() != DataType.SCALAR))
            grpX = new Group(grpX, Group.OperationTypes.Sort, DataType.MATRIX, ValueType.DOUBLE);
        grpX.getOutputParameters().setDimensions(X.getDim1(), X.getDim2(), X.getRowsInBlock(), X.getColsInBlock(), X.getNnz());
        setLineNumbers(grpX);
        Lop lU = constructLeftFactorMRLop(U, V, cacheU, m1Size);
        Lop lV = constructRightFactorMRLop(U, V, cacheV, m2Size);
        // reduce-side wdivmm, w/ or w/o broadcast
        Lop wdivmm = new WeightedDivMMR(grpW, lU, lV, grpX,
            DataType.MATRIX, ValueType.DOUBLE, wtype, cacheU, cacheV, ExecType.MR);
        setOutputDimensions(wdivmm);
        setLineNumbers(wdivmm);
        setLops(wdivmm);
    }
    // in contrast to wsloss/wsigmoid, wdivmm requires partial aggregation (for the final mm)
    Group grp = new Group(getLops(), Group.OperationTypes.Sort, getDataType(), getValueType());
    setOutputDimensions(grp);
    setLineNumbers(grp);
    Aggregate agg1 = new Aggregate(grp, HopsAgg2Lops.get(AggOp.SUM), getDataType(), getValueType(), ExecType.MR);
    // aggregation uses kahanSum but the inputs do not have correction values
    agg1.setupCorrectionLocation(CorrectionLocationType.NONE);
    setOutputDimensions(agg1);
    setLineNumbers(agg1);
    setLops(agg1);
}
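The wdivmm fusion exists because the naive formulation materializes the dense rank-k product U %*% t(V), even though for a basic variant such as (W / (U %*% t(V))) %*% V it is only needed at the nonzero cells of W; the trailing Group/Aggregate(SUM) pair above then merges the partial products of the final matrix multiply. Below is a dense, single-node sketch of that one variant, without the partitioning, broadcast, and caching decisions made above; the method name is an illustrative assumption, and division by a zero dot product is not guarded:

// Standalone sketch of a basic wdivmm right variant:
//   out = (W / (U %*% t(V))) %*% V, evaluated only at nonzeros of W,
// so the m x n product U %*% t(V) is never materialized.
// U is m x k, V is n x k, W is m x n (sparse), out is m x k.
static double[][] wdivmmRight(double[][] W, double[][] U, double[][] V) {
    int m = U.length, n = V.length, k = U[0].length;
    double[][] out = new double[m][k];
    for (int i = 0; i < m; i++)
        for (int j = 0; j < n; j++) {
            if (W[i][j] == 0) continue;  // sparse-safe: skip zeros in W
            double uv = 0;               // dot(U[i,:], V[j,:]) = one cell of U %*% t(V)
            for (int l = 0; l < k; l++)
                uv += U[i][l] * V[j][l];
            double q = W[i][j] / uv;     // one cell of W / (U %*% t(V))
            for (int l = 0; l < k; l++)  // aggregate q * V[j,:] into output row i
                out[i][l] += q * V[j][l];
        }
    return out;
}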