use of org.apache.sysml.lops.PickByCount in project incubator-systemml by apache.
the class BinaryOp method constructLopsIQM.
/**
 * Builds the low-level operator (LOP) plan for the IQM operation
 * (presumably inter-quartile mean -- confirm against the HOP definition)
 * and registers the root of that plan through {@code setLops}.
 *
 * <p>For every execution type other than MR, the plan is a weighted
 * sort-by-value of the two inputs followed by a single {@code PickByCount}
 * of type {@code IQM}. For MR, the plan is: combine (pre-sort) -> weighted
 * sort -> range pick with the literal {@code 0.25} -> partial SUM aggregate
 * -> group -> SUM aggregate -> cast to scalar -> {@code MR_IQM} unary, which
 * is created with execution type CP.
 *
 * @param et execution type the plan is generated for
 * @throws HopsException if HopsException occurs
 * @throws LopsException if LopsException occurs
 */
private void constructLopsIQM(ExecType et) throws HopsException, LopsException {
    // Both branches start by constructing the lops of input 0 and then input 1.
    final Hop firstInput = getInput().get(0);
    final Lop firstLop = firstInput.constructLops();
    final Lop secondLop = getInput().get(1).constructLops();

    if (et != ExecType.MR) {
        SortKeys sorted = SortKeys.constructSortByValueLop(firstLop, secondLop,
                SortKeys.OperationTypes.WithWeights, firstInput.getDataType(), firstInput.getValueType(), et);
        sorted.getOutputParameters().setDimensions(firstInput.getDim1(), firstInput.getDim2(),
                firstInput.getRowsInBlock(), firstInput.getColsInBlock(), firstInput.getNnz());

        PickByCount iqmPick = new PickByCount(sorted, null, getDataType(), getValueType(),
                PickByCount.OperationTypes.IQM, et, true);
        setOutputDimensions(iqmPick);
        setLineNumbers(iqmPick);
        setLops(iqmPick);
        return;
    }

    // ---- MR plan ----
    CombineBinary combined = CombineBinary.constructCombineLop(OperationTypes.PreSort,
            firstLop, secondLop, DataType.MATRIX, getValueType());
    combined.getOutputParameters().setDimensions(firstInput.getDim1(), firstInput.getDim2(),
            firstInput.getRowsInBlock(), firstInput.getColsInBlock(), firstInput.getNnz());

    SortKeys sorted = SortKeys.constructSortByValueLop(combined,
            SortKeys.OperationTypes.WithWeights, DataType.MATRIX, ValueType.DOUBLE, ExecType.MR);
    // The sorted output carries the same dimensions as the first input.
    sorted.getOutputParameters().setDimensions(firstInput.getDim1(), firstInput.getDim2(),
            firstInput.getRowsInBlock(), firstInput.getColsInBlock(), firstInput.getNnz());

    Data quarter = Data.createLiteralLop(ValueType.DOUBLE, Double.toString(0.25));
    setLineNumbers(quarter);

    PickByCount rangePick = new PickByCount(sorted, quarter, DataType.MATRIX, getValueType(),
            PickByCount.OperationTypes.RANGEPICK);
    // Row/column counts and nnz of the picked range are passed as -1 here.
    rangePick.getOutputParameters().setDimensions(-1, -1, getRowsInBlock(), getColsInBlock(), -1);
    setLineNumbers(rangePick);

    PartialAggregate partialSum = new PartialAggregate(rangePick,
            HopsAgg2Lops.get(Hop.AggOp.SUM), HopsDirection2Lops.get(Hop.Direction.RowCol),
            DataType.MATRIX, getValueType());
    setLineNumbers(partialSum);
    // Dimensions of the partial aggregate follow from the aggregation direction.
    partialSum.setDimensionsBasedOnDirection(getDim1(), getDim2(), getRowsInBlock(), getColsInBlock());

    Group grouped = new Group(partialSum, Group.OperationTypes.Sort, DataType.MATRIX, getValueType());
    setOutputDimensions(grouped);
    setLineNumbers(grouped);

    Aggregate fullSum = new Aggregate(grouped, HopsAgg2Lops.get(Hop.AggOp.SUM),
            DataType.MATRIX, getValueType(), ExecType.MR);
    setOutputDimensions(fullSum);
    fullSum.setupCorrectionLocation(partialSum.getCorrectionLocation());
    setLineNumbers(fullSum);

    UnaryCP sumAsScalar = new UnaryCP(fullSum, HopsOpOp1LopsUS.get(OpOp1.CAST_AS_SCALAR),
            DataType.SCALAR, getValueType());
    sumAsScalar.getOutputParameters().setDimensions(0, 0, 0, 0, -1);
    setLineNumbers(sumAsScalar);

    // The final operator consumes the sorted data together with the scalar sum.
    Unary iqmLop = new Unary(sorted, sumAsScalar, Unary.OperationTypes.MR_IQM,
            DataType.SCALAR, ValueType.DOUBLE, ExecType.CP);
    iqmLop.getOutputParameters().setDimensions(0, 0, 0, 0, -1);
    setLineNumbers(iqmLop);
    setLops(iqmLop);
}
use of org.apache.sysml.lops.PickByCount in project incubator-systemml by apache.
the class Dag method getRecordReaderInstructions.
/**
 * Recursively generates the record reader instructions of an MR job for the
 * lop DAG rooted at {@code node}.
 *
 * <p>The inputs of {@code node} are visited first. If {@code node} itself
 * executes in the record reader, it is assigned the next free index from
 * {@code start_index[0]} (which is then incremented), the assignment is
 * recorded in {@code nodeIndexMapping}, and its instruction is appended to
 * {@code recordReaderInstructions}. For a {@code RANGEPICK} pick-by-count
 * lop, the second (scalar) input is additionally registered in
 * {@code inputLabels}/{@code inputLops} unless it is a literal data lop.
 *
 * @param node low-level operator
 * @param execNodes list of exec nodes
 * @param inputStrings list of input strings
 * @param recordReaderInstructions list of record reader instructions
 * @param nodeIndexMapping node index mapping
 * @param start_index start index
 * @param inputLabels list of input labels
 * @param inputLops list of input lops
 * @param MRJobLineNumbers MR job line numbers
 * @return the index already mapped to {@code node} if it is present in
 *         {@code nodeIndexMapping}; the newly assigned output index if
 *         {@code node} executes in the record reader; -1 otherwise (node not
 *         in {@code execNodes}, or not a record reader lop)
 * @throws LopsException if a record reader lop does not have exactly two inputs
 */
private static int getRecordReaderInstructions(Lop node, ArrayList<Lop> execNodes, ArrayList<String> inputStrings, ArrayList<String> recordReaderInstructions, HashMap<Lop, Integer> nodeIndexMapping, int[] start_index, ArrayList<String> inputLabels, ArrayList<Lop> inputLops, ArrayList<Integer> MRJobLineNumbers) throws LopsException {
    // if input source, return index
    if (nodeIndexMapping.containsKey(node)) {
        return nodeIndexMapping.get(node);
    }
    // not input source and not in exec nodes, then return.
    if (!execNodes.contains(node)) {
        return -1;
    }

    // recurse into all inputs first and remember the index each one produced
    ArrayList<Integer> inputIndices = new ArrayList<Integer>();
    for (Lop childNode : node.getInputs()) {
        inputIndices.add(getRecordReaderInstructions(childNode, execNodes, inputStrings, recordReaderInstructions, nodeIndexMapping, start_index, inputLabels, inputLops, MRJobLineNumbers));
    }

    // only record reader lops produce an instruction
    if (node.getExecLocation() != ExecLocation.RecordReader) {
        return -1;
    }

    // Input indices are never reused for the output: always take a fresh one.
    // (need to add better indexing schemes)
    int outputIndex = start_index[0];
    start_index[0]++;
    nodeIndexMapping.put(node, outputIndex);

    // only Rangepick lop can contribute to labels
    if (node.getType() == Type.PickValues) {
        PickByCount pbc = (PickByCount) node;
        if (pbc.getOperationType() == PickByCount.OperationTypes.RANGEPICK) {
            // always the second input is a scalar
            Lop scalarInput = node.getInputs().get(1);
            // A literal data lop needs no label. A non-literal data lop or an
            // intermediate variable (any non-data lop) must be registered.
            boolean isLiteral = scalarInput.getExecLocation() == ExecLocation.Data
                    && ((Data) scalarInput).isLiteral();
            if (!isLiteral) {
                inputLabels.add(scalarInput.getOutputParameters().getLabel());
                inputLops.add(scalarInput);
            }
        }
    }

    // get recordreader instruction.
    if (node.getInputs().size() != 2) {
        throw new LopsException("Unexpected number of inputs while generating a RecordReader Instruction");
    }
    recordReaderInstructions.add(node.getInstructions(inputIndices.get(0), inputIndices.get(1), outputIndex));
    if (DMLScript.ENABLE_DEBUG_MODE) {
        MRJobLineNumbers.add(node._beginLine);
    }
    return outputIndex;
}
Aggregations