use of org.apache.sysml.parser.Expression.DataType in project incubator-systemml by apache.
the class ArithmeticBinarySPInstruction method parseInstruction.
/**
 * Parses the string form of a Spark arithmetic binary instruction and creates the
 * matching instruction object.
 *
 * <p>A string starting with {@code "SPARK" + Lop.OPERAND_DELIMITOR + "map"} is handled as
 * the broadcast ("map") variant: it must have five fields after the opcode, of which
 * fields 1-3 are the two inputs and the output, and field 5 names the {@link VectorType}.
 * Field 4 is not consumed by this method. Every other string is delegated to
 * {@code parseBinaryInstruction}.
 *
 * @param str the serialized instruction
 * @return a matrix-vector (broadcast), matrix-matrix or matrix-scalar instruction,
 *         or {@code null} when neither operand is a matrix
 * @throws DMLRuntimeException declared by the signature; presumably raised by the
 *         parsing helpers on malformed input (not visible here)
 */
public static ArithmeticBinarySPInstruction parseInstruction(String str) throws DMLRuntimeException {
    CPOperand in1 = new CPOperand("", ValueType.UNKNOWN, DataType.UNKNOWN);
    CPOperand in2 = new CPOperand("", ValueType.UNKNOWN, DataType.UNKNOWN);
    CPOperand out = new CPOperand("", ValueType.UNKNOWN, DataType.UNKNOWN);

    // The "map" prefix alone decides whether this is the broadcast variant.
    final boolean isBroadcast = str.startsWith("SPARK" + Lop.OPERAND_DELIMITOR + "map");
    final String opcode;
    VectorType vtype = null;
    if (!isBroadcast) {
        opcode = parseBinaryInstruction(str, in1, in2, out);
    } else {
        String[] fields = InstructionUtils.getInstructionPartsWithValueType(str);
        InstructionUtils.checkNumFields(fields, 5);
        opcode = fields[0];
        in1.split(fields[1]);
        in2.split(fields[2]);
        out.split(fields[3]);
        vtype = VectorType.valueOf(fields[5]);
    }

    // Operands of differing data types get a scalar operator; the flag tells the
    // parser whether the scalar is the left operand.
    final DataType leftType = in1.getDataType();
    final DataType rightType = in2.getDataType();
    final Operator operator;
    if (leftType == rightType) {
        operator = InstructionUtils.parseExtendedBinaryOperator(opcode);
    } else {
        operator = InstructionUtils.parseScalarBinaryOperator(opcode, (leftType == DataType.SCALAR));
    }

    final boolean leftIsMatrix = (leftType == DataType.MATRIX);
    final boolean rightIsMatrix = (rightType == DataType.MATRIX);
    if (!leftIsMatrix && !rightIsMatrix) {
        // no matrix operand: nothing to instantiate here
        return null;
    }
    if (leftIsMatrix != rightIsMatrix) {
        // exactly one matrix operand
        return new MatrixScalarArithmeticSPInstruction(operator, in1, in2, out, opcode, str);
    }
    if (isBroadcast) {
        return new MatrixBVectorArithmeticSPInstruction(operator, in1, in2, out, vtype, opcode, str);
    }
    return new MatrixMatrixArithmeticSPInstruction(operator, in1, in2, out, opcode, str);
}
use of org.apache.sysml.parser.Expression.DataType in project incubator-systemml by apache.
the class BinaryOp method inferOutputCharacteristics.
/**
 * Infers output rows, columns and number of non-zeros of this binary operation from the
 * statistics of its two inputs.
 *
 * <ul>
 * <li>CBIND / RBIND: the shared dimension is taken from whichever input knows it, the
 *     appended dimension and the non-zeros are summed when both inputs know them;
 *     unknown entries are reported as -1.</li>
 * <li>SOLVE: a column vector with as many rows as the first input has columns, with the
 *     non-zero count set to that same number.</li>
 * <li>otherwise: dimensions are derived from the matrix input(s) and the non-zeros from
 *     a sparsity estimate provided by {@code OptimizerUtils}.</li>
 * </ul>
 *
 * @param memo memo table from which the input statistics are obtained
 * @return {@code {rows, cols, nnz}}, or {@code null} if nothing could be inferred
 */
@Override
protected long[] inferOutputCharacteristics(MemoTable memo) {
    MatrixCharacteristics[] stats = memo.getAllInputStats(getInput());
    Hop lhs = getInput().get(0);
    Hop rhs = getInput().get(1);
    DataType lhsType = lhs.getDataType();
    DataType rhsType = rhs.getDataType();

    if (op == OpOp2.CBIND) {
        // rows are shared, columns and non-zeros add up
        long rows = stats[0].rowsKnown() ? stats[0].getRows() : (stats[1].rowsKnown() ? stats[1].getRows() : -1);
        long cols = (stats[0].colsKnown() && stats[1].colsKnown()) ? stats[0].getCols() + stats[1].getCols() : -1;
        long nnz = (stats[0].nnzKnown() && stats[1].nnzKnown()) ? stats[0].getNonZeros() + stats[1].getNonZeros() : -1;
        if (rows >= 0 || cols >= 0 || nnz >= 0) {
            return new long[] { rows, cols, nnz };
        }
        return null;
    }

    if (op == OpOp2.RBIND) {
        // columns are shared, rows and non-zeros add up
        long cols = stats[0].colsKnown() ? stats[0].getCols() : (stats[1].colsKnown() ? stats[1].getCols() : -1);
        long rows = (stats[0].rowsKnown() && stats[1].rowsKnown()) ? stats[0].getRows() + stats[1].getRows() : -1;
        long nnz = (stats[0].nnzKnown() && stats[1].nnzKnown()) ? stats[0].getNonZeros() + stats[1].getNonZeros() : -1;
        if (rows >= 0 || cols >= 0 || nnz >= 0) {
            return new long[] { rows, cols, nnz };
        }
        return null;
    }

    if (op == OpOp2.SOLVE) {
        // Output is a (likely to be dense) vector of size number of columns in the first input
        if (stats[0].getCols() >= 0) {
            return new long[] { stats[0].getCols(), 1, stats[0].getCols() };
        }
        return null;
    }

    // general case
    long rows;
    long cols;
    double lhsSparsity = 1.0;
    double rhsSparsity = 1.0;
    if (lhsType == DataType.MATRIX && rhsType == DataType.SCALAR && stats[0].dimsKnown()) {
        rows = stats[0].getRows();
        cols = stats[0].getCols();
        if (stats[0].getNonZeros() > 0) {
            lhsSparsity = OptimizerUtils.getSparsity(rows, cols, stats[0].getNonZeros());
        }
    } else if (lhsType == DataType.SCALAR && rhsType == DataType.MATRIX) {
        rows = stats[1].getRows();
        cols = stats[1].getCols();
        if (stats[1].getNonZeros() > 0) {
            rhsSparsity = OptimizerUtils.getSparsity(rows, cols, stats[1].getNonZeros());
        }
    } else {
        // MATRIX - MATRIX (also reached for matrix-scalar when the matrix dims are unknown)
        if (outer) {
            // outer vector operation: rows from the left input, columns from the right
            rows = stats[0].getRows();
            cols = stats[1].getCols();
        } else {
            // Careful with matrix-vector operations: a dimension of the right input is
            // only used as a fallback when it is larger than 1.
            if (stats[0].rowsKnown()) {
                rows = stats[0].getRows();
            } else if (stats[1].getRows() > 1) {
                rows = stats[1].getRows();
            } else {
                rows = -1;
            }
            if (stats[0].colsKnown()) {
                cols = stats[0].getCols();
            } else if (stats[1].getCols() > 1) {
                cols = stats[1].getCols();
            } else {
                cols = -1;
            }
        }
        if (stats[0].getNonZeros() > 0) {
            lhsSparsity = OptimizerUtils.getSparsity(rows, cols, stats[0].getNonZeros());
        }
        if (stats[1].getNonZeros() > 0) {
            rhsSparsity = OptimizerUtils.getSparsity(rows, cols, stats[1].getNonZeros());
        }
    }

    if (rows < 0 || cols < 0) {
        return null;
    }

    final long nnz;
    if (OptimizerUtils.isBinaryOpConditionalSparseSafe(op) && rhs instanceof LiteralOp) {
        nnz = (long) (rows * cols * OptimizerUtils.getBinaryOpSparsityConditionalSparseSafe(lhsSparsity, op, (LiteralOp) rhs));
    } else {
        // sparsity estimates are conservative in terms of the worstcase behavior, however,
        // for outer vector operations the average case is equivalent to the worst case.
        nnz = (long) (rows * cols * OptimizerUtils.getBinaryOpSparsity(lhsSparsity, rhsSparsity, op, !outer));
    }
    return new long[] { rows, cols, nnz };
}
use of org.apache.sysml.parser.Expression.DataType in project incubator-systemml by apache.
the class BinaryOp method constructLopsBinaryDefault.
/**
 * Builds the low-level operator (lop) for this binary hop in the default case and
 * registers it via {@code setLops}.
 *
 * The lop that is created depends on the data types of the two inputs:
 * scalar/scalar yields a {@code BinaryScalar}, matrix/scalar (either order) yields a
 * {@code Unary}, and every remaining combination is handled by the matrix/matrix
 * branch, whose result further depends on the execution type returned by
 * {@code optFindExecType()} (GPU softmax pattern, CP/GPU, SPARK, otherwise MR).
 */
private void constructLopsBinaryDefault() {
/* Default behavior for BinaryOp */
// it depends on input data types
DataType dt1 = getInput().get(0).getDataType();
DataType dt2 = getInput().get(1).getDataType();
if (dt1 == dt2 && dt1 == DataType.SCALAR) {
// Both operands scalar
BinaryScalar binScalar1 = new BinaryScalar(getInput().get(0).constructLops(), getInput().get(1).constructLops(), HopsOpOp2LopsBS.get(op), getDataType(), getValueType());
binScalar1.getOutputParameters().setDimensions(0, 0, 0, 0, -1);
setLineNumbers(binScalar1);
setLops(binScalar1);
} else if ((dt1 == DataType.MATRIX && dt2 == DataType.SCALAR) || (dt1 == DataType.SCALAR && dt2 == DataType.MATRIX)) {
// One operand is Matrix and the other is scalar
ExecType et = optFindExecType();
// select specific operator implementations:
// POW or MULT with a literal 2.0 as second input map to dedicated operation types
Unary.OperationTypes ot = null;
Hop right = getInput().get(1);
if (op == OpOp2.POW && right instanceof LiteralOp && ((LiteralOp) right).getDoubleValue() == 2.0)
ot = Unary.OperationTypes.POW2;
else if (op == OpOp2.MULT && right instanceof LiteralOp && ((LiteralOp) right).getDoubleValue() == 2.0)
ot = Unary.OperationTypes.MULTIPLY2;
else
// general case
ot = HopsOpOp2LopsU.get(op);
Unary unary1 = new Unary(getInput().get(0).constructLops(), getInput().get(1).constructLops(), ot, getDataType(), getValueType(), et);
setOutputDimensions(unary1);
setLineNumbers(unary1);
setLops(unary1);
} else {
// Both operands are Matrixes
ExecType et = optFindExecType();
// GPU softmax pattern: a DIV whose left input is an EXP unary hop and whose right input
// is a row-direction SUM aggregate over that very same EXP hop (checked by reference).
boolean isGPUSoftmax = et == ExecType.GPU && op == Hop.OpOp2.DIV && getInput().get(0) instanceof UnaryOp && getInput().get(1) instanceof AggUnaryOp && ((UnaryOp) getInput().get(0)).getOp() == OpOp1.EXP && ((AggUnaryOp) getInput().get(1)).getOp() == AggOp.SUM && ((AggUnaryOp) getInput().get(1)).getDirection() == Direction.Row && getInput().get(0) == getInput().get(1).getInput().get(0);
if (isGPUSoftmax) {
// the softmax lop consumes the input of the EXP hop directly
UnaryCP softmax = new UnaryCP(getInput().get(0).getInput().get(0).constructLops(), UnaryCP.OperationTypes.SOFTMAX, getDataType(), getValueType(), et);
setOutputDimensions(softmax);
setLineNumbers(softmax);
setLops(softmax);
} else if (et == ExecType.CP || et == ExecType.GPU) {
Lop binary = null;
// detect a left input of the form (X > 0), i.e. a GREATER hop whose second input is a literal 0
boolean isLeftXGt = (getInput().get(0) instanceof BinaryOp) && ((BinaryOp) getInput().get(0)).getOp() == OpOp2.GREATER;
Hop potentialZero = isLeftXGt ? ((BinaryOp) getInput().get(0)).getInput().get(1) : null;
boolean isLeftXGt0 = isLeftXGt && potentialZero != null && potentialZero instanceof LiteralOp && ((LiteralOp) potentialZero).getDoubleValue() == 0;
// (X > 0) * Y with both sides non-vector and of known dimensions becomes a single RELU_BACKWARD lop over X and Y
if (op == OpOp2.MULT && isLeftXGt0 && !getInput().get(0).isVector() && !getInput().get(1).isVector() && getInput().get(0).dimsKnown() && getInput().get(1).dimsKnown()) {
binary = new ConvolutionTransform(getInput().get(0).getInput().get(0).constructLops(), getInput().get(1).constructLops(), ConvolutionTransform.OperationTypes.RELU_BACKWARD, getDataType(), getValueType(), et, -1);
} else
binary = new Binary(getInput().get(0).constructLops(), getInput().get(1).constructLops(), HopsOpOp2LopsB.get(op), getDataType(), getValueType(), et);
setOutputDimensions(binary);
setLineNumbers(binary);
setLops(binary);
} else if (et == ExecType.SPARK) {
Hop left = getInput().get(0);
Hop right = getInput().get(1);
// physical operator selection for Spark
MMBinaryMethod mbin = optFindMMBinaryMethodSpark(left, right);
Lop binary = null;
if (mbin == MMBinaryMethod.MR_BINARY_UAGG_CHAIN) {
// only the left input is passed on; the aggregate of the right hop is encoded in the lop itself
AggUnaryOp uRight = (AggUnaryOp) right;
binary = new BinaryUAggChain(left.constructLops(), HopsOpOp2LopsB.get(op), HopsAgg2Lops.get(uRight.getOp()), HopsDirection2Lops.get(uRight.getDirection()), getDataType(), getValueType(), et);
} else if (mbin == MMBinaryMethod.MR_BINARY_M) {
// partitioning is never requested on this path
boolean partitioned = false;
// right is treated as a column vector if it has one column and as many rows as left
boolean isColVector = (right.getDim2() == 1 && left.getDim1() == right.getDim1());
binary = new BinaryM(left.constructLops(), right.constructLops(), HopsOpOp2LopsB.get(op), getDataType(), getValueType(), et, partitioned, isColVector);
} else {
binary = new Binary(left.constructLops(), right.constructLops(), HopsOpOp2LopsB.get(op), getDataType(), getValueType(), et);
}
setOutputDimensions(binary);
setLineNumbers(binary);
setLops(binary);
} else // MR
{
Hop left = getInput().get(0);
Hop right = getInput().get(1);
// physical operator selection for MR
MMBinaryMethod mbin = optFindMMBinaryMethod(left, right);
if (mbin == MMBinaryMethod.MR_BINARY_M) {
boolean needPart = requiresPartitioning(right);
Lop dcInput = right.constructLops();
if (needPart) {
// right side in distributed cache;
// the partitioning itself runs in CP if the estimated size of right is below the local memory budget, else in MR
ExecType etPart = (OptimizerUtils.estimateSizeExactSparsity(right.getDim1(), right.getDim2(), OptimizerUtils.getSparsity(right.getDim1(), right.getDim2(), right.getNnz())) < OptimizerUtils.getLocalMemBudget()) ? ExecType.CP : // operator selection
ExecType.MR;
// single-column right input is partitioned row-block-wise, anything else column-block-wise
dcInput = new DataPartition(dcInput, DataType.MATRIX, ValueType.DOUBLE, etPart, (right.getDim2() == 1) ? PDataPartitionFormat.ROW_BLOCK_WISE_N : PDataPartitionFormat.COLUMN_BLOCK_WISE_N);
dcInput.getOutputParameters().setDimensions(right.getDim1(), right.getDim2(), right.getRowsInBlock(), right.getColsInBlock(), right.getNnz());
dcInput.setAllPositions(right.getFilename(), right.getBeginLine(), right.getBeginColumn(), right.getEndLine(), right.getEndColumn());
}
BinaryM binary = new BinaryM(left.constructLops(), dcInput, HopsOpOp2LopsB.get(op), getDataType(), getValueType(), ExecType.MR, needPart, (right.getDim2() == 1 && left.getDim1() == right.getDim1()));
setOutputDimensions(binary);
setLineNumbers(binary);
setLops(binary);
} else if (mbin == MMBinaryMethod.MR_BINARY_UAGG_CHAIN) {
AggUnaryOp uRight = (AggUnaryOp) right;
BinaryUAggChain bin = new BinaryUAggChain(left.constructLops(), HopsOpOp2LopsB.get(op), HopsAgg2Lops.get(uRight.getOp()), HopsDirection2Lops.get(uRight.getDirection()), getDataType(), getValueType(), et);
setOutputDimensions(bin);
setLineNumbers(bin);
setLops(bin);
} else if (mbin == MMBinaryMethod.MR_BINARY_OUTER_R) {
// replication is required whenever the driving dimension is unknown or exceeds one block
// NOTE(review): the row check below compares against right.getRowsInBlock(), not left's — confirm this is intended
boolean requiresRepLeft = (!right.dimsKnown() || right.getDim2() > right.getColsInBlock());
boolean requiresRepRight = (!left.dimsKnown() || left.getDim1() > right.getRowsInBlock());
Lop leftLop = left.constructLops();
Lop rightLop = right.constructLops();
if (requiresRepLeft) {
// ncol of right determines rep of left
Lop offset = createOffsetLop(right, true);
leftLop = new RepMat(leftLop, offset, true, left.getDataType(), left.getValueType());
setOutputDimensions(leftLop);
setLineNumbers(leftLop);
}
if (requiresRepRight) {
// nrow of left determines rep of right (the offset lop is built from left)
Lop offset = createOffsetLop(left, false);
rightLop = new RepMat(rightLop, offset, false, right.getDataType(), right.getValueType());
setOutputDimensions(rightLop);
setLineNumbers(rightLop);
}
// both (possibly replicated) inputs go through a sort-group lop before the binary lop
Group group1 = new Group(leftLop, Group.OperationTypes.Sort, getDataType(), getValueType());
setLineNumbers(group1);
setOutputDimensions(group1);
Group group2 = new Group(rightLop, Group.OperationTypes.Sort, getDataType(), getValueType());
setLineNumbers(group2);
setOutputDimensions(group2);
Binary binary = new Binary(group1, group2, HopsOpOp2LopsB.get(op), getDataType(), getValueType(), et);
setOutputDimensions(binary);
setLineNumbers(binary);
setLops(binary);
} else // MMBinaryMethod.MR_BINARY_R
{
boolean requiresRep = requiresReplication(left, right);
Lop rightLop = right.constructLops();
if (requiresRep) {
// ncol of left input (determines num replicates)
Lop offset = createOffsetLop(left, (right.getDim2() <= 1));
rightLop = new RepMat(rightLop, offset, (right.getDim2() <= 1), right.getDataType(), right.getValueType());
setOutputDimensions(rightLop);
setLineNumbers(rightLop);
}
// both inputs go through a sort-group lop before the binary lop
Group group1 = new Group(getInput().get(0).constructLops(), Group.OperationTypes.Sort, getDataType(), getValueType());
setLineNumbers(group1);
setOutputDimensions(group1);
Group group2 = new Group(rightLop, Group.OperationTypes.Sort, getDataType(), getValueType());
setLineNumbers(group2);
setOutputDimensions(group2);
Binary binary = new Binary(group1, group2, HopsOpOp2LopsB.get(op), getDataType(), getValueType(), et);
setLineNumbers(binary);
setOutputDimensions(binary);
setLops(binary);
}
}
}
}
use of org.apache.sysml.parser.Expression.DataType in project incubator-systemml by apache.
the class TernaryOp method constructLopsCtable.
/**
 * Constructs the LOPs for this hop when {@code _op == OpOp3.CTABLE} and registers the
 * resulting lop via {@code setLops}.
 *
 * <p>The concrete ctable operation is derived from the data types of the first three
 * inputs and may be replaced by {@code CTABLE_EXPAND_SCALAR_WEIGHT} when the sequence
 * rewrite applies. For CP and SPARK a single {@code Ctable} lop is created; for MR the
 * matrix inputs are wrapped into sort-group lops and, unless the inputs are disjoint or
 * the expand rewrite was chosen, a group plus SUM aggregate is appended.
 *
 * @throws HopsException if {@code _op} is not CTABLE, or if the selected ctable
 *         operation type is not handled by the MR code path
 */
private void constructLopsCtable() {
    if (_op != OpOp3.CTABLE) {
        throw new HopsException("Unexpected operation: " + _op + ", expecting " + OpOp3.CTABLE);
    }
    /*
     * We must handle three different cases: case1 : all three
     * inputs are vectors (e.g., F=ctable(A,B,W)) case2 : two
     * vectors and one scalar (e.g., F=ctable(A,B)) case3 : one
     * vector and two scalars (e.g., F=ctable(A))
     */
    // identify the particular case
    // F=ctable(A,B,W)
    DataType dt1 = getInput().get(0).getDataType();
    DataType dt2 = getInput().get(1).getDataType();
    DataType dt3 = getInput().get(2).getDataType();
    Ctable.OperationTypes ternaryOpOrig = Ctable.findCtableOperationByInputDataTypes(dt1, dt2, dt3);
    // Compute lops for all inputs
    Lop[] inputLops = new Lop[getInput().size()];
    for (int i = 0; i < getInput().size(); i++) {
        inputLops[i] = getInput().get(i).constructLops();
    }
    ExecType et = optFindExecType();
    // reset reblock requirement (see MR ctable / construct lops)
    setRequiresReblock(false);
    if (et == ExecType.CP || et == ExecType.SPARK) {
        // for CP we support only ctable expand left
        Ctable.OperationTypes ternaryOp = isSequenceRewriteApplicable(true) ? Ctable.OperationTypes.CTABLE_EXPAND_SCALAR_WEIGHT : ternaryOpOrig;
        boolean ignoreZeros = false;
        if (isMatrixIgnoreZeroRewriteApplicable()) {
            // table - rmempty - rshape:
            // bypass the two parameterized builtins and feed the inputs of their target hops directly
            ignoreZeros = true;
            inputLops[0] = ((ParameterizedBuiltinOp) getInput().get(0)).getTargetHop().getInput().get(0).constructLops();
            inputLops[1] = ((ParameterizedBuiltinOp) getInput().get(1)).getTargetHop().getInput().get(0).constructLops();
        }
        Ctable ternary = new Ctable(inputLops, ternaryOp, getDataType(), getValueType(), ignoreZeros, et);
        // NOTE(review): these dimensions are set again in both branches below; this call looks redundant
        ternary.getOutputParameters().setDimensions(_dim1, _dim2, getRowsInBlock(), getColsInBlock(), -1);
        setLineNumbers(ternary);
        // force blocked output in CP (see below), otherwise binarycell
        if (et == ExecType.SPARK) {
            ternary.getOutputParameters().setDimensions(_dim1, _dim2, -1, -1, -1);
            setRequiresReblock(true);
        } else {
            ternary.getOutputParameters().setDimensions(_dim1, _dim2, getRowsInBlock(), getColsInBlock(), -1);
        }
        // ternary opt, w/o reblock in CP
        setLops(ternary);
    } else // MR
    {
        // for MR we support both ctable expand left and right
        Ctable.OperationTypes ternaryOp = isSequenceRewriteApplicable() ? Ctable.OperationTypes.CTABLE_EXPAND_SCALAR_WEIGHT : ternaryOpOrig;
        Group group1 = null, group2 = null, group3 = null, group4 = null;
        group1 = new Group(inputLops[0], Group.OperationTypes.Sort, getDataType(), getValueType());
        group1.getOutputParameters().setDimensions(getDim1(), getDim2(), getRowsInBlock(), getColsInBlock(), getNnz());
        setLineNumbers(group1);
        Ctable ternary = null;
        // create "group" lops for MATRIX inputs;
        // in every case, five input lops instead of three means two extra lops
        // (inputLops[3], inputLops[4]) carrying the given output dimensions are appended
        switch(ternaryOp) {
            case CTABLE_TRANSFORM:
                // F = ctable(A,B,W)
                group2 = new Group(inputLops[1], Group.OperationTypes.Sort, getDataType(), getValueType());
                group2.getOutputParameters().setDimensions(getDim1(), getDim2(), getRowsInBlock(), getColsInBlock(), getNnz());
                setLineNumbers(group2);
                group3 = new Group(inputLops[2], Group.OperationTypes.Sort, getDataType(), getValueType());
                group3.getOutputParameters().setDimensions(getDim1(), getDim2(), getRowsInBlock(), getColsInBlock(), getNnz());
                setLineNumbers(group3);
                if (inputLops.length == 3)
                    ternary = new Ctable(new Lop[] { group1, group2, group3 }, ternaryOp, getDataType(), getValueType(), et);
                else
                    // output dimensions are given
                    ternary = new Ctable(new Lop[] { group1, group2, group3, inputLops[3], inputLops[4] }, ternaryOp, getDataType(), getValueType(), et);
                break;
            case CTABLE_TRANSFORM_SCALAR_WEIGHT:
                // F = ctable(A,B) or F = ctable(A,B,1)
                group2 = new Group(inputLops[1], Group.OperationTypes.Sort, getDataType(), getValueType());
                group2.getOutputParameters().setDimensions(getDim1(), getDim2(), getRowsInBlock(), getColsInBlock(), getNnz());
                setLineNumbers(group2);
                if (inputLops.length == 3)
                    ternary = new Ctable(new Lop[] { group1, group2, inputLops[2] }, ternaryOp, getDataType(), getValueType(), et);
                else
                    ternary = new Ctable(new Lop[] { group1, group2, inputLops[2], inputLops[3], inputLops[4] }, ternaryOp, getDataType(), getValueType(), et);
                break;
            case CTABLE_EXPAND_SCALAR_WEIGHT:
                // F=ctable(seq(1,N),A) or F = ctable(seq,A,1)
                // left 1, right 0 (index of input data)
                int left = isSequenceRewriteApplicable(true) ? 1 : 0;
                // NOTE(review): unlike the other group lops, no setLineNumbers is applied to this one — confirm intended
                Group group = new Group(getInput().get(left).constructLops(), Group.OperationTypes.Sort, getDataType(), getValueType());
                group.getOutputParameters().setDimensions(getDim1(), getDim2(), getRowsInBlock(), getColsInBlock(), getNnz());
                if (inputLops.length == 3)
                    ternary = new Ctable(new Lop[] { // matrix
                    group, // weight
                    getInput().get(2).constructLops(), // left
                    new LiteralOp(left).constructLops() }, ternaryOp, getDataType(), getValueType(), et);
                else
                    ternary = new Ctable(new Lop[] { // matrix
                    group, // weight
                    getInput().get(2).constructLops(), // left
                    new LiteralOp(left).constructLops(), inputLops[3], inputLops[4] }, ternaryOp, getDataType(), getValueType(), et);
                break;
            case CTABLE_TRANSFORM_HISTOGRAM:
                // F=ctable(A,1) or F = ctable(A,1,1)
                if (inputLops.length == 3)
                    ternary = new Ctable(new Lop[] { group1, getInput().get(1).constructLops(), getInput().get(2).constructLops() }, ternaryOp, getDataType(), getValueType(), et);
                else
                    ternary = new Ctable(new Lop[] { group1, getInput().get(1).constructLops(), getInput().get(2).constructLops(), inputLops[3], inputLops[4] }, ternaryOp, getDataType(), getValueType(), et);
                break;
            case CTABLE_TRANSFORM_WEIGHTED_HISTOGRAM:
                // F=ctable(A,1,W)
                group3 = new Group(getInput().get(2).constructLops(), Group.OperationTypes.Sort, getDataType(), getValueType());
                group3.getOutputParameters().setDimensions(getDim1(), getDim2(), getRowsInBlock(), getColsInBlock(), getNnz());
                setLineNumbers(group3);
                if (inputLops.length == 3)
                    ternary = new Ctable(new Lop[] { group1, getInput().get(1).constructLops(), group3 }, ternaryOp, getDataType(), getValueType(), et);
                else
                    ternary = new Ctable(new Lop[] { group1, getInput().get(1).constructLops(), group3, inputLops[3], inputLops[4] }, ternaryOp, getDataType(), getValueType(), et);
                break;
            default:
                // Report the ctable operation type that was switched on. _op is always
                // CTABLE at this point (checked at method entry), so printing it would
                // hide the actual offending value.
                throw new HopsException("Invalid ternary operator type: " + ternaryOp);
        }
        // output dimensions are not known at compilation time
        ternary.getOutputParameters().setDimensions(_dim1, _dim2, (_dimInputsPresent ? getRowsInBlock() : -1), (_dimInputsPresent ? getColsInBlock() : -1), -1);
        setLineNumbers(ternary);
        Lop lctable = ternary;
        if (!(_disjointInputs || ternaryOp == Ctable.OperationTypes.CTABLE_EXPAND_SCALAR_WEIGHT)) {
            // no need for aggregation if (1) input indexed disjoint or one side is sequence w/ 1 increment
            group4 = new Group(ternary, Group.OperationTypes.Sort, getDataType(), getValueType());
            group4.getOutputParameters().setDimensions(_dim1, _dim2, (_dimInputsPresent ? getRowsInBlock() : -1), (_dimInputsPresent ? getColsInBlock() : -1), -1);
            setLineNumbers(group4);
            Aggregate agg1 = new Aggregate(group4, HopsAgg2Lops.get(AggOp.SUM), getDataType(), getValueType(), ExecType.MR);
            agg1.getOutputParameters().setDimensions(_dim1, _dim2, (_dimInputsPresent ? getRowsInBlock() : -1), (_dimInputsPresent ? getColsInBlock() : -1), -1);
            setLineNumbers(agg1);
            // kahamSum is used for aggregation but inputs do not have
            // correction values
            agg1.setupCorrectionLocation(CorrectionLocationType.NONE);
            lctable = agg1;
        }
        setLops(lctable);
        // to introduce reblock lop since table itself outputs in blocked format if dims known.
        if (!dimsKnown() && !_dimInputsPresent) {
            setRequiresReblock(true);
        }
    }
}
use of org.apache.sysml.parser.Expression.DataType in project incubator-systemml by apache.
the class InterProceduralAnalysis method extractFunctionCallReturnStatistics.
/**
 * Propagates statistics of a function's return variables into the variable map of the
 * calling program.
 *
 * <p>For each declared function output, a caller variable of the same target name whose
 * data type differs from the declared output type is removed first. Matrix outputs that
 * are present in {@code tmpVars} are then either added to the caller's map as a fresh
 * matrix object (when absent there, or when {@code overwrite} is set) or, if the caller
 * already holds a matrix object under that name, merged by keeping the statistics with
 * the larger size estimate.
 *
 * @param fstmt     The function statement.
 * @param fop       The function op.
 * @param tmpVars   Function's map of variables eligible for extraction.
 * @param callVars  Calling program's map of variables.
 * @param overwrite Whether or not to overwrite variables in the calling program's
 *                  variable map.
 * @throws HopsException wrapping any exception raised during extraction
 */
private static void extractFunctionCallReturnStatistics(FunctionStatement fstmt, FunctionOp fop, LocalVariableMap tmpVars, LocalVariableMap callVars, boolean overwrite) {
    ArrayList<DataIdentifier> declaredOutputs = fstmt.getOutputParams();
    String[] callerNames = fop.getOutputVariableNames();
    String fkey = fop.getFunctionKey();
    try {
        for (int pos = 0; pos < declaredOutputs.size(); pos++) {
            DataIdentifier di = declaredOutputs.get(pos);
            // name in function signature
            String fvarname = di.getName();
            // name in calling program
            String pvarname = callerNames[pos];

            // The caller reassigns this variable with the function output: if the data
            // type changes by doing so, drop the stale entry from the caller's map.
            if (callVars.keySet().contains(pvarname) && di.getDataType() != callVars.get(pvarname).getDataType()) {
                callVars.remove(pvarname);
            }

            // Only matrix outputs with available function-side statistics are propagated.
            if (di.getDataType() != DataType.MATRIX || !tmpVars.keySet().contains(fvarname)) {
                continue;
            }
            MatrixObject moIn = (MatrixObject) tmpVars.get(fvarname);

            if (!callVars.keySet().contains(pvarname) || overwrite) {
                // not existing so far (or forced overwrite): register fresh statistics
                callVars.put(pvarname, createOutputMatrix(moIn.getNumRows(), moIn.getNumColumns(), moIn.getNnz()));
                continue;
            }

            // already existing: take largest
            Data existing = callVars.get(pvarname);
            if (!(existing instanceof MatrixObject)) {
                continue;
            }
            MatrixCharacteristics mc = ((MatrixObject) existing).getMatrixCharacteristics();
            double existingSparsity = (mc.getNonZeros() > 0) ? OptimizerUtils.getSparsity(mc) : 1.0;
            if (OptimizerUtils.estimateSizeExactSparsity(mc.getRows(), mc.getCols(), existingSparsity) < OptimizerUtils.estimateSize(moIn.getNumRows(), moIn.getNumColumns())) {
                // update statistics if necessary
                mc.setDimension(moIn.getNumRows(), moIn.getNumColumns());
                mc.setNonZeros(moIn.getNnz());
            }
        }
    } catch (Exception ex) {
        throw new HopsException("Failed to extract output statistics of function " + fkey + ".", ex);
    }
}
Aggregations