Search in sources :

Example 21 with DataOp

use of org.apache.sysml.hops.DataOp in project systemml by apache.

the class DMLTranslator method constructHopsForIterablePredicate.

/**
 * Builds the predicate hops for the FROM, TO, and (optional) INCREMENT
 * expressions of an iterable predicate and attaches each of them to the
 * given statement block.
 *
 * Shared by ForStatementBlock and ParForStatementBlock.
 *
 * @param fsb for statement block
 */
public void constructHopsForIterablePredicate(ForStatementBlock fsb) {
    // transient reads created so far; the map is shared by all three predicate expressions
    HashMap<String, Hop> readHops = new HashMap<>();
    ForStatement forStmt = (ForStatement) fsb.getStatement(0);
    IterablePredicate predicate = forStmt.getIterablePredicate();
    for (int pos = 0; pos < 3; pos++) {
        // pos 0 = from, pos 1 = to, pos 2 = increment (possibly null)
        Expression expr;
        if (pos == 0)
            expr = predicate.getFromExpr();
        else if (pos == 1)
            expr = predicate.getToExpr();
        else
            expr = predicate.getIncrementExpr();
        VariableSet readSet = (expr == null) ? null : expr.variablesRead();
        if (readSet != null) {
            for (String varName : readSet.getVariables().keySet()) {
                DataIdentifier liveVar = fsb.liveIn().getVariable(varName);
                // every variable read by the predicate must be live-in to the block
                if (liveVar == null) {
                    String msg = "variable '" + varName + "' is not available for iterable predicate";
                    LOG.error(msg);
                    throw new ParseException(msg);
                }
                // indexed identifiers carry the original (un-indexed) dimensions
                boolean indexed = liveVar instanceof IndexedIdentifier;
                long rows = indexed ? ((IndexedIdentifier) liveVar).getOrigDim1() : liveVar.getDim1();
                long cols = indexed ? ((IndexedIdentifier) liveVar).getOrigDim2() : liveVar.getDim2();
                DataOp tread = new DataOp(liveVar.getName(), liveVar.getDataType(), liveVar.getValueType(), DataOpTypes.TRANSIENTREAD, null, rows, cols, liveVar.getNnz(), liveVar.getRowsInBlock(), liveVar.getColumnsInBlock());
                tread.setParseInfo(liveVar);
                readHops.put(varName, tread);
            }
        }
        // wrap the expression hop in a transient write to the internal predicate
        // variable in order to ensure proper instruction generation
        Hop predHop = processTempIntExpression(expr, readHops);
        if (predHop != null)
            predHop = HopRewriteUtils.createDataOp(ProgramBlock.PRED_VAR, predHop, DataOpTypes.TRANSIENTWRITE);
        // hand the hop to the matching slot of the statement block
        switch (pos) {
            case 0:
                fsb.setFromHops(predHop);
                break;
            case 1:
                fsb.setToHops(predHop);
                break;
            default:
                // increment hops are only set when an increment expression exists
                if (predicate.getIncrementExpr() != null)
                    fsb.setIncrementHops(predHop);
                break;
        }
    }
}
Also used : HashMap(java.util.HashMap) Hop(org.apache.sysml.hops.Hop) DataOp(org.apache.sysml.hops.DataOp)

Example 22 with DataOp

use of org.apache.sysml.hops.DataOp in project systemml by apache.

the class DMLTranslator method processMultipleReturnParameterizedBuiltinFunctionExpression.

/**
 * Constructs the function hop for a parameterized builtin function that
 * returns multiple values. Only TRANSFORMENCODE is handled; it takes the
 * "target" and "spec" parameters as inputs and produces two outputs named
 * after the first two entries of the target list (a double matrix and a
 * string frame).
 *
 * @param source parameterized builtin function expression
 * @param targetList identifiers receiving the function outputs
 * @param hops map of high-level operators
 * @return function hop representing the builtin call
 * @throws ParseException if the opcode of source is not supported here
 */
private Hop processMultipleReturnParameterizedBuiltinFunctionExpression(ParameterizedBuiltinFunctionExpression source, ArrayList<DataIdentifier> targetList, HashMap<String, Hop> hops) {
    FunctionType ftype = FunctionType.MULTIRETURN_BUILTIN;
    String nameSpace = DMLProgram.INTERNAL_NAMESPACE;
    // Create an array list to hold the outputs of this lop.
    // Exact list of outputs are added based on opcode.
    ArrayList<Hop> outputs = new ArrayList<>();
    // Construct Hop for current builtin function expression based on its type
    Hop currBuiltinOp = null;
    switch(source.getOpCode()) {
        case TRANSFORMENCODE: {
            ArrayList<Hop> inputs = new ArrayList<>();
            inputs.add(processExpression(source.getVarParam("target"), null, hops));
            inputs.add(processExpression(source.getVarParam("spec"), null, hops));
            String[] outputNames = new String[targetList.size()];
            outputNames[0] = targetList.get(0).getName();
            outputNames[1] = targetList.get(1).getName();
            // first output: encoded matrix, second output: frame of strings
            outputs.add(new DataOp(outputNames[0], DataType.MATRIX, ValueType.DOUBLE, inputs.get(0), DataOpTypes.FUNCTIONOUTPUT, outputNames[0]));
            outputs.add(new DataOp(outputNames[1], DataType.FRAME, ValueType.STRING, inputs.get(0), DataOpTypes.FUNCTIONOUTPUT, outputNames[1]));
            currBuiltinOp = new FunctionOp(ftype, nameSpace, source.getOpCode().toString(), inputs, outputNames, outputs);
            break;
        }
        default:
            throw new ParseException("Invalid Opcode in DMLTranslator:processMultipleReturnParameterizedBuiltinFunctionExpression(): " + source.getOpCode());
    }
    // set properties for created hops based on outputs of source expression
    for (int i = 0; i < source.getOutputs().length; i++) {
        setIdentifierParams(outputs.get(i), source.getOutputs()[i]);
        outputs.get(i).setParseInfo(source);
    }
    currBuiltinOp.setParseInfo(source);
    return currBuiltinOp;
}
Also used : FunctionType(org.apache.sysml.hops.FunctionOp.FunctionType) Hop(org.apache.sysml.hops.Hop) ArrayList(java.util.ArrayList) ParameterizedBuiltinFunctionOp(org.apache.sysml.parser.Expression.ParameterizedBuiltinFunctionOp) BuiltinFunctionOp(org.apache.sysml.parser.Expression.BuiltinFunctionOp) FunctionOp(org.apache.sysml.hops.FunctionOp) DataOp(org.apache.sysml.hops.DataOp)

Example 23 with DataOp

use of org.apache.sysml.hops.DataOp in project systemml by apache.

the class DMLTranslator method constructHops.

/**
 * Constructs the hops for a single statement block and stores them on the block.
 *
 * While, if, for (incl. parfor), and function statement blocks are delegated
 * to their dedicated construction methods. For any other block, the method
 * creates transient reads for the live-in variables that are updated or
 * generated in the block, translates output, print, assignment, and
 * multi-assignment statements into hops, adds transient writes for updated
 * live-out variables, and finally sets the collected root hops on the block.
 *
 * @param sb statement block
 * @throws LanguageException for unsupported constructs (e.g., indexed write
 *         targets, function calls in left indexing, unsupported
 *         multi-assignment sources) and for calls to undefined functions
 */
public void constructHops(StatementBlock sb) {
    if (sb instanceof WhileStatementBlock) {
        constructHopsForWhileControlBlock((WhileStatementBlock) sb);
        return;
    }
    if (sb instanceof IfStatementBlock) {
        constructHopsForIfControlBlock((IfStatementBlock) sb);
        return;
    }
    if (sb instanceof ForStatementBlock) {
        // incl ParForStatementBlock
        constructHopsForForControlBlock((ForStatementBlock) sb);
        return;
    }
    if (sb instanceof FunctionStatementBlock) {
        constructHopsForFunctionControlBlock((FunctionStatementBlock) sb);
        return;
    }
    HashMap<String, Hop> ids = new HashMap<>();
    ArrayList<Hop> output = new ArrayList<>();
    VariableSet liveIn = sb.liveIn();
    VariableSet liveOut = sb.liveOut();
    VariableSet updated = sb._updated;
    VariableSet gen = sb._gen;
    VariableSet updatedLiveOut = new VariableSet();
    // handle liveout variables that are updated --> target identifiers for Assignment
    // (maps each live-out target name to the index of the last statement assigning it)
    HashMap<String, Integer> liveOutToTemp = new HashMap<>();
    for (int i = 0; i < sb.getNumStatements(); i++) {
        Statement current = sb.getStatement(i);
        if (current instanceof AssignmentStatement) {
            AssignmentStatement as = (AssignmentStatement) current;
            DataIdentifier target = as.getTarget();
            if (target != null) {
                if (liveOut.containsVariable(target.getName())) {
                    liveOutToTemp.put(target.getName(), Integer.valueOf(i));
                }
            }
        }
        if (current instanceof MultiAssignmentStatement) {
            MultiAssignmentStatement mas = (MultiAssignmentStatement) current;
            for (DataIdentifier target : mas.getTargetList()) {
                if (liveOut.containsVariable(target.getName())) {
                    liveOutToTemp.put(target.getName(), Integer.valueOf(i));
                }
            }
        }
    }
    // create transient reads for live-in variables that are updated or generated
    // in this block (i.e., from LV analysis, updated and gen sets)
    if (!liveIn.getVariables().values().isEmpty()) {
        for (String varName : liveIn.getVariables().keySet()) {
            if (updated.containsVariable(varName) || gen.containsVariable(varName)) {
                DataIdentifier var = liveIn.getVariables().get(varName);
                long actualDim1 = (var instanceof IndexedIdentifier) ? ((IndexedIdentifier) var).getOrigDim1() : var.getDim1();
                long actualDim2 = (var instanceof IndexedIdentifier) ? ((IndexedIdentifier) var).getOrigDim2() : var.getDim2();
                DataOp read = new DataOp(var.getName(), var.getDataType(), var.getValueType(), DataOpTypes.TRANSIENTREAD, null, actualDim1, actualDim2, var.getNnz(), var.getRowsInBlock(), var.getColumnsInBlock());
                read.setParseInfo(var);
                ids.put(varName, read);
            }
        }
    }
    for (int i = 0; i < sb.getNumStatements(); i++) {
        Statement current = sb.getStatement(i);
        if (current instanceof OutputStatement) {
            OutputStatement os = (OutputStatement) current;
            DataExpression source = os.getSource();
            DataIdentifier target = os.getIdentifier();
            // error handling unsupported indexing expression in write statement
            if (target instanceof IndexedIdentifier) {
                throw new LanguageException(source.printErrorLocation() + ": Unsupported indexing expression in write statement. " + "Please, assign the right indexing result to a variable and write this variable.");
            }
            DataOp ae = (DataOp) processExpression(source, target, ids);
            String formatName = os.getExprParam(DataExpression.FORMAT_TYPE).toString();
            ae.setInputFormatType(Expression.convertFormatType(formatName));
            if (ae.getDataType() == DataType.SCALAR) {
                ae.setOutputParams(ae.getDim1(), ae.getDim2(), ae.getNnz(), ae.getUpdateType(), -1, -1);
            } else {
                switch(ae.getInputFormatType()) {
                    case TEXT:
                    case MM:
                    case CSV:
                        // write output in textcell format
                        ae.setOutputParams(ae.getDim1(), ae.getDim2(), ae.getNnz(), ae.getUpdateType(), -1, -1);
                        break;
                    case BINARY:
                        // write output in binary block format
                        ae.setOutputParams(ae.getDim1(), ae.getDim2(), ae.getNnz(), ae.getUpdateType(), ConfigurationManager.getBlocksize(), ConfigurationManager.getBlocksize());
                        break;
                    default:
                        throw new LanguageException("Unrecognized file format: " + ae.getInputFormatType());
                }
            }
            output.add(ae);
        }
        if (current instanceof PrintStatement) {
            DataIdentifier target = createTarget();
            target.setDataType(DataType.SCALAR);
            target.setValueType(ValueType.STRING);
            target.setParseInfo(current);
            PrintStatement ps = (PrintStatement) current;
            PRINTTYPE ptype = ps.getType();
            try {
                if (ptype == PRINTTYPE.PRINT) {
                    Hop.OpOp1 op = Hop.OpOp1.PRINT;
                    Expression source = ps.getExpressions().get(0);
                    Hop ae = processExpression(source, target, ids);
                    Hop printHop = new UnaryOp(target.getName(), target.getDataType(), target.getValueType(), op, ae);
                    printHop.setParseInfo(current);
                    output.add(printHop);
                } else if (ptype == PRINTTYPE.ASSERT) {
                    Hop.OpOp1 op = Hop.OpOp1.ASSERT;
                    Expression source = ps.getExpressions().get(0);
                    Hop ae = processExpression(source, target, ids);
                    Hop printHop = new UnaryOp(target.getName(), target.getDataType(), target.getValueType(), op, ae);
                    printHop.setParseInfo(current);
                    output.add(printHop);
                } else if (ptype == PRINTTYPE.STOP) {
                    Hop.OpOp1 op = Hop.OpOp1.STOP;
                    Expression source = ps.getExpressions().get(0);
                    Hop ae = processExpression(source, target, ids);
                    Hop stopHop = new UnaryOp(target.getName(), target.getDataType(), target.getValueType(), op, ae);
                    stopHop.setParseInfo(current);
                    output.add(stopHop);
                    // avoid merge
                    sb.setSplitDag(true);
                } else if (ptype == PRINTTYPE.PRINTF) {
                    List<Expression> expressions = ps.getExpressions();
                    Hop[] inHops = new Hop[expressions.size()];
                    // process each printf argument expression into a hop; these
                    // hops become the inputs of the n-ary PRINTF hop below
                    for (int j = 0; j < expressions.size(); j++) {
                        Hop inHop = processExpression(expressions.get(j), target, ids);
                        inHops[j] = inHop;
                    }
                    target.setValueType(ValueType.STRING);
                    Hop printfHop = new NaryOp(target.getName(), target.getDataType(), target.getValueType(), OpOpN.PRINTF, inHops);
                    output.add(printfHop);
                }
            } catch (HopsException e) {
                throw new LanguageException(e);
            }
        }
        if (current instanceof AssignmentStatement) {
            AssignmentStatement as = (AssignmentStatement) current;
            DataIdentifier target = as.getTarget();
            Expression source = as.getSource();
            // CASE: regular assignment statement -- source is DML expression that is NOT user-defined or external function
            if (!(source instanceof FunctionCallIdentifier)) {
                // CASE: target is regular data identifier
                if (!(target instanceof IndexedIdentifier)) {
                    // process right hand side and accumulation
                    Hop ae = processExpression(source, target, ids);
                    if (as.isAccumulator()) {
                        DataIdentifier accum = liveIn.getVariable(target.getName());
                        if (accum == null)
                            throw new LanguageException("Invalid accumulator assignment " + "to non-existing variable " + target.getName() + ".");
                        ae = HopRewriteUtils.createBinary(ids.get(target.getName()), ae, OpOp2.PLUS);
                        target.setProperties(accum.getOutput());
                    } else
                        target.setProperties(source.getOutput());
                    ids.put(target.getName(), ae);
                    // add transient write if needed
                    // (only for the last statement that assigns this live-out variable)
                    Integer statementId = liveOutToTemp.get(target.getName());
                    if ((statementId != null) && (statementId.intValue() == i)) {
                        DataOp transientwrite = new DataOp(target.getName(), target.getDataType(), target.getValueType(), ae, DataOpTypes.TRANSIENTWRITE, null);
                        transientwrite.setOutputParams(ae.getDim1(), ae.getDim2(), ae.getNnz(), ae.getUpdateType(), ae.getRowsInBlock(), ae.getColsInBlock());
                        transientwrite.setParseInfo(target);
                        updatedLiveOut.addVariable(target.getName(), target);
                        output.add(transientwrite);
                    }
                } else // CASE: target is indexed identifier (left-hand side indexed expression)
                {
                    Hop ae = processLeftIndexedExpression(source, (IndexedIdentifier) target, ids);
                    ids.put(target.getName(), ae);
                    // obtain origDim values BEFORE they are potentially updated during setProperties call
                    // (this is incorrect for LHS Indexing)
                    long origDim1 = ((IndexedIdentifier) target).getOrigDim1();
                    long origDim2 = ((IndexedIdentifier) target).getOrigDim2();
                    target.setProperties(source.getOutput());
                    ((IndexedIdentifier) target).setOriginalDimensions(origDim1, origDim2);
                    // force matrix/double metadata on the target if it is not a matrix
                    // (required for scalar input to left indexing)
                    if (target.getDataType() != DataType.MATRIX) {
                        target.setDataType(DataType.MATRIX);
                        target.setValueType(ValueType.DOUBLE);
                        target.setBlockDimensions(ConfigurationManager.getBlocksize(), ConfigurationManager.getBlocksize());
                    }
                    Integer statementId = liveOutToTemp.get(target.getName());
                    if ((statementId != null) && (statementId.intValue() == i)) {
                        DataOp transientwrite = new DataOp(target.getName(), target.getDataType(), target.getValueType(), ae, DataOpTypes.TRANSIENTWRITE, null);
                        transientwrite.setOutputParams(origDim1, origDim2, ae.getNnz(), ae.getUpdateType(), ae.getRowsInBlock(), ae.getColsInBlock());
                        transientwrite.setParseInfo(target);
                        updatedLiveOut.addVariable(target.getName(), target);
                        output.add(transientwrite);
                    }
                }
            } else {
                // assignment, function call
                FunctionCallIdentifier fci = (FunctionCallIdentifier) source;
                FunctionStatementBlock fsb = this._dmlProg.getFunctionStatementBlock(fci.getNamespace(), fci.getName());
                // error handling missing function
                if (fsb == null) {
                    String error = source.printErrorLocation() + "function " + fci.getName() + " is undefined in namespace " + fci.getNamespace();
                    LOG.error(error);
                    throw new LanguageException(error);
                }
                // error handling unsupported function call in indexing expression
                if (target instanceof IndexedIdentifier) {
                    String fkey = DMLProgram.constructFunctionKey(fci.getNamespace(), fci.getName());
                    throw new LanguageException("Unsupported function call to '" + fkey + "' in left indexing expression. " + "Please, assign the function output to a variable.");
                }
                ArrayList<Hop> finputs = new ArrayList<>();
                for (ParameterExpression paramName : fci.getParamExprs()) {
                    Hop in = processExpression(paramName.getExpr(), null, ids);
                    finputs.add(in);
                }
                // create function op
                FunctionType ftype = fsb.getFunctionOpType();
                FunctionOp fcall = (target == null) ? new FunctionOp(ftype, fci.getNamespace(), fci.getName(), finputs, new String[] {}, false) : new FunctionOp(ftype, fci.getNamespace(), fci.getName(), finputs, new String[] { target.getName() }, false);
                output.add(fcall);
            // TODO function output dataops (phase 3)
            // DataOp trFoutput = new DataOp(target.getName(), target.getDataType(), target.getValueType(), fcall, DataOpTypes.FUNCTIONOUTPUT, null);
            // DataOp twFoutput = new DataOp(target.getName(), target.getDataType(), target.getValueType(), trFoutput, DataOpTypes.TRANSIENTWRITE, null);
            }
        } else if (current instanceof MultiAssignmentStatement) {
            // multi-assignment, by definition a function call
            MultiAssignmentStatement mas = (MultiAssignmentStatement) current;
            Expression source = mas.getSource();
            if (source instanceof FunctionCallIdentifier) {
                FunctionCallIdentifier fci = (FunctionCallIdentifier) source;
                FunctionStatementBlock fsb = this._dmlProg.getFunctionStatementBlock(fci.getNamespace(), fci.getName());
                // error handling missing function: guard the dereference of fsb so that an
                // undefined function raises the LanguageException below instead of an NPE
                FunctionStatement fstmt = (fsb != null) ? (FunctionStatement) fsb.getStatement(0) : null;
                if (fstmt == null) {
                    String error = source.printErrorLocation() + "function " + fci.getName() + " is undefined in namespace " + fci.getNamespace();
                    LOG.error(error);
                    throw new LanguageException(error);
                }
                ArrayList<Hop> finputs = new ArrayList<>();
                for (ParameterExpression paramName : fci.getParamExprs()) {
                    Hop in = processExpression(paramName.getExpr(), null, ids);
                    finputs.add(in);
                }
                // create function op
                String[] foutputs = mas.getTargetList().stream().map(d -> d.getName()).toArray(String[]::new);
                FunctionType ftype = fsb.getFunctionOpType();
                FunctionOp fcall = new FunctionOp(ftype, fci.getNamespace(), fci.getName(), finputs, foutputs, false);
                output.add(fcall);
            // TODO function output dataops (phase 3)
            /*for ( DataIdentifier paramName : mas.getTargetList() ){
						DataOp twFoutput = new DataOp(paramName.getName(), paramName.getDataType(), paramName.getValueType(), fcall, DataOpTypes.TRANSIENTWRITE, null);
						output.add(twFoutput);
					}*/
            } else if (source instanceof BuiltinFunctionExpression && ((BuiltinFunctionExpression) source).multipleReturns()) {
                // construct input hops
                Hop fcall = processMultipleReturnBuiltinFunctionExpression((BuiltinFunctionExpression) source, mas.getTargetList(), ids);
                output.add(fcall);
            } else if (source instanceof ParameterizedBuiltinFunctionExpression && ((ParameterizedBuiltinFunctionExpression) source).multipleReturns()) {
                // construct input hops
                Hop fcall = processMultipleReturnParameterizedBuiltinFunctionExpression((ParameterizedBuiltinFunctionExpression) source, mas.getTargetList(), ids);
                output.add(fcall);
            } else
                throw new LanguageException("Class \"" + source.getClass() + "\" is not supported in Multiple Assignment statements");
        }
    }
    sb.updateLiveVariablesOut(updatedLiveOut);
    sb.setHops(output);
}
Also used : HashMap(java.util.HashMap) ArrayList(java.util.ArrayList) PRINTTYPE(org.apache.sysml.parser.PrintStatement.PRINTTYPE) List(java.util.List) ArrayList(java.util.ArrayList) DataOp(org.apache.sysml.hops.DataOp) AggUnaryOp(org.apache.sysml.hops.AggUnaryOp) UnaryOp(org.apache.sysml.hops.UnaryOp) FunctionType(org.apache.sysml.hops.FunctionOp.FunctionType) Hop(org.apache.sysml.hops.Hop) ParameterizedBuiltinFunctionOp(org.apache.sysml.parser.Expression.ParameterizedBuiltinFunctionOp) BuiltinFunctionOp(org.apache.sysml.parser.Expression.BuiltinFunctionOp) FunctionOp(org.apache.sysml.hops.FunctionOp) HopsException(org.apache.sysml.hops.HopsException) NaryOp(org.apache.sysml.hops.NaryOp)

Example 24 with DataOp

use of org.apache.sysml.hops.DataOp in project systemml by apache.

the class DMLTranslator method processDataExpression.

/**
 * Translates a data expression (read, write, rand, or matrix) from the
 * parse tree into the corresponding high-level operator.
 *
 * @param source data expression
 * @param target data identifier; if null, a target is created from the source
 * @param hops map of high-level operators
 * @return high-level operator
 */
private Hop processDataExpression(DataExpression source, DataIdentifier target, HashMap<String, Hop> hops) {
    // the expression carries multiple "named" parameters: build a hop for each
    // of them and keep the hops keyed by parameter name
    HashMap<String, Hop> namedInputs = new HashMap<>();
    for (String key : source.getVarParams().keySet())
        namedInputs.put(key, processExpression(source.getVarParam(key), null, hops));
    if (target == null)
        target = createTarget(source);
    // construct hop based on opcode
    Hop result = null;
    switch(source.getOpCode()) {
        case READ: {
            DataOp readOp = new DataOp(target.getName(), target.getDataType(), target.getValueType(), DataOpTypes.PERSISTENTREAD, namedInputs);
            StringIdentifier fileName = (StringIdentifier) source.getVarParam(DataExpression.IO_FILENAME);
            readOp.setFileName(fileName.getValue());
            result = readOp;
            break;
        }
        case WRITE: {
            // the written data is the hop currently bound to the target name
            Hop written = hops.get(target.getName());
            result = new DataOp(target.getName(), target.getDataType(), target.getValueType(), DataOpTypes.PERSISTENTWRITE, written, namedInputs);
            break;
        }
        case RAND: {
            // We limit RAND_MIN, RAND_MAX, RAND_SPARSITY, RAND_SEED, and RAND_PDF to be constants
            // (a string-valued RAND_MIN selects SINIT, anything else RAND)
            boolean stringMin = namedInputs.get(DataExpression.RAND_MIN).getValueType() == ValueType.STRING;
            DataGenMethod genMethod = stringMin ? DataGenMethod.SINIT : DataGenMethod.RAND;
            result = new DataGenOp(genMethod, target, namedInputs);
            break;
        }
        case MATRIX: {
            // reshape inputs in fixed order: data, rows, cols, by-row flag
            ArrayList<Hop> reshapeInputs = new ArrayList<>();
            reshapeInputs.add(namedInputs.get(DataExpression.RAND_DATA));
            reshapeInputs.add(namedInputs.get(DataExpression.RAND_ROWS));
            reshapeInputs.add(namedInputs.get(DataExpression.RAND_COLS));
            reshapeInputs.add(namedInputs.get(DataExpression.RAND_BY_ROW));
            result = new ReorgOp(target.getName(), target.getDataType(), target.getValueType(), ReOrgOp.RESHAPE, reshapeInputs);
            break;
        }
        default: {
            String msg = source.printErrorLocation() + "processDataExpression():: Unknown operation:  " + source.getOpCode();
            LOG.error(msg);
            throw new ParseException(msg);
        }
    }
    // set identifier meta data (incl dimensions and blocksizes)
    setIdentifierParams(result, source.getOutput());
    if (source.getOpCode() == DataExpression.DataOp.READ) {
        DataOp readOp = (DataOp) result;
        readOp.setInputBlockSizes(target.getRowsInBlock(), target.getColumnsInBlock());
    }
    result.setParseInfo(source);
    return result;
}
Also used : HashMap(java.util.HashMap) DataGenOp(org.apache.sysml.hops.DataGenOp) Hop(org.apache.sysml.hops.Hop) ArrayList(java.util.ArrayList) ReorgOp(org.apache.sysml.hops.ReorgOp) DataGenMethod(org.apache.sysml.hops.Hop.DataGenMethod) DataOp(org.apache.sysml.hops.DataOp)

Example 25 with DataOp

use of org.apache.sysml.hops.DataOp in project systemml by apache.

the class RewriteForLoopVectorization method vectorizeIndexedCopy.

/**
 * Tries to replace a for loop whose body is a single indexed copy with one
 * vectorized left/right indexing operation over the loop range [from, to].
 *
 * The rewrite is only attempted for a literal increment of 1 and a body
 * with exactly one matrix root hop whose first input is a left indexing
 * over a DataOp, fed by a right indexing over a DataOp; whether it applies
 * is then decided by checkLeftAndRightIndexing.
 *
 * @param sb original (for) statement block, returned when the rewrite does not apply
 * @param csb child statement block holding the loop body hops
 * @param from hop of the loop lower bound
 * @param to hop of the loop upper bound
 * @param increment hop of the loop increment
 * @param itervar name of the loop iteration variable
 * @return csb with modified index bounds if the rewrite was applied, sb otherwise
 */
private static StatementBlock vectorizeIndexedCopy(StatementBlock sb, StatementBlock csb, Hop from, Hop to, Hop increment, String itervar) {
    // check supported increment values
    boolean unitStep = increment instanceof LiteralOp && ((LiteralOp) increment).getDoubleValue() == 1.0;
    if (!unitStep)
        return sb;
    // check for applicability: exactly one root hop in the body
    if (csb.getHops() == null || csb.getHops().size() != 1)
        return sb;
    Hop root = csb.getHops().get(0);
    if (root.getDataType() != DataType.MATRIX || !(root.getInput().get(0) instanceof LeftIndexingOp))
        return sb;
    LeftIndexingOp lix = (LeftIndexingOp) root.getInput().get(0);
    Hop lhs = lix.getInput().get(0);
    Hop rhs = lix.getInput().get(1);
    if (!(lhs instanceof DataOp) || !(rhs instanceof IndexingOp) || !(rhs.getInput().get(0) instanceof DataOp))
        return sb;
    boolean[] flags = checkLeftAndRightIndexing(lix, (IndexingOp) rhs, itervar);
    boolean applicable = flags[0];
    // row or col
    boolean rowIx = flags[1];
    if (!applicable)
        return sb;
    // apply rewrite
    IndexingOp rix = (IndexingOp) lix.getInput().get(1);
    int lowerPos = rowIx ? 2 : 4;
    int upperPos = rowIx ? 3 : 5;
    // modify left indexing bounds
    HopRewriteUtils.replaceChildReference(lix, lix.getInput().get(lowerPos), from, lowerPos);
    HopRewriteUtils.replaceChildReference(lix, lix.getInput().get(upperPos), to, upperPos);
    // modify right indexing (its bound inputs sit one position earlier)
    HopRewriteUtils.replaceChildReference(rix, rix.getInput().get(lowerPos - 1), from, lowerPos - 1);
    HopRewriteUtils.replaceChildReference(rix, rix.getInput().get(upperPos - 1), to, upperPos - 1);
    updateLeftAndRightIndexingSizes(rowIx, lix, rix);
    LOG.debug("Applied vectorizeIndexedCopy.");
    return csb;
}
Also used : IndexingOp(org.apache.sysml.hops.IndexingOp) LeftIndexingOp(org.apache.sysml.hops.LeftIndexingOp) Hop(org.apache.sysml.hops.Hop) LiteralOp(org.apache.sysml.hops.LiteralOp) DataOp(org.apache.sysml.hops.DataOp) IfStatementBlock(org.apache.sysml.parser.IfStatementBlock) WhileStatementBlock(org.apache.sysml.parser.WhileStatementBlock) ForStatementBlock(org.apache.sysml.parser.ForStatementBlock) StatementBlock(org.apache.sysml.parser.StatementBlock) LeftIndexingOp(org.apache.sysml.hops.LeftIndexingOp)

Aggregations

DataOp (org.apache.sysml.hops.DataOp)86 Hop (org.apache.sysml.hops.Hop)75 LiteralOp (org.apache.sysml.hops.LiteralOp)44 ArrayList (java.util.ArrayList)23 AggUnaryOp (org.apache.sysml.hops.AggUnaryOp)20 UnaryOp (org.apache.sysml.hops.UnaryOp)18 StatementBlock (org.apache.sysml.parser.StatementBlock)17 MatrixObject (org.apache.sysml.runtime.controlprogram.caching.MatrixObject)17 HopsException (org.apache.sysml.hops.HopsException)16 IndexingOp (org.apache.sysml.hops.IndexingOp)16 HashMap (java.util.HashMap)13 FunctionOp (org.apache.sysml.hops.FunctionOp)13 ForStatementBlock (org.apache.sysml.parser.ForStatementBlock)13 WhileStatementBlock (org.apache.sysml.parser.WhileStatementBlock)13 DMLRuntimeException (org.apache.sysml.runtime.DMLRuntimeException)12 DataIdentifier (org.apache.sysml.parser.DataIdentifier)11 IfStatementBlock (org.apache.sysml.parser.IfStatementBlock)11 Data (org.apache.sysml.runtime.instructions.cp.Data)11 BinaryOp (org.apache.sysml.hops.BinaryOp)9 LeftIndexingOp (org.apache.sysml.hops.LeftIndexingOp)9