Search in sources:

Example 41 with StatementBlock

Use of org.apache.sysml.parser.StatementBlock in project incubator-systemml by Apache.

From the class RewriteRemoveUnnecessaryBranches, method rewriteStatementBlock.

/**
 * Eliminates an if-else whose predicate is a constant literal by replacing the
 * if-statement block with the statement blocks of the branch that is taken.
 * If the taken branch is empty, the result is empty, i.e., the whole if-else disappears.
 * Any other statement block (not an if, or a non-constant predicate) is returned
 * unchanged as a single-element list.
 *
 * @param sb    statement block to rewrite
 * @param state rewrite status; flagged via setRemovedBranches() when the rewrite applies
 * @return the statement blocks that replace sb
 */
@Override
public List<StatementBlock> rewriteStatementBlock(StatementBlock sb, ProgramRewriteStatus state) {
    ArrayList<StatementBlock> ret = new ArrayList<>();
    // guard: only if-statement blocks are candidates for this rewrite
    if (!(sb instanceof IfStatementBlock)) {
        ret.add(sb);
        return ret;
    }
    IfStatementBlock isb = (IfStatementBlock) sb;
    Hop pred = isb.getPredicateHops().getInput().get(0);
    // guard: a non-constant predicate cannot be decided at compile time
    if (!(pred instanceof LiteralOp)) {
        ret.add(sb);
        return ret;
    }
    IfStatement istmt = (IfStatement) isb.getStatement(0);
    boolean takeIfBranch = HopRewriteUtils.getBooleanValue((LiteralOp) pred);
    // pull out the taken branch; an empty branch contributes nothing, which removes the if-else
    ret.addAll(takeIfBranch ? istmt.getIfBody() : istmt.getElseBody());
    state.setRemovedBranches();
    LOG.debug("Applied removeUnnecessaryBranches (lines " + sb.getBeginLine() + "-" + sb.getEndLine() + ").");
    return ret;
}
Also used: IfStatement(org.apache.sysml.parser.IfStatement) ArrayList(java.util.ArrayList) Hop(org.apache.sysml.hops.Hop) LiteralOp(org.apache.sysml.hops.LiteralOp) StatementBlock(org.apache.sysml.parser.StatementBlock) IfStatementBlock(org.apache.sysml.parser.IfStatementBlock)

Example 42 with StatementBlock

Use of org.apache.sysml.parser.StatementBlock in project incubator-systemml by Apache.

From the class RewriteSplitDagDataDependentOperators, method rewriteStatementBlock.

/**
 * Splits the hop DAG of a last-level statement block at data-dependent operators.
 *
 * <p>Candidates are collected via collectDataDependentOperators. Each candidate c is cut out
 * of sb as follows:
 * <ul>
 *   <li>every parent of c that is not itself a child of a candidate is rewired to a new
 *       transient read;</li>
 *   <li>c feeds a transient write that is moved into a new statement block sb1.</li>
 * </ul>
 * If c already has a transient-write parent and HopRewriteUtils.rHasSimpleReadChain allows it,
 * that existing write (and its variable name) is reused. Otherwise an artificial variable name
 * from createCutVarName(false) is introduced. The rewrite is then applied recursively to sb1.
 *
 * @param sb    statement block to rewrite
 * @param state program rewrite status, passed on to the recursive call
 * @return a list containing only sb if no split is applied; otherwise the blocks produced by
 *         rewriting sb1, followed by sb (sb1 and sb are both marked as split DAGs)
 * @throws HopsException wrapping any exception raised while splitting the DAG
 */
@Override
public List<StatementBlock> rewriteStatementBlock(StatementBlock sb, ProgramRewriteStatus state) {
    // DAG splits not required for forced single node;
    // the rewrite is also only applied to last-level statement blocks
    if (DMLScript.rtplatform == RUNTIME_PLATFORM.SINGLE_NODE || !HopRewriteUtils.isLastLevelStatementBlock(sb))
        return Arrays.asList(sb);
    ArrayList<StatementBlock> ret = new ArrayList<>();
    // collect candidate data-dependent operators (e.g., csv reads of unknown size),
    // then reset the visit status for subsequent DAG traversals
    ArrayList<Hop> cand = new ArrayList<>();
    collectDataDependentOperators(sb.getHops(), cand);
    Hop.resetVisitStatus(sb.getHops());
    // split hop dag on demand
    if (!cand.isEmpty()) {
        // collect child operators of candidates (to prevent rewrite anomalies);
        // parents contained in this set are never rewired below
        HashSet<Hop> candChilds = new HashSet<>();
        collectCandidateChildOperators(cand, candChilds);
        try {
            // duplicate sb incl live variable sets
            // (sb1 starts with empty live-in/live-out sets that are filled per candidate below)
            StatementBlock sb1 = new StatementBlock();
            sb1.setDMLProg(sb.getDMLProg());
            sb1.setParseInfo(sb);
            sb1.setLiveIn(new VariableSet());
            sb1.setLiveOut(new VariableSet());
            // move data-dependent ops incl transient writes to new statement block
            // (and replace original persistent read with transient read)
            ArrayList<Hop> sb1hops = new ArrayList<>();
            for (Hop c : cand) {
                // if there are already transient writes use them and don't introduce artificial variables;
                // unless there are transient reads w/ the same variable name in the current dag which can
                // lead to invalid reordering if variable consumers are not feeding into the candidate op.
                // (moveTWrite can only be true if hasTWrites is true)
                boolean hasTWrites = hasTransientWriteParents(c);
                boolean moveTWrite = hasTWrites ? HopRewriteUtils.rHasSimpleReadChain(c, getFirstTransientWriteParent(c).getName()) : false;
                // characteristics of c, copied to the new transient read/write and the live variable entry
                String varname = null;
                long rlen = c.getDim1();
                long clen = c.getDim2();
                long nnz = c.getNnz();
                UpdateType update = c.getUpdateType();
                int brlen = c.getRowsInBlock();
                int bclen = c.getColsInBlock();
                if (// reuse existing transient_write
                hasTWrites && moveTWrite) {
                    Hop twrite = getFirstTransientWriteParent(c);
                    varname = twrite.getName();
                    // create new transient read
                    DataOp tread = new DataOp(varname, c.getDataType(), c.getValueType(), DataOpTypes.TRANSIENTREAD, null, rlen, clen, nnz, update, brlen, bclen);
                    tread.setVisited();
                    HopRewriteUtils.copyLineNumbers(c, tread);
                    // replace data-dependent operator with transient read
                    // (iterate over a snapshot copy of c's parents, since rewiring may modify c's parent list)
                    ArrayList<Hop> parents = new ArrayList<>(c.getParent());
                    for (int i = 0; i < parents.size(); i++) {
                        // prevent concurrent modification by index access
                        Hop parent = parents.get(i);
                        if (!candChilds.contains(parent)) {
                            // anomaly filter
                            // the reused transient write is not rewired: it is removed from sb's roots
                            // and moved (with its input sub dag) to sb1 below
                            if (parent != twrite)
                                HopRewriteUtils.replaceChildReference(parent, c, tread);
                            else
                                sb.getHops().remove(parent);
                        }
                    }
                    // add data-dependent operator sub dag to first statement block
                    sb1hops.add(twrite);
                } else // create transient write to artificial variables
                {
                    varname = createCutVarName(false);
                    // create new transient read
                    DataOp tread = new DataOp(varname, c.getDataType(), c.getValueType(), DataOpTypes.TRANSIENTREAD, null, rlen, clen, nnz, update, brlen, bclen);
                    tread.setVisited();
                    HopRewriteUtils.copyLineNumbers(c, tread);
                    // replace data-dependent operator with transient read
                    // (iterate over a snapshot copy of c's parents, since rewiring may modify c's parent list)
                    ArrayList<Hop> parents = new ArrayList<>(c.getParent());
                    for (int i = 0; i < parents.size(); i++) {
                        // prevent concurrent modification by index access
                        Hop parent = parents.get(i);
                        if (// anomaly filter
                        !candChilds.contains(parent))
                            HopRewriteUtils.replaceChildReference(parent, c, tread);
                    }
                    // add data-dependent operator sub dag to first statement block
                    // NOTE(review): twrite is created only after the rewiring loop above; if the DataOp
                    // constructor registers twrite as a parent of c, creating it earlier would get it
                    // rewired to tread as well -- keep this order.
                    DataOp twrite = new DataOp(varname, c.getDataType(), c.getValueType(), c, DataOpTypes.TRANSIENTWRITE, null);
                    twrite.setVisited();
                    twrite.setOutputParams(rlen, clen, nnz, update, brlen, bclen);
                    HopRewriteUtils.copyLineNumbers(c, twrite);
                    sb1hops.add(twrite);
                }
                // update live in and out of new statement block (for piggybacking):
                // the cut variable is live-out of sb1 and live-in of sb
                DataIdentifier diVar = new DataIdentifier(varname);
                diVar.setDimensions(rlen, clen);
                diVar.setBlockDimensions(brlen, bclen);
                diVar.setDataType(c.getDataType());
                diVar.setValueType(c.getValueType());
                sb1.liveOut().addVariable(varname, new DataIdentifier(diVar));
                sb.liveIn().addVariable(varname, new DataIdentifier(diVar));
            }
            // ensure disjoint operators across DAGs (prevent replicated operations)
            handleReplicatedOperators(sb1hops, sb.getHops(), sb1.liveOut(), sb.liveIn());
            // deep copy new dag (in order to prevent any dangling references)
            sb1.setHops(Recompiler.deepCopyHopsDag(sb1hops));
            sb1.updateRecompilationFlag();
            // avoid later merge by other rewrites
            sb1.setSplitDag(true);
            // recursive application of rewrite rule (in case of multiple data dependent operators
            // with data dependencies in between each other)
            List<StatementBlock> tmp = rewriteStatementBlock(sb1, state);
            // add new statement blocks to output
            // statement block with data dependent hops
            ret.addAll(tmp);
            // statement block with remaining hops
            ret.add(sb);
            // avoid later merge by other rewrites
            sb.setSplitDag(true);
        } catch (Exception ex) {
            // any failure during the split is reported as a HopsException (original cause preserved)
            throw new HopsException("Failed to split hops dag for data dependent operators with unknown size.", ex);
        }
        LOG.debug("Applied splitDagDataDependentOperators (lines " + sb.getBeginLine() + "-" + sb.getEndLine() + ").");
    } else // keep original hop dag
    {
        ret.add(sb);
    }
    return ret;
}
Also used: DataIdentifier(org.apache.sysml.parser.DataIdentifier) ArrayList(java.util.ArrayList) Hop(org.apache.sysml.hops.Hop) HopsException(org.apache.sysml.hops.HopsException) UpdateType(org.apache.sysml.runtime.controlprogram.caching.MatrixObject.UpdateType) VariableSet(org.apache.sysml.parser.VariableSet) DataOp(org.apache.sysml.hops.DataOp) StatementBlock(org.apache.sysml.parser.StatementBlock) HashSet(java.util.HashSet)

Example 43 with StatementBlock

Use of org.apache.sysml.parser.StatementBlock in project incubator-systemml by Apache.

From the class InterProceduralAnalysis, method propagateStatisticsAcrossBlock.

// ///////////////////////////
// INTRA-PROCEDURE ANALYSIS
// ////
/**
 * Propagates statistics through a single statement block, reading and updating callVars.
 *
 * <p>Control-flow blocks recurse into their bodies:
 * <ul>
 *   <li>Function blocks process their body in order.</li>
 *   <li>While and for/parfor blocks process their body a second time when
 *       Recompiler.reconcileUpdatedCallVarsLoops reports a change. For while loops, the
 *       predicate is also re-propagated in that second pass.</li>
 *   <li>If blocks process the if branch on callVars and the else branch on a clone, then
 *       reconcile the two via Recompiler.reconcileUpdatedCallVarsIf.</li>
 * </ul>
 * Updated constant scalars are removed after every control-flow block.
 *
 * <p>Last-level blocks get three passes over the hop DAG, each preceded by a visit-status reset:
 * scalar-to-literal replacement, a statistics refresh, and propagation into function calls.
 *
 * @param sb         statement block to process
 * @param callVars   variables (with statistics) visible at the start of the block; updated in place
 * @param fcallSizes function call size information, passed through to function propagation
 * @param fnStack    function stack, passed through to recursive calls and function propagation
 */
private void propagateStatisticsAcrossBlock(StatementBlock sb, LocalVariableMap callVars, FunctionCallSizeInfo fcallSizes, Set<String> fnStack) {
    if (sb instanceof FunctionStatementBlock) {
        FunctionStatementBlock funcBlock = (FunctionStatementBlock) sb;
        FunctionStatement funcStmt = (FunctionStatement) funcBlock.getStatement(0);
        for (StatementBlock child : funcStmt.getBody())
            propagateStatisticsAcrossBlock(child, callVars, fcallSizes, fnStack);
        return;
    }

    if (sb instanceof WhileStatementBlock) {
        WhileStatementBlock whileBlock = (WhileStatementBlock) sb;
        WhileStatement whileStmt = (WhileStatement) whileBlock.getStatement(0);
        // incoming stats flow into the predicate first
        propagateStatisticsAcrossPredicateDAG(whileBlock.getPredicateHops(), callVars);
        Recompiler.removeUpdatedScalars(callVars, whileBlock);
        LocalVariableMap varsBeforeBody = (LocalVariableMap) callVars.clone();
        for (StatementBlock child : whileStmt.getBody())
            propagateStatisticsAcrossBlock(child, callVars, fcallSizes, fnStack);
        boolean needsSecondPass = Recompiler.reconcileUpdatedCallVarsLoops(varsBeforeBody, callVars, whileBlock);
        if (needsSecondPass) {
            // second pass over predicate and body with the reconciled statistics
            propagateStatisticsAcrossPredicateDAG(whileBlock.getPredicateHops(), callVars);
            for (StatementBlock child : whileStmt.getBody())
                propagateStatisticsAcrossBlock(child, callVars, fcallSizes, fnStack);
        }
        Recompiler.removeUpdatedScalars(callVars, sb);
        return;
    }

    if (sb instanceof IfStatementBlock) {
        IfStatementBlock ifBlock = (IfStatementBlock) sb;
        IfStatement ifStmt = (IfStatement) ifBlock.getStatement(0);
        propagateStatisticsAcrossPredicateDAG(ifBlock.getPredicateHops(), callVars);
        // the if branch works on callVars itself, the else branch on a separate copy
        LocalVariableMap varsBeforeBranches = (LocalVariableMap) callVars.clone();
        LocalVariableMap elseVars = (LocalVariableMap) callVars.clone();
        for (StatementBlock child : ifStmt.getIfBody())
            propagateStatisticsAcrossBlock(child, callVars, fcallSizes, fnStack);
        for (StatementBlock child : ifStmt.getElseBody())
            propagateStatisticsAcrossBlock(child, elseVars, fcallSizes, fnStack);
        // NOTE(review): the reconciled map is only used locally here; the caller observes the
        // reconciliation only if reconcileUpdatedCallVarsIf updates/returns callVars itself -- confirm
        LocalVariableMap reconciled = Recompiler.reconcileUpdatedCallVarsIf(varsBeforeBranches, callVars, elseVars, ifBlock);
        Recompiler.removeUpdatedScalars(reconciled, sb);
        return;
    }

    if (sb instanceof ForStatementBlock) { // incl parfor
        ForStatementBlock forBlock = (ForStatementBlock) sb;
        ForStatement forStmt = (ForStatement) forBlock.getStatement(0);
        // incoming stats flow into from/to/increment first
        propagateStatisticsAcrossPredicateDAG(forBlock.getFromHops(), callVars);
        propagateStatisticsAcrossPredicateDAG(forBlock.getToHops(), callVars);
        propagateStatisticsAcrossPredicateDAG(forBlock.getIncrementHops(), callVars);
        Recompiler.removeUpdatedScalars(callVars, forBlock);
        LocalVariableMap varsBeforeBody = (LocalVariableMap) callVars.clone();
        for (StatementBlock child : forStmt.getBody())
            propagateStatisticsAcrossBlock(child, callVars, fcallSizes, fnStack);
        if (Recompiler.reconcileUpdatedCallVarsLoops(varsBeforeBody, callVars, forBlock)) {
            // second pass over the body with the reconciled statistics
            for (StatementBlock child : forStmt.getBody())
                propagateStatisticsAcrossBlock(child, callVars, fcallSizes, fnStack);
        }
        Recompiler.removeUpdatedScalars(callVars, sb);
        return;
    }

    // generic (last-level) block: old stats in, new stats out if updated
    Recompiler.removeUpdatedScalars(callVars, sb);
    ArrayList<Hop> roots = sb.getHops();
    DMLProgram prog = sb.getDMLProg();
    // scalar reads -> literals
    Hop.resetVisitStatus(roots);
    propagateScalarsAcrossDAG(roots, callVars);
    // statistics refresh across the dag
    Hop.resetVisitStatus(roots);
    propagateStatisticsAcrossDAG(roots, callVars);
    // statistics into function calls
    Hop.resetVisitStatus(roots);
    propagateStatisticsIntoFunctions(prog, roots, callVars, fcallSizes, fnStack);
}
Also used: ForStatementBlock(org.apache.sysml.parser.ForStatementBlock) ExternalFunctionStatement(org.apache.sysml.parser.ExternalFunctionStatement) FunctionStatement(org.apache.sysml.parser.FunctionStatement) IfStatement(org.apache.sysml.parser.IfStatement) FunctionStatementBlock(org.apache.sysml.parser.FunctionStatementBlock) LocalVariableMap(org.apache.sysml.runtime.controlprogram.LocalVariableMap) ArrayList(java.util.ArrayList) DMLProgram(org.apache.sysml.parser.DMLProgram) WhileStatement(org.apache.sysml.parser.WhileStatement) ForStatement(org.apache.sysml.parser.ForStatement) IfStatementBlock(org.apache.sysml.parser.IfStatementBlock) WhileStatementBlock(org.apache.sysml.parser.WhileStatementBlock) StatementBlock(org.apache.sysml.parser.StatementBlock)

Example 44 with StatementBlock

Use of org.apache.sysml.parser.StatementBlock in project incubator-systemml by Apache.

From the class InterProceduralAnalysis, method isUnarySizePreservingFunction.

/**
 * Checks whether a function is a unary, size-preserving function over matrices.
 *
 * <p>The function must have exactly one matrix input and exactly one matrix output. In addition,
 * propagating a synthetic input of size 7777 x 3333 through its body must yield an output with
 * the same dimensions. An output that is missing, or is not a matrix object, after propagation
 * is treated as not size-preserving.
 *
 * <p>Propagation updates statistics of the function body, so the body is propagated once more
 * afterwards with an input of unknown dimensions to reset them.
 *
 * @param fsb function statement block to check
 * @return true if the function is unary over matrices and its propagated output size equals its input size
 */
private boolean isUnarySizePreservingFunction(FunctionStatementBlock fsb) {
    FunctionStatement fstmt = (FunctionStatement) fsb.getStatement(0);
    // check unary functions over matrices
    boolean ret = (fstmt.getInputParams().size() == 1 && fstmt.getInputParams().get(0).getDataType() == DataType.MATRIX && fstmt.getOutputParams().size() == 1 && fstmt.getOutputParams().get(0).getDataType() == DataType.MATRIX);
    if (!ret)
        return false;

    String inName = fstmt.getInputParams().get(0).getName();
    String outName = fstmt.getOutputParams().get(0).getName();
    FunctionCallSizeInfo fcallSizes = new FunctionCallSizeInfo(_fgraph, false);
    HashSet<String> fnStack = new HashSet<>();

    // propagate statistics for a synthetic input with distinctive dimensions
    MatrixObject mo = createOutputMatrix(7777, 3333, -1);
    LocalVariableMap callVars = new LocalVariableMap();
    callVars.put(inName, mo);
    for (StatementBlock sbi : fstmt.getBody()) propagateStatisticsAcrossBlock(sbi, callVars, fcallSizes, fnStack);

    // compare output: a missing or non-matrix output has unknown size, so the function is
    // conservatively not size-preserving (instead of failing with an NPE/ClassCastException)
    Object out = callVars.get(outName);
    boolean sizePreserving = false;
    if (out instanceof MatrixObject) {
        MatrixObject mo2 = (MatrixObject) out;
        sizePreserving = mo.getNumRows() == mo2.getNumRows() && mo.getNumColumns() == mo2.getNumColumns();
    }

    // reset function: propagate again with unknown input dimensions. A fresh variable map is used
    // because the first pass updates callVars ("new stats out"), e.g., it may replace the input
    // entry if the body reassigns the input variable. In that case, mutating mo alone would not
    // reset the body, and other variables would keep their synthetic sizes.
    mo.getMatrixCharacteristics().setDimension(-1, -1);
    LocalVariableMap resetVars = new LocalVariableMap();
    resetVars.put(inName, mo);
    for (StatementBlock sbi : fstmt.getBody()) propagateStatisticsAcrossBlock(sbi, resetVars, fcallSizes, fnStack);

    return sizePreserving;
}
Also used: ExternalFunctionStatement(org.apache.sysml.parser.ExternalFunctionStatement) FunctionStatement(org.apache.sysml.parser.FunctionStatement) MatrixObject(org.apache.sysml.runtime.controlprogram.caching.MatrixObject) LocalVariableMap(org.apache.sysml.runtime.controlprogram.LocalVariableMap) FunctionStatementBlock(org.apache.sysml.parser.FunctionStatementBlock) IfStatementBlock(org.apache.sysml.parser.IfStatementBlock) WhileStatementBlock(org.apache.sysml.parser.WhileStatementBlock) ForStatementBlock(org.apache.sysml.parser.ForStatementBlock) StatementBlock(org.apache.sysml.parser.StatementBlock) HashSet(java.util.HashSet)

Example 45 with StatementBlock

Use of org.apache.sysml.parser.StatementBlock in project incubator-systemml by Apache.

From the class Recompiler, method recompileProgramBlockInstructions.

/**
 * Regenerates only the lops and instructions of a single program block from its existing hop
 * DAG(s). This is NOT a full program block recompile: there is no statistics update, no
 * rewrites and no recursion into child blocks. The primary use case is recompilation after
 * hop configuration changes. It preserves statistics (e.g., worst-case stats propagated from
 * other program blocks) and performs better when recompiling individual program blocks.
 *
 * <p>For while and if blocks the predicate is recompiled. For for/parfor blocks the from, to
 * and increment instructions are recompiled. For all other blocks the block's instructions are
 * recompiled. Missing statement blocks or hop DAGs are skipped.
 *
 * @param pb program block
 * @throws IOException propagated from recompiling the hop DAG instructions
 */
public static void recompileProgramBlockInstructions(ProgramBlock pb) throws IOException {
    StatementBlock sb = pb.getStatementBlock();

    if (pb instanceof WhileProgramBlock) {
        // while: recompile predicate instructions only
        WhileStatementBlock wsb = (WhileStatementBlock) sb;
        if (wsb == null || wsb.getPredicateHops() == null)
            return;
        WhileProgramBlock wpb = (WhileProgramBlock) pb;
        wpb.setPredicate(recompileHopsDagInstructions(wsb.getPredicateHops()));
        return;
    }

    if (pb instanceof IfProgramBlock) {
        // if: recompile predicate instructions only
        IfStatementBlock isb = (IfStatementBlock) sb;
        if (isb == null || isb.getPredicateHops() == null)
            return;
        IfProgramBlock ipb = (IfProgramBlock) pb;
        ipb.setPredicate(recompileHopsDagInstructions(isb.getPredicateHops()));
        return;
    }

    if (pb instanceof ForProgramBlock) {
        // for/parfor: recompile from/to/increment instructions, each only if present
        ForStatementBlock fsb = (ForStatementBlock) sb;
        if (fsb == null)
            return;
        ForProgramBlock fpb = (ForProgramBlock) pb;
        if (fsb.getFromHops() != null)
            fpb.setFromInstructions(recompileHopsDagInstructions(fsb.getFromHops()));
        if (fsb.getToHops() != null)
            fpb.setToInstructions(recompileHopsDagInstructions(fsb.getToHops()));
        if (fsb.getIncrementHops() != null)
            fpb.setIncrementInstructions(recompileHopsDagInstructions(fsb.getIncrementHops()));
        return;
    }

    // last-level program block: recompile the block's own instructions
    if (sb == null || sb.getHops() == null)
        return;
    pb.setInstructions(recompileHopsDagInstructions(sb, sb.getHops()));
}
Also used: IfProgramBlock(org.apache.sysml.runtime.controlprogram.IfProgramBlock) ForStatementBlock(org.apache.sysml.parser.ForStatementBlock) ForProgramBlock(org.apache.sysml.runtime.controlprogram.ForProgramBlock) ParForProgramBlock(org.apache.sysml.runtime.controlprogram.ParForProgramBlock) WhileProgramBlock(org.apache.sysml.runtime.controlprogram.WhileProgramBlock) WhileStatementBlock(org.apache.sysml.parser.WhileStatementBlock) IfStatementBlock(org.apache.sysml.parser.IfStatementBlock) StatementBlock(org.apache.sysml.parser.StatementBlock)

Aggregations

StatementBlock (org.apache.sysml.parser.StatementBlock): 67, ForStatementBlock (org.apache.sysml.parser.ForStatementBlock): 57, IfStatementBlock (org.apache.sysml.parser.IfStatementBlock): 57, WhileStatementBlock (org.apache.sysml.parser.WhileStatementBlock): 57, FunctionStatementBlock (org.apache.sysml.parser.FunctionStatementBlock): 39, Hop (org.apache.sysml.hops.Hop): 28, ArrayList (java.util.ArrayList): 24, FunctionStatement (org.apache.sysml.parser.FunctionStatement): 22, IfStatement (org.apache.sysml.parser.IfStatement): 22, ForStatement (org.apache.sysml.parser.ForStatement): 20, WhileStatement (org.apache.sysml.parser.WhileStatement): 19, ParForStatementBlock (org.apache.sysml.parser.ParForStatementBlock): 18, ForProgramBlock (org.apache.sysml.runtime.controlprogram.ForProgramBlock): 18, IfProgramBlock (org.apache.sysml.runtime.controlprogram.IfProgramBlock): 16, WhileProgramBlock (org.apache.sysml.runtime.controlprogram.WhileProgramBlock): 16, FunctionProgramBlock (org.apache.sysml.runtime.controlprogram.FunctionProgramBlock): 13, ProgramBlock (org.apache.sysml.runtime.controlprogram.ProgramBlock): 13, HashSet (java.util.HashSet): 11, ExternalFunctionStatement (org.apache.sysml.parser.ExternalFunctionStatement): 11, LocalVariableMap (org.apache.sysml.runtime.controlprogram.LocalVariableMap): 11