Search in sources:

Example 16 with StatementBlock

Use of org.apache.sysml.parser.StatementBlock in project incubator-systemml by Apache.

Class Recompiler, method rRecompileProgramBlock2Forced.

/**
 * Recursively recompiles a program block, all nested program blocks, and all functions
 * called from generic blocks, forcing the given execution type on every HOP DAG.
 *
 * <p>Predicates of while/if/for blocks are recompiled unless {@code et} is
 * {@link ExecType#CP} and their current instructions contain no MR job instruction.
 * The HOP DAGs of generic (last-level) blocks are always recompiled.
 *
 * @param pb      program block to recompile
 * @param tid     identifier passed through unchanged to {@code Recompiler.recompileHopsDag2Forced}
 * @param fnStack keys of functions already visited; each function is recompiled at most
 *                once, which also stops recursive call chains (entries are never removed)
 * @param et      execution type to force during recompilation
 */
private static void rRecompileProgramBlock2Forced(ProgramBlock pb, long tid, HashSet<String> fnStack, ExecType et) {
    if (pb instanceof WhileProgramBlock) {
        WhileProgramBlock pbTmp = (WhileProgramBlock) pb;
        WhileStatementBlock sbTmp = (WhileStatementBlock) pbTmp.getStatementBlock();
        // recompile predicate (skipped when forcing CP and no MR job instruction is present)
        if (sbTmp != null && (et != ExecType.CP || OptTreeConverter.containsMRJobInstruction(pbTmp.getPredicate(), true, true)))
            pbTmp.setPredicate(Recompiler.recompileHopsDag2Forced(sbTmp.getPredicateHops(), tid, et));
        // recompile body
        for (ProgramBlock pb2 : pbTmp.getChildBlocks())
            rRecompileProgramBlock2Forced(pb2, tid, fnStack, et);
    } else if (pb instanceof IfProgramBlock) {
        IfProgramBlock pbTmp = (IfProgramBlock) pb;
        IfStatementBlock sbTmp = (IfStatementBlock) pbTmp.getStatementBlock();
        // recompile predicate (skipped when forcing CP and no MR job instruction is present)
        if (sbTmp != null && (et != ExecType.CP || OptTreeConverter.containsMRJobInstruction(pbTmp.getPredicate(), true, true)))
            pbTmp.setPredicate(Recompiler.recompileHopsDag2Forced(sbTmp.getPredicateHops(), tid, et));
        // recompile both branches
        for (ProgramBlock pb2 : pbTmp.getChildBlocksIfBody())
            rRecompileProgramBlock2Forced(pb2, tid, fnStack, et);
        for (ProgramBlock pb2 : pbTmp.getChildBlocksElseBody())
            rRecompileProgramBlock2Forced(pb2, tid, fnStack, et);
    } else if (pb instanceof ForProgramBlock) { // includes ParForProgramBlock
        ForProgramBlock pbTmp = (ForProgramBlock) pb;
        ForStatementBlock sbTmp = (ForStatementBlock) pbTmp.getStatementBlock();
        // recompile from/to/increment predicates (each skipped when forcing CP and no MR job instruction is present)
        if (sbTmp != null) {
            if (sbTmp.getFromHops() != null && (et != ExecType.CP || OptTreeConverter.containsMRJobInstruction(pbTmp.getFromInstructions(), true, true)))
                pbTmp.setFromInstructions(Recompiler.recompileHopsDag2Forced(sbTmp.getFromHops(), tid, et));
            if (sbTmp.getToHops() != null && (et != ExecType.CP || OptTreeConverter.containsMRJobInstruction(pbTmp.getToInstructions(), true, true)))
                pbTmp.setToInstructions(Recompiler.recompileHopsDag2Forced(sbTmp.getToHops(), tid, et));
            if (sbTmp.getIncrementHops() != null && (et != ExecType.CP || OptTreeConverter.containsMRJobInstruction(pbTmp.getIncrementInstructions(), true, true)))
                pbTmp.setIncrementInstructions(Recompiler.recompileHopsDag2Forced(sbTmp.getIncrementHops(), tid, et));
        }
        // recompile body
        for (ProgramBlock pb2 : pbTmp.getChildBlocks())
            rRecompileProgramBlock2Forced(pb2, tid, fnStack, et);
    } else if (pb instanceof FunctionProgramBlock) { // includes ExternalFunctionProgramBlock and ExternalFunctionProgramBlockCP
        FunctionProgramBlock tmp = (FunctionProgramBlock) pb;
        for (ProgramBlock pb2 : tmp.getChildBlocks())
            rRecompileProgramBlock2Forced(pb2, tid, fnStack, et);
    } else {
        StatementBlock sb = pb.getStatementBlock();
        // Always recompile the HOP DAG here, unlike the predicates above: per the original
        // note, a selective recompile would be invalid with permutation matrix mult across
        // multiple dags.
        if (sb != null)
            pb.setInstructions(Recompiler.recompileHopsDag2Forced(sb, sb.getHops(), tid, et));
        // recompile functions called from this block
        if (OptTreeConverter.containsFunctionCallInstruction(pb)) {
            ArrayList<Instruction> insts = pb.getInstructions();
            for (Instruction inst : insts) {
                if (!(inst instanceof FunctionCallCPInstruction))
                    continue;
                FunctionCallCPInstruction func = (FunctionCallCPInstruction) inst;
                String fname = func.getFunctionName();
                String fnamespace = func.getNamespace();
                String fKey = DMLProgram.constructFunctionKey(fnamespace, fname);
                // HashSet.add returns false if already present: memoization for multiple calls, recursion
                if (fnStack.add(fKey)) {
                    FunctionProgramBlock fpb = pb.getProgram().getFunctionProgramBlock(fnamespace, fname);
                    // recompile chains of functions
                    rRecompileProgramBlock2Forced(fpb, tid, fnStack, et);
                }
            }
        }
    }
}
Also used: IfProgramBlock(org.apache.sysml.runtime.controlprogram.IfProgramBlock) ForStatementBlock(org.apache.sysml.parser.ForStatementBlock) FunctionProgramBlock(org.apache.sysml.runtime.controlprogram.FunctionProgramBlock) ForProgramBlock(org.apache.sysml.runtime.controlprogram.ForProgramBlock) ParForProgramBlock(org.apache.sysml.runtime.controlprogram.ParForProgramBlock) ArrayList(java.util.ArrayList) WhileProgramBlock(org.apache.sysml.runtime.controlprogram.WhileProgramBlock) MRJobInstruction(org.apache.sysml.runtime.instructions.MRJobInstruction) Instruction(org.apache.sysml.runtime.instructions.Instruction) FunctionCallCPInstruction(org.apache.sysml.runtime.instructions.cp.FunctionCallCPInstruction) SeqInstruction(org.apache.sysml.runtime.instructions.mr.SeqInstruction) RandInstruction(org.apache.sysml.runtime.instructions.mr.RandInstruction) ProgramBlock(org.apache.sysml.runtime.controlprogram.ProgramBlock) WhileStatementBlock(org.apache.sysml.parser.WhileStatementBlock) IfStatementBlock(org.apache.sysml.parser.IfStatementBlock) StatementBlock(org.apache.sysml.parser.StatementBlock)

Example 17 with StatementBlock

Use of org.apache.sysml.parser.StatementBlock in project incubator-systemml by Apache.

Class Recompiler, method rRecompileProgramBlock.

// ////////////////////////////
// private helper functions //
// ////////////////////////////
/**
 * Recursively recompiles a program block and its nested blocks against the given
 * variable map and recompile status, propagating output statistics as it goes.
 *
 * <p>Loop bodies (while/for, including parfor) are recompiled once. If reconciling the
 * variables/status captured before the pass with those after it reports a change, the
 * predicate(s) and body are recompiled a second time. If-branches are recompiled on
 * separate copies of {@code vars}/{@code status}, which are then reconciled.
 * Function program blocks are not processed here.
 *
 * @param pb             program block to recompile
 * @param vars           variable map used for recompilation; may be updated in place
 * @param status         recompile status; may be updated in place
 * @param tid            identifier passed through unchanged to the recompiler
 * @param resetRecompile reset policy; recompilation flags of generic blocks are reset
 *                       only when {@code resetRecompile.isReset()} (plus the conditions below)
 */
private static void rRecompileProgramBlock(ProgramBlock pb, LocalVariableMap vars, RecompileStatus status, long tid, ResetType resetRecompile) {
    if (pb instanceof WhileProgramBlock) {
        WhileProgramBlock wpb = (WhileProgramBlock) pb;
        WhileStatementBlock wsb = (WhileStatementBlock) wpb.getStatementBlock();
        // recompile predicate
        recompileWhilePredicate(wpb, wsb, vars, status, tid, resetRecompile);
        // remove updated scalars because in loop
        removeUpdatedScalars(vars, wsb);
        // copy vars for later compare
        LocalVariableMap oldVars = (LocalVariableMap) vars.clone();
        RecompileStatus oldStatus = (RecompileStatus) status.clone();
        for (ProgramBlock pb2 : wpb.getChildBlocks())
            rRecompileProgramBlock(pb2, vars, status, tid, resetRecompile);
        // non-short-circuit '|' so that both reconcile calls always execute
        // (presumably each has side effects on its second argument)
        if (reconcileUpdatedCallVarsLoops(oldVars, vars, wsb) | reconcileUpdatedCallVarsLoops(oldStatus, status, wsb)) {
            // second pass with unknowns if required
            recompileWhilePredicate(wpb, wsb, vars, status, tid, resetRecompile);
            for (ProgramBlock pb2 : wpb.getChildBlocks())
                rRecompileProgramBlock(pb2, vars, status, tid, resetRecompile);
        }
        removeUpdatedScalars(vars, wsb);
    } else if (pb instanceof IfProgramBlock) {
        IfProgramBlock ipb = (IfProgramBlock) pb;
        IfStatementBlock isb = (IfStatementBlock) ipb.getStatementBlock();
        // recompile predicate
        recompileIfPredicate(ipb, isb, vars, status, tid, resetRecompile);
        // copy vars for later compare; the else branch works on its own copies
        LocalVariableMap oldVars = (LocalVariableMap) vars.clone();
        LocalVariableMap varsElse = (LocalVariableMap) vars.clone();
        RecompileStatus oldStatus = (RecompileStatus) status.clone();
        RecompileStatus statusElse = (RecompileStatus) status.clone();
        for (ProgramBlock pb2 : ipb.getChildBlocksIfBody())
            rRecompileProgramBlock(pb2, vars, status, tid, resetRecompile);
        for (ProgramBlock pb2 : ipb.getChildBlocksElseBody())
            rRecompileProgramBlock(pb2, varsElse, statusElse, tid, resetRecompile);
        reconcileUpdatedCallVarsIf(oldVars, vars, varsElse, isb);
        reconcileUpdatedCallVarsIf(oldStatus, status, statusElse, isb);
        removeUpdatedScalars(vars, ipb.getStatementBlock());
    } else if (pb instanceof ForProgramBlock) { // includes ParFORProgramBlock
        ForProgramBlock fpb = (ForProgramBlock) pb;
        ForStatementBlock fsb = (ForStatementBlock) fpb.getStatementBlock();
        // recompile predicates
        recompileForPredicates(fpb, fsb, vars, status, tid, resetRecompile);
        // remove updated scalars because in loop
        removeUpdatedScalars(vars, fpb.getStatementBlock());
        // copy vars for later compare
        LocalVariableMap oldVars = (LocalVariableMap) vars.clone();
        RecompileStatus oldStatus = (RecompileStatus) status.clone();
        for (ProgramBlock pb2 : fpb.getChildBlocks())
            rRecompileProgramBlock(pb2, vars, status, tid, resetRecompile);
        // non-short-circuit '|' so that both reconcile calls always execute
        // (presumably each has side effects on its second argument)
        if (reconcileUpdatedCallVarsLoops(oldVars, vars, fsb) | reconcileUpdatedCallVarsLoops(oldStatus, status, fsb)) {
            // second pass with unknowns if required
            recompileForPredicates(fpb, fsb, vars, status, tid, resetRecompile);
            for (ProgramBlock pb2 : fpb.getChildBlocks())
                rRecompileProgramBlock(pb2, vars, status, tid, resetRecompile);
        }
        removeUpdatedScalars(vars, fpb.getStatementBlock());
    } else if (pb instanceof FunctionProgramBlock) { // includes ExternalFunctionProgramBlock and ExternalFunctionProgramBlockCP
        // do nothing
    } else {
        StatementBlock sb = pb.getStatementBlock();
        if (sb == null)
            return;
        // recompile all for stats propagation and recompile flags
        ArrayList<Instruction> tmp = Recompiler.recompileHopsDag(sb, sb.getHops(), vars, status, true, false, tid);
        pb.setInstructions(tmp);
        // propagate stats across hops (should be executed on clone of vars)
        Recompiler.extractDAGOutputStatistics(sb.getHops(), vars);
        // reset recompilation flags (w/ special handling functions)
        if (ParForProgramBlock.RESET_RECOMPILATION_FLAGs && !containsRootFunctionOp(sb.getHops()) && resetRecompile.isReset()) {
            Hop.resetRecompilationFlag(sb.getHops(), ExecType.CP, resetRecompile);
            sb.updateRecompilationFlag();
        }
    }
}
Also used: IfProgramBlock(org.apache.sysml.runtime.controlprogram.IfProgramBlock) ForStatementBlock(org.apache.sysml.parser.ForStatementBlock) FunctionProgramBlock(org.apache.sysml.runtime.controlprogram.FunctionProgramBlock) LocalVariableMap(org.apache.sysml.runtime.controlprogram.LocalVariableMap) ForProgramBlock(org.apache.sysml.runtime.controlprogram.ForProgramBlock) ParForProgramBlock(org.apache.sysml.runtime.controlprogram.ParForProgramBlock) ArrayList(java.util.ArrayList) WhileProgramBlock(org.apache.sysml.runtime.controlprogram.WhileProgramBlock) ProgramBlock(org.apache.sysml.runtime.controlprogram.ProgramBlock) WhileStatementBlock(org.apache.sysml.parser.WhileStatementBlock) IfStatementBlock(org.apache.sysml.parser.IfStatementBlock) StatementBlock(org.apache.sysml.parser.StatementBlock)

Example 18 with StatementBlock

Use of org.apache.sysml.parser.StatementBlock in project incubator-systemml by Apache.

Class ProgramRewriter, method rRewriteStatementBlocks.

/**
 * Rewrites a list of statement blocks and stores the result back into that list.
 *
 * <p>Three phases run in order:
 * <ol>
 *   <li>the block-level rule set is applied to the whole list;</li>
 *   <li>every resulting block is rewritten recursively, which may expand it into several blocks;</li>
 *   <li>the block-level rule set is applied again, which may contract blocks.</li>
 * </ol>
 * Rules reporting {@code createsSplitDag()} are applied only when {@code splitDags} is true.
 *
 * @param sbs       blocks to rewrite; cleared and refilled with the rewritten blocks
 * @param status    rewrite status to accumulate into; a fresh one is used when {@code null}
 * @param splitDags whether rules that create split DAGs may be applied
 * @return {@code sbs} itself, now holding the rewritten blocks
 */
public ArrayList<StatementBlock> rRewriteStatementBlocks(ArrayList<StatementBlock> sbs, ProgramRewriteStatus status, boolean splitDags) {
    // ensure robustness for calls from outside
    final ProgramRewriteStatus rwStatus = (status != null) ? status : new ProgramRewriteStatus();

    // phase 1: block-level rules over the input list
    List<StatementBlock> current = sbs;
    for (StatementBlockRewriteRule rule : _sbRuleSet) {
        boolean applicable = splitDags || !rule.createsSplitDag();
        if (applicable) {
            current = rule.rewriteStatementBlocks(current, rwStatus);
        }
    }

    // phase 2: recursive per-block rewrites (may expand a block into many)
    List<StatementBlock> expanded = new ArrayList<>();
    for (StatementBlock block : current) {
        expanded.addAll(rRewriteStatementBlock(block, rwStatus, splitDags));
    }

    // phase 3: block-level rules again (may merge blocks)
    for (StatementBlockRewriteRule rule : _sbRuleSet) {
        boolean applicable = splitDags || !rule.createsSplitDag();
        if (applicable) {
            expanded = rule.rewriteStatementBlocks(expanded, rwStatus);
        }
    }

    // write the result back into the caller's list
    sbs.clear();
    sbs.addAll(expanded);
    return sbs;
}
Also used : ArrayList(java.util.ArrayList) FunctionStatementBlock(org.apache.sysml.parser.FunctionStatementBlock) ParForStatementBlock(org.apache.sysml.parser.ParForStatementBlock) IfStatementBlock(org.apache.sysml.parser.IfStatementBlock) WhileStatementBlock(org.apache.sysml.parser.WhileStatementBlock) ForStatementBlock(org.apache.sysml.parser.ForStatementBlock) StatementBlock(org.apache.sysml.parser.StatementBlock)

Example 19 with StatementBlock

Use of org.apache.sysml.parser.StatementBlock in project incubator-systemml by Apache.

Class ProgramRewriter, method rRewriteStatementBlockHopDAGs.

/**
 * Recursively applies HOP DAG rewrites to a statement block and every block nested in it.
 *
 * <p>Function blocks descend into their body. While and if blocks rewrite their predicate
 * DAG and then descend into their body (both branches for if). For blocks, including
 * parfor, rewrite their from/to/increment DAGs and then descend into their body.
 * Any other block has its own HOP DAG rewritten.
 *
 * @param current statement block to rewrite in place
 * @param state   rewrite status to accumulate into; a fresh one is used when {@code null}
 */
public void rRewriteStatementBlockHopDAGs(StatementBlock current, ProgramRewriteStatus state) {
    // ensure robustness for calls from outside
    final ProgramRewriteStatus rwState = (state == null) ? new ProgramRewriteStatus() : state;

    if (current instanceof FunctionStatementBlock) {
        FunctionStatementBlock funcBlock = (FunctionStatementBlock) current;
        FunctionStatement funcStmt = (FunctionStatement) funcBlock.getStatement(0);
        for (StatementBlock child : funcStmt.getBody()) {
            rRewriteStatementBlockHopDAGs(child, rwState);
        }
        return;
    }

    if (current instanceof WhileStatementBlock) {
        WhileStatementBlock whileBlock = (WhileStatementBlock) current;
        WhileStatement whileStmt = (WhileStatement) whileBlock.getStatement(0);
        whileBlock.setPredicateHops(rewriteHopDAG(whileBlock.getPredicateHops(), rwState));
        for (StatementBlock child : whileStmt.getBody()) {
            rRewriteStatementBlockHopDAGs(child, rwState);
        }
        return;
    }

    if (current instanceof IfStatementBlock) {
        IfStatementBlock ifBlock = (IfStatementBlock) current;
        IfStatement ifStmt = (IfStatement) ifBlock.getStatement(0);
        ifBlock.setPredicateHops(rewriteHopDAG(ifBlock.getPredicateHops(), rwState));
        for (StatementBlock child : ifStmt.getIfBody()) {
            rRewriteStatementBlockHopDAGs(child, rwState);
        }
        for (StatementBlock child : ifStmt.getElseBody()) {
            rRewriteStatementBlockHopDAGs(child, rwState);
        }
        return;
    }

    if (current instanceof ForStatementBlock) { // incl parfor
        ForStatementBlock forBlock = (ForStatementBlock) current;
        ForStatement forStmt = (ForStatement) forBlock.getStatement(0);
        forBlock.setFromHops(rewriteHopDAG(forBlock.getFromHops(), rwState));
        forBlock.setToHops(rewriteHopDAG(forBlock.getToHops(), rwState));
        forBlock.setIncrementHops(rewriteHopDAG(forBlock.getIncrementHops(), rwState));
        for (StatementBlock child : forStmt.getBody()) {
            rRewriteStatementBlockHopDAGs(child, rwState);
        }
        return;
    }

    // generic (last-level) block
    current.setHops(rewriteHopDAG(current.getHops(), rwState));
}
Also used: ParForStatementBlock(org.apache.sysml.parser.ParForStatementBlock) ForStatementBlock(org.apache.sysml.parser.ForStatementBlock) FunctionStatement(org.apache.sysml.parser.FunctionStatement) IfStatement(org.apache.sysml.parser.IfStatement) FunctionStatementBlock(org.apache.sysml.parser.FunctionStatementBlock) WhileStatement(org.apache.sysml.parser.WhileStatement) ForStatement(org.apache.sysml.parser.ForStatement) IfStatementBlock(org.apache.sysml.parser.IfStatementBlock) WhileStatementBlock(org.apache.sysml.parser.WhileStatementBlock) StatementBlock(org.apache.sysml.parser.StatementBlock)

Example 20 with StatementBlock

Use of org.apache.sysml.parser.StatementBlock in project incubator-systemml by Apache.

Class ProgramRewriter, method rewriteProgramHopDAGs.

/**
 * Applies HOP DAG rewrites to an entire DML program. If block-level rewrite rules are
 * configured, statement-block rewrites are applied as well.
 *
 * <p>Function statement blocks of every namespace are handled first. Each gets its HOP DAGs
 * rewritten and then, if any block-level rules exist, a statement-block rewrite. Afterwards
 * the HOP DAGs of each main-program statement block are rewritten. Finally, if any block-level
 * rules exist, the main-program block list is rewritten and stored back into {@code dmlp}.
 *
 * @param dmlp      program to rewrite in place
 * @param splitDags whether block-level rules that create split DAGs may be applied
 * @return the rewrite status accumulated across all rewrites
 */
public ProgramRewriteStatus rewriteProgramHopDAGs(DMLProgram dmlp, boolean splitDags) {
    final ProgramRewriteStatus status = new ProgramRewriteStatus();

    // function statement blocks, per namespace
    for (String ns : dmlp.getNamespaces().keySet()) {
        for (String funcName : dmlp.getFunctionStatementBlocks(ns).keySet()) {
            FunctionStatementBlock funcBlock = dmlp.getFunctionStatementBlock(ns, funcName);
            rRewriteStatementBlockHopDAGs(funcBlock, status);
            if (!_sbRuleSet.isEmpty()) {
                // NOTE(review): returned block list is discarded; presumably the
                // function body is rewritten in place — confirm
                rRewriteStatementBlock(funcBlock, status, splitDags);
            }
        }
    }

    // regular statement blocks of the "main" program
    for (int idx = 0; idx < dmlp.getNumStatementBlocks(); idx++) {
        rRewriteStatementBlockHopDAGs(dmlp.getStatementBlock(idx), status);
    }
    if (!_sbRuleSet.isEmpty()) {
        dmlp.setStatementBlocks(rRewriteStatementBlocks(dmlp.getStatementBlocks(), status, splitDags));
    }

    return status;
}
Also used: FunctionStatementBlock(org.apache.sysml.parser.FunctionStatementBlock) ParForStatementBlock(org.apache.sysml.parser.ParForStatementBlock) IfStatementBlock(org.apache.sysml.parser.IfStatementBlock) WhileStatementBlock(org.apache.sysml.parser.WhileStatementBlock) ForStatementBlock(org.apache.sysml.parser.ForStatementBlock) StatementBlock(org.apache.sysml.parser.StatementBlock)

Aggregations

StatementBlock (org.apache.sysml.parser.StatementBlock)67 ForStatementBlock (org.apache.sysml.parser.ForStatementBlock)57 IfStatementBlock (org.apache.sysml.parser.IfStatementBlock)57 WhileStatementBlock (org.apache.sysml.parser.WhileStatementBlock)57 FunctionStatementBlock (org.apache.sysml.parser.FunctionStatementBlock)39 Hop (org.apache.sysml.hops.Hop)28 ArrayList (java.util.ArrayList)24 FunctionStatement (org.apache.sysml.parser.FunctionStatement)22 IfStatement (org.apache.sysml.parser.IfStatement)22 ForStatement (org.apache.sysml.parser.ForStatement)20 WhileStatement (org.apache.sysml.parser.WhileStatement)19 ParForStatementBlock (org.apache.sysml.parser.ParForStatementBlock)18 ForProgramBlock (org.apache.sysml.runtime.controlprogram.ForProgramBlock)18 IfProgramBlock (org.apache.sysml.runtime.controlprogram.IfProgramBlock)16 WhileProgramBlock (org.apache.sysml.runtime.controlprogram.WhileProgramBlock)16 FunctionProgramBlock (org.apache.sysml.runtime.controlprogram.FunctionProgramBlock)13 ProgramBlock (org.apache.sysml.runtime.controlprogram.ProgramBlock)13 HashSet (java.util.HashSet)11 ExternalFunctionStatement (org.apache.sysml.parser.ExternalFunctionStatement)11 LocalVariableMap (org.apache.sysml.runtime.controlprogram.LocalVariableMap)11